8

生成对抗网络(Generative Adversarial Networks, GAN)

 3 years ago
source link: http://www.cnblogs.com/chenhuabin/p/14195437.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

生成对抗网络(Generative Adversarial Networks, GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的学习方法之一。

GAN 主要包括了两个部分,即生成器 generator 与判别器 discriminator。生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对接收的图片进行真假判别。在整个过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假,这个过程相当于一个二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗,最终两个网络达到了一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近 0.5(相当于随机猜测类别)。

3qYf6jU.png!mobile

这就是GAN的基本思想,其实并不难理解。但是,回归到神经网络本身,怎么去实现这种思想才是关键,我认为,进一步地,我认为如何定义损失函数才是关键。下图为GAN原论文中的损失函数公式:

Zvyaauu.png!mobile

我们来说说这个公式:

  • 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。

  • D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。

  • G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G。

  • D的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此式子对于D来说是求最大(max_D)

接下来,我们通过代码来实际感受生成式对抗网络GAN。代码使用Python3.8+tensorflow2.3.1实现,数据集为mnist手写数字识别数据集。

In [1]:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np 
import glob
import os

In [2]:

tf.__version__

Out[2]:

'2.3.1'

加载数据,我们只使用训练集即可,忽略测试集:

In [3]:

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

In [4]:

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') # 将数据转换为图像模式——单通道,然后转换数据类型为float32。

当数据值在0周围时,激活函数效果更加,所以,我们最好进行数据归一化:

In [5]:

train_images = (train_images- 127.5) / 127.5  # 数据归一化

In [6]:

batch_size = 256
buffer_size = 60000

In [7]:

datasets = tf.data.Dataset.from_tensor_slices(train_images) # tensorflow原生方法存储数据

In [8]:

datasets = datasets.shuffle(buffer_size).batch(batch_size) # 打乱数据顺序,分批成簇

In [10]:

train_images.shape

Out[10]:

(60000, 28, 28, 1)

现在,我们先定义一个生成器模型,模型网络中,我们使用最原始的全连接网络:

In [11]:

def generator_model():
    model =  keras.Sequential()
    model.add(layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(28*28*1,  use_bias=False, activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28, 28, 1)))
    
    return model

判别器模型好理解,就是一个简单的全连接判别式网络:

In [12]:

def discriminator_model():
    model = keras.Sequential()
    model.add(layers.Flatten())
    
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(1))
    
    return model

定义损失计算方式,在GAN网络中使用的是交叉熵损失函数:

In [13]:

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

损失函数得分为两个部分,一个是计算判别器的损失,一个是计算生成器的损失。其中,判别器损失也分为两个部分,一个是计算对真实图片的损失计算,在这一部分,我们期望模型能判别为真实图片,也就是越靠近1越好,一个是计算对判别器的损失计算,在这一部分,我么希望判别器能将图像判别为假,也就是结果越靠近0越好:

In [14]:

# 生成器损失计算。
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss

In [15]:

# 判别器损失计算
def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
    return fake_loss

定义优化器:

In [16]:

generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)

In [17]:

epochs = 160
noise_dim = 100  # 每个噪声100维度
generate_image_num = 16  # 生成16个随机噪声
seed = tf.random.normal([generate_image_num, noise_dim])  # 16个随机噪声,用于可视化输出训练过程中的效果展示

In [18]:

generator = generator_model()
discriminator = discriminator_model()

训练过程中的每一轮迭代,计算梯度,反向传播:

In [19]:

def train_step(images):
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out = discriminator(images, training=True)
        gen_image = generator(noise, training=True)
        fake_out = discriminator(gen_image, training=True)
        
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))

我们再定义可视化函数:

In [20]:

def genetate_plot_images(gen_model, test_noise):
    pre_images = gen_model(test_noise, training=False)
    fig = plt.figure(figsize=(32, 128))
    for i in range(pre_images.shape[0]):
        plt.subplot(1, 16, i+1)
        plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap='gray')
        plt.axis('off')
    plt.show()

真实开始训练:

In [21]:

def train(dataset, epochs):
    for epoch in range(epochs):
        
        for image_batch in dataset:
            train_step(image_batch)
        if epoch % 20 == 0:
            print('---------------------------------------------------------------------------epoch:%s-----------------------------------------------------------------------------'%(epoch+1))
            genetate_plot_images(generator, seed)
    print('---------------------------------------------------------------------------epoch:%s-----------------------------------------------------------------------------'%(epoch+1))
    genetate_plot_images(generator, seed)

In [22]:

train(datasets, epochs)
---------------------------------------------------------------------------epoch:1-----------------------------------------------------------------------------

7Ffyimf.png!mobile

---------------------------------------------------------------------------epoch:21-----------------------------------------------------------------------------

neU36zz.png!mobile

---------------------------------------------------------------------------epoch:41-----------------------------------------------------------------------------

neI73ye.png!mobile

---------------------------------------------------------------------------epoch:61-----------------------------------------------------------------------------

zA3IBzR.png!mobile

---------------------------------------------------------------------------epoch:81-----------------------------------------------------------------------------

RJvE7vN.png!mobile

---------------------------------------------------------------------------epoch:101-----------------------------------------------------------------------------

aANBnmI.png!mobile

---------------------------------------------------------------------------epoch:121-----------------------------------------------------------------------------

myIVFnr.png!mobile

---------------------------------------------------------------------------epoch:141-----------------------------------------------------------------------------

Vbeyqqm.png!mobile

---------------------------------------------------------------------------epoch:160-----------------------------------------------------------------------------

VjM3MvZ.png!mobile


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK