1 GAN基本概念
1.1 GAN介绍
GAN的英文全称是Generative Adversarial Network,中文名是生成对抗网络。它由两个部分组成,生成器和鉴别器(又称判别器),生成网络(Generator)负责生成模拟数据;判别网络Discriminator)负责判断输入的数据是真实的还是生成的。生成网络要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确。它们之间的关系可以用竞争或敌对关系来描述。
在GAN的原作中,作者将生成器比喻为印假钞票的犯罪分子,判别器则类比为警察。犯罪分子努力让钞票看起来逼真,警察则不断提升对于假钞的辨识能力。二者互相博弈,随着时间的进行,都会越来越强。那么类比于图像生成任务,生成器不断生成尽可能逼真的假图像。判别器则判断图像是否是真实的图像,还是生成的图像,二者不断博弈优化。最终生成器生成的图像使得判别器完全无法判别真假。
1.2 GAN的基本架构图
生成器负责依据随机向量产生内容,这些内容可以是图片、文字,也可以是音乐,具体什么取决于你想要创造什么;判别器负责判别接收的内容是否是真实的,通常他会给出一个概率,代表内容的真实度。其中,真实数据分布中的数据与生成数据可以认为是相同形状的。
对抗是指GAN的交替训练的过程,对于原始的GAN,以图片生成为例子,由高斯分布随机采样得到的噪声通过生成器得到了生成的假图片,将假图片和真图片一起随机抽取送入判别器判别,让它学习区分两者,给真的高分,给假的低分,当判别器能够熟练判断现有的数据后,再让生成器以从判别器处获得高分为目标,不断生成更好的假图片,直到能骗过判别器,重复这一过程,直到判别器对任何图片的预测概率都接近0.5,也就是无法判别图片的真假,就可以停止训练了。
对抗式过程的最终目标是尽可能逼真地模拟数据集的分布。
我们训练一个GAN的最终目标就是获得一个足够好的生成器,也就生成一个足够已经乱真的内容
1.3 GAN原理简述
生成器和判别器都是神经网络,它们在训练阶段都相互竞争。重复这些步骤,在这个过程中,生成器和鉴别器在每次重复后在各自的工作中变得越来越好。
Generator
GAN中的Generator是一种神经网络,给定一组随机的值,通过一系列非线性计算产生真实的图像。该生成器产生假图像 Xfake,其中随机向量Z,服从多元高斯分布采样。 生成器的输入是服从多元正态分布或高斯分布采样,并生成一个等于原始图像Xreal大小的输出。如下是一个以随机矢量为输入,生成假数字图像的生成器的流程图。
生成器的作用是: 欺骗的判别器、产生逼真的图像、训练完成后能实现高性能生成效果。
Discriminator
在GAN中判别器基于判别建模的概念,它试图用特定的标签对数据集中的不同类进行分类。因此,在本质上,它类似于一个监督分类问题。此外,判别器对观察结果的分类能力不仅限于图像,还包括视频、文本和许多其他领域(多模态)。以下是判别器将生成器生成的图像,分类判断真假的流程图。
判别器的作用是解决一个二值分类问题,学习区分真假图像:预测观察结果是由生成器(假的)生成,还是来自原始数据分布(真实的),在此过程中,它学习一组参数或权重。随着训练的进行,权重也在不断更新。
2 GAN的样本生成过程
GAN模型不是一上来就能实现具体功能的,需要经历一个训练的过程。我将其训练前后状态称为“原始的GAN模型”和“成熟的GAN模型”,原始的GAN模型要经过一个训练的过程成为一个成熟的GAN模型,而这个“成熟的GAN模型”才是我们实际应用的GAN模型。这个训练过程具体是训练生成网络(Generator)和判别网络(Discriminator)。
生成器G是一个生成图片的网络,它接收一个随机的噪声 z ,通过这个噪声生成图片,生成的图片记做G(z)。判别器D判别一张图片是不是“真实的”。它的输入是 x , x 代表一张图片(其中,x 包含生成图片和真实图片,对于生成图片有 x=G(z) ,输出D(x)代表 x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表图片0%是真的(或者说100%是假的)
2.1 训练过程
生成器和鉴别器一对博弈关系:鉴别器惩罚生成器,鉴别器收益,生成器损失;生成器进化,使鉴别器对自己惩罚小,生成器收益,鉴别器损失。
具体过程:生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。
第一阶段只有判别模型D参与。将训练集中的样本x作为D的输入,输出0-1之间的某个值,数值越大意味着样本x为真实数据的可能性越大。在这个过程中,我们希望D尽可能使输出的值逼近1。
第二阶段中,判别模型D和生成模型G都参与,首先将噪声z输入G,G从真实数据集里学习概率分布并产生假的样本,然后将假的样本输入判别模型D,这一次D将尽可能输入数值0。所以在这个过程中,判别模型D相当于一个监督情况下的二分类器,数据要么归为1,要么归为0。
2.2 GAN的目标函数
生成模型捕获数据的分布,并以尝试最大化判别器出错的概率的方式进行训练。另一方面,判别器基于一个模型,该模型估计它获得的样本是从训练数据而不是从生成器接收的概率。GAN 被表述为一个极小极大游戏,其中判别器试图最大化其奖励V(D, G),而生成器试图最小化判别器的奖励,或者换句话说,最大化其损失。
GANs定义了一个噪声Pz(x) 作为先验,用于学习生成模型G在训练数据x上的概率分布Pg,G(z)表示将输入的噪声z映射成数据(例如生成图片)。D(x)代表x 来自于真实数据分布Pdata而不是Pg的概率。据此,优化的目标函数定义如下minmax的形式(具体见:http://t.csdn.cn/29Btw):
上式中的minmax可理解为当更新D时,需要最大化上式,而当更新G时,需要最小化上式,详细解释如下:
需要注意的是: 生成器不是最小化判别器的目标函数,而是最小化判别器目标函数的最大值,判别器目标函数的最大值代表的是真实数据分布与生成数据分布的JS散度,JS散度可以度量分布的相似性,两个分布越接近,JS散度越小。即判别器的目标是最小化交叉熵损失,生成器的目标是最小化生成数据分布和真实数据分布的JS散度。
(1)在对判别模型D的参数进行更新时:
- 对于来自真实分布Pdata的样本x而言,我们希望D(x)的输出越接近于1越好,即logD(x)越大越好;
- 对于通过噪声z生成的数据G(z)而言,我们希望D(G(z))尽量接近于0(即D能够区分出真假数据),因此log(1−D(G(z)))也是越大越好,所以需要maxD。
(2)在对生成模型G的参数进行更新时:
我们希望G(z)尽可能和真实数据一样,即Pg=Pdata。因此我们希望D(G(z))尽量接近于1,即log(1-D(G(z)))越小越好,所以需要minG。需要说明的是,logD(x)是与无关的项,在求导时直接为0。
D的最佳情况为:
2.3 两者分布之间的差异性
对于生成网络G,其输入 z ∼ N ( 0 , I ) ,表示 z 服从正态分布的数据,通过训练出来的参数 θ , 生成网络生成的图片为 G ( z , θ )。
对于判别网络,可以认为是二分类问题,一类是生成网络的输出,即 xgenerative=G(z,θ);另一类是真实数据 xreal,(其中, xreal∼Dreal ,表示 xreal 服从一种真实的分布distribution)。将 x (其中, x=xgenerative∪xreal) 数据输入到判别网络中,输出结果分别为:
D(xreal,ϕ)
D(xgenerative,ϕ)=D(G(z,θ),ϕ)
它可以用以下公式在数学上描述。
生成网络的损失函数:
上式中,G 代表生成网络,D 代表判别网络,H 代表交叉熵,z 是输入随机数据。D(G(z)) 是对生成数据的判断概率,1代表数据绝对真实,0代表数据绝对虚假。H(1,D(G(z))) 代表判断结果与1的距离。显然生成网络想取得良好的效果,那就要做到,让判别器将生成数据判别为真数据(即D(G(z))与1的距离越小越好)。
判别网络的损失函数:
上式中,是真实数据,这里要注意的是,代表真实数据与1的距离,代表生成数据与0的距离。显然,识别网络要想取得良好的效果,那么就要做到,在它眼里,真实数据就是真实数据,生成数据就是虚假数据(即真实数据与1的距离小,生成数据与0的距离小)。
将其看成二分类问题,二分类问题的损失函数可以使用交叉熵损失函数来表示,对于二分类,只有正样本(label=1)与负样本(label=0)。并且两者概率之和为1。对于一个输入 x,经过模型输出为 p(x)。y是真实的标签。采用Binary Cross-Entropy(BCE)损失函数对判别器进行训练。于是单个样本的损失函数就是:
LOSS = -y * log(p(x)) + (1-y)log(1-p(x))如果是计算 N 个样本的平均损失函数,只要将 N 个 Loss 叠加起来再除以N就行:
优化原理:生成网络和判别网络有了损失函数,就可以基于各自的损失函数,利用误差反向传播(Backpropagation)(BP)反向传播算法和最优化方法(如梯度下降法)来实现参数的调整),不断提高生成网络和判别网络的性能(最终生成网络和判别网络的成熟状态就是学习到了合理的映射函数)。
2.4 GAN的损失函数难下降
生成器和判别器的目的相反,也就是说两个生成器网络和判别器网络互为对抗,此消彼长。不可能Loss一直降到一个收敛的状态。
- 对于生成器,其Loss下降快,很有可能是判别器太弱,导致生成器很轻易的就"愚弄"了判别器。
- 对于判别器,其Loss下降快,意味着判别器很强,判别器很强则说明生成器生成的图像不够逼真,才使得判别器轻易判别,导致Loss下降很快。
也就是说,无论是判别器,还是生成器。loss的高低不能代表生成器的好坏。一个好的GAN网络,其GAN Loss往往是不断波动的。
似乎判断模型是否收敛就只能看生成的图像质量了。实际上,后文探讨的WGAN,提出了一种新的loss度量方式,让我们可以通过一定的手段来判断模型是否收敛。
3 GAN的网络架构
生成网络代码
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], *img_shape)
return img
对抗网络代码:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
)
4 GAN的缺陷
4.1 生成式模型、判别式模型的区别?
对于机器学习模型,我们可以根据模型对数据的建模方式将模型分为两大类,生成式模型和判别式模型。如果我们要训练一个关于猫狗分类的模型,对于判别式模型,只需要学习二者差异即可。比如说猫的体型会比狗小一点。而生成式模型则不一样,需要学习猫长什么样,狗长什么样。有了二者的长相以后,再根据长相去区分。具体而言:
-
生成式模型:由数据学习联合概率分布P(X,Y), 然后由P(Y|X)=P(X,Y)/P(X)求出概率分布P(Y|X)作为预测的模型。该方法表示了给定输入X与产生输出Y的生成关系
-
判别式模型:由数据直接学习决策函数Y=f(X)或条件概率分布P(Y|X)作为预测模型,即判别模型。判别方法关心的是对于给定的输入X,应该预测什么样的输出Y。
4.2 GAN的Loss难下降
作为一个训练良好的GAN,其Loss就是降不下去的。衡量GAN是否训练好了,只能由人肉眼去看生成的图片质量是否好。后文提及的WGAN就提出了一种新的Loss设计方式,较好的解决了难以判断收敛性的问题。
对于判别器而言,从GAN的Loss可以看出,生成器和判别器的目的相反,也就是说两个生成器网络和判别器网络互为对抗,此消彼长。不可能Loss一直降到一个收敛的状态。
- 对于生成器,其Loss下降快,很有可能是判别器太弱,导致生成器很轻易的就"愚弄"了判别器。
- 对于判别器,其Loss下降快,意味着判别器很强,判别器很强则说明生成器生成的图像不够逼真,才使得判别器轻易判别,导致Loss下降很快。
也就是说,无论是判别器,还是生成器。loss的高低不能代表生成器的好坏。一个好的GAN网络,其GAN Loss往往是不断波动的。
4.3 解决mode collapsing
所谓GAN的训练崩溃,指的是训练过程中,生成器和判别器存在一方压倒另一方的情况。 GAN原始判别器的Loss在判别器达到最优的时候,等价于最小化生成分布与真实分布之间的JS散度,由于随机生成分布很难与真实分布有不可忽略的重叠以及JS散度的突变特性,使得生成器面临梯度消失的问题;可是如果不把判别器训练到最优,那么生成器优化的目标就失去了意义。因此需要我们小心的平衡二者,要把判别器训练的不好也不坏才行。否则就会出现训练崩溃,得不到想要的结果
方法一:针对目标函数的改进方法
为了避免前面提到的由于优化maxmin导致mode跳来跳去的问题,UnrolledGAN采用修改生成器loss来解决。具体而言,UnrolledGAN在更新生成器时更新k次生成器,参考的Loss不是某一次的loss,是判别器后面k次迭代的loss。注意,判别器后面k次迭代不更新自己的参数,只计算loss用于更新生成器。这种方式使得生成器考虑到了后面k次判别器的变化情况,避免在不同mode之间切换导致的模式崩溃问题。此处务必和迭代k次生成器,然后迭代1次判别器区分开。DRAGAN则引入博弈论中的无后悔算法,改造其loss以解决mode collapse问题。
方法二:针对网络结构的改进方法
Multi agent diverse GAN(MAD-GAN)采用多个生成器,一个判别器以保障样本生成的多样性。相比于普通GAN,多了几个生成器,且在loss设计的时候,加入一个正则项。正则项使用余弦距离惩罚三个生成器生成样本的一致性。
MRGAN则添加了一个判别器来惩罚生成样本的mode collapse问题。这个判别器主要用于判断生成的样本是否具有多样性,即是否出现mode collapse。
方法三:Mini-batch Discrimination
Mini-batch discrimination在判别器的中间层建立一个mini-batch layer用于计算基于L1距离的样本统计量,通过建立该统计量,实现了一个batch内某个样本与其他样本有多接近。这个信息可以被判别器利用到,从而甄别出哪些缺乏多样性的样本。对生成器而言,则要试图生成具有多样性的样本。
4.4 避免GAN的训练崩溃
- 归一化图像输入到(-1,1)之间;Generator最后一层使用tanh激活函数
- 生成器的Loss采用:min (log 1-D)。因为原始的生成器Loss存在梯度消失问题;训练生成器的时候,考虑反转标签,real=fake, fake=real
- 不要在均匀分布上采样,应该在高斯分布上采样
- 一个Mini-batch里面必须只有正样本,或者负样本。不要混在一起;如果用不了Batch Norm,可以用Instance Norm
- 避免稀疏梯度,即少用ReLU,MaxPool。可以用LeakyReLU替代ReLU,下采样可以用Average Pooling或者Convolution + stride替代。上采样可以用PixelShuffle, ConvTranspose2d + stride
- 平滑标签或者给标签加噪声;平滑标签,即对于正样本,可以使用0.7-1.2的随机数替代;对于负样本,可以使用0-0.3的随机数替代。 给标签加噪声:即训练判别器的时候,随机翻转部分样本的标签。
- 如果可以,请用DCGAN或者混合模型:KL+GAN,VAE+GAN。
- 使用LSGAN,WGAN-GP
- Generator使用Adam,Discriminator使用SGD
- 尽快发现错误;比如:判别器Loss为0,说明训练失败了;如果生成器Loss稳步下降,说明判别器没发挥作用
- 不要试着通过比较生成器,判别器Loss的大小来解决训练过程中的模型坍塌问题。比如: While Loss D > Loss A: Train D While Loss A > Loss D: Train A
- 如果有标签,请尽量利用标签信息来训练
- 给判别器的输入加一些噪声,给G的每一层加一些人工噪声。
- 多训练判别器,尤其是加了噪声的时候
- 对于生成器,在训练,测试的时候使用Dropout
5 GANs的改进
5.1 CGAN
对抗网络的一个主要缺点是训练过程不稳定,为了提高训练的稳定性,Conditional Generative Adversarial Nets (CGAN)其一定程度上解决了GAN生成结果的不确定性。如果在Mnist数据集上训练原始GAN,GAN生成的图像是完全不确定的,具体生成的是数字1,还是2,还是几,根本不可控。为了让生成的数字可控,我们可以把数据集做一个切分,把数字0~9的数据集分别拆分开训练9个模型,不过这样太麻烦了,也不现实。因为数据集拆分不仅仅是分类麻烦,更主要在于,每一个类别的样本少,拿去训练GAN很有可能导致欠拟合。因此,CGAN就应运而生了。我们先看一下CGAN的网络结构:
从网络结构图可以看到,对于生成器Generator,其输入不仅仅是随机噪声的采样z,还有欲生成图像的标签信息。比如对于mnist数据生成,就是一个one-hot向量,某一维度为1则表示生成某个数字的图片。同样地,判别器的输入也包括样本的标签。这样就使得判别器和生成器可以学习到样本和标签之间的联系。 Loss设计和原始GAN基本一致,只不过生成器,判别器的输入数据是一个条件分布。在具体编程实现时只需要对随机噪声采样z和输入条件y做一个级联即可。
即通过把无监督的 GAN 变成半监督或者有监督的模型,从而为 GAN 的训练加上一点目标,其优化的目标函数为:
CGAN在生成模型G和判别模型D的建模中均引入了条件变量 y,这里y可以是label或者其他数据形态,将y和GAN原有的输入合并成一个向量作为CGAN的输入。这个简单直接的改进被证明很有效,并广泛用于后续的相关工作中。CGAN模型的示意图如下所示:
5.2 DCGAN
前面我们聊的GAN都是基于简单的神经网络�构建的。可是对于视觉问题,如果使用原始的基于DNN的GAN,则会出现许多问题。如果输入GAN的随机噪声为100维的随机噪声,输出图像为256x256大小。也就是说,要将100维的信息映射为65536维。如果单纯用DNN来实现,�那么整个模型参数会非常巨大,而且学习难度很大(低维度映射到高维度需要添加许多信息)。因此,深度卷积神将网络DCGAN就出现了。
生成网络:卷积神经网络+反卷积神经网络(前者负责提取图像特征,后者负责根据输入的特征重新生成图像(即假数据))。
判别网络:卷积神经网络+全连接层处理(传统神经网络)(前者负责提取图像特征,后者负责判别真假。)
具体而言,DCGAN将传统GAN的生成器,判别器均采用GAN实现,且使用了一下tricks:
- 将pooling层convolutions替代,其中,在discriminator上用strided convolutions替代,在generator上用fractional-strided convolutions替代。
- 在generator和discriminator上都使用batchnorm。
- 移除全连接层,global pooling增加了模型的稳定性,但伤害了收敛速度。
- 在generator的除了输出层外的所有层使用ReLU,输出层采用tanh。
- 在discriminator的所有层上使用LeakyReLU。
对抗网络与卷积神经网络相结合进行图片生成,DCGAN模型的结构如下:
DCGANs的基本架构就是使用几层“反卷积”(Deconvolution)。传统的CNN是将图像的尺寸压缩,变得越来越小,而反卷积是将初始输入的小数据(噪声)变得越来越大(但反卷积并不是CNN的逆向操作),例如在上面这张图中,从输入层的100维noise,到最后输出层64x64x3的图片,从小维度产生出大的维度。反卷积的示意图如下所示,2x2的输入图片,经过3x3 的卷积核,可产生4x4的feature map:
由于反卷积存在于卷积的反向传播中。其中反向传播的卷积核矩阵是前向传播的转置,所以其又可称为transport convolution。只不过我们把反向传播的操作拿到了前向传播来做,就产生了所谓的反卷积一说。但是transport convolution只能还原信号的大小,不能还原其值,因此不是真正的逆操作。
DCGAN的另一个改进是对生成模型中池化层的处理,传统CNN使用池化层(max-pooling或mean-pooling)来压缩数据的尺寸。在反卷积过程中,数据的尺寸会变得越来越大,而max-pooling的过程并不可逆,所以DCGAN的论文里并没有采用池化的逆向操作,而只是让反卷积的滑动步长设定为2或更大值,从而让尺寸按我们的需求增大。另外,DCGAN模型在G和D上均使用了batch normalization,这使得训练过程更加稳定和可控。
该文献将GANs应用于文本转图像(Text to Image),从而可根据特定输入文本所描述的内容来产生特定图像。因此,生成模型里除了输入随机噪声之外,还有一些特定的自然语言信息。所以判别模型不仅要区分样本是否是真实的,还要判定其是否与输入的语句信息相符。网络结构如下图所示:
5.3 WGAN
在生成对抗网络中,当判断网络为最优时,生成网络的优化目标是最小化真实分布 pr(x) 和模型分布 pθ(x) 之间的JS散度。当两个分布相同时,JS散度为0,最优生成网络对应的损失为−2log2。但是使用JS散度来训练生成对抗网络的一个问题是当两个分布没有重叠时,它们之间的JS散度恒等于常数log2。对生成网络来说,目标函数关于参数的梯度为0。
在GAN的基础上加入了Wasserstein距离,以解决GAN网络训练过程难以判断收敛性的问题。Wasserstein距离用于衡量两个分布之间的距离。相比KL散度和JS散度的优势在于即使两个分布没有重叠或者重叠非常少,Wasserstein距离仍然能反映两个分布的远近。其数学公式如下:
从公式上GAN似乎总是让人摸不着头脑,在代码实现上来说,其实就以下几点:
- 判别器最后一层去掉sigmoid
- 生成器和判别器的loss不取log
- 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
实际实验过程发现,WGAN没有那么好用,主要原因在于WAGN进行梯度截断。梯度截断将导致判别网络趋向于一个二值网络,造成模型容量的下降。 于是作者提出使用梯度惩罚来替代梯度裁剪。如果需要的话,可以选择Layer Normalization。
5.4 LSGAN
LSGAN(Least Squares GAN)这篇文章主要针对标准GAN的稳定性和图片生成质量不高做了一个改进。作者将原始GAN的交叉熵损失采用最小二乘损失替代。
实际实现的时候非常简单,最后一层去掉sigmoid,并且计算Loss的时候用平方误差即可。
5.5 seqGAN
seqGAN在GAN的框架下,结合强化学习来做文本生成。
在文本生成任务,seqGAN相比较于普通GAN区别在以下几点:
- 生成器不取argmax。
- 每生成一个单词,则根据当前的词语序列进行蒙特卡洛采样生成完成的句子。然后将句子送入判别器计算reward。
- 根据得到的reward进行策略梯度下降优化模型。
6 GAN的应用
如果你的训练数据不充分,没问题。GANs可以根据已知的数据并生成合成图像来扩充您的数据集。
从描述生成图像(从文本到图像合成)。
提高视频的分辨率,以捕捉更精细的细节(从低分辨率到高分辨率)。
在音频领域,GAN也可以用于合成高保真音频或执行语音翻译。
6.1 图像生成
GAN最常使用的地方就是图像生成,如超分辨率任务,语义分割等等。
6.2 数据增强
用GAN生成的图像来做数据增强,