11

PyTorch 零基础入门 GAN 模型之基础篇

 2 years ago
source link: https://bbs.cvmart.net/articles/5283
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
PyTorch 零基础入门 GAN 模型之基础篇
PyTorch 零基础入门 GAN 模型之基础篇精选
技术讨论
三井 · 发表于 2021-08-06 15:07:05 文章来源: 三井的专栏

来源 | 知乎

近年来,各种生成模型及其应用广泛地出现在大家的视野范围内,像最近非常火爆的 Alias-Free GAN 更是从一个全新的视角,为生成模型领域中新的发展方向打下了坚实的理论基础。但是现在来看,无论是之前的 StyleGAN2 还是现在的 Alias-Free GAN,模型细节还有训练过程都是非常繁杂的。

同时,如果再结合 PyTorch 的话,又需要考虑各种分布式训练的问题。在这种内卷和快速迭代的时代,如何快速上手,把握住机会呢?在这一系列的教程里面呢,我们将以 MMGeneration 为基础,来帮助大家快速入门 GAN 这个庞大的领域。选择 MMGeneration 是因为,在前期,你可以不需要任何 PyTorch 的基础,到后期熟练之后,你也只需要在对应模型上进行一些修改就可以轻松地上手做一些实验啦。这么好用的工具包,还不来 Star 一波?
MMGeneration 的链接:https://github.com/open-mmlab/mmgeneration

入门第一步:从 GAN 到 DCGAN

GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。看起来就像是两个人练武,虽然一开始大家都很菜,但是判别器进步一点,然后生成器就迎头赶上,然后慢慢地两个人携手成为一代宗师。这是一种很抽象的理解,具体到数学理论上,GAN 的思想是由 Ian J. Goodfellow 在 Generative Adversarial Nets 中完整提出并证明的。

v2-09bff2c2171c94db192f129ace16289e_b.png

v2-09bff2c2171c94db192f129ace16289e_720w.png

整个 GAN 的对抗思想就体现在上图所示的损失函数公式当中,下面我们一点一点解析这个公式。在公式中 G(z) 描述了生成器的工作方式:输入一个噪声信号,然后输出一个尽可能逼真的图片(样本)。D(x) 和 D(G(z)) 表示判别器输入的分别是真实数据集中的图片和生成的图片,判别器的任务就是需要尽可能将两者区分开(在这里可以看作一个简单的二分类问题),这也就是让 D 去最大化这个损失函数的意义。而 G 就需要尽可能让 D 无法分辨出真实图片和虚假图片的差异,换句话说期望 G 生成的图片可以以假乱真。

v2-bebe9f8053b21c400ab93f1bbc667966_b.jpg

v2-bebe9f8053b21c400ab93f1bbc667966_720w.jpg

上图展示具体的训练流程,可以看到,在训练过程中 G 和 D 的优化是交替进行的,而且一般情况下我们往往希望 D 学习的稍微快一点,这样能够带动 G 更好地朝着全局最优解方向优化。具体的理论推导,大家可以去参考 Generative Adversarial Networks 一文,在这里通过 KL 散度可以很非常直观地看到我们的优化目标最优解就是生成器能够完全拟合真实样本的分布。
在 Ian J. Goodfellow 的文章中,主要是提出了 GAN 的思想。可是面对图像,我们常用的算子是卷积层。在 DCGAN 一文中,作者将 transposed convolution 用到了生成器中,成为了一个里程碑式的工作。从此之后,以卷积神经网络为主体的 GAN 模型就不断地涌现了出来。那下面我们将从这个基础模型入手,来看一下 GAN 网络的构造。

模型结构和代码分析

v2-72d98432707f5bde603298d45c00b7d1_b.jpg

v2-72d98432707f5bde603298d45c00b7d1_720w.jpg

上图就展示了一个典型的生成器和判别器的结构。在生成器当中,首先我们需要一个将噪声向量转换成为二维特征的模块,也就是 noise2feat block。那接下来需要连续经过几个上采样块将低分辨率的特征转换成高分辨的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。
对应到 MMGeneration 当中,模型的具体代码都存放在 models/architectures/dcgan 文件夹下面,下图展示了对应的生成器和判别器的代码逻辑,在 mmgen 中我们也是严格按照这样的设计来构建代码,相信大家能够更容易上手。如果关心具体实现的同学,可以到文件中查看,如果你暂时还是 PyTorch 初学者,那你大可不必关心具体的实现,我们接下来告诉大家怎么用 mmgen 训练一个 DCGAN。

v2-3a2d8d2cbf71f8042609e928220bd31c_b.jpg

v2-3a2d8d2cbf71f8042609e928220bd31c_720w.jpg

DCGAN CelebA 实验

接下来,将通过 MMGeneration 来一步一步详细地教大家如何训练第一个 DCGAN 模型。这里不需要大家有太多的 PyTorch 基础知识,只需要跟着我们一步一步来就可以了。
Step 1 安装:使用 MMGeneration,你只需要克隆一下 github 上面的仓库到本地,然后按照安装手册配置一下环境即可,如果安装遇到什么问题,可以给 MMGeneration 提 issue,我们会尽快为小伙伴们解答。下面是具体的安装步骤:

# we assume that you have installed pytorch and mmcv-full in your env. 
# clone repo 
git clone https://github.com/open-mmlab/mmgeneration mmgen 
cd mmgen 
# install mmgen 
pip install -e . 

Step 2 数据:假设大家已经安装好了 MMGeneration,回到训练上来,首先我们要做的是准备训练数据,CelebA 的数据可以通过其官方网站下载,我们选用其中的 Align&Cropped 数据来进行训练。下载解压完了之后,我们需要回到 MMGeneration 仓库的文件夹,通过软链的方式将数据链接到仓库的 data 目录下面:

mkdir data 
ln -s absolute_path_to_CelebA ./data/celeba 

这样我们的数据准备工作就基本完成了。不过我们需要再更新一下我们的 config 文件的中的 img_root 字段,将我们现在的数据路径更新上去。具体要做的就是修改 dcgan-celeba config 文件的第 11 行:

# define dataset 
# you must set `samples_per_gpu` and `imgs_root` 
data = dict( 
    samples_per_gpu=128, 
    train=dict(imgs_root='data/celeba'))  # set img_root 

这个 config 文件其实就能帮我们定义整个训练的过程,包括数据集的构造,模型的定义以及训练流程的定义等等,详细的介绍后续会带给大家。大家现在可以先通过我们提供的 config 来实现快速的上手训练和采样生成图片。
Step 3 训练:训练的指令其实非常简单,通过我们之前修改的 config 文件,我们就可以通过如下命令进行训练了:

bash tools/dist_train.sh ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py 1 --work-dir ./work_dirs/dcgan-celeba 

在训练过程中,我们会自动保存不同阶段模型生成的样本到 work_dirs/dcgan-celeba 文件夹下面:

v2-e58a15b490f3914661a356992abd6325_b.gif

v2-e58a15b490f3914661a356992abd6325_b.jpg

这样我们就可以随时观测到模型的收敛情况了,当然后续的教程里面,我们还会介绍如何通过一些客观的评价指标来检测我们的训练过程。训练完成之后,我们就可以通过随机采样来看看模型能带给我们什么样的样本啦。在 MMGeneration 当中,可以轻松得通过 demo/unconditional_demo.py 来实现:

python demo/unconditional_demo.py ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py work_dirs/dcgan-celeba/ckpt/iter_290000.pth 

其实在 MMGeneration 当中已经支持了非常多模型的采样,并且提供了公共的 checkpoint 供大家把玩,在我们的快速上手教程中有更详细地介绍,欢迎大家来试用并且提出你们宝贵的意见。


相关推荐:

基于 GAN 的极限图像压缩框架
涵盖 18+ SOTA GAN 实现,这个开源工程 PyTorch 库火了
GAN 万字长文综述

  • unlike.svg 1
  • comment.svg 0
  • view.svg 1126
like_white.png有用

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK