torchvision 라이브러리는 웬만한 dataset을 모두 지원해주기 때문에
매우 편한 라이브러리다.
그러나 연구를 하다보면 가끔 dataset class 중 특정 class만이 필요할 때가 있다.
이때 다음과 같은 방법으로 손쉽게 특정 class를 추출할 수 있다.
먼저 원하는 데이터셋을 torchvision 코드로 설정한다.
(STL 10을 예시로 들겠다.)
training_data = torchvision.datasets.STL10(
root=".",
split='train',
download=True,
transform=transform
)
위는 일반적인 dataset을 불러오는 방법이다.
이렇게 만든 dataset은 모든 class가 포함되어 있다.
이제 Subset 모듈을 사용하면 된다.
from torch.utils.data import Subset
먼저 라이브러리를 import해주고 다음과 같이 작성한다.
subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])
위 예시는 class 0,1,2를 불러오는 코드다 이제 [0,1,2]자리에 원하는 class index를 입력하면 특정 class만 포함하는 dataset이 만들어진다.
이제 위 subsets을 Dataloader에 넣으면 끝이다.
loaders = DataLoader(subsets, batch_size= 64, num_workers=6)
잘 됐는지 print 해보면 다음과 같이 잘 설정된 것을 확인할 수 있다.
전체 코드는 다음과 같다.
import torchvision
from torch.utils.data import DataLoader, Subset
training_data = torchvision.datasets.STL10(
root=".",
split='train',
download=True,
transform=transform
)
subsets = Subset(training_data, [i for i, (x, y) in enumerate(training_data) if y in [0,1,2]])
loaders = DataLoader(subsets, batch_size= 64, num_workers=6)
for batch_idx, (inputs, targets) in enumerate(loaders):
print(targets)
'Pytorch > Tips' 카테고리의 다른 글
Pytorch Weighted cross entropy 구현 (0) | 2024.04.26 |
---|---|
Pytorch Learnable parameter, 학습가능한 파라미터 만들기 (2) | 2024.03.14 |
우분투 Ubuntu Nvidia 드라이버, CUDA, cuDNN 설치 (1) | 2023.06.08 |
우분투 16.04 cmake 설치 방법 (0) | 2023.05.03 |
한 모델의 loss계산에 2개 이상의 dataloader를 병렬로 사용할 경우 (1) | 2023.02.21 |