
The story about WGAN

Abstract

The generative adversarial network (GAN) has been a popular topic in deep learning in recent years. Unlike a discriminative model, a generative model captures the joint probability distribution of the data and the labels. However, training instability and rigid architectures make it hard to generate images that match our expectations.

In this article, I will try to explain the progression toward the EM distance. First, the failure modes and the awkward theory behind the original GAN are discussed. Next, I introduce the concept of WGAN. Finally, the improved version of WGAN is shown.

Note that the underlying theory touches a large body of knowledge, including topology, information theory, and probability distance measures. Since the material is quite involved, I will skip most of the derivations; you can refer to the original papers for the proofs behind the architecture.

GAN and its loss function

Figure 1. The basic concept of generative adversarial network

In 2014, Ian Goodfellow proposed the generative adversarial network, which contains two components: a generator and a discriminator. The training process resembles a zero-sum game, as illustrated in Figure 1.

The generator should produce images that look like real ones, while the discriminator should distinguish fake images from real ones. During training, the generator keeps improving its ability to generate images that look increasingly real, and the discriminator keeps improving its accuracy at telling them apart.

Figure 2. The formula of relative entropy

In fact, the idea of generative modeling predates GAN: the variational auto-encoder (VAE) was already a popular generative model before GAN was invented. Its goal is to minimize the relative entropy (also called the KL-divergence) between the generative distribution and the real data distribution; the formula is shown in Figure 2. When the KL-divergence reaches its minimum, the probability that the generative distribution assigns to each point equals that of the real data distribution.
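As a concrete illustration of the formula in Figure 2, here is a minimal sketch of my own (not from the paper) that evaluates the KL-divergence for two small discrete distributions:

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """Discrete KL(P || Q) = sum_x P(x) * log(P(x) / Q(x))."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return np.sum(p * np.log((p + eps) / (q + eps)))

p_real = np.array([0.4, 0.3, 0.2, 0.1])   # toy "real" distribution
p_gen  = np.array([0.1, 0.2, 0.3, 0.4])   # toy "generated" distribution

print(kl_divergence(p_gen, p_real))   # positive when the distributions differ
print(kl_divergence(p_real, p_real))  # exactly 0 when they match
```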

Figure 3. The mode dropping circumstance

During training, you should make sure both components learn at a comparable pace. Now look at the KL-divergence KL(Pg || Pr) term by term. Where Pg(x) is large but Pr(x) is nearly zero, the generator is producing images that do not look real, and the penalty is huge. Where Pr(x) is large but Pg(x) is nearly zero, the generator simply fails to cover that part of the real data, yet the contribution to the KL-divergence is tiny, so the model receives almost no penalty. This asymmetry pushes the generator toward producing a few "safe" images over and over, which is the collapse problem (also called mode dropping). Figure 3 shows the result after mode dropping happens: the generated images look realistic, but most of them are almost identical.
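To see this asymmetry numerically, here is a tiny sketch (again my own illustration, not from the paper) comparing the per-point contribution to KL(Pg || Pr) in the two situations:

```python
import numpy as np

eps = 1e-8

# The generator puts mass where the real data has almost none: huge penalty.
pg, pr = 0.5, eps
print(pg * np.log(pg / pr))    # roughly 0.5 * log(5e7), a large positive number

# The generator misses a region that the real data covers: almost no penalty.
pg, pr = eps, 0.5
print(pg * np.log(pg / pr))    # roughly -1.8e-7, essentially zero
```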

Figure 4. The derivation of the optimal discriminator

What is the form of the optimal discriminator? The derivation is sketched in Figure 4. We expand the loss function as an integral weighted by the corresponding probabilities, take the partial derivative with respect to the discriminator output, and set it to zero, which gives D*(x) = Pr(x) / (Pr(x) + Pg(x)). Where the two distributions agree (Pr(x) = Pg(x)), the optimal discriminator outputs 0.5; where only fake samples appear (Pr(x) = 0) it outputs 0, and where only real samples appear (Pg(x) = 0) it outputs 1.
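We can also check the closed form numerically. The sketch below (my own illustration) maximizes the per-point discriminator objective Pr(x) log D + Pg(x) log(1 - D) by brute force and compares the result with Pr(x) / (Pr(x) + Pg(x)):

```python
import numpy as np

def point_objective(d, pr, pg):
    # Per-point term of the discriminator loss: Pr(x) log D(x) + Pg(x) log(1 - D(x))
    return pr * np.log(d) + pg * np.log(1.0 - d)

pr, pg = 0.3, 0.1
ds = np.linspace(0.001, 0.999, 999)
best_d = ds[np.argmax(point_objective(ds, pr, pg))]

print(best_d)            # ~ 0.75, found by brute-force search
print(pr / (pr + pg))    # closed form D*(x) = Pr(x) / (Pr(x) + Pg(x)) = 0.75
```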

The main problem: optimal discriminator

Figure 5. The objective of usual GAN

However, Arjovsky et al. [1] argued that the optimal discriminator is itself the main problem, and that it leads to the instability of the usual GAN. Figure 5 shows the loss function of the usual GAN. We can plug the optimal discriminator into the loss function and see what happens. Note that the JS-divergence is another metric: it measures the average distance of the two distributions from their mixture (average) distribution.

Figure 6. The derivation of the JS-divergence in the usual GAN

This gives us a new insight: unlike the VAE, which shrinks the KL-divergence, the GAN with an optimal discriminator effectively reduces the JS-divergence. As a result, if we keep training the discriminator while the two distributions barely overlap, the JS-divergence stays at log 2, the loss collapses to a constant (zero in the form derived in Figure 6), and the generator receives almost no useful gradient.
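A quick numerical check of this saturation (my own sketch): for two discrete distributions with disjoint supports, the JS-divergence equals log 2 regardless of what the distributions look like.

```python
import numpy as np

def kl(p, q, eps=1e-12):
    return np.sum(p * np.log((p + eps) / (q + eps)))

def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Disjoint supports: real data lives on the first two bins, fake data on the last two.
p_real = np.array([0.5, 0.5, 0.0, 0.0])
p_gen  = np.array([0.0, 0.0, 0.5, 0.5])

print(js(p_real, p_gen), np.log(2))   # both ~ 0.693: JS saturates at log 2
```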

However, this result rests on an assumption: the distributions must be continuous (and overlap) so that a useful gradient exists. Does this assumption actually hold?

Figure 7. The ideal situation of the two distributions in the ambient space

In the ideal case, both distributions have full dimension in the ambient space, meaning the data spread over the whole data space. Take Figure 7 as an example: if the generated data and the real data both follow Gaussian-like distributions, then reducing the JS-divergence amounts to pulling the two Gaussians together. Once the two distributions are close, the generator can produce data that look as if they were drawn from the real data distribution.

However, in reality the data do not fill the ambient space. Simply put, the two distributions may have no intersection at all. One conjecture written in the paper is:

One possible cause for the distributions not to be continuous is if their supports lie on low dimensional manifolds.

Figure 8. The first case of manifolds lying in the ambient space

The paper splits the discussion into two cases. The first one: the two manifolds are perfectly aligned. Take Figure 8 as an example and call the vertical axis z. Here the support of Pr lies on a plane above z = 0, while the support of Pg lies exactly on z = 0. By a smooth version of Urysohn's lemma, we can always find an optimal discriminator that perfectly separates the two distributions.

Figure 9. The second case of manifolds lying in the ambient space

In fact, the probability that the two manifolds are perfectly aligned is essentially zero. So the authors discuss another case: if the two manifolds intersect transversally, the dimension of the intersection is strictly lower than the dimension of either manifold. Even then, we can still find a perfect discriminator!

Figure 9 illustrates this case. Suppose the real data distribution lies on the yellow plane and the generative distribution lies on the green plane. The orange region looks like an overlap when we project onto the white background, but the two planes do not actually coincide there: the transversal intersection is only a single line!

Figure 10. The relation between dimension projection and manifold

Generally speaking, the generator maps a random vector into the image feature space. However, this mapping cannot increase the dimension of the manifold. Take Figure 10 as an example: assume the noise vector has dimension 2 and the feature space has dimension 3. The generator maps the noise distribution onto the orange manifold. Although the data points live in a 3-dimensional space, the intrinsic dimension of that manifold is still at most 2.

The point of this simple example is that even though the generator embeds the data points in a high-dimensional space, the dimension of the underlying manifold does not change. Consequently, it does not increase the chance that the two distributions share an intersection of non-negligible measure.

Alternative method to improve

From the previous section, we know that the manifolds of the two distributions have no intersection of non-negligible measure. The authors also give a recommendation to mitigate this: add random noise to force the two distributions to overlap!

Figure 11. The change of transversality

Figure 11 demonstrates the effect. Suppose the two manifolds lie in a 3-dimensional space as the image above shows. Adding random noise indeed thickens (dilates) the manifolds, so they are more likely to intersect, as in the yellow region.

However, adding noise is not the best solution: in the experiments, the proportion of noise had to rise as high as 0.1 before the model converged! Before introducing WGAN, the theory of distribution distances should be clarified first.

Different distances

Figure 12. The definition of different distance metric

For measuring the distance between probability distributions, there are many metrics to choose from, as shown in Figure 12. The leftmost one is the total variation (TV) distance; the second is the KL-divergence, well known from the VAE; the third is the JS-divergence, which is effectively the loss minimized by the original GAN. The last one is the Wasserstein-1 distance, which will be introduced in the next section.

Figure 13. The values of the different distances when the two distributions are separated

Figure 13 shows the paper's original description of this critical shortcoming of the divergences, and Figure 14 illustrates the example. The green region is the support of the distribution P0, and the orange region is the support of Pθ. In the general case, the two distributions are separated.

When the supports of the two distributions are lower-dimensional than the ambient space and do not overlap, the KL-divergence blows up to infinity, while the JS-divergence stays fixed at log 2. This is the main reason the generator cannot learn well: the distance measure itself is unhelpful. Its discontinuity means the model cannot obtain a continuous, useful gradient.

Figure 14. Visualization of the example in Figure 13

To summarize why GAN is hard to train: in practice, the two distributions are separated in the ambient space. In that situation, the divergence is a constant (or infinite) value that cannot help the generator learn. The only case in which the GAN can learn successfully is when the two distributions share an intersection of non-negligible measure.

Since neither the KL-divergence nor the JS-divergence gives the generator the right direction to learn, Arjovsky et al. switched to another metric, the EM distance (also called the Wasserstein-1 distance), in their paper [2]. The physical intuition of the EM distance is the amount of work needed to transport one distribution onto the other. Consequently, its value is non-negative and it is symmetric in its two arguments. The EM distance has two key properties:

  1. It is continuous everywhere
  2. It is differentiable almost everywhere

The proofs of these two properties are not discussed in this article, but thanks to them we can avoid getting stuck in a saturated region during training and keep updating the model until it converges.
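These properties can be seen in a small 1-D analogue of the example in Figures 13 and 14 (my own sketch): two point masses separated by a distance theta. The JS-divergence is stuck at log 2 for every theta > 0, while the Wasserstein-1 distance grows smoothly with theta and therefore tells the generator which way to move.

```python
import numpy as np

def js(p, q, eps=1e-12):
    kl = lambda a, b: np.sum(a * np.log((a + eps) / (b + eps)))
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1(p, q, xs):
    # In 1-D, the Wasserstein-1 distance is the integral of |CDF_p - CDF_q|.
    dx = xs[1] - xs[0]
    return np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * dx

xs = np.linspace(0.0, 10.0, 1001)
for theta in (0.5, 2.0, 5.0):
    p = np.zeros_like(xs); p[0] = 1.0                           # all mass at x = 0
    q = np.zeros_like(xs); q[np.searchsorted(xs, theta)] = 1.0  # all mass at x = theta
    print(theta, round(js(p, q), 4), round(w1(p, q, xs), 4))
# JS stays at log 2 (~0.6931) for every theta, while W1 equals theta.
```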

Figure 15. The duality relation of EM-distance

However, the infimum in the definition is intractable: we cannot enumerate every possible coupling of the two distributions. By the Kantorovich-Rubinstein duality, we can convert the problem into its dual form and look for a supremum instead. The relation between the two forms is shown in Figure 15. The only constraint is that the function must be 1-Lipschitz continuous.

Figure 16. The objective of WGAN

In the usual GAN, the discriminator solves a classification task: it should output 0 for a fake image and 1 for a real one. WGAN turns the discriminator's task into a regression problem, and Arjovsky et al. rename it the critic. The critic estimates the EM distance, that is, how much work it takes to transport one distribution onto the other, by maximizing the dual objective. You can check this in Figure 16.
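Here is a minimal PyTorch sketch of what these objectives look like in code (my own illustration; the tiny network and random tensors are just stand-ins):

```python
import torch
import torch.nn as nn

# The critic outputs an unbounded real-valued score (no sigmoid).
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))

real = torch.randn(128, 2) + 3.0   # stand-in for a batch of real samples
fake = torch.randn(128, 2)         # stand-in for a batch of generated samples

# Dual form of the EM distance: maximize E_real[f(x)] - E_fake[f(x)] over
# 1-Lipschitz functions f, so the critic minimizes the negative of that gap.
critic_loss = critic(fake).mean() - critic(real).mean()

# The generator tries to raise the critic's score on its own samples.
generator_loss = -critic(fake).mean()
print(critic_loss.item(), generator_loss.item())
```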

Figure 17. The algorithm of WGAN

The training process of WGAN is shown in Figure 17, and it is very similar to the usual GAN. There are only four differences (a training-loop sketch follows the list):

  1. The critic is updated multiple times for every generator update
  2. No logarithm (no cross entropy) is taken when computing the loss
  3. Weight clipping is applied to (approximately) satisfy the Lipschitz constraint
  4. A momentum-based optimizer (such as Adam) is not used; the paper recommends RMSProp
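Putting these four differences together, a rough training-loop sketch might look like this (my own PyTorch illustration of the algorithm in Figure 17; the toy data and tiny networks are stand-ins, while the hyper-parameters follow the defaults suggested in the paper):

```python
import torch
import torch.nn as nn

z_dim, n_critic, clip_value = 64, 5, 0.01

gen = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, 2))
critic = nn.Sequential(nn.Linear(2, 128), nn.ReLU(), nn.Linear(128, 1))

# Difference 4: RMSProp instead of a momentum-based optimizer such as Adam.
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
opt_g = torch.optim.RMSprop(gen.parameters(), lr=5e-5)

def sample_real(n):
    # Stand-in for the real data loader.
    return torch.randn(n, 2) + 3.0

for step in range(100):
    # Difference 1: update the critic several times per generator step.
    for _ in range(n_critic):
        real = sample_real(128)
        fake = gen(torch.randn(128, z_dim)).detach()
        loss_c = critic(fake).mean() - critic(real).mean()   # difference 2: no logarithm
        opt_c.zero_grad(); loss_c.backward(); opt_c.step()
        # Difference 3: clip the weights to keep the critic (roughly) Lipschitz.
        for p in critic.parameters():
            p.data.clamp_(-clip_value, clip_value)

    # One generator update.
    fake = gen(torch.randn(128, z_dim))
    loss_g = -critic(fake).mean()
    opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```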
Figure 18. The convergence of the gradient in different structures

Arjovsky et al.'s experiments show that WGAN avoids the vanishing-gradient problem. As you can see in Figure 18, the gradient of the usual GAN drops to zero and saturates, whereas the EM distance provides a meaningful loss and the model keeps learning gradually.

Merging the constraint into the loss: WGAN-GP

After WGAN was released, many practitioners found problems when generating images: convergence is slow and the image quality is not very good. Gulrajani et al. [3] pointed out two issues:

  1. From the Kantorovich-Rubinstein duality, the optimal critic has a well-defined, unit-norm gradient at points x_t lying on straight lines between samples of the real distribution and samples of the generative distribution. Figure 19 shows this concept: between the real data distribution and the generative distribution there are intermediate points at which the gradient can be computed directly.
  2. The weight clipping mechanism limits the capacity of the whole model.
Figure 19. The unique gradient assumption under the Kantorovich-Rubinstein duality theory

The goal of weight clipping is to satisfy the Lipschitz constraint, and it is easy to implement. However, this simplification causes problems. A neural network is flexible enough to fit very complex functions, but if the learned weights are simply clipped into a rigid range, the model cannot adapt well to the constraint.

As a result, the model tends to push the weights toward the two clipping boundaries. The authors analyzed this phenomenon and provided evidence; you can see the result in Figure 20.

Figure 20. The distribution of weight in WGAN

Based on these observations, the authors proposed another way to enforce the Lipschitz constraint: merge it into the loss function as a penalty term. The idea is similar to adding a constraint term in an SVM, except that there the Lagrange multiplier is an optimal parameter found by quadratic programming, while here we simply set it to a constant. This term is called the gradient penalty.

Figure 21. The revised objective in WGAN-GP

Figure 21 shows the revised loss function. However, the Lipschitz condition is a constraint over every possible input point; can we really enforce it after moving to the dual form?

The answer is yes, and the authors argue that we do not need to consider every point at all. Gulrajani et al. propose sampling random linear combinations of real and generated samples and applying the penalty only at these intermediate points, which is exactly like generating the x_t points in Figure 19.
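A minimal sketch of this penalty term, under the assumption that `critic` is any differentiable PyTorch module (my own illustration, not the authors' code):

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Sample points on straight lines between real and generated samples
    # (the x_t points discussed above).
    alpha = torch.rand(real.size(0), 1)
    interp = (alpha * real + (1.0 - alpha) * fake).detach().requires_grad_(True)

    scores = critic(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True)[0]

    # Penalize the critic's gradient norm for drifting away from 1.
    return lambda_gp * ((grads.norm(2, dim=1) - 1.0) ** 2).mean()

# Toy usage with a stand-in critic and random batches.
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
gp = gradient_penalty(critic, torch.randn(32, 2) + 3.0, torch.randn(32, 2))
print(gp.item())
```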

Figure 22. The algorithm of WGAN-GP

The revised training process is shown in Figure 22, and Gulrajani et al. gave it a new name: WGAN-GP. In this version you can use a momentum-based optimizer (such as Adam) without the loss blowing up. When updating the critic, the gradient penalty term is added to the loss function.

There is one thing to remember: do not use batch normalization in the critic! The gradient penalty treats each input independently, but batch normalization couples the samples within a batch and disturbs this per-sample mapping. Therefore layer normalization (or a similar alternative) is recommended instead.
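One way to follow that advice (my own sketch, not the authors' code) is to build the critic with LayerNorm, which normalizes each sample independently, instead of BatchNorm:

```python
import torch.nn as nn

# A critic that uses LayerNorm so that no statistics are shared across the
# samples in a batch (the layer sizes are arbitrary).
critic = nn.Sequential(
    nn.Linear(2, 128), nn.LayerNorm(128), nn.LeakyReLU(0.2),
    nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)
```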

My own experiments

Finally, I want to show my own "trick" experiments in which I tried to enhance the WGAN-GP idea. There are three directions in which WGAN-GP could still be improved:

  1. Convergence is slow
  2. The EM-distance curve is not smooth everywhere. Is there a modification that makes it fully smooth?
  3. In the implementation, the interpolated samples and the extra gradient term have to be computed in every iteration, which costs additional time. Is there a way to simplify this step?
Figure 23. The Gaussian-like revision of the EM-distance

Figure 23 shows the function I proposed to examine, and its shape appears on the left side of Figure 24. The function looks like a Gaussian curve flipped upside down. The nicest properties of the Gaussian curve include:

  1. It is continuous everywhere
  2. It is differentiable everywhere
  3. It has a closed-form expression
Figure 24. The shape of both functions

In experiment 1, I simplified the gradient penalty term to performing only an L2 normalization. Since weight clipping merely restricts the capacity of the whole model by preventing the weights from exceeding a limit, I guessed that an ordinary normalization might satisfy the constraint as well. However, this attempt failed.

Figure 25. The loss of experiment 1

In experiment 2, I replaced the original EM distance with the Gaussian-like revision. However, it also failed to converge.

Figure 26. The loss of experiment 2

In experiment 3, I tried to form a linear combination of gradients that had already been computed, so that no extra gradient computation would be needed. My expectation was that this penalty would be at least as strong as the original one by Jensen's inequality; however, the function represented by a neural network is not convex, so the argument does not hold. In the end, both the theory and the empirical result failed.

Figure 27. The loss of experiment 3

Finally, I show the images that each experiment generates. Figure 28 illustrates the complete failure: none of the cases converges, and worse, a clear trend toward mode collapse can easily be observed.

Figure 28. The generated image in three experiments

Conclusion

In this article, I discussed the theory behind GAN and introduced the structure of WGAN. Then the gradient penalty idea was presented as a more principled way to enforce Lipschitz continuity. Finally, three simple experiments of my own were shown. Although all of my attempted improvements failed, WGAN is still a very creative idea that makes generative models more practical!

Reference

[1] M. Arjovsky and L. Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," arXiv:1701.04862 [stat.ML], Jan. 2017.

[2] M. Arjovsky, S. Chintala, and L. Bottou, "Wasserstein GAN," arXiv:1701.07875 [stat.ML], Mar. 2017.

[3] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville, "Improved Training of Wasserstein GANs," arXiv:1704.00028 [cs.LG], May 2017.

