Generative Adversarial Networks in Python

 4 years ago
source link: https://mc.ai/generative-adversarial-networks-in-python-2/

Generative adversarial networks (GANs) are a set of deep neural network models used to produce synthetic data. The method was developed by Ian Goodfellow and colleagues in 2014 and is outlined in the paper Generative Adversarial Networks. The goal of a GAN is to train a discriminator to distinguish between real and fake data while simultaneously training a generator to produce synthetic instances of data that can reliably trick the discriminator.

A popular application of GANs was the ‘GANGogh’ project, where synthetic paintings were generated by GANs trained on paintings from wikiart.org. The independent researchers, Kenny Jones and Derrick Bonafilia, were able to generate synthetic religious, landscape, flower and portrait images with impressive performance. The article GANGogh: Creating Art with GANs details the method. In this post, we will walk through the process of building a basic GAN in Python which we will use to generate synthetic images of handwritten digits.

Let’s get started!

First, let’s import the necessary packages: ‘matplotlib’, the ‘tensorflow.keras’ layers module, and the ‘tensorflow’ library. Let’s also import the Keras backend, which we use to clear any state left over from previous sessions:

import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow.keras import backend as K
K.clear_session()

Next, let’s load the ‘MNIST’ data set, which is available in the ‘tensorflow’ library. The data contains images of handwritten digits and labels corresponding to the digits:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

Let’s take a look at the first image in the training data:

plt.imshow(train_images[0], cmap='gray')

We can see that this is a handwritten ‘5’. Next, let’s reshape the data, convert the image pixels to floating point values, and normalize the pixel values to be between -1 and 1:

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
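
To make the normalization concrete, here is a minimal sketch on a handful of made-up pixel values (the array ‘pixels’ is purely illustrative). It also shows the inverse mapping that we use later when plotting generated images. The range -1 to 1 matches the ‘tanh’ activation on the generator’s last layer:

```python
import numpy as np

# Illustrative pixel intensities in the usual 0..255 range (not real MNIST data).
pixels = np.array([0.0, 64.0, 127.5, 255.0], dtype='float32')

# Same formula as above: maps 0 -> -1, 127.5 -> 0, 255 -> 1.
normalized = (pixels - 127.5) / 127.5

# Inverse mapping, used later to turn generator output back into pixel intensities.
restored = normalized * 127.5 + 127.5

print(normalized)
print(restored)
```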

Now let’s define our generator model:

def generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model
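
The shape assertions in the generator follow from simple arithmetic: with padding='same', a transposed convolution multiplies the spatial size by its stride. The sketch below reproduces that bookkeeping in plain Python. The helper name ‘transposed_conv_same_output’ is made up for this illustration and is not part of Keras:

```python
def transposed_conv_same_output(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

# The dense layer produces 7*7*256 values, reshaped to a 7x7 map with 256 channels.
dense_units = 7 * 7 * 256

size = 7
sizes = [size]
for stride in (1, 2, 2):  # strides of the three Conv2DTranspose layers
    size = transposed_conv_same_output(size, stride)
    sizes.append(size)

print(dense_units)  # 12544
print(sizes)        # [7, 7, 14, 28]
```

So the 100-dimensional noise vector is expanded step by step until it has the 28 x 28 shape of an MNIST image.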

We first initialize a sequential model object. We then add the first layer, which is an ordinary dense neural network layer. This is followed by a series of transposed convolution layers, which upsample their input: with a stride of 2, each one doubles the width and height of the feature map. For those unfamiliar, a convolutional layer learns matrices (kernels) of weights which are then combined to form filters used for feature extraction. Through learning the filter weights, convolutional layers learn convolved features that represent high level information about an image. Through the learned filters, these layers can perform operations like edge detection, image sharpening and image blurring.
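
To give a feel for what such kernels do, here is a small sketch using three classic hand-written kernels (edge detection, sharpening and box blur). In a real network the kernel weights are learned rather than fixed. The function ‘apply_kernel’ and the toy image are assumptions made for this illustration:

```python
import numpy as np

def apply_kernel(image, kernel):
    """Naive sliding-window filter without padding (cross-correlation, as in deep learning layers)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

edge = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=float)
sharpen = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=float)
box_blur = np.ones((3, 3)) / 9.0

# Toy 5x5 image: dark on the left (0), bright in the two rightmost columns (1).
image = np.zeros((5, 5))
image[:, 3:] = 1.0

edge_response = apply_kernel(image, edge)
print(edge_response)  # zero in the flat region, non-zero around the vertical edge
```

The edge kernel sums to zero, so it gives no response on flat regions. The sharpen and blur kernels sum to one, so they leave flat regions unchanged.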

If you are interested, you can learn more about convolutional neural networks here. The generator also contains a series of leaky ‘ReLU’ layers.

These are modified ‘ReLU’ activations which help to alleviate the dying neuron issue by giving negative inputs a small, non-zero slope instead of mapping them to zero. There are also batch normalization layers, which normalize the mean and variance of each layer’s inputs. This helps to improve the speed, performance, and stability of the neural network.
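
As a rough numerical illustration of these two ideas, the sketch below compares ‘ReLU’ with a leaky version and applies a simplified batch normalization to a tiny batch. The slope of 0.3 is an illustrative choice, and the ‘batch_norm’ helper leaves out the learnable scale and shift and the running statistics that the Keras layer maintains:

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def leaky_relu(x, slope=0.3):
    # Negative inputs keep a small gradient instead of being zeroed out.
    return np.where(x > 0, x, slope * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(relu(x))
print(leaky_relu(x))

def batch_norm(batch, eps=1e-5):
    """Simplified batch normalization: zero mean, unit variance per feature."""
    mean = batch.mean(axis=0)
    var = batch.var(axis=0)
    return (batch - mean) / np.sqrt(var + eps)

# Two features on very different scales (made-up numbers).
batch = np.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
normalized = batch_norm(batch)
print(normalized)
```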

The generator and discriminator networks are trained in a similar fashion to ordinary neural networks. Namely, weights are randomly initialized, a loss function and its gradients with respect to the weights are evaluated via backpropagation, and the weights are iteratively updated by the optimizer.
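
The same recipe can be sketched on a toy problem with a single weight. This is plain gradient descent on made-up data, not the GAN itself. With only one weight the gradient can be written by hand, whereas backpropagation computes it automatically for many layers:

```python
import random

random.seed(0)

# Made-up data following y = 2x; the goal is to recover the weight 2.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]

w = random.uniform(-1.0, 1.0)  # random initialization
learning_rate = 0.05

for step in range(200):
    # Gradient of the mean squared error mean((w*x - y)^2) with respect to w.
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    w -= learning_rate * grad  # iterative update

print(w)
```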

The training process will help the generator model produce real looking images from noise and help the discriminator do a better job at detecting seemingly authentic fake images. Let’s see an example of what our untrained generator produces. First let’s define our generator and sample a random noise vector:

generator = generator_model()
noise = tf.random.normal([1, 100])

Next, let’s pass our noise data to the generator and plot the resulting image using ‘matplotlib’:

generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

We see that this is just a noisy black and white image. The goal is for our generator to learn how to produce real looking images of digits, like the one we plotted earlier, through iterative training. Upon sufficient training, our generator should be able to generate authentic looking handwritten digits from noisy input like what is shown above.

Now let’s define our discriminator function. This will be an ordinary convolutional neural network used for classification:

def discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Next, let’s define our loss function and our discriminator object:

discriminator = discriminator_model()
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Next we define the loss function specific to the discriminator. This function measures how well the discriminator is able to distinguish real images from fake images. It compares the predictions of the discriminator to the labels on the real images and fake images, where ‘1’ corresponds to a real image and ‘0’ to a fake image:

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

The generator loss function measures how well the generator was able to trick the discriminator:

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
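
To see how these two losses behave, here is a NumPy sketch of binary cross-entropy on logits, averaged over the batch. It is a re-implementation for illustration only, not the Keras code, and the logit values are made up. The functions carry an ‘_np’ suffix to keep them apart from the ones defined above:

```python
import numpy as np

def bce_from_logits(labels, logits):
    """Numerically stable binary cross-entropy on raw logits, averaged over the batch."""
    labels = np.asarray(labels, dtype=float)
    logits = np.asarray(logits, dtype=float)
    per_example = (np.maximum(logits, 0.0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return per_example.mean()

def discriminator_loss_np(real_logits, fake_logits):
    real_loss = bce_from_logits(np.ones_like(real_logits), real_logits)
    fake_loss = bce_from_logits(np.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

def generator_loss_np(fake_logits):
    return bce_from_logits(np.ones_like(fake_logits), fake_logits)

undecided = np.zeros(4)     # discriminator outputs logit 0 (probability 0.5) everywhere
confident = np.full(4, 5.0)

print(discriminator_loss_np(undecided, undecided))    # 2 * ln(2), about 1.386
print(discriminator_loss_np(confident, -confident))   # close to 0: real and fake are separated
print(generator_loss_np(-confident))                  # large: the fakes were caught
print(generator_loss_np(confident))                   # small: the discriminator was fooled
```

The two networks pull in opposite directions: the discriminator loss falls when fakes receive low scores, while the generator loss falls when they receive high scores.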

Since the generator and discriminator are separate neural networks they each have their own optimizers. We will use the ‘Adam’ optimizer to train our discriminator and generator:

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Next, let’s define the number of epochs (which is the number of full passes over the training data), the dimension size of our noise data, and the number of samples to generate:

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
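
The training code further down refers to ‘BATCH_SIZE’, ‘train_dataset’ and ‘seed’, none of which are defined in this post. Below is a minimal sketch of how they could be set up with ‘tf.data’. The batch size of 256 and the shuffle buffer size are assumptions, and a zero-filled array stands in for the normalized MNIST images so that the block runs on its own:

```python
import numpy as np
import tensorflow as tf

# Stand-in for the normalized train_images array from earlier (assumption).
train_images = np.zeros((1000, 28, 28, 1), dtype='float32')

BUFFER_SIZE = train_images.shape[0]  # assumed: shuffle across the whole data set
BATCH_SIZE = 256                     # assumed: not specified in this post
noise_dim = 100
num_examples_to_generate = 16

# Shuffle the images and split them into batches.
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(BUFFER_SIZE)
                 .batch(BATCH_SIZE))

# Fixed noise vectors, reused at every epoch so the saved images are comparable over time.
seed = tf.random.normal([num_examples_to_generate, noise_dim])
```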

We then define our function for our training loop. The ‘@tf.function’ decorator compiles the function into a TensorFlow graph. The function uses a ‘BATCH_SIZE’ constant (the number of images per batch), which must be defined before training. The ‘train_step()’ function starts by generating a batch of images from random noise:

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)  # images generated from random noise

The discriminator is then used to classify real and fake images:

@tf.function
def train_step(images):
    ...
        # still inside the GradientTape block
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

We then calculate the generator and discriminator loss:

@tf.function
def train_step(images):
    ...
        # still inside the GradientTape block
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

We then calculate the gradients of the loss functions:

@tf.function
def train_step(images):
    ...
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

We then apply the optimizer to find the weights that minimize loss and we update the generator and discriminator:

@tf.function
def train_step(images):
    ...
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
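
Putting the pieces above together, the complete training step might look like the sketch below. To keep it small and quick to run, it uses tiny stand-in models and a batch size of 8 instead of the networks defined earlier. Both are assumptions made for this illustration, as is returning the two losses for inspection:

```python
import tensorflow as tf

BATCH_SIZE = 8   # assumed small value for this sketch
noise_dim = 100

# Tiny stand-in models; in this post they would be generator_model() and discriminator_model().
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(noise_dim,)),
    tf.keras.layers.Dense(28 * 28, activation='tanh'),
    tf.keras.layers.Reshape((28, 28, 1)),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    # Both forward passes and both losses are recorded inside the tape block.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Each network is updated using only the gradients of its own loss.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss  # returned only so we can inspect them here

# One step on a random stand-in batch with values in [-1, 1].
real_batch = tf.random.uniform([BATCH_SIZE, 28, 28, 1], minval=-1.0, maxval=1.0)
gen_loss, disc_loss = train_step(real_batch)
print(float(gen_loss), float(disc_loss))
```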

Next, we define a function that generates fake images with the current generator and saves them to disk. We will call it at the end of each epoch and once more after training is complete:

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

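Next, we define the training method that will allow us to train the generator and discriminator simultaneously. We start by iterating over the number of epochs and, within each epoch, running a training step on every batch of images:

def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)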

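At the end of each epoch we produce images from the current generator, and we do so once more after the final epoch. Here ‘display’ is the IPython display module (used to clear the notebook output), and ‘seed’ is expected to be a fixed batch of ‘num_examples_to_generate’ noise vectors:

from IPython import display

def train(dataset, epochs):
    ...
        # at the end of each epoch
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

    # after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)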

Finally, we can call the ‘train()’ method on the training data with the epochs parameter:

train(train_dataset, EPOCHS)

If we run our code for two epochs, the generated images are still very noisy. After 50 epochs (note that this takes several hours to run on a MacBook Pro with 16 GB of memory), some of the digits are recognizable while others need a bit more training to improve. Presumably, with more epochs the digits will look more authentic. I’ll stop here but feel free to play around with the data and code yourself. There are many other data sets that you can use to train GANs including the Intel Image Classification dataset, CIFAR dataset, and the Cats & Dogs dataset. Other interesting applications include deep fake videos and deep fake audio.

To get started on training a GAN on videos you can check out the paper Adversarial Video Generation on Complex Datasets. In this paper, the authors train a GAN on the UCF-101 Action Recognition Dataset, which contains videos from YouTube within 101 action categories. To get started on training a GAN on audio check out the paper Adversarial Audio Synthesis. In this paper, the authors train a GAN on the Speech Commands Zero Through Nine (SC09) data set, as well as on audio of drums, bird vocalizations, and much more.

CONCLUSIONS

To summarize, in this post we discussed the generative adversarial network (GAN) and how to implement it in Python. We showed that GANs simultaneously train two neural networks, one used for data generation and the other for data discrimination. The layers of the generator and discriminator most notably contain transposed convolution and ordinary convolution layers, respectively, which learn high level feature representations of images. I encourage you to try training a GAN on some other interesting data such as the speech or video data sets I mentioned above. I hope you found this post useful and interesting. The code from this post is available on GitHub. Thank you for reading!

