Pytorch/Tips

Pytorch dataset에서 특정 class만을 load하는 Dataloader 만드는법

khslab 2023. 5. 12. 16:34

torchvision 라이브러리는 웬만한 dataset을 모두 지원해주기 때문에

매우 편한 라이브러리다.

 

그러나 연구를 하다보면 가끔 dataset class 중 특정 class만이 필요할 때가 있다.

이때 다음과 같은 방법으로 손쉽게 특정 class를 추출할 수 있다.

 

먼저 원하는 데이터셋을 torchvision 코드로 설정한다.

(STL 10을 예시로 들겠다.)

 

training_data = torchvision.datasets.STL10(
     root=".",
     split='train',
     download=True,
     transform=transform
)

 

위는 일반적인 dataset을 불러오는 방법이다.

이렇게 만든 dataset은 모든 class가 포함되어 있다.

 

이제 Subset 모듈을 사용하면 된다.

from torch.utils.data import Subset

먼저 라이브러리를 import해주고 다음과 같이 작성한다.

subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])

위 예시는 class 0,1,2를 불러오는 코드다 이제 [0,1,2]자리에 원하는 class index를 입력하면 특정 class만 포함하는 dataset이 만들어진다.

이제 위 subsets을 Dataloader에 넣으면 끝이다.

loaders = DataLoader(subsets, batch_size= 64, num_workers=6)

잘 됐는지 print 해보면 다음과 같이 잘 설정된 것을 확인할 수 있다.


전체 코드는 다음과 같다.

import torchvision
from torch.utils.data import DataLoader, Subset
training_data = torchvision.datasets.STL10(
root=".",
split='train',
download=True,
transform=transform
)
subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])
loaders = DataLoader(subsets, batch_size= 64, num_workers=6)
for batch_idx, (inputs, targets) in enumerate(loaders):
print(targets)