본문 바로가기
Artificial Intelligence/Basic

[ Loss ] Cross-Entropy, Negative Log-Likelihood 내용 정리! ( + Pytorch Code )

by SuperMemi 2022. 8. 21.
반응형

[ Loss ] Cross-Entropy, Negative Log-Likelihood 내용 정리! ( + Pytorch Code )

 

https://pytorch.org/

 


 

[ Pytorch ] 파이토치 설치하기

[ Pytorch ] 파이토치 설치하기 머신러닝에서 tensorflow와 pytorch는 양대 산맥이죠 pytorch를 설치해봅시다. https://pytorch.org/get-started/locally/ PyTorch An open source machine learning framewor..

supermemi.tistory.com


이 글[1]을 바탕으로 작성되었습니다. 영어가 더 편하신 분은 원문 글로!!

 

Cross-Entropy, Negative Log-Likelihood, and All That Jazz

Two closely related mathematical formulations widely used in data science, and notes on their implementations in PyTorch

towardsdatascience.com


Cross-Entropy == Negative Log-Likelihood?

 

처음 머신러닝 또는 딥러닝을 공부할때 Cross-EntropyNegative Log-Likelihood에 대해서 많이 들어보셨을 겁니다. 그리고 Pytorch의 Loss 함수(CrossEntropyLoss,  NLLLoss)로도 많이 사용합니다. 

 

그런데 Pytorch의 CrossEntropyLoss 설명에 다음과 같이 적혀 있습니다.

 

 

Cross-Entropy 구현에 Negative Log-Likelihood 와 동일한 식이 포함되어 있다는데 정말일까요?

그렇다면 Cross-EntropyNegative Log-Likelihood는 같은 식이라고 봐도 될까요?

 

이번 글에서는 이와 관련하여 천천히 하나씩 설명드리겠습니다.

 


Maximum Likelihood Estimation

 

우선 쉬운 예시로 binary classification(정답이 0 또는 1)을 가정해 봅시다.

 

모델 \(f\)는 파라메타 \(\theta\)로 구성되어 있습니다. 주된 목적은 주어진 데이터의 likelihood를 최대화 하는 파라메타 \(\theta\)를 찾는 것입니다. 

 

$$\hat{y}_{\theta ,i}=\sigma (f_\theta (x_i))$$

 

\(\hat{y}_{\theta ,i}\)은 모델이 예측한 positivie일 확률이며, \(\sigma\)는 어떠한 비선형 활성 함수(non-linear activation function)입니다. 일반적으로 [0,1] 구간으로 매핑(mapping) 해주는 시그모이드(sigmoid)함수를 많이 사용합니다. 

 

$$\sigma (z) = \frac{1}{1+exp(-z)}$$

 

그리고, likelihood 함수는 아래와 같이 정의됩니다.

 

$$\mathbb{P}(\mathcal{D}|\theta) = \prod_{i=1}^n \hat{y}_{\theta ,i}^{y_i}(1-\hat{y}_{\theta ,i})^{(1-y_{\theta ,i})}$$

 

참고사항

  • \(\prod\) 는 곱셈 기호 입니다.
  • sigmoid 함수를 적용했기 때문에 예측값\((\hat{y}_{\theta ,i})\)은 0~1 의 값을 가집니다. 
  • 즉, 우리의 모델은 positive(class=1)이라고 생각하면 1에 가까운 답을 내고, negative(class=0)이라고 생각하면 0에 가까운 답을 내게 됩니다.
  • 모델의 예측과 실제 정답간의 차이가 적을 수록 likelihood의 값은 높아지고, 차이가 커질 수록 likelihood의 값은 낮아집니다. 

 

위의 likelihood 함수를 자세히 봅시다.

\(y_i\)와 \((1-y_i)\)가 지수 값으로 제곱되고 있습니다. 이는 어떤 의미일까요?

 

binary classification 에서는 정답(\(y\))가 0 또는 1 입니다. 따라서, 어떠한 \(i\)번째 데이터의 정답(\(y_i\))은 0 또는 1 입니다. 

 

  • \(y_i = 0\) 일때, 첫번째 항 \(\hat{y}_{\theta ,i}^{y_i}\)은 지수(\(y_i\))가 0이 되므로 첫번째 항이 1이 되고, 뒤쪽 항\(((1-\hat{y}_{\theta ,i})^{(1-y_{\theta ,i})})\)은 지수\((1-y_{\theta ,i})\)가 1이 되어 \((1-\hat{y}_{\theta ,i})\)만 남게 됩니다. 그래서 모델의 예측 값\((\hat{y}_{\theta ,i})\)이 정답\((y_i = 0)\)과 유사한 0이 될수록 likelihood가 1에 가까워집니다. 반대로 정답과 예측이 다른 경우 likelihood가 0에 가까워 집니다. 
  • 반대로 \(y_i = 1\) 일때, 첫번째 항 \(\hat{y}_{\theta ,i}^{y_i}\)은 지수(\(y_i\))가 1이 되므로 \(\hat{y}_{\theta ,i}^{y_i}\) 만 남게 되고, 뒤쪽 항\(((1-\hat{y}_{\theta ,i})^{(1-y_{\theta ,i})})\)은 지수\((1-y_{\theta ,i})\)가 0이 되어 1이 됩니다. 그래서 모델의 예측 값\((\hat{y}_{\theta ,i})\)이 정답\((y_i = 1)\)과 유사한 1이 될수록 likelihood가 1에 가까워집니다. 반대로 정답과 예측이 다른 경우 likelihood가 0에 가까워 집니다.

 

즉, Likelihood 는 정답 라벨에 대한 예측된 확률값을 비교함으로써 예측의 정확도를 측정한다고 볼 수 있습니다. 

 

 

Log-likelihood

 

위의 likelihood 식에 log 함수를 취하면 Log-likelihood 식입니다.

 

$\textrm{log}\mathbb{P}(\mathcal{D}|\theta)$

$\quad = \textrm{log}(\prod_{i=1}^n \hat{y}_{\theta ,i}^{y_i}(1-\hat{y}_{\theta ,i})^{(1-y_{\theta ,i})})$

$\quad = \sum_{i=1}^n(y_i \textrm{log}(\hat{y}_{\theta ,i})+(1-y_{\theta ,i})\textrm{log}(1-\hat{y}_{\theta ,i}))$

 

log 함수를 적용했을때 장점 

우선 곱셈으로 이루어진 likelihood 식이 덧셈으로 변합니다. 또한 지수값이 곱셈으로 변화합니다. 이는 연산이 용이하고 직관적이라는 장점이 있습니다. 

 

Summing up the correct entries (binary case)

 

아래의 동영상을 보시면 직관적인 이해가 가능하실 겁니다!

  1. 모델 예측. 만약 raw 값이라면 sigmoid 적용 (0~1 값으로 매핑)
  2. log 함수 적용
  3. log-likelihood 계산 (Masking principle : label 에 맞춘 값 선택)
  4. 각 데이터 log-likelihood 값 합산

The computation of binary negative log-likelihood, image by [1] (produced with Manim)

 

Minimizing the Negative Log-Likelihood

 

Log-Likelihood 의 값은 커질 수록 좋습니다. 그러나 우리가 경사하강법(Gradient Descent)라 불리는 최적화 방법을 사용하기 위해서는 loss 함수 값이 작아질 수록 좋은 것으로 정의해야 합니다. Pytorch 또한 loss 값을 줄여나가는 방향으로 학습을 진행하게 되죠.

 

$$l(\theta)= -\sum_{i=1}^n(y_i \textrm{log}(\hat{y}_{\theta ,i})+(1-y_{\theta ,i})\textrm{log}(1-\hat{y}_{\theta ,i}))$$

 

그래서 이에 맞추어 Log-Likelihood 에 마이너스(Negative)를 적용합니다.

(참고로, log 함수는 단조 함수(항상 증가)하기 때문에 이와 같은 Negative 적용이 가능합니다.)

 

  • Likelihood최대화(Maximizing) 하는 것은 Log-likelihood최대화(Maximizing) 하는 것과 동일합니다. 
  • 이를 반대로 뒤집으면 Likelihood 를 최대화 하는 것은 Negative likelihood 를 최소화 하는 것과 동일합니다.
  • 그리고 Negative likelihood 를 최소화 하는 것은 Negative Log-Likelihood 를 최소화 하는 것과 동일합니다. 
  • 즉, Negative Log-Likelihood를 최소화 하는 것은 Likelihood를 최대화 하는 것과 동일한 효과를 가집니다.

 

import torch
import torch.nn as nn

torch.manual_seed(77)

### Binary Setting ###

print(f"{'Setting up binary case':-^80}") 

z = torch.randn(5) # z = f(x)
y_hat = torch.sigmoid(z) # y_hat = sigmoid(f(x))
y = torch.tensor([0.,1.,1.,0.,1.]) # ground truth (float type)

print(f"z = {z}\ny_hat = {y_hat}\ny = {y}\n{'':-^80}")

# Negative Log-likelihoods
loss_NLL_scratch = -(y * y_hat.log() + (1 - y) * (1 - y_hat).log())
print(f"Negative Log-likelihoods\n    {loss_NLL_scratch}")
print(f"Loss Summantion : {loss_NLL_scratch.sum()}")

 

 

 

Generalizing to Multiclass

 

위의 binary Log-likelihood 를 Multiclass 로 확장시켜 일반화해봅시다. 여기서 주의하실 점은 binary 와 multiclass 의 차이입니다.

 

Binary의 경우 하나의 데이터가 input으로 들어가면 모델이 단 하나의 값(0~1)을 output으로 만들어내는 구조를 가지고 있습니다. sigmoid function 을 사용합니다. 따라서 binary likelihood 에서는 "not positive means negative"라는 개념을 직접적인 식으로 표현하였습니다. 

 

Multiclass의 경우 하나의 데이터가 input으로 들어갔을때 output이 하나의 값이 아니라 전체 클래스의 길이를 가진 벡터(C 차원)로 나타냅니다. 이러한 환경에서는 output 벡터 자체를 확률 분포로 표현하기 위해 sigmoid가 아닌 softmax function 을 사용합니다. (사실 binary setting 또한 entry 가 2인 multiclass 라고 생각하셔도 됩니다.)

 

$$\textrm{softmax}(z)_i = \frac{\textrm{exp}(z_i)}{\sum_{j=o}^{C-1} \textrm{exp}(z_j)}$$

 

  • $z_i$ 는 C 차원의 벡터 입니다. 
  • softmax의 결과 : 각 class에 해당하는 element는 0~1의 값을 가지며, 모든 class element 를 합하면 1이 됩니다. 
  • 만약 C=2 일경우, Binary 와 동일한 식(sigmoid)이 됩니다!! 직접 계산해 보세요!!

 

Log-likelihood 함수를 일반화 시켜서 다시쓰면 아래와 같습니다. 

 

$$\textrm{log}\mathbb{P}(\mathcal{D}|\theta)= \sum_{i=1}^n \textrm{log}(\hat{y}_{\theta ,i})^{y_i}$$

 

여기서 주의할 점은 $y_i$ 이 지수함수의 제곱 역할이 아니라 upper index 의 개념으로 바뀝니다. 즉, True label 과 관련되는 확률만 보겠다는 의미입니다. 

 

Summing up the correct entries (multiclass case)

  1. Predicted Values($z_i$)는 아직 확률값이 아닙니다.
  2. softmax 함수를 적용하여 이 값들을 확률값으로 변환해 주고, log 함수를 적용합니다. (pytorch 에서는 LogSoftmax 를 적용하면 간편합니다.)
  3. class 에 해당하는 true label 부분의 log probability 부분을 다 더해 줍니다. 

The computation of multiclass negative log-likelihood, image by the author (produced with  Manim )

 

import torch
import torch.nn as nn

torch.manual_seed(77)

### Multi class Setting ###

print(f"{'Setting up multiclass case':-^80}") 

z = torch.randn(5,3) # z = f(x)
y_hat = torch.softmax(z, dim=1) # y_hat = softmax(f(x))

# ground truth : each element in target has to have 0 <= value < C
y = torch.tensor([0,1,0,2,1]) 

print(f"z = {z}\ny_hat = {y_hat}\ny = {y}\n{'':-^80}")

# Negative Log-likelihoods 
#     - masking the correct entries
loss_NLL_scratch = -y_hat.log()[torch.arange(5),y.long()] 
print(f"Negative Log-likelihoods\n    {loss_NLL_scratch}")
print(f"Loss Summantion : {loss_NLL_scratch.sum()}")


Cross-Entropy

 

이산 환경(Discrete setting)에서 두 확률 분포 $p$와 $q$가 주어졌을때, cross-entropy 는 다음과 같이 정의 됩니다.

 

$$H(p,q)=-\sum_{x\in \chi}p(x)\textrm{log}(q)$$

 

사실 Cross-Entropy의 $p$를 정답 분포, $q$를 모델의 예측 분포라고 생각하면 Negative Log-Likelihood 식과 매우 유사해 보입니다. 

 

Cross-Entropy Vs. Negative Log-likelihood

 

Pytorch의 구현된 함수에서 약간의 차이가 존재합니다. Pytorch의 CrossEntropyLoss 설명에 다음과 같이 적혀 있습니다.

 

 

CrossEntropyLoss 에서는 LogSoftmax 를 적용한 후 Negative Log-likelihood(NLL) Loss를 적용한 식과 동일하다는 내용인데요. (여기서 LogSoftmax는Softmax 함수를 적용한 후 Log 함수를 적용하는 것과 동일한 함수입니다.)

 

정리하자면..

  • CrossEntropyLoss 안에서 LogSoftmax와 Negative Log-likelihood 가 진행되기 때문에 softmax나 log 함수가 적용되지 않은 모델 output(raw data)을 input으로 주어야 합니다.
  • 이와 달리 NLLLoss 안에서는 softmax나 log함수가 이뤄지지 않습니다. 그래서 모델 output(raw data)을 input으로 그대로 사용하는 것이 아니라 LogSoftmax 함수를 적용한 후 input으로 사용해야합니다.

 

Pytorch Code 

nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리

 

[ Pytorch ] nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리

[ Pytorch ] nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리 이 글은 아래 링크된 글에 이어지는 글입니다. 내용이 궁금하시다면 먼저 보고 오시길 바랍니다! [ Loss ] Cross..

supermemi.tistory.com


 

반응형