728x90
File "/-/.conda/envs/torch38/lib/python3.8/site packages/torch/utils/data/dataloader.py", line 444, in __iter__ return self._get_iterator()
이 부분 부터 에러가 시작됐다.
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=self.batchsize,
shuffle=True,
num_workers=self.num_workers,
# 이 부분 추가
generator=torch.Generator(device='cuda')
)
generator=torch.Generator(device='cuda')
위의 코드를 추가하여 수동으로 Generator를 cuda에 넣어서 해결했다
728x90