73

实战生成对抗网络(二):生成手写数字

 5 years ago
source link: https://mp.weixin.qq.com/s/An1Mz10B6Sxh6t2z7FFFQQ?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

在开始本文之前,让我们先看看一则报道:

人民网讯 据英国广播电视公司10月25日报道,由人工智能创作的艺术作品以432000美元(约合300万人民币)的高价成功拍卖。

看起来一则不起眼的新闻,其实意义深远,它意味着人们开始认可计算机创作的艺术价值,那些沾沾自喜认为不会被人工智能取代的艺术家也要瑟瑟发抖了。

这幅由人工智能创作的作品长啥样,有啥过人之处?

na6rYji.jpg!web

嗯,以我这种外行人士看来,实在不怎么样,但这不意味着人工智能不行。要知道,AlphaGo初出道时,也只敢挑战一下樊麾这样的二流棋手,接下来挑战顶级棋手李世石,人类还能勉力一战,等进化到AlphaGo Master,零封人类棋手。然而这还没有完,AlphaGo Zero不再学习人类棋譜,完全通过自学,碾压AlphaGo Master,对付人类棋手,更如我们捏死一只蚂蚁那么容易。

所以说,尽管人工智能创作的第一副作品如同鬼画桃符,但其潜力无可限量。

那么,接下来我们会探讨如何创作出一幅名画?No. No.

创作一副画并不是那么容易。这幅名为《埃德蒙·贝拉米肖像》的画作是由巴黎一个名为“显而易见”(Obvious)的艺术团体创作利用人工智能技术创作而成,这幅作品是用算法和15000幅从14世纪到20世纪的肖像画数据制作而成。

我们还没有那个条件去创作一副人工智能的画作,但我们可以先从基本的着手,生成手写数字。手写数字对于机器学习的同学来说,太熟悉不过了。既然是老朋友了,那让我们开始吧!

首先回顾一下《 实战生成对抗网络[1]:简介 》这篇文章的内容,GAN由生成器和判别器组成。简单起见,我们选择简单的二层神经网络来实现生成器和判别器。

生成器

实现生成器并不难,我们采取的全连接网络拓扑结构为:100 → 128 → 784,最后的输出为784是因为MNIST数据集就是由28 x 28像素的灰度图像组成。代码如下:

G_W1 = tf.Variable(initializer([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(initializer([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]def generator(z):
  G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
  G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
  G_prob = tf.nn.sigmoid(G_log_prob)  return G_prob

判别器

判别器正好相反,以MNIST图像作为输入并返回一个代表真实图像的概率的标量,代码如下:

D_W1 = tf.Variable(initializer(shape=[784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(initializer(shape=[128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name="D_W2")
theta_D = [D_W1, D_W2, D_b1, D_b2]def discriminator(x):
  D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
  D_logit = tf.matmul(D_h1, D_W2) + D_b2
  D_prob = tf.nn.sigmoid(D_logit)  return D_prob, D_logit

训练算法

在论文arXiv: 1406.2661, 2014中给出了训练算法的伪代码:

NreYBnn.jpg!web

TensorFlow中的优化器只能做最小化,因为为了最大化损失函数,我们在伪代码给出的损失函数前加上一个负号。

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

接下来定义优化器:

# 仅更新D(X)的参数, var_list=theta_DD_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)# 仅更新G(X)的参数, var_list=theta_GG_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

最后进行迭代,更新参数:

for it in range(60000):
  X_mb, _ = mnist.train.next_batch(mb_size)

  _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
  _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

整个流程下来,其实和之前的深度学习算法差不多,非常容易理解。算法是不是有效果呢?我们可以将迭代过程中生成的手写数字显示出来:

jmeY32U.gif

嗯,结果虽然有点差强人意,但差不多是手写数字的字形,而且随着迭代,越来越接近手写数字,可以说GAN算法还是有效的。

小结

一个简单的GAN网络就这么几行代码就能搞定,看样子生成一副画也没有什么难的。先不要这么乐观,其实,GAN网络中的坑还是不少,比如在迭代过程中,就出现过如下提示:

Iter: 9000
D loss: nan
G_loss: nan

从代码中我们可以看出,GAN网络依然采用的梯度下降法来迭代求解参数。梯度下降的启动会选择一个减小所定义问题损失的方向,但是我们并没有一个办法来确保利用GAN网络可以进入纳什均衡的状态,这是一个高维度的非凸优化目标。网络试图在接下来的步骤中最小化非凸优化目标,最终有可能导致进入振荡而不是收敛到底层正式目标。

另外还有模型坍塌、计数、角度以及全局结构方面的问题,要解决这些问题,需要使用一些特殊的技巧和方法,后面我们深入各种GAN模型时将会探讨。

本文完整的代码请参考: https://github.com/mogoweb/aiexamples

参考

  1. 首幅人工智能画作拍卖43.2万美元 远超预估价

  2. 实战生成对抗网络[1]:简介


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK