Tech Corner - 8. April 2021

What’s wrong with you, GANs? Or how to avoid problems when working with GANs

Hey, my fellow readers. If you're reading this blog post, it means you're already familiar with the concept of a GAN. That's good, because I'm not going to explain the details here, like the architecture or the math. Instead, I want to share my experience with the thorny training process of these neural networks.

GAN meme

(Created by me on imgflip.com)

So where should I start? …Let me tell you a bit about the project I'm currently working on. It's my diploma thesis, which focuses on different use cases where GANs and 5G networks can work together to make the world a better place (in theory). I don't want to go into the details of the work here, but I can tell you that I used the pix2pix GAN architecture (see this link for more info), which belongs to the category of conditional GANs.

The first step was, of course, coding the overall architecture (I used PyTorch for that) and preparing the data for training. Everything went pretty well until I started running experiments, which also meant training the networks.

Here are the three issues I struggled with:

  • Mode collapse
  • Deceptive loss function
  • Slow convergence

You're probably quite curious about them, so let's shed some light on each one.

Mode collapse

Probably the most common issue with GAN training is mode collapse, and believe me, it can be really difficult to solve. You can observe it when your generator produces a small set of very similar outputs, which seem fine to the discriminator but are actually trash. The core of this issue can be explained with exhaustive math, but to keep it simple, I'll use my own words.

The main point is that the generator somehow finds a local optimum in the sample space that looks good to the discriminator at that moment. This often happens when the discriminator is not yet properly trained to recognize what is correct and what is incorrect, and the whole architecture collapses to a point from which there is no going back.

As a result of this collapse, the generator keeps producing a very similar, small subset of samples to reduce its loss, and the discriminator does not improve at all, so there is no way to punish the generator and force it to produce something new and better.
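By the way, you don't have to judge this by eye alone. A cheap way to watch for collapse is to track the diversity of each generated batch. Below is a minimal sketch (the helper and the threshold are my own invention, not part of any library) that measures the mean pairwise distance between samples in a batch, a number that tends to shrink as the generator collapses:

```python
import torch

def batch_diversity(fake_images: torch.Tensor) -> float:
    """Mean pairwise L2 distance between samples in a generated batch.

    A value that keeps shrinking over training hints that the generator
    is collapsing onto a few very similar outputs.
    """
    n = fake_images.size(0)
    flat = fake_images.reshape(n, -1)       # (N, C*H*W)
    dists = torch.cdist(flat, flat, p=2)    # (N, N) pairwise distances
    # The diagonal is zero, so average over the off-diagonal entries only
    return (dists.sum() / (n * (n - 1))).item()

# Hypothetical usage inside a training loop:
# fake = generator(batch_of_inputs)
# if batch_diversity(fake) < SOME_TUNED_THRESHOLD:
#     print("Warning: generated batch looks suspiciously uniform")
```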

There is no obvious, perfect solution to this issue, and it may vary from project to project, but here are some of the most common tricks (the last one worked for me; there is a sketch of it right after the list):

  • Add layers to your generator / remove layers from your discriminator
  • Train your GAN longer
  • Tune the hyperparameters of both networks
  • Add dropout and batch normalization layers to the generator
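Since that last trick is the one that saved me, here is roughly what it looks like in PyTorch. This is a minimal sketch of one pix2pix-style decoder block, not my exact thesis code; the 0.5 dropout rate is the one used in the original pix2pix paper, and the dropout acts as noise that makes it harder for the generator to lock onto one narrow family of outputs:

```python
import torch.nn as nn

def up_block(in_ch: int, out_ch: int, use_dropout: bool = True) -> nn.Sequential:
    """One pix2pix-style decoder block: upsample, normalize, optionally drop out."""
    layers = [
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2,
                           padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
    ]
    if use_dropout:
        layers.append(nn.Dropout(0.5))  # rate used in the original pix2pix paper
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
```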

Deceptive loss function

Despite solving the previous problem, I still felt that something was wrong. We all know that the loss function should be very useful when it comes to neural network training, but in my case, and with GANs in general, you need to be aware of what the loss function is actually telling you. Let me explain.

After several epochs, the program showed me plots of the loss functions for the generator and discriminator. At first glance, they looked quite solid: both curves were declining and there was no sign of one net dominating the other. Do you still remember what I said about loss functions? Good, because here it comes... the plot twist.

As a next step, I checked the generated images, and after about 10,000 epochs the results were really bad and far from my expectations. I realized that the loss function might not be the best metric for measuring training progress, and after some Google magic 🙂, I found that it is recommended to use additional (and custom) metrics to track GAN performance. Blessed by this knowledge, I used a well-known metric from computer vision tasks (SSIM) as well as my own metrics. ...and guess what? It really helped me spot mistakes in the early stages of training.
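Tracking something like SSIM next to the losses only takes a few lines. Here is a minimal sketch assuming the torchmetrics package (skimage's structural_similarity works just as well on NumPy arrays); the logging helper is hypothetical, just to show where the metric fits:

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

# SSIM expects (N, C, H, W) tensors; data_range=1.0 says pixels are in [0, 1]
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

def log_image_quality(fake: torch.Tensor, real: torch.Tensor, epoch: int) -> None:
    """Hypothetical helper: compare generated images against ground truth."""
    score = ssim(fake, real).item()
    print(f"epoch {epoch}: SSIM = {score:.4f}")  # close to 1.0 = near-identical
```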

Slow convergence

I can imagine that by now you must be thinking, "Wow, so much pain has to have a happy ending." Well, my friend, it does, but first let me introduce you to the term slow convergence.

(Created by me on imgflip.com)

In the previous section, I mentioned that my program ran for 10,000 epochs. That took my GPU (an Nvidia RTX 2070 Super) about 3 hours, which is quite a lot of time. But after a few experiments, I found that to generate images of the desired quality, I needed to run the program for almost 50,000 epochs. To reach that number of epochs, my GPU had to run for 21 hours.

My recommendation after this observation is that even if your model performs poorly from the beginning, you should always try to run it a little longer. Another remedy for slow convergence and poor results may be to remove some of the inner layers from the generator; you can also try the same on the discriminator (a sketch of this idea follows below).
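To make that kind of layer surgery painless, it helps to parameterize the depth of your networks instead of editing the model code by hand. Here is a toy sketch of the idea; it is not the real pix2pix U-Net (no skip connections), just a plain encoder-decoder where "remove some inner layers" becomes a single argument:

```python
import torch.nn as nn

def make_generator(depth: int = 6, base_ch: int = 64) -> nn.Sequential:
    """Toy encoder-decoder generator whose depth is a hyperparameter,
    so removing inner layers is just depth=4 instead of model surgery."""
    enc, dec = [], []
    ch = base_ch
    enc.append(nn.Conv2d(3, ch, 4, stride=2, padding=1))   # 3-channel input
    for _ in range(depth - 1):                              # downsampling path
        enc += [nn.LeakyReLU(0.2),
                nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1),
                nn.BatchNorm2d(ch * 2)]
        ch *= 2
    for _ in range(depth - 1):                              # upsampling path
        dec += [nn.ReLU(),
                nn.ConvTranspose2d(ch, ch // 2, 4, stride=2, padding=1),
                nn.BatchNorm2d(ch // 2)]
        ch //= 2
    dec += [nn.ReLU(),
            nn.ConvTranspose2d(ch, 3, 4, stride=2, padding=1),
            nn.Tanh()]                                      # back to 3 channels
    return nn.Sequential(*(enc + dec))

# Hypothetical usage: a shallower net that may converge faster
# generator = make_generator(depth=4)
```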

My second recommendation: if you really want to play with GANs, buy a GPU with a decent number of CUDA cores, otherwise it will take forever to get any results.

That's all for this time. You are now ready to fight these obstacles in your own projects.

Dávid Hreško
