-
WGAN-GP 리뷰(Wasserstein GAN with gradient penalty)논문 리뷰/Generative Model 2024. 7. 22. 19:14
2017년 NeuIPS에 게재된 WGAN-GP 논문을 리뷰하고자 한다. 이 논문은 WGAN의 발전된 버전으로 WGAN과 주저자가 동일해 WGAN에 대한 설명은 많이 생략했다. 따라서 이 논문을 읽기 전에 WGAN에 대한 이해가 필요하다.
(틈새 홍보) WGAN 논문 리뷰 ↓ ↓ ↓ ↓ ↓
https://jjo-mathstory.tistory.com/entry/WGAN-%EB%A6%AC%EB%B7%B0Wasserstein-GAN
WGAN-GP에서 GP는 gradient policy를 의미한다. 이 논문의 가장 중요한 핵심인 GP는 기존 WGAN이 weight clippling을 통해 Lipschitz constraint를 만족시키고자 했던 한계를 극복하는데 사용된다.
We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input.
Abstract에 있는 이 문장이 이 논문의 핵심문장이자 전부라고 봐도 무방하다.
그럼 이제 WGAN-GP에 대해 알아보자.
Introduction
WGAN에서는 critic(GAN에서의 discriminator)가 1-Lipschitz functions space에 있어야 하며, 이를 위해 저자는 weight clipping을 사용하였다. ( critic의 parameter에 제약조건을 추가하여 일정 범위 안에서만 존재하고록 강제)
이 논문에서는
1. 왜 weight clippling이 WGAN의 학습을 방해하는지를 설명하고
2. gradient policy를 제안한다.(WGAN-GP)
* GAN과 WGAN 소개는 생략합니다. GAN 포스팅과 WGAN 포스팅을 참고하세요!
Proposition
더보기증명)
Proposition 1과 그 결과를 잘 살펴보자. 이 증명이 곧 Gradient policy를 어떤 식으로 정의할 지 정하는 Proposition이다.
Difficulties with weight constraints
1. capacity underuse
clipping biase를 통해 k-Lipschitz constraint를 구현하면 모델이 지나치게 심플해진다. 아래 그림에서 보이듯이 weight clipping을 통해 학습된 WGAN은 복잡한 형태의 모델을 잘 학습하지 못하는 경향을 보인다.(weight clipping ignores higher moments of the data distribution)
2. Exploding and vanishing gradients
weight clipping에서는 clipping threshold c에 값에 따라 vanishing gradient 또는 exploding gradient 문제가 발생할 수 있다.
Gradient Penalty
Lipschitz constraint를 강제하기 위해 우리는 더이상 weight clipping 기법을 사용하지 않고, 대신 gradient penalty를 사용할 것이다. 미분가능한 함수가 1-Lipschitz일 필요충분조건이 이 함수의 gradient norm이 1이라는 점을 감안해 우리는 이제 gradient norm이 1이 될 수 있도록 강제할 것이다. 아래는 우리의 새로운 목적함수다.
(1) Sampling distribution : $P_{\hat{x}}$는 $P_r$과 $P_g$의 내분하는 지점에 있는 분포로 생각하자. Proposition 1에 따라 최적화된 critic은 $P_r$과 $P_g$로부터 각각 추출된 샘플의 내분점들이자 gradient norm이 1인 샘플들을 포함하고 있어야 한다. 사실 우리는 unit gradient norm을 가진 모든 점들에 대해 $P_{\hat{x}}$를 고려해보아야 하지만, $P_r$과 $P_g$로부터 각각 추출된 샘플의 내분점에 대해서만 gradient norm을 1로 강제해도 학습이 잘 된다는 점을 알아냈다.
(2) Penalty coefficient : 이 논문에서는 $lambda = 10$으로 두었다.
(3) No critic batch normalization : 앞선 GAN 계열 논문들에서는 batch normalization을 generator와 discriminator에 모두 사용함으로서 성능을 향상시켰다. 하지만 WGAN-GP에서는 각각의 input에 대해 critic의 gradient norm에 penalty를 주고 있으므로 batch 단위로 업데이트하는 batch normalization은 적절하지 않다. 다만 각각의 input에 대해 normalization을 하는 것은 좋은 방법으로 보여진다.(=layer normalization)
(4) Two-sided penalty : WGAN-GP에서는 critic의 gradient norm이 1이 되도록 강제하므로 1보다 작거나 같게 만드는 one-sided penalty가 아닌 two-sided penalty를 부여한다.(1보다 작거나 같게 하면서 동시에 1보다 크거가 같게 강제한다)
실험 결과
* 실험 결과 자체는 생략하고 시사점만 정리
5.3 (Improved performance over weight clipping) WGAN-GP는 WGAN 모델에 비해 training speed와 sample quality 측면에서 뛰어나다.
5.5 (Modeling discrete data with a continuous generator) degeneratr distribution에 대한 WGAN-GP 성능을 알아보기 위해 character-level의 GAN language model을 고안해보았다. 이 논문에서는 굉장히 심플한 generator를 사용해 간단한 실험을 해보았는데, 기존 GAN 모델로는 만들기 어려웠던 유의미한 샘플을 만들었다. 언어 모델로서의 WGAN-GP가 성공할 수 있었던 이유는 아래와 같다.
5.6 (Meaningful loss curves and detecting overfitting) WGAN-GP's loss는 sample quality와 연관이 있다. 이 실험에서는 일부로 critic이 과적합된 GAN 모델을 구현해보았는데 왼쪽에 있는 WGAN-GP의 경우 훈련 손실은 증가하지만 검증 손실은 감소하는 양성을 보인다. 이는 critic이 generator보다 빠르게 과적합되면서 아래와 같은 양상을 띄게 된다.
코드
WGAN-GP 코드 출처 :
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py
알아보기 쉽게 코드로 구현해놓은 깃허브 페이지를 찾아서 소개한다.
여기서 눈여겨 보아야 할 부분은 gradient에 penalty를 부여하는 부분이다!
위 코드에서는 compute_gradient_penalty를 계산하고 이를 discriminator 학습 시 loss에 더해준다.
compute_gradient_penalty 구현은 아래와 같다.
def compute_gradient_penalty(D, real_samples, fake_samples): alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) d_interpolates = D(interpolates) fake = Variable(torch.Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad = False) # 주어진 출력에 대한 입력의 그래디언트를 계산 gradients = autograd.grad( outputs = d_interpolates, inputs = interpolates, grad_outputs = fake, create_graph = True, retain_graph = True, only_inputs = True, )[0] gradients = gradients.view(gradients.size(0), -1) gradient_penalty = ((gradients.norm(2, dim = 1) -1) ** 2).mean() return gradient_penalty
나머지 부분은 레퍼런스에 달아놓은 깃허브 홈페이지를 참고하면 된다.
'논문 리뷰 > Generative Model' 카테고리의 다른 글