
A First Look at Generative Adversarial Networks with TensorFlow


In an earlier post I covered the basic principles of generative adversarial networks; see 生成对抗网络浅析(GAN).

Today we pair that with the popular TensorFlow and look at how those principles are implemented.

01

Model

The previous post (see 生成对抗网络浅析(GAN)) defined the GAN model.


With TFGAN, we mainly need to define four key pieces (a note on the imports assumed by the code follows this list):

a. Generator: generates fake images from input noise;

b. Discriminator: judges whether an input image is real or fake;

c. Real images;

d. Random noise;
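
Note: the snippets below come from the TFGAN MNIST tutorial and assume TensorFlow 1.x. The short aliases they rely on are roughly the following (my assumption, matching that tutorial; the leaky ReLU slope in particular is assumed):

import tensorflow as tf

tfgan = tf.contrib.gan              # the TFGAN library
layers = tf.contrib.layers          # slim-style layers (fully_connected, conv2d, ...)
framework = tf.contrib.framework    # arg_scope
# Leaky ReLU used by the discriminator; the exact alpha value is my assumption.
leaky_relu = lambda net: tf.nn.leaky_relu(net, alpha=0.01)

The helper modules data_provider and util come from the MNIST example in the tensorflow/models GAN research code (linked in the references).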

Generator

def generator_fn(noise, weight_decay=2.5e-5, is_training=True):
  """Generator network that produces MNIST images.

  Args:
    noise: A single Tensor representing the input noise.
    weight_decay: The value of the L2 weight decay (light regularization).
    is_training: If `True`, batch norm uses batch statistics. If `False`,
      batch norm uses the exponential moving averages collected from the
      population statistics.

  Returns:
    A generated image in the range [-1, 1].
  """
  with framework.arg_scope(
      [layers.fully_connected, layers.conv2d_transpose],
      activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
      weights_regularizer=layers.l2_regularizer(weight_decay)), \
      framework.arg_scope([layers.batch_norm], is_training=is_training,
                          zero_debias_moving_mean=True):
    net = layers.fully_connected(noise, 1024)
    net = layers.fully_connected(net, 7 * 7 * 256)
    net = tf.reshape(net, [-1, 7, 7, 256])
    net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
    net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
    # Make sure the generator output is in the same range as `inputs`,
    # i.e. [-1, 1].
    net = layers.conv2d(net, 1, 4, normalizer_fn=None, activation_fn=tf.tanh)

    return net
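
As a quick sanity check (a hypothetical usage sketch, assuming noise_dims = 64 as in the tutorial): the two stride-2 transposed convolutions upsample 7x7 to 14x14 to 28x28, so the output matches MNIST.

noise = tf.random_normal([32, 64])   # 32 noise vectors, noise_dims = 64 (assumed)
fake_images = generator_fn(noise)    # shape [32, 28, 28, 1], values in [-1, 1]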

Discriminator

def discriminator_fn(img, unused_conditioning, weight_decay=2.5e-5,
                     is_training=True):
  """Discriminator network on MNIST digits.

  Args:
    img: Real or generated MNIST digits, in the range [-1, 1].
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      require extra "conditioning" information for the generator and the
      discriminator. Since this example is unconditional, the argument is
      unused.
    weight_decay: The value of the L2 weight decay.
    is_training: Same as in the generator network.

  Returns:
    Logits for the probability that the image is real.
  """
  with framework.arg_scope(
      [layers.conv2d, layers.fully_connected],
      activation_fn=leaky_relu, normalizer_fn=None,
      weights_regularizer=layers.l2_regularizer(weight_decay),
      biases_regularizer=layers.l2_regularizer(weight_decay)):
    net = layers.conv2d(img, 64, [4, 4], stride=2)
    net = layers.conv2d(net, 128, [4, 4], stride=2)
    net = layers.flatten(net)
    with framework.arg_scope([layers.batch_norm], is_training=is_training):
      net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
    return layers.linear(net, 1)
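
A hypothetical call, reusing fake_images from the sketch above (the conditioning argument is ignored in this unconditional example):

logits = discriminator_fn(fake_images, None)   # shape [32, 1], unbounded real-valued logits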

Real images: use the MNIST dataset as the real-image input.

with tf.device('/cpu:0'):
  real_images, _, _ = data_provider.provide_data(
      'train', batch_size, MNIST_DATA_DIR)

GANModel Tuple

gan_model = tfgan.gan_model(
    generator_fn,
    discriminator_fn,
    real_data=real_images,
    generator_inputs=tf.random_normal([batch_size, noise_dims]))

02

Loss Function

A loss function measures how far the model's prediction f(x) deviates from the true value Y.

θ* = argmin_θ { (1/N) Σ_{i=1}^{N} L(y_i, f(x_i; θ)) + λ Φ(θ) }

Here the leading average is the empirical risk, L is the loss function, and the trailing Φ is the regularizer (penalty term), which can be L1, L2, or some other regularization function. The whole expression says: find the θ that minimizes this objective.
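
As a minimal illustration of this structure (a sketch with hypothetical names, not from the original post): an L2-regularized mean squared error.

def regularized_mse(y_true, y_pred, theta, lam=1e-4):
  # Empirical risk: the mean of the per-example losses L(y_i, f(x_i; theta)).
  empirical_risk = tf.reduce_mean(tf.square(y_true - y_pred))
  # Penalty term lambda * Phi(theta): here an L2 penalty on the parameters.
  penalty = lam * tf.reduce_sum(tf.square(theta))
  return empirical_risk + penalty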

For GANs, the loss function in the original paper is the two-player minimax objective:

min_G max_D V(D, G) = E_{x~p_data(x)}[ log D(x) ] + E_{z~p_z(z)}[ log(1 − D(G(z))) ]

Using the minimax loss functions from TFGAN:

# Use the minimax loss from the original paper.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)
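
Conceptually (a sketch with raw TF ops, not the TFGAN internals): the minimax discriminator loss is a binary cross-entropy that labels real images 1 and generated images 0, and the minimax generator loss is its negation on the generated batch.

def minimax_d_loss(real_logits, gen_logits):
  # D minimizes -E[log D(x)] - E[log(1 - D(G(z)))], i.e. classify real as 1, fake as 0.
  real_term = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.ones_like(real_logits), logits=real_logits)
  gen_term = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.zeros_like(gen_logits), logits=gen_logits)
  return tf.reduce_mean(real_term + gen_term)

def minimax_g_loss(gen_logits):
  # G minimizes E[log(1 - D(G(z)))], i.e. pushes D(G(z)) toward 1.
  return -tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.zeros_like(gen_logits), logits=gen_logits))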

You can also use the Wasserstein or Improved Wasserstein loss instead; see https://arxiv.org/pdf/1701.07875.pdf.

# Use the Wasserstein loss (https://arxiv.org/abs/1701.07875) with the
# gradient penalty from the improved WGAN paper (https://arxiv.org/abs/1704.00028).
improved_wgan_loss = tfgan.gan_loss(
    gan_model,
    # We make the loss explicit for demonstration, even though the default is
    # the Wasserstein loss.
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)
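
In essence (and matching the TF implementation quoted below), the Wasserstein generator loss is −E[D(G(z))] and the discriminator (critic) loss is E[D(G(z))] − E[D(x)]; the improved variant adds a gradient penalty on the critic, weighted here by gradient_penalty_weight=1.0.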

For reference, the TF implementation:

# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
def wasserstein_generator_loss(
    discriminator_gen_outputs,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Wasserstein generator loss for GANs.

  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `discriminator_gen_outputs`, and must be broadcastable to
      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
      the same as the corresponding dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add detailed summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_wasserstein_loss', (
      discriminator_gen_outputs, weights)) as scope:
    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)

    loss = - discriminator_gen_outputs
    loss = losses.compute_weighted_loss(
        loss, weights, scope, loss_collection, reduction)

    if add_summaries:
      summary.scalar('generator_wass_loss', loss)

  return loss

Custom loss functions can also be defined:

def silly_custom_generator_loss(gan_model, add_summaries=False):
  return tf.reduce_mean(gan_model.discriminator_gen_outputs)

def silly_custom_discriminator_loss(gan_model, add_summaries=False):
  return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
          tf.reduce_mean(gan_model.discriminator_real_outputs))
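
These can then be plugged into tfgan.gan_loss in the same way as the built-in losses (a sketch, mirroring the calls above):

custom_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=silly_custom_generator_loss,
    discriminator_loss_fn=silly_custom_discriminator_loss)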

03

Training & Evaluation

Training

During GAN training, the Generator and Discriminator networks are trained alternately, keeping the two in continuous optimization and competition, exactly as in the algorithm from the paper.

(Algorithm 1 in the original GAN paper: k gradient steps on the discriminator alternate with one gradient step on the generator.)

The procedure is fairly simple: set up the optimizers and then build the GANTrainOps tuple.

generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
gan_train_ops = tfgan.gan_train_ops(
    gan_model,
    improved_wgan_loss,
    generator_optimizer,
    discriminator_optimizer)
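
For intuition, the resulting tuple can also drive the alternation by hand (a sketch; the field names generator_train_op / discriminator_train_op are my reading of the GANTrainOps tuple, not something shown in the original post):

# Hypothetical manual training loop: k discriminator steps per generator step.
k = 5
with tf.train.SingularMonitoredSession() as sess:
  for step in range(1000):
    for _ in range(k):
      sess.run(gan_train_ops.discriminator_train_op)
    sess.run(gan_train_ops.generator_train_op)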

Evaluation

Use the 'Inception Score' and the 'Frechet Inception Distance' to measure how close the distribution of generated images is to the distribution of real images.
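
For reference: the Inception score is exp(E_x[ KL(p(y|x) || p(y)) ]), where p(y|x) comes from a classifier (here the frozen MNIST classifier instead of Inception); higher is better. The Frechet distance fits a Gaussian (μ_r, Σ_r) to the classifier features of the real images and (μ_g, Σ_g) to those of the generated images, and computes ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}); lower is better.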

num_images_to_eval = 500
MNIST_CLASSIFIER_FROZEN_GRAPH = './mnist/data/classify_mnist_graph_def.pb'

# To reuse the trained variables, use the same variable scope as the training job.
with tf.variable_scope('Generator', reuse=True):
  eval_images = gan_model.generator_fn(
      tf.random_normal([num_images_to_eval, noise_dims]),
      is_training=False)

# Compute the Inception score.
eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Compute the Frechet Inception distance.
with tf.device('/cpu:0'):
  real_images, _, _ = data_provider.provide_data(
      'train', num_images_to_eval, MNIST_DATA_DIR)
frechet_distance = util.mnist_frechet_distance(
    real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Reshape the eval images for visualization.
generated_data_to_visualize = tfgan.eval.image_reshaper(
    eval_images[:20, ...], num_cols=10)

Training process and results

TFGAN adopts the alternating-training scheme that follows from the GAN minimax game, and the update ratio between G and D can be changed (see the sketch after the training loop below).

train_step_fn = tfgan.get_sequential_train_steps()

global_step = tf.train.get_or_create_global_step()
loss_values, mnist_scores, frechet_distances = [], [], []

with tf.train.SingularMonitoredSession() as sess:
  start_time = time.time()
  for i in xrange(1601):
    cur_loss, _ = train_step_fn(
        sess, gan_train_ops, global_step, train_step_kwargs={})
    loss_values.append((i, cur_loss))
    if i % 200 == 0:
      mnist_score, f_distance, digits_np = sess.run(
          [eval_score, frechet_distance, generated_data_to_visualize])
      mnist_scores.append((i, mnist_score))
      frechet_distances.append((i, f_distance))
      print('Current loss: %f' % cur_loss)
      print('Current MNIST score: %f' % mnist_scores[-1][1])
      print('Current Frechet distance: %f' % frechet_distances[-1][1])
      visualize_training_generator(i, start_time, digits_np)
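
As mentioned above, the G/D update ratio is configurable. If I read the TFGAN API correctly (an assumption, not shown in the original post), this is done by passing a GANTrainSteps tuple to get_sequential_train_steps:

# Assumed TFGAN API: GANTrainSteps(generator_train_steps, discriminator_train_steps).
train_step_fn = tfgan.get_sequential_train_steps(
    train_steps=tfgan.GANTrainSteps(1, 5))  # e.g. 5 discriminator steps per generator step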

(Generated MNIST samples visualized every 200 training steps.)

You can see the same kind of progression as the evolution curves in the paper (see 生成对抗网络浅析(GAN)):

(Selected frames from the sample sequence above.)

The evaluation metrics over time are shown below:

(Plots of the training loss, MNIST score, and Frechet distance over training steps.)

Further reading

生成对抗网络浅析(GAN)

References:

https://github.com/tensorflow/models/tree/master/research/gan

https://arxiv.org/pdf/1701.07875.pdf

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/losses/python/losses_impl.py

https://blog.csdn.net/stalbo/article/details/79356739

https://zhuanlan.zhihu.com/p/44407513

http://www.csuldw.com/2016/03/26/2016-03-26-loss-function/

https://arxiv.org/pdf/1606.03498.pdf

THE END

- Good night -


