A brief introduction to Generative Adversarial Networks
Why should we care about Generative Adversarial Networks (GANs for short) in the first place?
Well… take a look at the images below.
What do you think about them?
If you came across one of these images somewhere on the internet, would you suspect anything?
Let’s say, for example, that someone follows you on Twitter and uses one of the above images as a profile picture. Would you be able to tell whether it shows a real person or a fake?
The people in the images above do not exist. These images are generated by a GAN. And they are almost indistinguishable from real people.
If you zoom in far enough on the images above, you can spot some odd artifacts that give them away as fakes. But if the images don’t have a high enough resolution, you would not be able to tell the difference.
GANs are not limited to generating fake images of people; they can be used for a wide variety of other applications as well.
Now that we have an idea of what GANs are capable of, let’s try to understand what exactly a GAN is, how it learns, and where it fits in this big bag of machine learning tricks.
Discriminative and Generative Models
There are 2 main types of machine learning models: discriminative and generative.
Discriminative models learn the conditional probability P(y|x), where x is the input and y is the output of the model. This conditional probability means that, for a given input x, the model can tell us what y may be, that is, which values of y are likely.
A discriminative model learns a mapping from inputs to outputs, which is useful for classification and regression tasks. These are the models most commonly taught in beginner-level machine learning courses and books.
Generative models, on the other hand, learn the joint probability distribution P(x, y), which is much richer than the conditional one. If we know the joint probability P(x, y), we can also find the marginals P(x) and P(y) and the conditionals P(x|y) and P(y|x) (for example, P(x) = Σ_y P(x, y) and P(y|x) = P(x, y) / P(x)).
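Logistic regression is a standard example of a discriminative model. In the minimal sketch below, the weights `w` and `b` are made-up values standing in for a trained classifier; all the model can give us is P(y|x) for an input x we already have.

```python
import numpy as np

def sigmoid(t):
    """Logistic function, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-t))

# Hypothetical, hand-picked parameters standing in for a trained logistic-regression classifier.
w = np.array([2.0, -1.0])
b = 0.5

def p_y_given_x(x):
    """Return the model's estimate of P(y=1 | x) for one input vector x."""
    return sigmoid(w @ x + b)

x = np.array([1.0, 3.0])
p = p_y_given_x(x)
print(f"P(y=1|x) = {p:.3f}, P(y=0|x) = {1 - p:.3f}")  # prints: P(y=1|x) = 0.378, P(y=0|x) = 0.622
```

Nothing in this model says how likely the input x itself is, so it has no way to produce new inputs.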
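To see why the joint distribution is “richer”, here is a small sketch with a made-up discrete table P(x, y) (x takes 3 values, y takes 2). Both marginals and both conditionals fall out of it with a sum and a division, while a model of P(y|x) alone cannot recover P(x).

```python
import numpy as np

# Toy joint distribution P(x, y): rows are x in {0, 1, 2}, columns are y in {0, 1}.
# The numbers are invented for illustration; all entries sum to 1.
joint = np.array([
    [0.10, 0.20],
    [0.25, 0.05],
    [0.15, 0.25],
])

p_x = joint.sum(axis=1)              # marginal P(x): sum over y
p_y = joint.sum(axis=0)              # marginal P(y): sum over x
p_y_given_x = joint / p_x[:, None]   # P(y|x) = P(x, y) / P(x); each row sums to 1
p_x_given_y = joint / p_y[None, :]   # P(x|y) = P(x, y) / P(y); each column sums to 1

print("P(x) =", p_x)
print("P(y) =", p_y)
print("P(y|x) =\n", p_y_given_x)
print("P(x|y) =\n", p_x_given_y)
```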
So, a generative model can be used to draw examples from the distribution P(x, y) that it learned.
For example, given a dataset of cat and dog images, a generative model that has learned enough from those examples can produce new images of cats and dogs. A discriminative model, by contrast, can only distinguish between cats and dogs; it cannot produce images, no matter how good it is at classification.
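What “drawing examples from P(x, y)” means can be illustrated with a toy discrete table used as a stand-in for a learned distribution. A GAN never writes such a table down explicitly (it learns a sampler instead), so treat this only as a sketch of the idea of sampling from a joint distribution.

```python
import numpy as np

# Toy joint table P(x, y), standing in for a "learned" distribution (illustrative numbers).
joint = np.array([
    [0.10, 0.20],
    [0.25, 0.05],
    [0.15, 0.25],
])

rng = np.random.default_rng(seed=0)

def sample_joint(joint, n, rng):
    """Draw n (x, y) pairs from a discrete joint distribution table."""
    flat = joint.ravel()                           # treat the table as one categorical distribution
    idx = rng.choice(flat.size, size=n, p=flat)    # sample flat cell indices
    xs, ys = np.unravel_index(idx, joint.shape)    # map flat indices back to (x, y)
    return xs, ys

xs, ys = sample_joint(joint, 100_000, rng)

# Count how often each (x, y) cell was drawn; frequencies should approach the table.
counts = np.zeros_like(joint)
np.add.at(counts, (xs, ys), 1)
print(counts / counts.sum())
```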
As you probably guessed, GANs are part of the generative models category, but they are not the only such models. Other types of generative models are Variational Autoencoders, Boltzmann Machines, Gaussian Mixture Models, and many others…
What a GAN is made of and how it learns
GANs are actually made of 2 models: a discriminative and a generative model.
“Wait. Didn’t you say that GANs are generative models? What’s up with the discriminative model?” you may ask.
Yes, GANs are generative models in the sense that we are mostly interested in the generator. But we need a discriminative model to train the generative one.
This is what sets GANs apart from other machine learning algorithms: adversarial training.
These 2 models inside a GAN act like 2 enemies fighting each other, generator vs. discriminator, hence the name adversarial. The generator’s job is to produce fakes realistic enough to fool the discriminator, and the discriminator’s job is to tell fakes apart from real images; basically, to catch the generator.
The way GANs learn is similar to a minimax game, which you can read more about in this article:
In most other deep learning algorithms there is only one loss function that has to be minimized. So the learning process consists of finding the network’s parameters that minimize that loss function.
But in a GAN, the generator and the discriminator each have their own objective function. The discriminator D’s objective is to maximize its chances of correctly classifying fakes vs. reals. The generator G’s objective is to minimize D’s chances of making a correct classification.
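In the original GAN paper (Goodfellow et al., 2014), these two objectives are combined into a single value function V(D, G) that D tries to maximize and G tries to minimize. Here x is a real image drawn from the data distribution p_data, z is a noise vector drawn from a prior p_z, and D(x) is the discriminator’s estimated probability that x is real:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The first term rewards D for assigning high probability to real images, and the second rewards it for assigning low probability to the fakes G(z). G can only influence the second term.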
At the start of the algorithm, both G and D are initialized with random parameters, so both of them will do a horrible job. But gradually, they will start getting better. Each one will improve itself iteration by iteration. The discriminator will get better at recognizing fakes and the generator will get better at fooling the discriminator.
Eventually, though, the generator should be producing fakes so realistic that the discriminator can no longer tell them apart from real images, and its accuracy at detecting fakes gets stuck around 1/2, no better than a coin flip.
The generator and the discriminator in a GAN are implemented as feed-forward neural networks. But feed-forward neural networks are just function approximators: they learn mappings from inputs to outputs, and that’s all they do. Therefore, the so-called generator in a GAN cannot actually “generate” images from scratch. Instead, we use a random number generator to draw values from some probability distribution (typically a multivariate normal distribution), and the generator network has to learn to map these multivariate normal random variables to “fake image random variables”.
In other words, the generator network learns to transform a multivariate normal distribution into a probability distribution over all possible images at the given resolution. It does so in such a way that when we draw a sample from this output distribution, we are likely to get an authentic-looking image. In the generator’s output distribution, combinations of pixel values that don’t look like a real image should have a probability close to 0.
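Why 1/2? The original GAN paper shows that, for a fixed generator whose outputs follow a density p_g, the best possible discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)). Once the generator matches the data (p_g = p_data), this equals 1/2 for every input. The sketch below checks this numerically, using two 1D Gaussians as toy stand-ins for the real and generated image distributions (an assumption made purely for illustration).

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a 1D normal distribution N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)), given both densities evaluated at x."""
    return p_data / (p_data + p_g)

xs = np.linspace(-3, 3, 7)                       # evaluation points -3, -2, ..., 3
p_data = gaussian_pdf(xs, mean=0.0, std=1.0)     # toy "real data" density (assumption)

# Early in training: the generator's density is shifted away from the data.
early = optimal_discriminator(p_data, gaussian_pdf(xs, mean=2.0, std=1.0))
# At convergence: the generator's density matches the data density exactly.
late = optimal_discriminator(p_data, gaussian_pdf(xs, mean=0.0, std=1.0))

print(np.round(early, 3))   # near 1 where real data dominates, near 0 where fakes dominate
print(np.round(late, 3))    # 0.5 everywhere
```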
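To make the “generator only transforms noise” point concrete, here is a minimal NumPy sketch of a feed-forward generator. The layer sizes, the ReLU and tanh activations, and the 8x8 output are all illustrative assumptions, and the weights are untrained, so the outputs look like noise. The point is only the data flow: random z drawn from N(0, I) goes in, an image-shaped array comes out.

```python
import numpy as np

rng = np.random.default_rng(seed=42)

LATENT_DIM = 16          # size of the noise vector z (hypothetical choice)
HIDDEN = 64              # hidden layer width (hypothetical choice)
IMG_SHAPE = (8, 8)       # tiny grayscale "image" resolution (hypothetical choice)
IMG_SIZE = IMG_SHAPE[0] * IMG_SHAPE[1]

# Randomly initialised weights of a two-layer feed-forward generator.
# In a real GAN these would be learned; here they are untrained.
W1 = rng.normal(0, 0.1, size=(LATENT_DIM, HIDDEN))
b1 = np.zeros(HIDDEN)
W2 = rng.normal(0, 0.1, size=(HIDDEN, IMG_SIZE))
b2 = np.zeros(IMG_SIZE)

def generator(z):
    """Map a batch of noise vectors (n, LATENT_DIM) to images (n, 8, 8) with pixels in [-1, 1]."""
    h = np.maximum(0, z @ W1 + b1)         # ReLU hidden layer
    out = np.tanh(h @ W2 + b2)             # tanh squashes pixel values into [-1, 1]
    return out.reshape(-1, *IMG_SHAPE)

# The randomness comes from here, not from the network:
# z is drawn from a standard multivariate normal distribution N(0, I).
z = rng.standard_normal(size=(4, LATENT_DIM))
fake_images = generator(z)
print(fake_images.shape)   # (4, 8, 8)
```

Calling `generator` twice on the same `z` returns the same images; a different image requires a different noise vector.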
Each iteration in a GAN’s training has 2 phases: one for the discriminator and one for the generator.
When we train the discriminator, we keep the generator fixed and use it only to produce fakes, which are fed to the discriminator along with real images.
Then we use the binary cross-entropy (BCE) loss to measure how good the discriminator’s predictions are compared to the true labels.
The true labels are: 1 for real images and 0 for fakes.
We take the gradient of the BCE loss with respect to (w.r.t.) the discriminator’s parameters and use those for the update.
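Here is a minimal NumPy sketch of the discriminator phase. To keep the gradient easy to derive by hand, the discriminator is assumed to be a logistic-regression model D(x) = sigmoid(w·x + b) on flattened images rather than a deep network, and the “real” and “fake” batches are toy Gaussian data. The steps follow the description above: real images labeled 1, fakes labeled 0, BCE loss, and a gradient step on the discriminator’s parameters only.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def bce(p, y, eps=1e-12):
    """Binary cross-entropy averaged over the batch."""
    p = np.clip(p, eps, 1 - eps)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

def discriminator_step(w, b, real, fake, lr=0.1):
    """One gradient-descent step on a logistic-regression discriminator.

    real, fake: arrays of shape (n, d) holding flattened images. The generator is
    not touched here; `fake` is just a batch it produced earlier.
    """
    x = np.concatenate([real, fake])
    y = np.concatenate([np.ones(len(real)), np.zeros(len(fake))])  # 1 = real, 0 = fake
    p = sigmoid(x @ w + b)
    loss = bce(p, y)
    g = (p - y) / len(y)        # for sigmoid + BCE, dLoss/dlogit = (p - y) / n
    grad_w = x.T @ g            # gradient w.r.t. the discriminator's weights
    grad_b = g.sum()            # gradient w.r.t. the discriminator's bias
    return w - lr * grad_w, b - lr * grad_b, loss

rng = np.random.default_rng(0)
d = 5
real = rng.normal(1.0, 1.0, size=(64, d))    # toy "real" data (assumption)
fake = rng.normal(-1.0, 1.0, size=(64, d))   # toy generator output (assumption)
w, b = np.zeros(d), 0.0
losses = []
for _ in range(50):
    w, b, loss = discriminator_step(w, b, real, fake)
    losses.append(loss)
print(losses[0], losses[-1])   # starts at log(2) ~ 0.693, then decreases
```

With a deep discriminator, an autodiff library would compute these gradients instead of the hand-derived `(p - y) / n` formula.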
When we train the generator, the process is almost the same as for the discriminator, but with 2 big differences: we use only fake images, and we use misleading labels. Since we want the generator to get better at fooling the discriminator, we give the network the opposite of the labels the discriminator should predict. These are the misleading labels: 0 for real images and 1 for fakes (because this phase uses only fakes, every sample is labeled 1). This time, the gradient of the loss is taken w.r.t. the generator’s parameters, while the discriminator is kept fixed.
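A matching sketch of the generator phase, under similar simplifying assumptions: a linear generator G(z) = z·A + c instead of a deep network, and a frozen logistic-regression discriminator with hand-picked weights. Every fake gets the misleading label 1, and the gradient flows back through the frozen discriminator into the generator’s parameters A and c only.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def generator_step(A, c, w, b, z, lr=0.1):
    """One gradient step on a linear generator G(z) = z @ A + c.

    The discriminator D(x) = sigmoid(x @ w + b) is held fixed: w and b are read
    but never updated. Every fake gets the misleading label 1 ("real").
    """
    fake = z @ A + c                                  # (n, d) generated samples
    p = sigmoid(fake @ w + b)                         # D's belief that each fake is real
    loss = -np.mean(np.log(np.clip(p, 1e-12, 1.0)))   # BCE with target label 1
    g_logit = (p - 1.0) / len(z)                      # dLoss/dlogit for label 1
    g_fake = np.outer(g_logit, w)                     # chain rule through D: dLoss/dfake, (n, d)
    grad_A = z.T @ g_fake                             # chain rule through G
    grad_c = g_fake.sum(axis=0)
    return A - lr * grad_A, c - lr * grad_c, loss

rng = np.random.default_rng(1)
k, d = 3, 5                              # noise size and sample size (toy values)
w, b = np.ones(d), 0.0                   # fixed discriminator weights (hypothetical)
A = rng.normal(0, 0.1, size=(k, d))
c = -np.ones(d)                          # generator starts out producing "obvious fakes"
losses = []
for _ in range(100):
    z = rng.standard_normal(size=(32, k))  # fresh noise every step
    A, c, loss = generator_step(A, c, w, b, z)
    losses.append(loss)
print(losses[0], losses[-1])
```

Labeling fakes as 1 amounts to minimizing -log D(G(z)), the “non-saturating” generator loss suggested in the original GAN paper as an alternative to minimizing log(1 - D(G(z))), which gives weak gradients early in training.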
A few closing thoughts…
What I find most interesting about GANs is not the generative part but the adversarial one, because it is more closely related to how humans learn. We humans don’t have an explicit loss function. How we learn and how we set goals is largely shaped by the social interactions we take part in.
A GAN simulates this kind of social interaction through the adversarial game it plays. It would be interesting to apply this idea to discriminative models as well.
Interestingly, the generator never actually “sees” any real images. It is trained only on the feedback it receives from the discriminator, which can have a regularization effect.
A disadvantage of this adversarial training is that these networks are harder to train than other types of networks. The discriminator and the generator need to stay well balanced; otherwise, they cannot learn from each other effectively.
Here is the original paper on GANs in case you want to read it.
I hope you found this information useful and thanks for reading!
This article is also posted on Medium here. Feel free to have a look!