(GAN) Generative Adversarial Nets, NeurIPS 2014

2025. 1. 20. 17:57·논문 리뷰
반응형

논문: Generative Adversarial Nets

code: goodfeli/adversarial: Code and hyperparameters for the paper "Generative Adversarial Networks"

 

0. Summary

 본 논문은 적대적인 과정(adversarial process)을 통해 생성 모델을 추정하기 위한 새로운 프레임워크를 제안한다. 데이터 분포를 학습하는 생성 모델 G와, 샘플이 실제 훈련 데이터에서 나온 것인지 G에서 생성된 것인지를 구분할 확률을 추정하는 판별 모델 D를 동시에 학습한다.

 G와 D가 다층 퍼셉트론(MLP)로 구성되어 있을 때, 전체 시스템은 역전파를 통해 학습할 수 있다. 또한, 훈련 과정이나 샘플을 생성하는 동안 Markov Chain 및 unrolled approximate inference network를 사용할 필요가 없다.


1. Introduction

 Deep generative model은 최대우도추정(maximum likelihood estimation)이나 관련 기법에서 발생하는 계산 불가능에 가까운 많은 확률적 연산을 근사해야 하는 어려움, 선형 활성화 함수의 이점을 가져오는 것 등의 어려움이 있기 때문에 크게 임팩트 있지 않았다. 

 이에 본 논문은 이러한 어려움을 우회할 수 있는 새로운 생성 모델 추정 방법을 제안한다. 적대적 신경망 프레임워크에서, 생성 모델은 판별 모델과 경쟁 관계에 놓인다. 판별 모델은 모델 분포에서 생성된 샘플인지, 실제 데이터 분포에서 나온 샘플인지를 구분하도록 학습하는 모델이다.

 본 논문에서는 생성 모델이 무작위 noise를 다층 퍼셉트론에 통과시켜 샘플을 생성하고, 판별 모델 또한 다층 퍼셉트론으로 구성되는 특수한 경우를 살펴본다. 이 경우, 역전파와 dropout 같은 이미 널리 성공을 거둔 알고리즘만으로 두 모델을 학습할 수 있으며, 생성 모델에서 샘플을 얻을 때도 순방향 전파(forward propagation)만 수행하면 된다. 근사 추론(approximate inference)이나 마르코프 체인(Markov chains)은 필요하지 않다.


2. Related Work

2.1 Parametric specification of a probability distribution function

 기존에 존재하는 대부분의 generative model 연구는 데이터의 분포를 명시적으로 parametric function으로 나타나는데 초점이 맞춰져 있었다. 이러한 모델은 로그 우도를 최대화함으로써 학습할 수 있다. 하지만, 이러한 모델은 일반적으로 intractable한 우도 함수를 가지며, likelihood gradient에 대해 많은 근사치를 필요로 한다. 따라서, 데이터 분포를 명시적으로 나타내지 않으면서도 원하는 분포에서 데이터를 샘플링하는 방법을 제시한다.

 


3. Adversarial Nets

적대적 모델링 프레임워크는 두 모델이 모두 다층 퍼셉트론으로 구성되어 있다. 우선, 입력 노이즈 변수 $z$에 대한 prior $p_z(z)$를 정의한다. 그 다음, 이를 데이터 공간으로 나타내는 $G(z;\theta_g)$를 정의한다. 또한, $x$가 실제 데이터에서 온 것인지, 생성기 $G$에 의해 만들어진 $p_g$로부터 온 것인지를 판별하는 확률 값을 출력하는 $D(x;\theta_d)$를 정의한다.

$G$는 noise 분포$p_z(z)$로 부터 샘플링한 $z$를 입력으로 받아 생성한 데이터의 확률 분포 $p_g$가 실제 데이터 분포인 $p_{data}$와 유사해지는 방향으로 $\theta_G$를 업데이트 하며 학습된다. $D$는 $G$가 만들어낸 데이터 분포인 $p_g$에서 온 데이터 인지, 실제 데이터 분포인 $p_{data}$에서 온 데이터인지를 구분하는 방향으로 $\theta_D$를 업데이트 하며 학습된다. 즉, D와 G는 다음 functiono $V(D,G)$를 통해 two-player minimax game을 수행한다.

 

$$min_G max_D (D,G) = E_{x\sim p_{data}(x)}[logD(x)] + E_{z\sim p_z(z)}[log(1-D(G(z)))]$$

 - $G$: Generator

 - $D$: Discriminator

 - $p_g$: generator's distribution

 - $p_{data}$: real data's distribution

 - $p_z(z)$: prior on input noise variables

 

 이러한 학습 방법은 D를 k번 업데이트하고 G를 한 번 업데이트하는 과정을 번갈아 수행한다. 또한, 학습 초기 

$log(1-D(G(z)))$를 최소화하기보다 $log(D(G(z)))$를 최대화 하도록 G를 학습한다. 

 

학습 초기, D는 real data(1) 과 fake data(0)을 구분하기 쉽다고 한다. 따라서, 기존 목적함수(파란 선)을 기준으로, G는 학습 초기 D에 의해 0으로 구분될 확률이 높고, $log(1)=0$ 즉, low gradient를 얻어, 학습 속도가 매우 느리다. 반면, 수정된 목적함수(초록 선)을 기준으로는 high gradient를 얻어, 학습 속도가 매우 빠르다. 이는 학습이 많이 필요한 초기 상태에 매우 적합한 목적 함수이다.


4. Theoretical Results

G는 노이즈 $z \sim p_z$를 입력받아 G(z)를 출력함으로써 암묵적(implicit)으로 확률분포 $p_g$를 정의한다. 알고리즘 1이 충분한 모델 용량과 학습 시간이 주어졌을 때, $p_{data}$를 올바르게 추정하도록 수렴하기를 바란다. 

Fig 1

GAN은 판별 분포(파란 점선)을 업데이트하며 훈련된다. D는 $p_x$(실제 데이터가 따르는 확률 분포, 검은색 점선)에서 생성한 샘플과 $p_g$ (G, 녹색 실선)에서 생성한 샘플을 구분하도록 학습된다. 

 

(a) 적대적 쌍이 수렴에 가까워진 상황은 다음과 같다. $p_g$는 $p_{data}$와 유사하며, D는 정확한 분류기이다.

(b) D는 $D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$로 수렴한다.

(c) D의 gradient는 G(z)가 실제 데이터로 분류될 가능성이 더 높은 영역으로 이동하도록 유도한다.

(d) 여러 학습 단계를 거친 후, G와 D가 충분한 model capacity를 가지고 있다면, $p_g=p_{data}$에 도달해 서로 개선할 수 없는 지점에 이르게 된다. 이 때, 판별자는 두 분포를 구분할 수 없게 되어 $D(x)=\frac{1}{2}$가 된다.

 

 

 

4.1 Global loptimality of $p_g=p_{data}$

우선, 임의의 생성기 G에 대해 최적의 판별기 D를 고려한다.

 

Proposition 1. G가 고정되었을 때, 최적의 판별기 D는 다음과 같다.

$$D^*_G(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$$

 

Proof) 임의의 생성기 G에 대해 판별기 D의 학습 기준은 다음 값을 최대화 하는 것이다.

$$ max_D V(G,D) = E_{x\sim p_{data}(x)}[logD(x)] + E_{z\sim p_z(z)}[log(1-D(G(z)))] $$

연속 확률 변수에서 기대값은 적분($\int$)를 사용하여 다음과 같이 나타낸다.

 

$$ V(G, D) = \int_x p_{data}(x) log(D(x))dx + \int_z p_z(z) log(1-D(G(z)))dz \\ = \int_x p_{data}(x)log(D(x)) + p_g log(1-D(x))dx $$

 

일반적으로, 어떤 실수 값을 갖는 (a,b)에 대한 function $y-> alog(y)+blog(1-y)$ 은 [0,1] 구간에서 $\frac{a}{a+b}$에서 최대값을 갖는다. 따라서, 목적함수는 다음과 같이 정의될 수 있다.

 

$$C(G) = max_D V(G,D) \\ = E_{x\sim p_{data}}[logD^*_G(x)] + E_{z\sim p_z}[log(1-D^*_G(G(z)))] \\ = E_{x\sim p_{data}}[logD^*_G(x)] + E_{x\sim p_g}[log(1-D^*_G(x))] \\ = E_{x\sim p_{data}} [log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}] + E_{x\sim p_g} [log \frac{p_g(x)}{p_{data}(x) + p_g(x)}]$$

 

$p_g = p_{data}$ 인 경우, $D^*_G(x)=\frac{1}{2}$ 이며, $C(G) = log \frac{1}{2} + log \frac{1}{2} = -log(4)$ 이다. 이를 증명하기 위해, C(G)에 log(4)를 뺀 결과를 살펴본다.

 

$$ C(G) = E_{x\sim p_{data}} [log \frac{2 * p_{data}(x)}{p_{data}(x) + p_g(x)}] + E_{x\sim p_g} [log \frac{2 * p_g(x)}{p_{data}(x) + p_g(x)}]  - log(4) \\ = -log(4) + KL(p_{data}||\frac{p_{data}+p_g}{2}) + KL(p_g||\frac{p_{data}+p_g}{2})$$

 

KL divergence는 항상  0 이상이기 때문에, $C*=-log(4)$는 $C(G)$의 global minimum 이다.

4.2 Convergence of Algorithm 1

-


5. Experiments

Fig. 2 a) MNIST b) TFD c) CIFAR-10 (fully connected model) d) CIFAR-10 (convolutional discriminator and "deconvolutional" generator

 

그림 2는 MNIST, TFD, CIFAR-10을 포함한 데이터셋에서 적대적 신경망을 훈련시킨 결과이다. 이 샘플이 기존 방법들로 생성된 샘플보다 우수하다고 주장하지는 않지만, 기존 생성 모델과 경쟁할 만한 수준이라고 소개한다. 또한, 정량적인 평가 지표는 아래 Table 1에 제시한다.

 

Table 1 : Parzen window-based log-likelihood estimates

 

$GAN$은 $p_g(x)$를 명시적으로 정의하지 않기 때문에, $p_g$와 $p_{data}$가 얼마나 유사한지 계산할 수 없다. 따라서, Gaussian Parzen Window 방법을 사용하여, GAN이 생성한 샘플이 실제 테스트 데이터와 얼마나 일치한지 로그 우도로 추정하였다. 높은 로그 우도 값은 GAN이 생성한 샘플의 분포 $p_g(x)$가 실제 데이터 분포 $p_{data}(x)$ 와 더 유사하다는 것을 의미한다.


6. Advantages and disadvantages

제안된 프레임워크의 장점으로는 마르코프 체인이 전혀 필요하지 않고, 그래디언트를 얻기 위해 역전파만 사용되며, 학습 중에 추론이 필요하지 않다는 점, 모델에 다양한 함수들(dropout 등)을 통합할 수 있다는 점이 있다.

 

단점으로는 주로 $p_g(x)$에 대한 명시적 표현이 없다는 점과, 훈련 과정에서 $D$와 $G$가 잘 동기화되어야 한다는 점이 있다. 특히, $D$를 업데이트하지 않고 $G$를 너무 많이 훈련하면, $G$가 $z$의 너무 많은 값을 하나의 $x$ 값으로 수렴시켜 $p_{data}$를 표현하기에 충분한 다양성을 잃게 되는 "the Helvetica scenario" (?) 가 발생할 수 있다.


7. Conclusions and future work

1. conditional generative model (CGAN)으로 발전할 수 있음

2. $x$가 주어졌을 때, $z$를 예측하는 보조 네트워크를 훈련하여 근사 추론을 수행할 수 있음

3. 제한된 라벨 데이터가 주어진 경우, classifier의 성능을 향상시킬 수 있음

4. G, D를 조정하는 더 나은 방법을 고안하거나, 학습 중 z를 샘플링하기 위한 더 나은 분포를 결정함으로써 훈련 속도를 향상시킬 수 있음

반응형

'논문 리뷰' 카테고리의 다른 글

(AdaIN) Arbitrary Style Transfer in Real-Time With Adaptive Instance Normalization, ICCV 2017  (2) 2025.01.21
(CycleGAN) Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, ICCV 2017  (0) 2025.01.21
(Grad-CAM) Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, ICCV 2017  (0) 2025.01.14
(CAM) Learning Deep Features Discriminative Localization, CVPR 2016  (0) 2025.01.13
'논문 리뷰' 카테고리의 다른 글
  • (AdaIN) Arbitrary Style Transfer in Real-Time With Adaptive Instance Normalization, ICCV 2017
  • (CycleGAN) Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, ICCV 2017
  • (Grad-CAM) Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, ICCV 2017
  • (CAM) Learning Deep Features Discriminative Localization, CVPR 2016
hangyuwon
hangyuwon
  • hangyuwon
    191
    hangyuwon
  • 전체
    오늘
    어제
  • 글쓰기 관리
    • 분류 전체보기 (38)
      • 기타 (1)
      • Stanford CS231n (19)
      • 논문 리뷰 (5)
      • Error (4)
      • 알고리즘 (2)
      • Linux (1)
      • 잡동사니 (2)
      • 딥러닝 (4)
  • 인기 글

  • 태그

    error
    알고리즘
    논문 리뷰
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
hangyuwon
(GAN) Generative Adversarial Nets, NeurIPS 2014
상단으로

티스토리툴바