GAN(Generative Adversarial Network)의 간단한 이해
GAN에 대해 이해하기 위해 간단한 이야기를 읽어봅시다.
대장장이 마을에서 두 명장이 있었습니다.
한명은 무엇이든 뚫어버리는 칼을 만드는 명장이고, 다른 한명은 무엇이든 막는 방패를 만드는 명장이었습니다. 두 대장장이는 서로의 작품이 최고라며 싸웠습니다.
어떤날에 장인이 만든 칼이 방패를 뚫어버린 적이 있었습니다.
방패를 만드는 장인은 자존심이 상했고, 더욱 강력한 방패를 만들기 위해 약점을 보완하여 방패를 발전시켰습니다. 결국 다음 결투때는 칼을 막아냈지요.
그러자 칼을 만드는 대장장이 또한 자존심이 강해 더욱 강한 칼을 만들기 위해 노력했습니다.
그결과 두 대장장이의 작품은 그 무엇보다도 강력한 무기가 될 수 있었습니다.
위의 이야기는 GAN의 원리와 유사합니다.
GAN은 두가지 모델(Generator (생성자) 모델, Discriminator (판별자) 모델)을 가지고 있습니다.
이미지 GAN을 생각하면 이해가 쉽습니다.
Generator network(생성자;G) 모델에서는 사람의 얼굴을 만들어낸다고 가정합시다. 하지만 어색한 점이 많습니다. 눈코입도 이상하고, 누가봐도 사람이라고 하기 힘든 사진입니다.
Discriminator network(판별자;D) 모델에서는 실제 사람얼굴 데이터와 Generator Network(생성자 ;G) 모델의 결과를 입력값으로 받아와 사람인지 아닌지를 판별합니다. 처음에는 생성자(G)가 바보라서 쉽게 구분합니다. 그래서 판별자(D)는 생성자(G)에게 피드백을 줍니다.
"너 좀 잘 만들어봐.. 너무 못만들어서 딱봐도 가짜인걸 알아보겠다!!!"
피드백을 받은 생성자(G)는 더욱 정교하게 실제 데이터와 비슷하도록 만들어갑니다. 판별자(G)는 생성자(G)가 만들어낸 Fake image와 실제 존재하는 real image의 차이를 점점 판별하기 어려워합니다. 그래서 생성자(G)도 계속해서 차이를 발견해내기 위해 발전합니다.
대장장이 예시와 유사하게 서로를 적대적으로 발전시켜나가는 모델을 GAN이라고 부릅니다.
모델 학습
자 그럼 구체적으로 어떻게 모델을 학습시키는지 확인해 봅시다.
$$\min_{G}\max_{D} V(D,G) = \mathbb{E}_{x} [logD(x)] +\mathbb{E}_{x^*} [log(1-D(x^*))]$$
생성기 : $G$
판별기 : $D$
실제 데이터 : $x$
$G$에 의해 생성된 데이터 : $x^*=G(z)$
실제 데이터에 대한 예측값 : $\mathbb{E}_{x} $
생성된 데이터에 대한 예측값 : $\mathbb{E}_{x^*}$
real data를 real 이라고 판단할 확률 : $D(x)$
generated(fake) data를 real 이라고 판단할 확률 : $ D(x^*)$
판별자 (D)
판별자(D)는 만들어진 데이터와 실제데이터를 잘 구분하는 것이 가장 중요합니다.
판별자(D)의 입장에서는 $D(x)$가 높을수록, $D(x^*)$가 낮을 수록($1-D(x^*)$가 높을 수록) 판단을 잘하는 것입니다.
따라서 $D(x)$, $1-D(x^*)$를 maximize(최대화) 하는 것을 목표로 update 합니다.
$$ \theta_d = \theta_d + \frac{1}{m} \nabla_{\theta_d} \sum\limits_{i=1}^{m}(logD(x_i)+log(1-D(x_i^*)))$$
$\theta_d$ : $D$의 매개변수
$m$ : 미니 배치의 크기
$i$ : 미니 배치 샘플의 인덱스
생성자 (G)
생성자(G)입장에서는 반대겠지요.
판별자(D)가 진짜인지 가짜인지 판별을 잘 못하는 경우, 데이터 생성을 잘 한다고 할 수 있습니다.
따라서 $1-D(x^*)$를 minimize(최소화) 하는 것을 폭표로 update 합니다.
판별자(D)와 달리 생성자(G)는 $D(x)$에 직접적인 영향을 주지 못함으로 update과정에서는 $1-D(x_i^*)$만 최소화 하는 방향으로 update합니다. $D(x_i^*)$를 maximize하는 것과 같습니다.
$$\theta_g = \theta_g - \frac{1}{m} \nabla_{\theta_g} \sum\limits_{i=1}^{m}log(1-D(x_i^*))$$
[참고자료]
1. <PyTorch를 이용한 GAN 실제>,John Hany 지음, 차정원 옮김.
'Artificial Intelligence > Neural Networks' 카테고리의 다른 글
[ CNN ] 3. Point-wise convolution - PyTorch Code (0) | 2021.08.12 |
---|---|
[ CNN ] 2. Grouped convolution - PyTorch Code (0) | 2021.08.11 |
[ CNN ] 1. 바닐라 합성곱(vanilla convolution) - PyTorch Code (0) | 2021.08.11 |
[CNN] Padding 무엇인가? (0) | 2020.03.16 |
[ CNN ] pooling이란? (tf.keras.layers.MaxPool2D) (0) | 2020.03.16 |