学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)
GAN图像生成
模型简介
生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。
GAN论文逐段精读【论文精读】_哔哩哔哩_bilibili
最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):
- 生成器的任务是生成看起来像训练图像的“假”图像;
- 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。
GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。
GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷 。
在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。
用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥)需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)数值应该趋近于 0 。因此 𝐷(𝑥)也可以被认为是传统的二分类器。
用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧)的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 的数据分布,我们希望找到 θ 使得尽可能的接近,其中 𝜃代表网络参数。
符号 表示的是生成模型 生成数据的概率分布,参数为。具体来说:
:表示数据样本,例如图像或其他形式的数据。
:表示生成器的参数,这些参数是通过训练过程来优化的。
:表示通过生成器在参数下生成数据的概率分布。
表示生成器 𝐺 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数;而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数。因此GAN的损失函数为:
- 判别器的目标函数:
- 生成器的目标函数:
在GAN(生成式对抗网络)中,损失函数反映了生成器和判别器之间的对抗关系。有两个模型:生成器和判别器。判别器的目标是最大化其正确分类真实图像和生成图像的概率,而生成器 的目标是最小化被判别器正确识别为生成图像的概率。
判别器的目标是最大化以下期望值:
其中,表示真实图像,是从标准正态分布中采样的隐变量,表示生成的假的图像。第一项期望值表示真实图像被判别为真实的概率,第二项期望值表示生成的图像被判别为假的概率。
另一方面,生成器的目标是最小化判别器正确识别为生成图像的概率,这相当于最小化以下期望值:
综合两个目标, 可以给出生成式对抗网络的损失函数(目标函数)为:
这是一个双重优化问题,两个模型在同一时间被训练:判别器尽可能区分真实和生成图像,生成器尽可能生成能够欺骗判别器的图像。
以下是解释关键项的详细含义:
1.:期望真实数据被判别为真实的概率的对数,这部分使得判别器更好地区分真实数据。
2. :期望生成的假数据被判别为假的概率的对数,这部分反映了生成器通过尝试生成能够欺骗判别器的假数据来提高生成效果。
通过这样的对抗训练过程,生成器会逐渐提高其生成质量,生成的图像会愈发接近真实图像,使得判别器愈加难以区分真假图像。
从理论上讲,此博弈游戏的平衡点是,此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:·
- 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
- 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
- 生成器通过优化,生成出更加贴近真实数据分布的数据。
- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
在理想的均衡点,生成器生成的数据分布 和真实的数据分布 相同,即 。此时,对于任意一个输入数据,判别器无法区分它是生成器生成的还是实际数据,这就表示判别器的输出应该在0和1之间摇摆不定,最可能的值是0.5。
在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,𝑧 表示隐码,𝑥 表示生成的虚假图像 𝐺(𝑧)。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。
在这四张图片中,黑色虚线代表了真实数据分布的概率密度。我们可以看到左侧通常表示真实数据分布较高的区域。因此,判别器在这些区域应该输出更高的概率值。从图中蓝色虚线的形状可以看出,随着训练的进行,判别器在左侧区域的输出值是大于0.5的,因为这个区域内的样本被判别器认为更可能是真实数据。第四张图片在理想情况下,生成器生成的数据分布(绿色实线)与真实数据分布(黑色虚线)几乎完全重合。判别器无法区分生成数据和真实数据,只能随机猜测,所以在这个阶段的判别器输出会接近1/2。
数据集
数据集简介
MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。
本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。
数据集下载
使用download
接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download
安装download
包。
下载解压后的数据集目录结构如下:
./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test
├─ t10k-images-idx3-ubyte
└─ t10k-labels-idx1-ubyte
数据下载的代码如下:
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore Version: 2.2.14 Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios. Home-page: https://www.mindspore.cn Author: The MindSpore Authors Author-email: contact@mindspore.cn License: Apache 2.0 Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy Required-by:
# 数据下载
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB) file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 116MB/s] Extracting zip file... Successfully downloaded / unzipped to .[3]:
'.'
数据加载
使用MindSpore自己的MnistDatase
接口,读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理。
import numpy as np
import mindspore.dataset as ds
batch_size = 64
latent_size = 100 # 隐码的长度
# 载入数据集
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')
# 定义数据加载函数
def data_load(dataset):
dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)
# 数据增强
mnist_ds = dataset1.map(
operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
output_columns=["image", "latent_code"])
mnist_ds = mnist_ds.project(["image", "latent_code"])
# 批量操作
mnist_ds = mnist_ds.batch(batch_size, True)
return mnist_ds
mnist_ds = data_load(train_dataset)
iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
Iter size: 156
dataset1
是一个 GeneratorDataset
,它从原始数据集中生成一个新的数据集并对其进行打乱(shuffle),限制样本数量为10000。operations
部分是一个匿名函数(lambda 函数),它接收输入的数据 x
,然后返回两个新的值:将图像数据转换为 float32
类型、生成一个服从标准正态分布的随机向量,大小为 latent_size
,并将其转化为 float32
类型。output_columns=["image", "latent_code"]
定义了输出数据集的列名。project
方法用于选择数据集中指定的列。["image", "latent_code"]
表示只保留图像和生成的潜在编码(latent code)列,丢弃其他不需要的列如 label
。
数据集可视化
通过create_dict_iterator
函数将数据转换成字典迭代器,然后使用matplotlib
模块可视化部分训练数据。
import matplotlib.pyplot as plt
data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
image = data_iter['image'][idx]
figure.add_subplot(rows, cols, idx)
plt.axis("off")
plt.imshow(image.squeeze(), cmap="gray")
plt.show()
隐码构造
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise
输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。
import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
生成隐码test_noise,用来生成
图像。
模型构建
本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,我们在判别器和生成器中采用全连接网络架构和 ReLU
激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout
策略和可学习激活函数 Maxout
。
生成器
生成器 Generator
的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense
全连接层来完成的,每层都与 BatchNorm1d
批归一化层和 ReLU
激活层配对,输出数据会经过 Tanh
函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。
from mindspore import nn
import mindspore.ops as ops
img_size = 28 # 训练图像长(宽)
class Generator(nn.Cell):
def __init__(self, latent_size, auto_prefix=True):
super(Generator, self).__init__(auto_prefix=auto_prefix)
self.model = nn.SequentialCell()
# [N, 100] -> [N, 128]
# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
self.model.append(nn.Dense(latent_size, 128))
self.model.append(nn.ReLU())
# [N, 128] -> [N, 256]
self.model.append(nn.Dense(128, 256))
self.model.append(nn.BatchNorm1d(256))
self.model.append(nn.ReLU())
# [N, 256] -> [N, 512]
self.model.append(nn.Dense(256, 512))
self.model.append(nn.BatchNorm1d(512))
self.model.append(nn.ReLU())
# [N, 512] -> [N, 1024]
self.model.append(nn.Dense(512, 1024))
self.model.append(nn.BatchNorm1d(1024))
self.model.append(nn.ReLU())
# [N, 1024] -> [N, 784]
# 经过线性变换将其变成784维
self.model.append(nn.Dense(1024, img_size * img_size))
# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
self.model.append(nn.Tanh())
def construct(self, x):
img = self.model(x)
return ops.reshape(img, (-1, 1, 28, 28))
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
定义了一个生成器 (Generator) 模型,用于生成与真实图像大小相同的灰度图像 (或RGB彩色图像),该生成器使用了五层 Dense 全连接层,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,最后一层使用 Tanh 激活函数。
使用五层 Dense 全连接层(全连接神经网络层)主要是为了逐步提升输入的维度以生成高质量的图像。到具体每层参数的选择,比如128、256、512、1024,是通过经验和实验选择的,这些数值通常能够在保持计算高效的同时提升生成图像的质量。
- 输入100维到128维:从一个100维的噪声输入开始,初步扩大到128维。
- 128维到256维:进一步扩大,增加特征复杂度。
- 256维到512维:继续递增维度,捕捉更多特征信息。
- 512维到1024维:最大层,提供足够的容量以生成复杂图像。
- 1024维到784维(28×28):最终输出尺寸为28x28(784个像素),适合常见的小图像数据集如MNIST。
BatchNorm1d 批归一化层
批归一化层(Batch Normalization)在神经网络训练中有如下作用:
1. 稳定训练过程:它通过标准化每个小批次的输入,使其均值为0方差为1,来稳定和加速神经网络的训练。
2. 减轻内部协变量偏移:每一层输入的分布保持相对稳定,使网络各层能够更容易学习。
3. 减轻过拟合:有轻微的正则化效果,从而减少过拟合的现象。
BatchNorm1d 对高维向量进行归一化。其参数(如256,512)与前后层输出的维度相匹配。
代码中使用了两个激活函数:
1. ReLU (Rectified Linear Unit):
用于隐藏层。ReLU激活函数具有稀疏激活和梯度消失问题的解决能力。
2. Tanh (Hyperbolic Tangent):
用于生成器的最后一层。
Tanh激活函数将输出限制在[-1, 1]之间。
在实例化生成器之后需要修改参数的名称,这是因为在静态图模式(如有些深度学习框架中的图计算模式)下,参数名称需要唯一且一致,不然可能会发生名称冲突或不一致的问题,导致内部图计算无法正确构建和运行。new_parameters_prefix 设为 generator
,调用 net_g.update_parameters_name('generator')
后,net_g 内部所有参数的名称都会被自动地加上generator
前缀,确保参数名称的唯一性。
判别器
如前所述,判别器 Discriminator
是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense
层和 LeakyReLU
层对其进行处理,最后通过 Sigmoid
激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。
# 判别器
class Discriminator(nn.Cell):
def __init__(self, auto_prefix=True):
super().__init__(auto_prefix=auto_prefix)
self.model = nn.SequentialCell()
# [N, 784] -> [N, 512]
self.model.append(nn.Dense(img_size * img_size, 512)) # 输入特征数为784,输出为512
self.model.append(nn.LeakyReLU()) # 默认斜率为0.2的非线性映射激活函数
# [N, 512] -> [N, 256]
self.model.append(nn.Dense(512, 256)) # 进行一个线性映射
self.model.append(nn.LeakyReLU())
# [N, 256] -> [N, 1]
self.model.append(nn.Dense(256, 1))
self.model.append(nn.Sigmoid()) # 二分类激活函数,将实数映射到[0,1]
def construct(self, x):
x_flat = ops.reshape(x, (-1, img_size * img_size))
return self.model(x_flat)
net_d = Discriminator()
net_d.update_parameters_name('discriminator')
Dense层的参数(神经元数)主要是从输入图像大小逐层缩减到单一输出(表示图像为真实图的概率)。
输入层(784 -> 512):输入图像的每个像素点展平成一个一维向量,长度为28*28=784。使用512个神经元来处理这个输入,在信息传递过程中做到信息浓缩而不过多丢失细节。
隐藏层1(512 -> 256将512维的输入映射到256维。继续收缩特征空间,减少特征数量,同时保持重要的表示。
输出层(256 -> 1):最后一层只输出一个值,通过 Sigmoid 激活函数,将其映射到[0,1]区间,表示判别结果的概率分数。
两个激活函数:
LeakyReLU:在隐藏层(512和256)后面使用,它是一种改进的ReLU函数,通过线性保持负数部分而不是直接切断它们(默认斜率为0.2),这有助于缓解神经元的"死亡"问题并允许在负值区域有一定的梯度传播。
Sigmoid:在最后的输出层后面使用,将线性输出值压缩到[0,1]区间,使其成为一个概率值,适合二分类任务。
损失函数和优化器
定义了 Generator
和 Discriminator
后,损失函数使用MindSpore中二进制交叉熵损失函数BCELoss
;这里生成器和判别器都是使用Adam
优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。
lr = 0.0002 # 学习率
# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')
# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')
BCELoss
(Binary Cross Entropy Loss)是二分类问题常用的一种损失函数。
Adam
(Adaptive Moment Estimation)优化器是一种能很好地适应训练过程中参数调整的优化算法。
模型训练
训练分为两个主要部分。
第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 的值。
第二部分是训练生成器。如论文所述,最小化来训练生成器,以产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到生成器中,以直观地跟踪生成器 Generator
的训练效果。
import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint
total_epoch = 12 # 训练周期数
batch_size = 64 # 用于训练的训练集批量大小
# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'
checkpoints_path = "./result/checkpoints" # 结果保存路径
image_path = "./result/images" # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):
fake_data = net_g(test_noises)
fake_out = net_d(fake_data)
loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
return loss_g
# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
fake_data = net_g(test_noises)
fake_out = net_d(fake_data)
real_out = net_d(real_data)
real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
loss_d = real_loss + fake_loss
return loss_d
# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())
def train_step(real_data, latent_code):
# 计算判别器损失和梯度
loss_d, grads_d = grad_d(real_data, latent_code)
optimizer_d(grads_d)
loss_g, grads_g = grad_g(latent_code)
optimizer_g(grads_g)
return loss_d, loss_g
# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
for i3 in range(gen_imgs1.shape[0]):
plt.subplot(5, 5, i3 + 1)
plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
plt.axis("off")
plt.savefig(image_path + "/test_{}.png".format(idx))
# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)
net_g.set_train()
net_d.set_train()
# 储存生成器和判别器loss
losses_g, losses_d = [], []
for epoch in range(total_epoch):
start = time.time()
for (iter, data) in enumerate(mnist_ds):
start1 = time.time()
image, latent_code = data
image = (image - 127.5) / 127.5 # [0, 255] -> [-1, 1]
image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
d_loss, g_loss = train_step(image, latent_code)
end1 = time.time()
if iter % 10 == 10:
print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
f"loss_d:{d_loss.asnumpy():>4f} , "
f"loss_g:{g_loss.asnumpy():>4f} , "
f"time:{(end1 - start1):>3f}s, "
f"lr:{lr:>6f}")
end = time.time()
print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))
losses_d.append(d_loss.asnumpy())
losses_g.append(g_loss.asnumpy())
# 每个epoch结束后,使用生成器生成一组图片
gen_imgs = net_g(test_noise)
save_imgs(gen_imgs.asnumpy(), epoch)
# 根据epoch保存模型权重文件
if epoch % 1 == 0:
save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))
time of epoch 1 is 84.03s time of epoch 2 is 7.17s time of epoch 3 is 7.07s time of epoch 4 is 6.86s time of epoch 5 is 7.00s time of epoch 6 is 6.96s time of epoch 7 is 7.02s time of epoch 8 is 7.02s time of epoch 9 is 7.00s time of epoch 10 is 6.88s time of epoch 11 is 6.91s time of epoch 12 is 6.93s
效果展示
运行下面代码,描绘D
和G
损失与训练迭代的关系图:
plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
可视化训练过程中通过隐向量生成的图像。
import cv2
import matplotlib.animation as animation
# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):
image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):
plt.axis("off")
show_list.append([plt.imshow(image_list[epoch], cmap='gray')])
ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)
epoch为100时:
epoch为200时:
从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch
达到100以上时,生成的手写数字图片与数据集中的较为相似。下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:
模型推理
下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:
import mindspore as ms
# test_ckpt = './result/checkpoints/Generator199.ckpt'
# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
fig.add_subplot(5, 5, i + 1)
plt.axis("off")
plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()
epoch为12时效果:
epoch为200时效果:
GAN的基本原理:GAN由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是生成看起来像真实数据的假数据,而判别器的目标是区分真实数据和生成器生成的假数据。通过这种对抗训练过程,生成器和判别器相互竞争,最终生成器能够生成高质量的假数据。
数据集处理:使用了MNIST手写数字数据集进行训练。数据集通过MindSpore的MnistDataset接口进行加载和预处理。
模型构建:生成器和判别器都使用了全连接层(Dense)和激活函数(ReLU、LeakyReLU、Sigmoid)。生成器通过多层全连接层将随机噪声映射到图像空间,而判别器则通过多层全连接层将图像映射到概率空间。
损失函数和优化器:使用了二进制交叉熵损失函数(BCELoss)来计算生成器和判别器的损失。
使用了Adam优化器来更新生成器和判别器的参数。
训练过程:训练过程分为两个主要部分:训练判别器和训练生成器。在每个epoch结束时,生成一组固定的随机噪声输入到生成器中,以评估生成器的性能。训练过程中,生成器和判别器的损失被记录下来,并在训练结束后进行可视化。
结果展示:通过加载预训练的生成器模型参数文件,生成新的手写数字图像。生成的图像通过Matplotlib进行展示。