cross entropy는 일반적으로 torch.nn 라이브러리를 통해
사용할 수도 있지만, 이를 변형하기 위해서는 cross_entropy를 구현할 수도 있어야 한다.
다음은 일반적인 cross entropy다.
output = model(input)
output = F.log_softmax(output, dim=-1)
target = F.one_hot(label, num_classes = N)
loss = -(target*output).sum(dim=-1)
loss = loss.sum()/output.shape[0]
이때 weighted 옵션을 주면 다음과 같다.
output = model(input)
output = F.log_softmax(output, dim=-1)
target = F.one_hot(label, num_classes = N)
weight_list = [0,2,5,7] #class 0,2,5,7에 가중치를 주겠다는 뜻.
class_weight = torch.ones(self.nclass)
class_weight[weight_list] = 50 #가중치 50을 줌.
loss = -(target*output).sum(dim=-1) * class_weight[label]
loss = loss.sum()/class_weight[label].sum()
여기서 우리가 주목해야 할 점은 cross entropy의 weight는
가중치를 줄 class를 가진 input이 들어올 겨우
가중치 수 만큼 똑같은 샘플이 더 있다고 생각한다는 것이다.
위 코드를 바탕으로 예를 들면,
label이 [2]인 input이 들어올 경우
이 input과 똑같은 것이 50개 더 있다고 간주한다는 것이다.
보통 cross entropy까지는 구현이 있는데
weighted cross entropy는 없어서 한번 구현해봤다.
'Pytorch > Tips' 카테고리의 다른 글
터미널마다 다른 CUDA버전 적용하는 법 (0) | 2025.03.27 |
---|---|
랜덤시드 고정 안될때 (0) | 2025.03.25 |
Pytorch Learnable parameter, 학습가능한 파라미터 만들기 (2) | 2024.03.14 |
우분투 Ubuntu Nvidia 드라이버, CUDA, cuDNN 설치 (1) | 2023.06.08 |
Pytorch dataset에서 특정 class만을 load하는 Dataloader 만드는법 (0) | 2023.05.12 |