Mindspore框架DCGAN模型实现漫画头像生成
- Mindspore框架DCGAN模型实现漫画头像生成|(一)漫画头像数据集准备
- Mindspore框架DCGAN模型实现漫画头像生成|(二)DCGAN模型构建
- Mindspore框架DCGAN模型实现漫画头像生成|(三)DCGAN模型训练和推理
- Mindspore框架DCGAN模型实现漫画头像生成|(四)应用程序生成实践
Mindspore框架DCGAN模型实现漫画头像生成|(二)DCGAN模型构建
DCGAN,全称是 Deep Convolution Generative Adversarial Networks,深度卷积生成对抗网络
。
1. DCGAN模型特点
- make GAN + CNN more stable and deeper,能够产生更高分辨率的图像;
- 全卷积网络(all convolutional net):用步幅卷积(strided convolutions)替代确定性空间池化函数(deterministic spatial pooling functions)(比如最大池化),让网络自己学习downsampling方式。作者对 generator 和 discriminator 都采用了这种方法。
- 取消全连接层:使用 全局平均池化(global average pooling)替代 fully connected layer。global average pooling会降低收敛速度,但是可以提高模型的稳定性。GAN的输入采用均匀分布初始化,可能会使用全连接层(矩阵相乘),然后得到的结果可以reshape成一个4 dimension的tensor,然后后面堆叠卷积层即可;对于鉴别器,最后的卷积层可以先flatten,然后送入一个sigmoid分类器。
- 批归一化(Batch Normalization):BN 被证明是深度学习中非常重要的 加速收敛 和 减缓过拟合 的手段。这样有助于解决 poor initialization 问题并帮助梯度流向更深的网络。防止G把所有rand input都折叠到一个点,同时防止样本震荡和模型的不稳定,只对生成器(G)的输出层和鉴别器(D)的输入层使用BN。
- Leaky Relu 激活函数: 生成器(G),输出层使用tanh 激活函数,其余层使用relu 激活函数。鉴别器(D),都采用leaky rectified activation。
- DCGAN生成器G的结构如下:
2. 构造网络:生成器G
生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
# 通过输入部分中设置的nz、ngf和nc来影响代码中的生成器结构。
class Generator(nn.Cell):
"""DCGAN网络生成器"""
def __init__(self):
super(Generator, self).__init__()
self.generator = nn.SequentialCell(
nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
nn.ReLU(),
nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
nn.ReLU(),
nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
nn.ReLU(),
nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf, gamma_init=gamma_init),
nn.ReLU(),
nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
nn.Tanh()
)
def construct(self, x):
return self.generator(x)
generator = Generator()
注意:nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数
。
2. 构造网络:判别器D
判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。形如:
通过一系列的Conv2d、BatchNorm2d和LeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。
class Discriminator(nn.Cell):
"""DCGAN网络判别器"""
def __init__(self):
super(Discriminator, self).__init__()
self.discriminator = nn.SequentialCell(
nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
nn.LeakyReLU(0.2),
nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
nn.LeakyReLU(0.2),
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
nn.LeakyReLU(0.2),
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
nn.LeakyReLU(0.2),
nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
)
self.adv_layer = nn.Sigmoid()
def construct(self, x):
out = self.discriminator(x)
out = out.reshape(out.shape[0], -1)
return self.adv_layer(out)
discriminator = Discriminator()
模型结构输出: