ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • GAN을 어떻게 학습할 것인가(How to Train a GAN?)
    논문 리뷰/Generative Model 2024. 7. 18. 17:42

    GitHub - soumith/ganhacks: starter from "How to Train a GAN?" at NIPS2016

     

    GitHub - soumith/ganhacks: starter from "How to Train a GAN?" at NIPS2016

    starter from "How to Train a GAN?" at NIPS2016. Contribute to soumith/ganhacks development by creating an account on GitHub.

    github.com

     

    cs236을 듣다가 알게된 페이지인데 흥미로워서 한국어로 정리해보고자 한다.

    정확한 내용은 위에 페이지를 참고하면 된다!

     

    1. Normalize the inputs : 입력값을 정규화하자

    -1에서 1사이의 값으로 정규화하고 Tanh를 generator의 마지막 output layer로 사용하자

     

    2. A modified loss funcion : loss function을 수정해서 사용하자

    GAN 논문에서는 G를 최적화하기 위한 손실함수를 $min \log{1-D}$로 정의하고 있지만, 실제로는 $max \logD$를 사용하자.(즉, fake과 real의 label을 서로 바꾸자는 말이다) 이렇게 하면 vanishing gradient 문제를 좀 더 손쉽게 다룰 수 있다.

     

    이유) 처음에 Generator가 만들어 낸 샘플은 fake로 판단될 확률이 높다. 즉 D(G(z))는 0에 가까운 값을 갖을 것이며, 이는 gradient vanishing problem을 겪을 가능성을 높인다.

    생성자의 수정된 손실 함수

    또한 실전에서는 label flip기법을 사용하기도 한다. label flip기법은 생성자를 훈련할 때 일부 실제 데이터를 가짜로, 가짜 데이터를 진짜로 라벨링하여 Discriminator를 혼랍스럽게 하여 Generator를 정교하게 학습시키는 방법이다.

     

    3. Use a spherical Z : uniform distribution이 아닌 gaussian distribution에서 샘플링하자.

    spherical sampling을 하면 잠재 공간의 모든 방향이 동일하게 취급되어 생성되는 데이터의 다양성을 높이고, 모델이 더욱 일반화된 패턴을 학습할 수 있도록 돕는다. 또한 직선 보간이 아닌 대원(Great Cricle) 보간을 하면 보간된 값이 구의 표면을 따라 움직이므로 보다 자연스러운 보간값을 얻을 수 있다.

     

    4. BatchNorm : real과 fake을 섞어 batchnormalization하지 말자.

    batchnorm 사용시 real data와 fake data를 mini batch에 섞어 사용하지 말자. 하나의 미니 배치에는 모두 real data만, 다른 미니배치에는 모두 fake data만 정규화하여 훈련하는 것이 더 안정적이다. 

    reference : ganhacks/images/batchmix.png at master · soumith/ganhacks · GitHub 마

    만약 batch normalization을 사용하기 여의치 않다면 instance normalization을 사용하자.

    ✨Instance Normalization : 배치 정규화의 대안으로 각 샘플(이미지)마다 별도의 평균과 분산을 계산해 정규화

     

    5. Avoid Sparse Gradients (ReLU, MaxPool) : gradient가 너무 작아지지 않도록 하자.

    LeakyReLU는 Generator와 Discriminator에게 모두 좋다. 

    Downsampling : use Average Pooling, Conv2d + stride

    Upsampling : use PixelShuffle, ConvTranspose2d + stride

     

    6. Use Soft and Noisy Labels : label 부여시 smooth하고 noisy하게 부여하자.

    Label Smoothing : 만약 real = 1, fake = 0으로 라벨링하고 싶다면 real에 대해서는 (0.7, 1.2) 사이 랜덤 숫자로, fake은 (0.0, 0.3)사이 랜덤 숫자를 부여하자. 이렇게 label을 약간의 불확실성과 함께 discriminator에게 제공하므로써 고정된 레이블에 discriminator가 지나치게 확신하게 되는 과적합(overfitting)현상을 방지할 수 있다.

     

    또한 Discriminator 훈련시 라벨을 가끔씩 뒤집는 Label Noise를 활용하여 Discriminator가 real과 fake를 쉽게 구분하지 못하도록 하고, Generator가 더 강력한 이미지를 생성하도록 유도한다. 

     

    7. DCGAN / Hybrid Models : 할 수 있다면 DCGAN을 사용하자!

    만약 DCGAN을 사용할 수 없다면 다양한 모델을 섞어서 사용하자.

    KL + GAN 또는 VAE + GAN의 조합을 추천한다.

     

    8. Use stability tricks from RL(Reinforce learning) : 강화학습에서 얻은 여러 팁을 사용하자

    - 강화학습에서 유래된 Experience Replay기법을 활용해 GAN을 좀 더 안정적으로 학습하자.

    - Deep deterministic policy gradient를 사용해보자.

    Pfau & Vinyals (2016)의  'Connecting generative adversarial networks and actor-critic methods'를 참고하자.

    (링크 : 1610.01945 (arxiv.org))

     

    Experience Replay : 과거의 경험을 저장하고, 이를 학습 과정에서 재사용하는 기법

    - 과거에 생성된 이미지를 replay buffer에 저장한 후, 주기적으로 이 버퍼에서 이미지를 샘플링하여 판별자에게 다시 보여준다. 이는 판별기가 과거의 이미지를 기억하고, 더 다양한 데이터 분포를 학습하도록 돕는다.

    - 생성자와 판별자의 이전 상태를 저장하는 checkpoints를 유지하고, 주기적으로 이 checkpoint를 교체하여 몇 번의 iterations 동안 훈련한다. 이는 모델이 과거의 학습 상태를 반영하고, 새로운 학습을 촉진하도록 한다.

     

    9. Use the ADAM Optimizer : ADAM 옵티마이저를 사용하자. 

    판별자에는 SGD를, 생성자에는 ADAM을 사용하자.

     

    10. Track failures early : 판별자와 생성자의 손실, gradient norm 등을 통해 학습이 잘 되는지를 확인하자

    → 판별자의 손실이 0으로 수렴하는 경우 : 생성자의 가짜 이미지를 완벽하게 구분하고 있다는 의미

    → gradient norm > 100이라면 문제가 발생하고 있다는 신호

    → 판별자 D의 손실이 작지만 일정하게 유지되는 것이 이상적

    → 생성자 G의 손실이 꾸준히 감소하지만 이미지 품질이 낮은 경우, 판별자 D를 쓰레기(garbage) 샘플로 속이는 중😏

     

    11. Don't balance loss via statistics(unless you have a good reason to) : D와 G의 훈련스케쥴 조정은 어렵지..🥲

    생성자와 판별자의 훈련 스케쥴을 조정하는 것은 매우 어렵다!! 그러나 만약 훈련 스케쥴을 조정하고 싶다면 직관에 의지하기 보다는 원칙에 기반한 접근법을 추천한다.

    예시) 
    while lossD > A:
      train D
    while lossG > B:
      train G

     

    ✨원칙에 기반한 접근법

    - 학습률 조정 : D와 G의 학습률을 각기 다르게 설정하여 학습의 불안정성을 줄인다. 일반적으로 판별자 D의 학습률을 더 낮게 설정한다. 학습 중 동적으로 학습률을 조정하기도 한다.

    - Gradient Clipping 

    - 혼합 손실함수 사용 : 여러 손실함수를 결합하여 사용한다. 예를 들어 L2 손실이나 direction 손실 등을 추가하여 손실함수를 설정하면 생성된 이미지의 품질을 향상시킬 수 있다.

    - 정규화 기법 : Batch Normalization이나 Spectral Normalization을 사용하자

    - Momentum 기반 Opimizer 사용 : ADAM, RMSProp등을 사용하자.

    - Target Network 사용 : RL에 사용되는 Target Network 기법을 사용하자.

     

    12. If you have labels, use them : Auxiliary GNA을 사용해보자

    라벨이 있는 데이터를 사용해 판별기가 진짜와 가짜를 구분하는 것뿐 아니라 샘플을 분류하도록 훈련할 수 있다. 이러한 접근 방식을 Auxiliary GAN(또는 Conditional GAN)이라고 부르며, 이를 통해 GAN의 성능을 높일 수 있다.

     

    13. Add noise to inputs, decay over time : Noise를 추가하면 학습이 잘 될수도 있다

    (1) 판별자의 입력에 노이즈를 추가하거나 (2) 생성자의 모든 층에 노이즈를 추가하고 이를 감소시키는 방향으로 GAN을 학습하면 mode collapse와 unstability를 완화하는데 도움이 된다.

     

    14. [notsure] Train distriminator more(sometimes)

    만약 데이터에 noise가 있거나 판별자와 생성자의 훈련 스케쥴을 알기 어려울 때는 Discriminator를 좀 더 학습해보자

     

    15. [notsure] Batch Discrimination : Mixed results

     

    16. Discrete variables in Conditional GANS : embedding layer를 사용하자.

    라벨이 있는 데이터를 효과적으로 활용하기 위해 embedding layer를 사용하고, 이를 이미지의 추가 채널로 제공한다. embedding layer의 차원은 작게 유지하며, 이미지 채널의 크기에 맞게 upsampling하여 GAN 모델에 통합해 사용한다.

     

    17. Use Dropouts in G in both train and test phase : Dropout을 생성자의 train뿐만 아니라 test에서도 사용하자

    Dropout은 일반적으로 train시에만 사용되는데, GAN 훈련에서는 테스트 시에도 Dropout을 적용하여 모델이 더 견고하게 학습되도록 한다. 

     

     

    댓글

Designed by Tistory.