基本介绍
今日要实践的模型是Pix2Pix模型,用于图像转换。使用官方的指定数据集,该数据集是已经经过处理的外墙(facades)数据,可以直接使用mindspore.dataset的方法读取。由于Pix2Pix模型是基于cGAN(条件生成对抗网络),本文会先简单介绍cGAN图像生成的原理,再简单介绍一下Pix2Pix模型,然后展示自己的运行结果,不作代码展示,最后进行总结。
cGAN基本原理
该部分内容来自官方文档,非原创
cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。
首先定义一些在整个过程中需要用到的符号:
- 𝑥:代表观测图像的数据。
- 𝑧:代表随机噪声的数据。
- 𝑦=𝐺(𝑥,𝑧):生成器网络,给出由观测图像𝑥与随机噪声𝑧生成的“假”图片,其中𝑥来自于训练数据而非生成器。
- 𝐷(𝑥,𝐺(𝑥,𝑧)):判别器网络,给出图像判定为真实图像的概率,其中𝑥来自于训练数据,𝐺(𝑥,𝑧)来自于生成器。
cGAN的目标可以表示为:
该公式是cGAN的损失函数,D
想要尽最大努力去正确分类真实图像与“假”图像,也就是使参数𝑙𝑜𝑔𝐷(𝑥,𝑦)最大化;而G
则尽最大努力用生成的“假”图像𝑦欺骗D
,避免被识破,也就是使参数𝑙𝑜𝑔(1−𝐷(𝐺(𝑥,𝑧)))最小化。cGAN的目标可简化为:
为了对比cGAN和GAN的不同,我们将GAN的目标也进行了说明:
从公式可以看出,GAN直接由随机噪声𝑧生成“假”图像,不借助观测图像𝑥的任何信息。过去的经验告诉我们,GAN与传统损失混合使用是有好处的,判别器的任务不变,依旧是区分真实图像与“假”图像,但是生成器的任务不仅要欺骗判别器,还要在传统损失的基础上接近训练数据。假设cGAN与L1正则化混合使用,那么有:
进而得到最终目标:
图像转换问题本质上其实就是像素到像素的映射问题,Pix2Pix使用完全一样的网络结构和目标函数,仅更换不同的训练数据集就能分别实现以上的任务。
Pix2Pix模型简介
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器
- 生成器
生成器G用到的是U-Net结构,输入的轮廓图𝑥编码再解码成真是图片。U-Net是德国Freiburg大学模式识别和图像处理组提出的一种全卷积结构。它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。网络模型整体是一个U形的结构,因此被叫做U-Net。和常见的先降采样到低维度,再升采样到原始分辨率的编解码结构的网络相比,U-Net的区别是加入skip-connection,对应的feature maps和decode之后的同样大小的feature maps按通道拼一起,用来保留不同分辨率下像素级的细节信息。
- 判别器
判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图𝑥的条件下,对于生成的图片𝐺(𝑥)判断为假,对于真实判断为真。
Pix2Pix代码实践
官方给的代码实践是经典的深度学习流程。即数据集预处理,模型搭建,模型训练,模型评估,模型推理。这个流程中重点是模型搭建中的生成器和判别器,这二者是cGAN的核心,与GAN有所不同,最好结合代码和原理进行学习理解。详细的可直接参考官方的代码实践,这里给出我自己的运行结果和部分代码
- 数据集部分可视化结果
- 模型训练结果:只训练了3轮
- 模型推理结果:3轮训练的推理结果,感觉一般般,最好多训练几轮
总结
cGAN与GAN有所不同,他们之间的差异也比较明显,个人感觉cGAN是优于GAN的。Pix2Pix的思想不难,但是数学公式挺难的,要看懂数学公式,理解数学公式是需要一定的数学基础和时间投入的。如果有时间我会认真去钻研这些公式,今天就简单理解,跑跑代码,体验一下这个模型。