@(TOC)[CycleGAN图像风格迁移呼唤]
模型介绍
模型简介
CycleGAN(Cycle Generative Adversaial Network)即循环对抗生成网络,来自论文Link:Unpaired lmage-to-mage Translation using Cycle-Consistent AdvesairalNetworks该模型实现了—种在没有配对示例的情况下学习将图像从源域×转换到目标域Y的方法。
该模型一个重要应用领城是域迁移(Dom in Adaptation),可以通俗地理解为图像风格迁移。其实在CycieGAV之前,就已经有了域迁移模型,比以D Pi2Pk,但是Pi2Fik要求训练数据必须是成对的,而现实生活中,要找到两个城(画风)中成对出现的图片是相当困难的,因此 CyclCGAN诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。
模型结构
CycleGAN网络本质上是由两个镜像对称的GAN网络组成,其结构如下图所示(图片来源于原论文)∶
为了方便理解,这里以苹果和橘子为例介绍。上图中
X
X
X可以理解为苹果,
Y
Y
Y为橘子;
G
G
G为将苹果生成橘子风格的生成器,
F
F
F为将橘子生成的苹果风格的生成器,
D
x
D_x
Dx和
D
x
D_x
Dx为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。
该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle ConsistencyLoss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文)︰
图中苹果图片
x
x
x经过生成器
G
G
G得到伪橘子
Y
ˆ
\^Y
Yˆ,然后将伪橘子
Y
ˆ
\^Y
Yˆ结果送进生成器
F
F
F又产生苹果风格的结果
x
ˆ
\^x
xˆ,最后将生成的苹果风格结果
x
ˆ
\^x
xˆ与原苹果图片
x
x
x一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。
1 数据集
本案例使用的数据集里面的图片来源于Link:ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统─缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。
这里对数据进行了随机裁剪、水平随机翻转和归—化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为MindRecord格式的数据,以省略大部分数据预处理的代码。
1.1 数据集下载
使用download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download安装download 包。
1.2 数据集加载
使用MindSpore的MindDataset接读取和解析数据集。
1.3 可视化
通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib 模块可视化部分训练数据。
2 构建生成器
本案例生成器的模型结构参考的ResNet模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数n_layers参数控制残差块数。
生成器的结构如下所示:
具体的模型结构请参照下文代码:
3 构建判别器
判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是Patch大小为70x70的PatchGANs模型。通过一系列的Conv2d 、 BatchNorm2d和LeakyReLu层对其进行处理,最后通过Sigmoid 激活函数得到最终概率。
4 优化器和损失函数
根据不同模型需要单独的设置优化器,这是训练过程决定的。
对生成器
G
G
G及其判别器
D
y
Dy
Dy ,目标损失函数定义为:
其中
G
G
G试图生成看起来与
Y
Y
Y中的图像相似的图像
G
(
x
)
G(x)
G(x),而
D
y
D_y
Dy的目标是区分翻译样本
G
(
x
)
G(x)
G(x)和真实样本
y
y
y,生成器的目标是最小化这个损失函数以此来对抗判别器。即
单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于X的每个图像x,图像转换周期应能够将x带回原始图像,可以称之为正向循环—致性,即
对于
Y
Y
Y,类似的
可以理解采用了一个循环一致性损失来激励这种行为。
循环一致损失函数定义如下:
5 前向计算
搭建模型前向计算损失的过程,过程如下代码。
为了减少模型振荡[1],这里遵循Shrivastava等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。
6 计算梯度和反向传播
其中梯度计算也是分开不同的模型来进行的,详情见如下代码:
7 模型训练
8 模型推理
下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第—行为原图,第二行为对应生成的结果图。
9 参考
[1] I.Goodfellow.NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016.2,4,5
[2]A.Shwivastava T.Pister,O. Tuzel, J.Susskind W.Wang, R.Webb.Learning from simulated and unsupervised images through adversarial training. In CVPR,2017.3,5,6,7