Python/numpy & Pytorch

[Data] torchvision datasets으로 받은 데이터 나누기

파송송 2023. 3. 27. 10:52
728x90

데이터 받기

dataset을 datasets으로 받아준다.

    train_dataset = datasets.MNIST(config.data_path,
                                    train=True,
                                    download=True,
                                    transform=config.augmentation
                                  )

그러고 data Loader에 넣어주면 다음과 같이 data정보가 나온다.

60000장의 데이터를 500장으로 줄이는 작업을 할 것이다.


데이터 나누기

먼저 dataloader를 list로 만들어준다.

이렇게 하면 슬라이스 작업을 할 수 있다.

a=list(train_loader)
(20, 1, 64, 64)
(2,)
3000

a는 64x64x1 이미지가 20(batch) 개 있고 

그것에 대한 label 값을 합쳐서 3000개의 세트를 가지고 있음


train_loader = a[:25]

다음과 같이 25*20(batch) 만큼 데이터를 소분해서 사용할 수 있음

728x90