
Easily Improve Any GAN with Metropolis-Hastings

source link: https://mc.ai/easily-improve-any-gan-with-metropolis-hastings/

Create higher fidelity samples with no additional training in TensorFlow 2

How can we improve our GAN with as little work as possible? (Image from Pixabay)

Ever had trouble getting believable outputs from your GAN? In this article we review a recent improvement to GANs that boosts output quality, along with an easy-to-follow TensorFlow 2 implementation to get your own Metropolis-Hastings GAN off the ground!

About the Paper

In late 2018, Uber Engineering released a paper entitled “Metropolis-Hastings Generative Adversarial Networks” (2018, Turner et al.), in which the authors introduce a simple trick that improves the output of GANs without any additional training of the GAN. If you are not very familiar with GANs, I suggest reading an introduction to them before proceeding, to better understand the models we are dealing with.

A visual from “Metropolis-Hastings Generative Adversarial Networks” (2018, Turner et al.) showing how MH can improve the outputs of Progressive GANs

The basic premise of Uber’s work is that the discriminator, which is normally thrown out after training, contains valuable information for determining which samples resemble the true data distribution. The authors propose a novel sampling scheme in which candidate samples are drawn from the generator (by sampling its latent space), and Metropolis-Hastings, guided by the discriminator, chooses which of them to accept.
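Why would a discriminator know anything about the data distribution? Here is a short sketch of the standard argument from the original GAN analysis (Goodfellow et al., 2014), not something specific to this article's code: for a fixed generator with density p_G, the optimal discriminator is the one below, so the odds it assigns to a sample being real equal the density ratio between the data and the generator. A trained discriminator only approximates D*, which is one reason the calibration step discussed later matters.

```latex
D^*(x) = \frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_G(x)}
\quad\Longrightarrow\quad
\frac{D^*(x)}{1 - D^*(x)} = \frac{p_{\mathrm{data}}(x)}{p_G(x)}
```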

Figure 1: The Metropolis-Hastings GAN pseudocode from the Uber paper

What’s ingenious about this method is that it works with any GAN architecture or loss function, as MH-GAN is really just a post-training sampling methodology. Before we show how to apply this sampling method, let’s first train a GAN to use it on!

Training our GAN in TensorFlow 2

To start on our journey of training an MH-GAN, we are first going to train a Least-Squares GAN (LS-GAN) due to its stability and ease of implementation. I included a link at the bottom of the article if you want to learn more about LS-GANs. For the data, we are going to work with the Fashion-MNIST dataset included in TensorFlow, so these results are easy to replicate at home!
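For reference, here is a NumPy sketch of the least-squares losses that give the LS-GAN its name. It assumes the common target choice of 1 for real samples, 0 for fake samples, and 1 as the generator's target for its own samples (other target choices are possible). The function names are our own, not the article's; in a TensorFlow 2 training step the same expressions would be written with tf.reduce_mean over the discriminator outputs.

```python
import numpy as np

# Least-squares GAN objectives with targets assumed to be: real -> 1, fake -> 0,
# and the generator trying to push D(G(z)) toward 1.

def lsgan_discriminator_loss(real_scores, fake_scores):
    """Squared distance of real scores from 1 plus squared distance of fake scores from 0."""
    real_scores = np.asarray(real_scores, dtype=float)
    fake_scores = np.asarray(fake_scores, dtype=float)
    return 0.5 * np.mean((real_scores - 1.0) ** 2) + 0.5 * np.mean(fake_scores ** 2)


def lsgan_generator_loss(fake_scores):
    """Squared distance of the discriminator's scores on generated samples from 1."""
    fake_scores = np.asarray(fake_scores, dtype=float)
    return 0.5 * np.mean((fake_scores - 1.0) ** 2)
```

A discriminator that scores every real sample 1 and every fake sample 0 has zero loss, and the generator's loss against it is 0.5.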

Please note that some of the training code in this article is adapted from the official TensorFlow DCGAN tutorial.

Note that we are also using the tensorflow-addons package for some more advanced functionality, such as the Ranger optimizer and Mish activation layers, to smooth the training process.
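Mish is simple enough to write down directly: Mish(x) = x * tanh(softplus(x)). The NumPy version below is only an illustration of the formula. In the TensorFlow code you would use tensorflow-addons' tfa.activations.mish, and Ranger can be assembled from tensorflow-addons by wrapping tfa.optimizers.RectifiedAdam in tfa.optimizers.Lookahead (this is the usual recipe as we understand it, not necessarily the article's exact configuration).

```python
import numpy as np


def softplus(x):
    # log(1 + exp(x)), computed stably via logaddexp to avoid overflow for large x
    return np.logaddexp(0.0, x)


def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    x = np.asarray(x, dtype=float)
    return x * np.tanh(softplus(x))
```

Mish behaves like the identity for large positive inputs and smoothly decays toward 0 for large negative inputs.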

Defining Our Generator and Discriminator

The next step in training our LS-GAN is to define the generator and discriminator objects that we feed into our LS-GAN class.

We now have a base class to train a Least-Squares GAN, as well as a generator and discriminator to feed into it. Now it’s time to dive into implementing the MH-GAN!

The Metropolis-Hastings Algorithm

The first thing we need to understand when implementing an MH-GAN is the Metropolis-Hastings algorithm. MH is a Markov chain Monte Carlo (MCMC) method for sampling from a distribution that is difficult to sample from directly. In this case, we are trying to sample from the true data distribution, as implied by the discriminator.
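To see the flavor of Metropolis-Hastings used here before bringing GANs into it, below is a toy independent MH sampler in NumPy (our own illustration, not code from the paper). The target p is a standard normal, and every proposal is drawn from a wider normal q, independently of the current state. The acceptance test min(1, p(x')q(x) / (p(x)q(x'))) corrects for the mismatch between proposal and target. In the MH-GAN, the generator plays the role of q, and the discriminator supplies the ratio p/q.

```python
import numpy as np


def log_normal_pdf(x, sigma):
    # Log density of N(0, sigma^2); the normalizing constants cancel in the MH ratio anyway
    return -0.5 * (x / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


def independent_mh(n_steps, target_sigma=1.0, proposal_sigma=2.0, seed=0):
    """Independent Metropolis-Hastings: proposals ignore the current state."""
    rng = np.random.default_rng(seed)
    x = 0.0
    chain = np.empty(n_steps)
    for i in range(n_steps):
        x_new = rng.normal(0.0, proposal_sigma)
        # log of [p(x_new) q(x)] / [p(x) q(x_new)]
        log_ratio = (log_normal_pdf(x_new, target_sigma) + log_normal_pdf(x, proposal_sigma)
                     - log_normal_pdf(x, target_sigma) - log_normal_pdf(x_new, proposal_sigma))
        if np.log(rng.uniform()) < log_ratio:
            x = x_new          # accept the proposal
        chain[i] = x           # on rejection the current state is repeated
    return chain
```

Running independent_mh(20000) should give a chain whose mean is close to 0 and whose standard deviation is close to 1, even though the proposals have a standard deviation of 2.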

We won’t dive too deep into the theory behind Metropolis-Hastings, but I have included a link at the bottom of this article if you want to learn more.

Once we have a trained generator and discriminator, we use the following steps to draw samples that follow the true data distribution (feel free to refer to the pseudocode above):

  1. Seed the chain with a real sample of data. (The authors do this to avoid the lengthy burn-in period that Metropolis-Hastings can otherwise require.)
  2. For K iterations, draw a new sample from the generator and compute its discriminator score.
  3. At each iteration, draw a random number uniformly from [0, 1]. If the following expression is greater than that number, the new sample is accepted and becomes the current sample that the next proposal is compared against:
Figure 2: Condition for accepting a new sample in an MH-GAN (from the Uber MH-GAN paper)
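Since Figure 2 is an image, here is the acceptance rule written out. Here x_k is the current sample, x' is the new proposal from the generator, D is the calibrated probability that a sample is real, and u is the uniform random number from step 3. It is the independent MH test with the generator as the proposal, using the density-ratio estimate D/(1 - D); note that (1/D(x_k) - 1)/(1/D(x') - 1) is the ratio of the proposal's odds D/(1 - D) to the current sample's odds.

```latex
\alpha(x', x_k) = \min\left(1,\ \frac{D(x_k)^{-1} - 1}{D(x')^{-1} - 1}\right),
\qquad
x_{k+1} =
\begin{cases}
x' & \text{if } u \le \alpha(x', x_k),\quad u \sim \mathcal{U}[0, 1],\\
x_k & \text{otherwise.}
\end{cases}
```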

Essentially, this rule compares the proposed sample’s discriminator score with the current sample’s score (more precisely, it compares their odds D/(1 - D) of being real) to decide whether to “accept” the new sample. Because every proposal is drawn from the generator independently of the current state, this is an independent Metropolis-Hastings chain, and with enough iterations it approximates the true data distribution implied by our discriminator’s output!
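The acceptance probability translates into a few lines of NumPy. This is a minimal sketch. The function name and the clipping constant eps are our own choices, and the clipping is there because scores of exactly 0 or 1 would otherwise cause a division by zero.

```python
import numpy as np


def mh_acceptance_probability(d_current, d_proposal, eps=1e-7):
    """alpha = min(1, (1/D(x_k) - 1) / (1/D(x') - 1)) for calibrated scores D in (0, 1)."""
    # Keep scores strictly inside (0, 1) so neither term divides by zero
    d_current = np.clip(d_current, eps, 1.0 - eps)
    d_proposal = np.clip(d_proposal, eps, 1.0 - eps)
    ratio = (1.0 / d_current - 1.0) / (1.0 / d_proposal - 1.0)
    return np.minimum(1.0, ratio)
```

For example, moving from a sample with D = 0.5 to a proposal with D = 0.8 is always accepted, while the reverse move is accepted with probability 0.25. The function also works element-wise on arrays of scores.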

If we fail to accept any samples after a certain number of iterations, the authors recommend restarting the chain from a synthetic sample to guarantee convergence, albeit at a slower pace. However, this is not the end of the story, as there is an outstanding issue with the output of our discriminator, which we address in the next section.
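Putting steps 1 to 3 and the restart rule together, one chain might look like the sketch below. The callables sample_fn and score_fn are hypothetical stand-ins for “draw one image from the generator” and “calibrated discriminator score”. The restart policy is our reading of the paragraph above (restart once from a generated sample, then run another K steps) and may differ in detail from the paper.

```python
import numpy as np


def mh_gan_chain(real_seed, sample_fn, score_fn, k=100, rng=None, eps=1e-7):
    """One independent Metropolis-Hastings chain seeded with a real example.

    sample_fn(rng) -> one generated sample (hypothetical stand-in for the generator)
    score_fn(x)    -> calibrated probability that x is real (hypothetical stand-in)
    """
    rng = np.random.default_rng() if rng is None else rng

    def inv_odds(x):
        # 1/D(x) - 1, with D clipped away from 0 and 1 to avoid division by zero
        d = min(max(float(score_fn(x)), eps), 1.0 - eps)
        return 1.0 / d - 1.0

    def run(current):
        cur_inv, accepted = inv_odds(current), False
        for _ in range(k):                          # step 2: K proposals from the generator
            proposal = sample_fn(rng)
            prop_inv = inv_odds(proposal)
            alpha = min(1.0, cur_inv / prop_inv)    # acceptance probability
            if rng.uniform() <= alpha:              # step 3: accept with probability alpha
                current, cur_inv, accepted = proposal, prop_inv, True
        return current, accepted

    current, accepted = run(real_seed)              # step 1: seed with a real example
    if not accepted:
        # Still sitting on the real seed: restart once from a generated sample
        current, _ = run(sample_fn(rng))
    return current
```

Because a real example is only ever used as the starting point, the returned value is always a generated sample.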

Calibrating Our Discriminator

One of the assumptions made in applying Metropolis-Hastings is that our discriminator’s probabilities are well-calibrated, meaning that its output can be interpreted as a confidence level: among samples scored 0.8, about 80% should actually be real. Unfortunately, the raw outputs of neural network classifiers are often poorly calibrated, so we have to apply an additional calibration step. This is true even if we use a sigmoid or softmax activation on the final layer!

The Scikit-Learn documentation has a great section on probability calibration, but what the authors of the paper recommend is using a technique called isotonic regression to convert the network outputs to calibrated probabilities as a post-processing step. Unfortunately, this needs to be done on a held-out set, meaning we have to divide Fashion-MNIST into a training set and a validation set (fortunately, TensorFlow already provides the dataset with a separate test split!)
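Isotonic regression fits the best non-decreasing function (in the least-squares sense) from scores to labels. That is exactly the shape a calibration map should have, because a higher raw score should never mean a lower probability of being real. The sketch below implements the classic pool-adjacent-violators algorithm in NumPy to show what happens under the hood. It is illustrative only (for example, tied scores are not merged specially), and in practice Scikit-Learn’s IsotonicRegression does this for you.

```python
import numpy as np


def isotonic_fit(scores, labels):
    """Pool-adjacent-violators: non-decreasing least-squares fit of labels vs. scores.

    Returns (sorted_scores, fitted_values), which can be interpolated for new scores.
    """
    order = np.argsort(scores, kind="mergesort")
    x = np.asarray(scores, dtype=float)[order]
    y = np.asarray(labels, dtype=float)[order]
    # Each block holds [sum of labels, count]; merge while a block's mean exceeds the next one's
    blocks = []
    for value in y:
        blocks.append([value, 1.0])
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
    # Every point in a block gets the block's mean label
    fitted = np.concatenate([np.full(int(c), s / c) for s, c in blocks])
    return x, fitted


def isotonic_predict(x_fit, y_fit, new_scores):
    # Linear interpolation between fitted points; values outside the fitted range are clipped
    return np.interp(new_scores, x_fit, y_fit)
```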

Below is a calibration curve (a diagnostic tool that illustrates how well a model is calibrated; the 45-degree line represents perfect calibration) on held-out data, showing that our isotonic predictions are well-calibrated, as Metropolis-Hastings requires. Our raw outputs, on the other hand, are clearly poorly calibrated.
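A calibration curve bins the predicted probabilities and compares each bin’s average prediction with the fraction of positives in that bin. A minimal NumPy version with uniform bins is sketched below; this is our own helper, and Scikit-Learn’s sklearn.calibration.calibration_curve computes the same quantities (returning the fraction of positives first). Points close to the 45-degree line indicate good calibration.

```python
import numpy as np


def calibration_curve_bins(labels, probs, n_bins=10):
    """Return (mean predicted probability, observed fraction of positives) per non-empty bin."""
    labels = np.asarray(labels, dtype=float)
    probs = np.asarray(probs, dtype=float)
    # Uniform bins on [0, 1]; a probability of exactly 1.0 goes into the last bin
    bin_ids = np.minimum((probs * n_bins).astype(int), n_bins - 1)
    mean_pred, frac_pos = [], []
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            mean_pred.append(probs[mask].mean())
            frac_pos.append(labels[mask].mean())
    return np.array(mean_pred), np.array(frac_pos)
```

Plotting frac_pos against mean_pred, together with the diagonal, reproduces the kind of diagnostic shown in Figure 3.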

Figure 3: Calibration curve showing improvement after applying isotonic regression to discriminator outputs

Once we have our discriminator calibrated, we can then use the calibrated probabilities in determining the acceptance probabilities of new samples in the independent MCMC chain we outlined above!

Coding our MH-GAN

Now it’s finally time to put calibration and Metropolis-Hastings sampling into practice with a subclass of the LS-GAN we coded above. Note that we could easily apply this subclass to any other sort of GAN that follows the same API, or even adapt it to existing GAN packages.

There are a couple of things going on here. First, we need to implement a method to train our calibrator, as well as a calibrated scoring method.

To train the calibrator, we take a held-out set of real images and generate the same number of samples from the generator. Next, we fit the IsotonicRegression class from Scikit-Learn to translate the discriminator scores into calibrated probabilities.
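Here is a minimal sketch of the calibrator-fitting step using Scikit-Learn’s IsotonicRegression. The function name and arguments are our own, not the article’s. It assumes you have already run the discriminator on the held-out real images and on an equal number of generated images, and it flattens those outputs to 1-D. Setting y_min, y_max, and out_of_bounds="clip" keeps predictions inside [0, 1], even for scores outside the range seen during fitting.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression


def fit_discriminator_calibrator(real_scores, fake_scores):
    """Fit an isotonic map from raw discriminator scores to P(real).

    real_scores: discriminator outputs on held-out real images (label 1).
    fake_scores: discriminator outputs on an equal number of generated images (label 0).
    """
    real_scores = np.ravel(real_scores)
    fake_scores = np.ravel(fake_scores)
    scores = np.concatenate([real_scores, fake_scores])
    labels = np.concatenate([np.ones_like(real_scores), np.zeros_like(fake_scores)])
    calibrator = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    calibrator.fit(scores, labels)
    return calibrator
```

Calibrated scores for new samples would then be calibrator.predict(np.ravel(raw_scores)).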

Finally, we have our generate_mh_samples() method, which generates a specified number of accepted samples given a true example to seed the MCMC chain. Can you see a common theme here? The calibrated scoring and MH sample generation methods are simply the MH-GAN counterparts to the score_samples() and generate_samples() methods from the LS-GAN!
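To illustrate the subclassing pattern, here is a self-contained toy version in NumPy. ToyGAN and ToyMHGAN are hypothetical stand-ins for the article’s classes: they work on 1-D numbers rather than images, use made-up generator and discriminator functions, and have guessed method signatures (only the method names come from the text above). Instead of running chains one after another, this sketch runs n chains in parallel, one per requested sample. As a simplification, any chain that never accepts a proposal is replaced with a fresh generated sample rather than given a full restart.

```python
import numpy as np


class ToyGAN:
    """Hypothetical stand-in for the article's LS-GAN class, working on 1-D 'images'."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)

    def generate_samples(self, n):
        return self.rng.normal(0.0, 2.0, size=n)        # made-up generator

    def score_samples(self, x):
        return np.tanh(3.0 - np.abs(x))                  # made-up raw discriminator output


class ToyMHGAN(ToyGAN):
    def __init__(self, calibrator, seed=0):
        super().__init__(seed)
        self.calibrator = calibrator                     # maps raw scores to P(real)

    def score_samples_calibrated(self, x, eps=1e-7):
        return np.clip(self.calibrator(self.score_samples(x)), eps, 1.0 - eps)

    def generate_mh_samples(self, n, real_seed, k=50):
        """Run n independent MH chains in parallel, all seeded with real_seed."""
        current = np.full(n, real_seed, dtype=float)
        cur_inv = 1.0 / self.score_samples_calibrated(current) - 1.0
        accepted = np.zeros(n, dtype=bool)
        for _ in range(k):
            proposal = self.generate_samples(n)
            prop_inv = 1.0 / self.score_samples_calibrated(proposal) - 1.0
            alpha = np.minimum(1.0, cur_inv / prop_inv)  # per-chain acceptance probability
            take = self.rng.uniform(size=n) <= alpha
            current[take], cur_inv[take] = proposal[take], prop_inv[take]
            accepted |= take
        # Chains that never moved still hold the real seed; swap in fresh generated samples
        stuck = ~accepted
        current[stuck] = self.generate_samples(int(stuck.sum()))
        return current
```

With the toy calibrator lambda s: (s + 1.0) / 2.0, the MH samples should concentrate in the high-scoring region |x| < 3 noticeably more than plain generate_samples does.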

Time to get our feet wet with this new class using the Fashion-MNIST data!

Comparing the Results

To start using our shiny new MH-GAN class, the first thing we have to do is load the Fashion-MNIST data, separate it into training and validation sets, and normalize the pixels. Fortunately, tf.keras makes this super easy!
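tf.keras.datasets.fashion_mnist.load_data() returns ((x_train, y_train), (x_test, y_test)), with 60,000 training images and 10,000 test images of 28×28 uint8 pixels; the test split can serve as the held-out set. The normalization step is sketched below in NumPy. Scaling to [-1, 1] and adding a channel axis are assumptions on our part (they suit a generator with a tanh output and convolutional layers), not details confirmed by the article.

```python
import numpy as np


def normalize_images(images):
    """Scale uint8 pixels in [0, 255] to float32 in [-1, 1] and add a channel axis."""
    images = np.asarray(images, dtype=np.float32)
    images = (images - 127.5) / 127.5
    return images[..., np.newaxis]      # (n, 28, 28) -> (n, 28, 28, 1)
```

Usage would then be along the lines of x_train = normalize_images(x_train) and x_test = normalize_images(x_test).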

Now that we have our training and test sets, all we have to do is initialize our MHGAN object with the generator and discriminator we declared above, as well as the dimensionality of the generator’s noise input and the optimizer learning rate. We then train for 50 epochs (a rather arbitrary choice; the training length of a GAN usually needs more thought) and finally calibrate our discriminator. Our object-oriented approach allows us to do all of this in 3 lines of code!

After the training is completed, we can draw samples from our trained MH-GAN by using either our normal naive GAN sampling, or the Metropolis-Hastings algorithm from the Uber paper.

Let’s take a look at 15 images generated using LS-GAN, MH-GAN, and some true images to compare:

Figure 4: Comparison of LS-GAN and MH-GAN outputs with true samples from the Fashion-MNIST dataset

Wow! We can clearly see that the MH-GAN nearly eliminates obviously fake samples! This is in line with the findings of the Uber team, who reported that on the CelebA dataset their method drastically reduces the number of low-fidelity images. Several samples from the LS-GAN look like ill-defined blobs, whereas this happens much less frequently with our MH-GAN.

The team also argues that the method addresses mode collapse, but I didn’t really see this in my trials; then again, MNIST and Fashion-MNIST aren’t always the best indicators of algorithm performance.

Conclusions

If you are interested in applying the quick improvement to GANs demonstrated in this article, I highly recommend reading the original MH-GAN paper, which is linked below! We have also shown that TensorFlow 2 makes it easy to apply this technique with very little code!

There are quite a few other intriguing papers regarding discriminator-based GAN sampling, such as “Discriminator Rejection Sampling” by Azadi et al. and “Your GAN is Secretly an Energy-based Model and You Should use Discriminator Driven Latent Sampling” by Che et al., showing the promise of this field. No doubt it will be interesting to see where this goes over the next few years.

If you have any questions about this article, please feel free to reach out!

References

Original MH-GAN paper by Uber: “Metropolis-Hastings Generative Adversarial Networks” (Turner et al., 2018)

Original Uber Engineering blog article on MH-GANs

Helpful article on Least-Squares GANs

A quick tutorial on Metropolis-Hastings sampling

