A First Look at Generative Adversarial Networks with TensorFlow
source link: https://mp.weixin.qq.com/s/LXINXIs2t0O2XbUbD2Agkg?amp%3Butm_medium=referral
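An earlier post introduced the basic principles of generative adversarial networks (see "A Brief Analysis of Generative Adversarial Networks (GAN)").
Today, using the currently popular TensorFlow, we look at how those principles are implemented.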
01
Model
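The previous post (see "A Brief Analysis of Generative Adversarial Networks (GAN)") defined the GAN model.
With TFGAN, we need to define four key components:
a. Generator: turns random noise into fake images;
b. Discriminator: decides whether an input image is real (from the training set) or fake (produced by the generator);
c. Real images;
d. Random noise.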
Generator
import tensorflow as tf

# TF 1.x contrib aliases used throughout (tf.contrib was removed in TF 2.x).
tfgan = tf.contrib.gan
layers = tf.contrib.layers
framework = tf.contrib.framework

def generator_fn(noise, weight_decay=2.5e-5, is_training=True):
    """Generator network that produces MNIST images.

    Args:
      noise: A single Tensor representing noise.
      weight_decay: The L2 weight-decay coefficient (a light weight decay).
      is_training: If `True`, batch norm uses batch statistics. If `False`,
        batch norm uses the exponential moving averages collected from
        population statistics.

    Returns:
      A generated image in the range [-1, 1].
    """
    with framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)), \
         framework.arg_scope([layers.batch_norm], is_training=is_training,
                             zero_debias_moving_mean=True):
        net = layers.fully_connected(noise, 1024)
        net = layers.fully_connected(net, 7 * 7 * 256)
        net = tf.reshape(net, [-1, 7, 7, 256])
        net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
        net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
        # Make sure that generator output is in the same range as `inputs`,
        # i.e. [-1, 1].
        net = layers.conv2d(net, 1, 4, normalizer_fn=None, activation_fn=tf.tanh)
        return net
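To see why the generator ends at 28×28×1 (the MNIST image size), here is a small plain-Python sketch of the spatial-size arithmetic. It assumes `'SAME'` padding, the default for `layers.conv2d_transpose` and `layers.conv2d` in `tf.contrib.layers`: with that padding a stride-2 transposed convolution doubles height and width, and a stride-1 convolution keeps them. The helper names are hypothetical.

```python
def transpose_conv_same(size, stride):
    # 'SAME' transposed convolution: output size = input size * stride.
    return size * stride


def conv_same(size, stride):
    # 'SAME' convolution: output size = ceil(input size / stride).
    return -(-size // stride)


def generator_output_shapes():
    h = w = 7                                  # after tf.reshape(net, [-1, 7, 7, 256])
    shapes = [(h, w, 256)]
    for channels in (64, 32):                  # two stride-2 conv2d_transpose layers
        h, w = transpose_conv_same(h, 2), transpose_conv_same(w, 2)
        shapes.append((h, w, channels))
    h, w = conv_same(h, 1), conv_same(w, 1)    # final stride-1 conv2d with 1 channel
    shapes.append((h, w, 1))
    return shapes


print(generator_output_shapes())
# [(7, 7, 256), (14, 14, 64), (28, 28, 32), (28, 28, 1)]
```

Starting from a 7×7 feature map and upsampling twice is a common way to reach 28×28 without any cropping.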
Discriminator
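# leaky_relu is used below but not defined in the excerpt; this definition
# (alpha=0.01) is an assumption.
def leaky_relu(net):
    return tf.nn.leaky_relu(net, alpha=0.01)

def discriminator_fn(img, unused_conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
      img: Real or generated MNIST digits, in the range [-1, 1].
      unused_conditioning: The TFGAN API can help with conditional GANs, which
        need extra "conditioning" information passed to both the generator and
        the discriminator. Since this example is not conditional, we do not use
        this argument.
      weight_decay: The L2 weight-decay coefficient.
      is_training: Same as in the generator.

    Returns:
      Logits for the probability that the image is real.
    """
    with framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=leaky_relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        with framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)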
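The discriminator mirrors the generator: two stride-2 convolutions shrink 28×28 back to 7×7 before the dense layers. A plain-Python sketch of that arithmetic plus the scalar leaky-ReLU formula, assuming `'SAME'` padding (the `tf.contrib.layers.conv2d` default) and a hypothetical slope of 0.01 for negative inputs.

```python
def conv_same(size, stride):
    # 'SAME' convolution: output size = ceil(input size / stride).
    return -(-size // stride)


def leaky_relu(x, alpha=0.01):
    # Leaky ReLU on one value: x for x >= 0, alpha * x otherwise (alpha assumed 0.01).
    return x if x >= 0 else alpha * x


def discriminator_shapes(image_size=28):
    h = image_size
    shapes = []
    for channels in (64, 128):     # two stride-2 conv2d layers
        h = conv_same(h, 2)
        shapes.append((h, h, channels))
    flat = h * h * 128             # size after layers.flatten, fed to the 1024-unit layer
    return shapes, flat


print(discriminator_shapes())      # ([(14, 14, 64), (7, 7, 128)], 6272)
print(leaky_relu(-2.0))            # -0.02
```

The final `layers.linear(net, 1)` then maps each image to a single unbounded logit.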
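Real images: the MNIST dataset provides the real-image input.
# batch_size and MNIST_DATA_DIR are assumed to be defined earlier; data_provider
# comes from the MNIST example under research/gan in tensorflow/models.
with tf.device('/cpu:0'):
    real_images, _, _ = data_provider.provide_data(
        'train', batch_size, MNIST_DATA_DIR)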
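Because the generator ends in `tanh`, real images should also lie in [-1, 1] before they reach the discriminator. A plain-Python sketch of one common mapping from 8-bit pixels, `x / 127.5 - 1`; this formula is an assumption for illustration, and `data_provider.provide_data` may scale slightly differently.

```python
def to_tanh_range(pixels):
    # Map 8-bit pixel values in [0, 255] linearly onto [-1, 1].
    return [p / 127.5 - 1.0 for p in pixels]


def from_tanh_range(values):
    # Inverse mapping, e.g. to turn generator output back into 8-bit pixels.
    return [round((v + 1.0) * 127.5) for v in values]


print(to_tanh_range([0, 255]))         # [-1.0, 1.0]
print(from_tanh_range([-1.0, 1.0]))    # [0, 255]
```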
GANModel Tuple
# noise_dims (the length of each noise vector) is assumed to be defined earlier.
gan_model = tfgan.gan_model(
    generator_fn,
    discriminator_fn,
    real_data=real_images,
    generator_inputs=tf.random_normal([batch_size, noise_dims]))
02
Loss Functions
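A loss function measures how far a model's prediction f(x) is from the true value Y. Training typically minimizes a regularized objective of the form
$$\theta^* = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} L\big(y_i, f(x_i;\theta)\big) + \lambda\,\Phi(\theta)$$
Here the leading mean term is the empirical risk and L is the loss function; Φ is the regularizer (also called the penalty term), which can be an L1 or L2 norm or another regularization function, and λ sets its weight. The whole expression says: find the θ that minimizes the objective.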
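To make the regularized objective (mean loss plus λΦ(θ)) concrete, here is a plain-Python sketch with a one-parameter linear model f(x) = θx, squared error as L and an L2 penalty Φ(θ) = θ²; all three are illustrative choices. On data generated by y = 2x the unregularized minimizer is θ = 2, and the penalty pulls the minimizer slightly toward zero.

```python
def regularized_risk(theta, xs, ys, lam):
    """Mean squared error of f(x) = theta * x, plus lam * Phi(theta)."""
    # Empirical risk: the mean of the per-example losses L(y_i, f(x_i; theta)).
    empirical = sum((y - theta * x) ** 2 for x, y in zip(xs, ys)) / len(xs)
    # Regularizer Phi(theta): an L2 penalty on the single parameter.
    penalty = theta ** 2
    return empirical + lam * penalty


xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]    # data generated by y = 2x
# Crude grid search for theta* over [0, 4] in steps of 0.001.
grid = [i / 1000 for i in range(4001)]
theta_star = min(grid, key=lambda t: regularized_risk(t, xs, ys, lam=0.1))
print(theta_star)    # 1.958: the L2 penalty pulls theta* below 2
```

A grid search is used only to keep the sketch dependency-free; real models minimize the objective with gradient-based optimizers.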
For GANs, the loss in the original paper is a two-player minimax game:
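For reference, the value function from the original GAN paper (Goodfellow et al., 2014), where D(x) is the discriminator's probability that x is real and G(z) is the image generated from noise z. D tries to maximize V while G tries to minimize it:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```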
Using TFGAN's minimax loss functions:
# Use the minimax loss from the original paper.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)
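Since the discriminator returns logits, D = sigmoid(logit), and the two minimax terms can be written with the numerically stable softplus: -log sigmoid(l) = softplus(-l) and log(1 - sigmoid(l)) = -softplus(l). A plain-Python sketch of those formulas; it is not TFGAN's source, and TFGAN's versions take extra arguments (weights, reductions and similar).

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def minimax_discriminator_loss(real_logits, fake_logits):
    # -mean log D(x) - mean log(1 - D(G(z))), with D = sigmoid(logit).
    real_term = sum(softplus(-l) for l in real_logits) / len(real_logits)
    fake_term = sum(softplus(l) for l in fake_logits) / len(fake_logits)
    return real_term + fake_term


def minimax_generator_loss(fake_logits):
    # The generator minimizes mean log(1 - D(G(z))) = -mean softplus(logit).
    return -sum(softplus(l) for l in fake_logits) / len(fake_logits)


print(minimax_discriminator_loss([0.0], [0.0]))   # 2 * log 2, about 1.386 (D is undecided)
print(minimax_generator_loss([0.0]))              # -log 2, about -0.693
```

When the discriminator is confident and correct (large positive logits on real images, large negative logits on fakes), its loss approaches 0 and the generator's gradient through log(1 - D(G(z))) becomes small, which is the saturation problem noted in the original paper.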
The Wasserstein loss (https://arxiv.org/pdf/1701.07875.pdf) and its improved, gradient-penalty variant (https://arxiv.org/abs/1704.00028) can be used as well:
# Wasserstein loss (https://arxiv.org/abs/1701.07875) with the gradient penalty
# from the improved Wasserstein GAN paper (https://arxiv.org/abs/1704.00028).
improved_wgan_loss = tfgan.gan_loss(
    gan_model,
    # We make the loss explicit for demonstration, even though the default is
    # Wasserstein loss.
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)
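`gradient_penalty_weight=1.0` adds the penalty from the improved-WGAN paper: the critic's gradient norm is pushed toward 1 at random interpolations between real and generated samples. A plain-Python sketch of the resulting critic loss on tiny vectors; the function name, the finite-difference gradient (standing in for automatic differentiation) and the toy linear critic are illustrative assumptions, not TFGAN's implementation.

```python
import math
import random


def critic_loss_with_gp(critic, real, fake, gp_weight=1.0, h=1e-5, seed=0):
    """WGAN-GP critic loss for equal-sized batches of points (lists of floats).

    loss = mean D(fake) - mean D(real) + gp_weight * mean((||grad D(x_hat)|| - 1)^2),
    with x_hat = eps * real + (1 - eps) * fake and eps ~ U[0, 1] per example.
    """
    rng = random.Random(seed)
    n = len(real)
    wasserstein = sum(critic(f) for f in fake) / n - sum(critic(r) for r in real) / n
    penalty = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(r, f)]
        # Central finite differences approximate the critic's gradient at x_hat.
        grad = []
        for i in range(len(x_hat)):
            up, down = list(x_hat), list(x_hat)
            up[i] += h
            down[i] -= h
            grad.append((critic(up) - critic(down)) / (2 * h))
        norm = math.sqrt(sum(g * g for g in grad))
        penalty += (norm - 1.0) ** 2
    return wasserstein + gp_weight * penalty / n


# A linear critic D(x) = 3*x0 + 4*x1 has gradient norm 5 everywhere,
# so every example contributes a penalty of (5 - 1)^2 = 16.
def linear_critic(x):
    return 3.0 * x[0] + 4.0 * x[1]


real = [[1.0, 0.0], [0.0, 1.0]]
fake = [[0.0, 0.0], [0.0, 0.0]]
print(critic_loss_with_gp(linear_critic, real, fake))   # about 12.5 (= -3.5 + 16)
```

The penalty replaces the weight clipping of the original WGAN as a way to keep the critic approximately 1-Lipschitz.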
For reference, here is TFGAN's implementation of the Wasserstein generator loss:
# Excerpt from tensorflow/contrib/gan/python/losses/python/losses_impl.py;
# `ops`, `losses`, `summary` and `_to_float` come from that module.

# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
def wasserstein_generator_loss(
        discriminator_gen_outputs,
        weights=1.0,
        scope=None,
        loss_collection=ops.GraphKeys.LOSSES,
        reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
        add_summaries=False):
    """Wasserstein generator loss for GANs.

    See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

    Args:
      discriminator_gen_outputs: Discriminator output on generated data. Expected
        to be in the range of (-inf, inf).
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `discriminator_gen_outputs`, and must be broadcastable to
        `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
        the same as the corresponding dimension).
      scope: The scope for the operations performed in computing the loss.
      loss_collection: collection to which this loss will be added.
      reduction: A `tf.losses.Reduction` to apply to loss.
      add_summaries: Whether or not to add detailed summaries for the loss.

    Returns:
      A loss Tensor. The shape depends on `reduction`.
    """
    with ops.name_scope(scope, 'generator_wasserstein_loss', (
            discriminator_gen_outputs, weights)) as scope:
        discriminator_gen_outputs = _to_float(discriminator_gen_outputs)

        loss = - discriminator_gen_outputs
        loss = losses.compute_weighted_loss(
            loss, weights, scope, loss_collection, reduction)

        if add_summaries:
            summary.scalar('generator_wass_loss', loss)

    return loss
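Stripped of graph bookkeeping, the function negates the discriminator outputs and reduces them with `losses.compute_weighted_loss`. The plain-Python sketch below mirrors the default `SUM_BY_NONZERO_WEIGHTS` reduction (weighted sum divided by the number of non-zero weights); it handles only flat lists and is meant for intuition, not as a drop-in replacement.

```python
def wasserstein_generator_loss(discriminator_gen_outputs, weights=1.0):
    """Plain-Python analogue: -D(G(z)) reduced with SUM_BY_NONZERO_WEIGHTS."""
    losses = [-float(d) for d in discriminator_gen_outputs]   # per-example loss
    if not isinstance(weights, (list, tuple)):
        weights = [float(weights)] * len(losses)              # broadcast a scalar weight
    weighted_sum = sum(w * l for w, l in zip(weights, losses))
    num_nonzero = sum(1 for w in weights if w != 0)
    # Weighted sum divided by the number of non-zero weights (0 if there are none).
    return weighted_sum / num_nonzero if num_nonzero else 0.0


print(wasserstein_generator_loss([1.0, 5.0, 3.0]))                   # -3.0
print(wasserstein_generator_loss([1.0, 5.0, 3.0], [1.0, 0.0, 1.0]))  # -2.0
```

With a scalar weight this is just the negated mean critic score, so the generator lowers its loss by producing samples the critic scores higher.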
TFGAN also accepts custom loss functions, for example:
def silly_custom_generator_loss(gan_model, add_summaries=False):
    return tf.reduce_mean(gan_model.discriminator_gen_outputs)
def silly_custom_discriminator_loss(gan_model, add_summaries=False):
    return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
            tf.reduce_mean(gan_model.discriminator_real_outputs))

custom_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=silly_custom_generator_loss,
    discriminator_loss_fn=silly_custom_discriminator_loss)
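A quick plain-Python check of what these two custom losses compute, using a hypothetical namedtuple in place of a real `GANModel`. The discriminator loss has the same form as the Wasserstein critic loss, while the generator loss has the opposite sign of the Wasserstein generator loss (which negates D(G(z))). Minimizing it therefore pushes D(G(z)) down, which is presumably why the example calls it "silly".

```python
from collections import namedtuple

# Hypothetical stand-in holding the two GANModel fields the losses read.
ToyGANModel = namedtuple(
    'ToyGANModel', ['discriminator_gen_outputs', 'discriminator_real_outputs'])


def mean(values):
    return sum(values) / len(values)


def silly_custom_generator_loss(model):
    return mean(model.discriminator_gen_outputs)


def silly_custom_discriminator_loss(model):
    return mean(model.discriminator_gen_outputs) - mean(model.discriminator_real_outputs)


def wasserstein_generator_loss(model):
    # For comparison: the Wasserstein generator loss negates D(G(z)).
    return -mean(model.discriminator_gen_outputs)


model = ToyGANModel(discriminator_gen_outputs=[1.0, 3.0],
                    discriminator_real_outputs=[4.0, 6.0])
print(silly_custom_generator_loss(model))      # 2.0
print(wasserstein_generator_loss(model))       # -2.0
print(silly_custom_discriminator_loss(model))  # -3.0
```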
03
Training & Evaluation
Training
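GAN training alternates between updating the generator and the discriminator, keeping the two networks in continual optimization and competition, as in the algorithm from the original paper.
In TFGAN this is fairly simple: create an optimizer for each network, then build the GANTrainOps tuple.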
generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)
gan_train_ops = tfgan.gan_train_ops(
    gan_model,
    improved_wgan_loss,
    generator_optimizer,
    discriminator_optimizer)
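Both optimizers use Adam with `beta1=0.5` rather than the default 0.9, a common choice for GANs (popularized by DCGAN) that shortens the momentum's memory; note also that the generator's learning rate here is 10× the discriminator's. A plain-Python sketch of the Adam update (Kingma & Ba's formulas) on a single scalar parameter shows where `beta1` enters; `tf.train.AdamOptimizer` folds the bias correction into the step size, so its epsilon enters slightly differently.

```python
import math


def adam_minimize(grad_fn, x, lr, beta1=0.5, beta2=0.999, eps=1e-8, steps=1):
    """Run `steps` Adam updates on a scalar parameter x."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g        # first-moment (momentum) estimate
        v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
        m_hat = m / (1 - beta1 ** t)           # bias corrections
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimize x^2 (gradient 2x) from x = 1: the first Adam step moves x by about lr.
print(adam_minimize(lambda x: 2 * x, 1.0, lr=0.001))   # about 0.999
```

A smaller `beta1` makes the update react faster to the changing gradients of the adversarial game, at the cost of noisier steps.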
Evaluation
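We use the Inception Score and the Fréchet Inception Distance (FID) to evaluate the generated images: the Inception Score looks only at the generated images, while FID measures how close their distribution is to that of real images. For MNIST, both are computed with a pretrained MNIST classifier (the frozen graph below) in place of the Inception network.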
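The Inception Score is exp(E_x[KL(p(y|x) || p(y))]): it is high when each image is classified confidently (sharp p(y|x)) and the predicted classes are spread evenly (flat p(y)). A plain-Python sketch of that formula, assuming the classifier's softmax outputs are already available as Python lists; `util.mnist_score` itself works on TF tensors and the frozen classifier graph.

```python
import math


def inception_score(probs):
    """exp(mean over images of KL(p(y|x) || p(y))) for per-image class distributions."""
    n, k = len(probs), len(probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]
    kl_total = 0.0
    for p in probs:
        # Terms with p(y|x) = 0 contribute nothing (the 0 * log 0 = 0 convention).
        kl_total += sum(pj * math.log(pj / mj) for pj, mj in zip(p, marginal) if pj > 0)
    return math.exp(kl_total / n)


# Ten images, each classified with certainty as a different digit: the maximum, 10.
confident = [[1.0 if j == i else 0.0 for j in range(10)] for i in range(10)]
print(inception_score(confident))      # about 10.0
# Ten images all classified as the same digit: the minimum, 1.
collapsed = [[1.0] + [0.0] * 9 for _ in range(10)]
print(inception_score(collapsed))      # 1.0
```

The second case is what mode collapse looks like to this metric: every sample is confidently one class, yet the score is as low as possible.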
num_images_to_eval = 500
MNIST_CLASSIFIER_FROZEN_GRAPH = './mnist/data/classify_mnist_graph_def.pb'

# To load the trained variables, use the same variable scope as the training job.
with tf.variable_scope('Generator', reuse=True):
    eval_images = gan_model.generator_fn(
        tf.random_normal([num_images_to_eval, noise_dims]),
        is_training=False)

# Calculate the Inception-style score (util comes from research/gan/mnist).
eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Calculate the Frechet distance between real and generated images.
with tf.device('/cpu:0'):
    real_images, _, _ = data_provider.provide_data(
        'train', num_images_to_eval, MNIST_DATA_DIR)
frechet_distance = util.mnist_frechet_distance(
    real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)

# Arrange the first 20 eval images into a 10-column grid for display.
generated_data_to_visualize = tfgan.eval.image_reshaper(
    eval_images[:20, ...], num_cols=10)
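FID fits a Gaussian to the classifier features of real images and another to those of generated images, then measures the Fréchet distance between the two Gaussians; lower is better. The full metric uses full covariance matrices, ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2}). The plain-Python sketch below simplifies by assuming diagonal covariances, in which case each dimension contributes (μ_r - μ_g)² + (σ_r - σ_g)².

```python
import math


def mean_and_var(values):
    mu = sum(values) / len(values)
    return mu, sum((v - mu) ** 2 for v in values) / len(values)


def frechet_distance_diag(real_feats, gen_feats):
    """Fréchet distance between per-dimension Gaussians fitted to two feature sets."""
    total = 0.0
    for i in range(len(real_feats[0])):
        mu_r, var_r = mean_and_var([f[i] for f in real_feats])
        mu_g, var_g = mean_and_var([f[i] for f in gen_feats])
        # var_r + var_g - 2 * sqrt(var_r * var_g) == (sigma_r - sigma_g)^2
        total += (mu_r - mu_g) ** 2 + (math.sqrt(var_r) - math.sqrt(var_g)) ** 2
    return total


real_feats = [[0.0], [2.0]]                             # mean 1, variance 1
print(frechet_distance_diag(real_feats, real_feats))    # 0.0: identical statistics
print(frechet_distance_diag(real_feats, [[2.0], [4.0]]))  # 4.0: mean shifted by 2
```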
Training Process and Results
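TFGAN trains by alternating G and D updates, an idea that comes from the GAN minimax game, and the ratio of G updates to D updates can be changed.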
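In TFGAN the ratio is set through the `train_steps` argument of `tfgan.get_sequential_train_steps`, a `tfgan.GANTrainSteps(generator_train_steps, discriminator_train_steps)` tuple that defaults to 1:1. A hypothetical plain-Python sketch of the update schedule such a ratio implies; the generator-first order within each round is an assumption of this sketch.

```python
def update_schedule(rounds, generator_steps=1, discriminator_steps=1):
    """List which network is updated at each optimizer step, round by round."""
    schedule = []
    for _ in range(rounds):
        schedule += ['G'] * generator_steps       # generator updates for this round
        schedule += ['D'] * discriminator_steps   # then discriminator updates
    return schedule


print(update_schedule(2))          # ['G', 'D', 'G', 'D']
print(update_schedule(1, 1, 5))    # ['G', 'D', 'D', 'D', 'D', 'D']
```

Several discriminator (critic) steps per generator step, as in the second call, is the setting the WGAN papers recommend.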
import time

# Alternates generator and discriminator updates (1:1 by default).
train_step_fn = tfgan.get_sequential_train_steps()
global_step = tf.train.get_or_create_global_step()
loss_values, mnist_scores, frechet_distances = [], [], []

with tf.train.SingularMonitoredSession() as sess:
    start_time = time.time()
    for i in range(1601):
        cur_loss, _ = train_step_fn(
            sess, gan_train_ops, global_step, train_step_kwargs={})
        loss_values.append((i, cur_loss))
        if i % 200 == 0:
            mnist_score, f_distance, digits_np = sess.run(
                [eval_score, frechet_distance, generated_data_to_visualize])
            mnist_scores.append((i, mnist_score))
            frechet_distances.append((i, f_distance))
            print('Current loss: %f' % cur_loss)
            print('Current MNIST score: %f' % mnist_scores[-1][1])
            print('Current Frechet distance: %f' % frechet_distances[-1][1])
            # visualize_training_generator is a plotting helper that is not
            # defined in this excerpt.
            visualize_training_generator(i, start_time, digits_np)
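The generated digits evolve over training much like the progression shown in the paper (see "A Brief Analysis of Generative Adversarial Networks (GAN)").
How the evaluation metrics change over training time: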
Further Reading
References:
https://github.com/tensorflow/models/tree/master/research/gan
https://arxiv.org/pdf/1701.07875.pdf
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/losses/python/losses_impl.py
https://blog.csdn.net/stalbo/article/details/79356739
https://zhuanlan.zhihu.com/p/44407513
http://www.csuldw.com/2016/03/26/2016-03-26-loss-function/
https://arxiv.org/pdf/1606.03498.pdf
THE END
- Good night -