본문 바로가기
Pytorch/Tips

Pytorch Weighted cross entropy 구현

by khslab 2024. 4. 26.

cross entropy는 일반적으로 torch.nn 라이브러리를 통해

사용할 수도 있지만, 이를 변형하기 위해서는 cross_entropy를 구현할 수도 있어야 한다.

 

다음은 일반적인 cross entropy다.

output = model(input)
output = F.log_softmax(output, dim=-1)

target = F.one_hot(label, num_classes = N)

loss = -(target*output).sum(dim=-1)
loss = loss.sum()/output.shape[0]

 

이때 weighted 옵션을 주면 다음과 같다.

output = model(input)
output = F.log_softmax(output, dim=-1)

target = F.one_hot(label, num_classes = N)

weight_list = [0,2,5,7] #class 0,2,5,7에 가중치를 주겠다는 뜻.
class_weight = torch.ones(self.nclass)
class_weight[weight_list] = 50 #가중치 50을 줌.

loss = -(target*output).sum(dim=-1) * class_weight[label]
loss = loss.sum()/class_weight[label].sum()

 

여기서 우리가 주목해야 할 점은 cross entropy의 weight는

가중치를 줄 class를 가진 input이 들어올 겨우

가중치 수 만큼 똑같은 샘플이 더 있다고 생각한다는 것이다.

 

위 코드를 바탕으로 예를 들면,

label이 [2]인 input이 들어올 경우

이 input과 똑같은 것이 50개 더 있다고 간주한다는 것이다.

 

보통 cross entropy까지는 구현이 있는데

weighted cross entropy는 없어서 한번 구현해봤다.