文章目录
- 一、原理讲述
- 1.1 概念讲解
- 1.2 生成模型和判别模型
- 二、训练过程
- 2.1 训练原理
- 2.2 损失函数
- 三、应用
一、原理讲述
1.1 概念讲解
1. 生成式对抗网络(Generative Adversarial Network,GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。它启发自博弈论中的二人零和博弈(two-player game),两位博弈方分别由生成模型(generative model)和判别模型(discriminative model)充当。
2. 判别模型用于判断一个给定的图片是不是真实的图片(从数据集中获取的图片),生成模型的任务是去创造一个看起来像真的图片一样的图片。两个模型一起对抗训练,生成模型产生一张图片去欺骗判别模型,判别模型判断这张图片是真的还是假的。在这两个模型训练的过程中,两个模型的能力越来越强,最终达到稳态。
latent random variable:潜在的随机变量
generated fake samples:生成的假样本
fine tune training:微调训练
generator:生成器
discriminator:判别器
1.2 生成模型和判别模型
1. 生成模型:生产模型的输入是二维高斯模型中的一个随机向量,生成模型的输出是一张伪造的假图片(fake image),同时获取数据集中的真实图片,然后将假图片和真实图片传给判别模型,由判别模型给出是真实图片还是假图片的判别结果。
2. 判别模型:根据输入的图片类型是假图片或真实图片,将输入数据的lable(标签)标记为0或者1。经过判别模型后输出值为一个0到1之间的数,用于表示输入图片为真实图片的概率,1表示真实图片,0表示假图片。
二、训练过程
2.1 训练原理
1. GAN的训练在同一轮梯度反转的过程中可以细分为2步:(1)先训练D。(2)再训练G。注意:不是等所有的D训练好了才开始训练G,因为D的训练也需要上一轮梯度反转中的G的输出值作为输入。
梯度反转(Gradient Reversal)是一种无监督学习方法,通过将梯度乘上一个负数来反转梯度方向,以达到欺骗判别器的效果,使得源域和目标域之间的特征分布可以互相“融合”,从而实现域自适应的目的。
2. 当训练D的时候:上一轮G产生的图片和真实图片,直接拼接在一起作为x。然后按顺序摆放成0和1,假图对应0,真图对应1。然后就可以通过D、x输入生成一个score(从0到1之间的数),通过score和y组成的损失函数,就可以进行梯度反转了。
3. 当训练G的时候:需要把G和D当作一个整体,这里取名叫做’D_on_G’。这个整体(简称DG系统)的输出仍然是score。输入一组随机向量z,就可以在G生成一张图,通过D对生成的这张图进行打分得到score,这就是DG系统的前向过程。score=1就是DG系统需要优化的目标,score和y=1之间的差异可以组成损失函数,然后可以采用反向传播梯度。注意,这里的D的参数是不可训练的。这样就能保证G的训练是符合D的打分标准的。
2.2 损失函数
1. 判别模型
D
D
D 的损失函数为如下所示。其中:
x
x
x 表示真实图像;
z
z
z 表示输入网络中的噪声;
G
(
z
)
G(z)
G(z) 表示生成器生成的假图像;
D
(
x
)
D(x)
D(x) 表示判别模型判断真实图像是否为真的概率(由于是真实图像,我们当然希望概率越接近
1
1
1 越好);
D
(
G
(
z
)
)
D(G(z))
D(G(z)) 为判别模型
D
D
D 去判断生成模型
G
G
G 生成的假图像是否为真图像的概率(由于是生成模型生成的假图像,我们希望概率越接近
0
0
0 越好)。
我们总是期望
D
(
x
)
D(x)
D(x) 越大,
D
(
G
(
z
)
)
D(G(z))
D(G(z)) 越小,因此要最大化下式,用
l
o
g
log
log 函数约束它们之间的关系,通过训练不断调整网络的权值,以达到我们的期望。
2. 生成模型
G
G
G 的损失函数为如下所示。生成模型的主要作用就是从随机信号生成一张图像,来尽可能地拟合真实图像,使得判别模型
D
D
D 无法判断生成图像的真伪。由
l
o
g
log
log 函数的性质可知,只有当
D
(
G
(
z
)
)
D(G(z))
D(G(z)) 的值接近
1
1
1 的时候,下式才能有最小值。
D
(
G
(
z
)
)
=
1
D(G(z))=1
D(G(z))=1 表示判别模型
D
D
D 将生成模型
G
G
G 生成的图像判断为真实图像,所以最小化这个函数就可以使生成模型
G
G
G 通过不断训练生成接近真实图像分布的图像。
3. 结合上面的叙述,总的优化函数为:
三、应用
GAN最常使用的地方图像生成,如超分辨率任务,语义分割等。用GAN生成的图像也可以来做数据增强。