Pytorch/Tips
Pytorch dataset에서 특정 class만을 load하는 Dataloader 만드는법
khslab
2023. 5. 12. 16:34
torchvision 라이브러리는 웬만한 dataset을 모두 지원해주기 때문에
매우 편한 라이브러리다.
그러나 연구를 하다보면 가끔 dataset class 중 특정 class만이 필요할 때가 있다.
이때 다음과 같은 방법으로 손쉽게 특정 class를 추출할 수 있다.
먼저 원하는 데이터셋을 torchvision 코드로 설정한다.
(STL 10을 예시로 들겠다.)
training_data = torchvision.datasets.STL10(
root=".",
split='train',
download=True,
transform=transform
)
위는 일반적인 dataset을 불러오는 방법이다.
이렇게 만든 dataset은 모든 class가 포함되어 있다.
이제 Subset 모듈을 사용하면 된다.
from torch.utils.data import Subset
먼저 라이브러리를 import해주고 다음과 같이 작성한다.
subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])
위 예시는 class 0,1,2를 불러오는 코드다 이제 [0,1,2]자리에 원하는 class index를 입력하면 특정 class만 포함하는 dataset이 만들어진다.
이제 위 subsets을 Dataloader에 넣으면 끝이다.
loaders = DataLoader(subsets, batch_size= 64, num_workers=6)
잘 됐는지 print 해보면 다음과 같이 잘 설정된 것을 확인할 수 있다.
전체 코드는 다음과 같다.
import torchvision
from torch.utils.data import DataLoader, Subset
training_data = torchvision.datasets.STL10(
root=".",
split='train',
download=True,
transform=transform
)
subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])
loaders = DataLoader(subsets, batch_size= 64, num_workers=6)
for batch_idx, (inputs, targets) in enumerate(loaders):
print(targets)