w = torch.empty(3, 5)
nn.init.uniform_(w)
[ CNN ] 가중치 초기화 (Weight Initialization)
CNN이든 머신러닝이든 결국 우리는 목적함수의 값을 최적화하는 방향으로 학습을 시켜나가죠. 그래서 보통 역전파를 이용해서 가중치(weight) 값을 조금씩 변화시켜나가며 정답을 찾아갑니다. 결국 우리가 찾고 싶은건 가중치(weight) 값이 되는 건데요. 그 가중치(weight)의 초기값에 따라서 학습의 진행방향이 달라집니다. 즉, 가중치 초기화(Weight initialization)은 학습과정에서 중요한 역할을 합니다.
1. torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
균일 분포($u(a,b)$)로 텐서를 초기화 합니다.
확률에 기반해서 각자리마다 0~1 값으로 초기화 해줍니다.(figure 1 참고)
import torch
import matplotlib.pyplot as plt
import numpy as np
tensor = torch.empty(1000)
torch.nn.init.uniform_(tensor)
# figure 1(left)
plt.plot(range(len(tensor)),tensor)
plt.show()
# figure 2(right)
plt.hist(np.sort(tensor))
plt.show()
2. torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
정규 분포($N(mean,std^2)$) 로 초기화 시킴.
import torch
import matplotlib.pyplot as plt
import numpy as np
tensor = torch.empty(1000)
torch.nn.init.uniform_(tensor)
# figure 2(left)
plt.plot(range(len(tensor)),tensor)
plt.show()
# figure 2(right)
plt.hist(np.sort(tensor))
plt.show()
3. torch.nn.init.xavier_uniform_(tensor, gain=1.0)
Xavier-uniform : $u(-a,a)$
Glorot initialization 이라고도 함.
Xavier 초기화는 고정된 표준편차를 사용하지 않는다는 특징이 있다.
이전 은닉층의 노드수(fan_in)과 현재 은닉층의 노드(fan_out) 을 고려하여 만들어 진다.
활성값이 고르게 분포한다.
w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
4. torch.nn.init.xavier_normal_(tensor, gain=1.0)
Xavier-normal : $N(0,std^2)$
Glorot initialization 이라고도 함.
w = torch.empty(3, 5)
nn.init.xavier_normal_(w, gain=nn.init.calculate_gain('relu'))
5. torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
Kaiming_uniform : $u(-bound, bound)$
He-initialization이라고도 불림.
relu 나 leaky_relu를 activation function으로 사용하는 경우 많이 사용함.
w = torch.empty(3, 5)
nn.init.kaiming_uniform_(w, mode='fan_out', nonlinearity='relu')
6. torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
kaiming_normal : $N(0,std^2)$
He-initialization이라고도 불림.
relu 나 leaky_relu를 activation function으로 사용하는 경우 많이 사용함.
w = torch.empty(3, 5)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
'Artificial Intelligence > Neural Networks' 카테고리의 다른 글
[ CVPR2022 / GML4VC ] 2. Open Remarks (0) | 2022.08.22 |
---|---|
[ CVPR2022 / GML4VC ] 1. 개요 (Graph Machine Learning, GNNs) (0) | 2022.08.22 |
[ CNN ] 6. Dilated convolution - PyTorch Code (0) | 2021.08.12 |
[ CNN ] 5. Depth-wise Separable convolution - PyTorch Code (0) | 2021.08.12 |
[ CNN ] 4. Depth-wise convolution - PyTorch Code (0) | 2021.08.12 |