昇思25天学习打卡营第25天|GAN图像生成

news2024/9/21 0:50:39

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

GAN图像生成

模型简介

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

GAN论文逐段精读【论文精读】_哔哩哔哩_bilibili

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷 。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥)需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)数值应该趋近于 0 。因此 𝐷(𝑥)也可以被认为是传统的二分类器。

用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧)的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 的数据分布,我们希望找到 θ 使得尽可能的接近,其中 𝜃代表网络参数。

符号p_G(x; \theta) 表示的是生成模型 G 生成数据x的概率分布,参数为\theta。具体来说:

x:表示数据样本,例如图像或其他形式的数据。
\theta:表示生成器G的参数,这些参数是通过训练过程来优化的。
p_G(x; \theta):表示通过生成器G在参数\theta下生成数据x的概率分布。

表示生成器 𝐺 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数;而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数。因此GAN的损失函数为:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

- 判别器的目标函数:

\max_D \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

- 生成器的目标函数:
\min_G \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

在GAN(生成式对抗网络)中,损失函数反映了生成器和判别器之间的对抗关系。有两个模型:生成器G和判别器D。判别器D的目标是最大化其正确分类真实图像和生成图像的概率,而生成器 G的目标是最小化被判别器正确识别为生成图像的概率。

判别器D的目标是最大化以下期望值:

\mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

其中,x表示真实图像,z是从标准正态分布中采样的隐变量,G(z)表示生成的假的图像。第一项期望值表示真实图像被判别为真实的概率,第二项期望值表示生成的图像被判别为假的概率。

另一方面,生成器G的目标是最小化判别器正确识别为生成图像的概率,这相当于最小化以下期望值:

\mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

综合两个目标, 可以给出生成式对抗网络的损失函数(目标函数)为:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

这是一个双重优化问题,两个模型在同一时间被训练:判别器D尽可能区分真实和生成图像,生成器G尽可能生成能够欺骗判别器的图像。

以下是解释关键项的详细含义:

1.\mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)]:期望真实数据x被判别为真实的概率的对数,这部分使得判别器更好地区分真实数据。

2. \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]:期望生成的假数据G(z)被判别为假的概率的对数,这部分反映了生成器通过尝试生成能够欺骗判别器的假数据来提高生成效果。

通过这样的对抗训练过程,生成器会逐渐提高其生成质量,生成的图像会愈发接近真实图像,使得判别器愈加难以区分真假图像。

从理论上讲,此博弈游戏的平衡点是,此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:·

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。

在理想的均衡点,生成器生成的数据分布 p_g和真实的数据分布 p_{data} 相同,即 p_g = p_{data}。此时,对于任意一个输入数据,判别器无法区分它是生成器生成的还是实际数据,这就表示判别器的输出应该在0和1之间摇摆不定,最可能的值是0.5。

gan

在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,𝑧 表示隐码,𝑥 表示生成的虚假图像 𝐺(𝑧)。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。

在这四张图片中,黑色虚线代表了真实数据分布的概率密度。我们可以看到左侧通常表示真实数据分布较高的区域。因此,判别器在这些区域应该输出更高的概率值。从图中蓝色虚线的形状可以看出,随着训练的进行,判别器在左侧区域的输出值是大于0.5的,因为这个区域内的样本被判别器认为更可能是真实数据。第四张图片在理想情况下,生成器生成的数据分布(绿色实线)与真实数据分布(黑色虚线)几乎完全重合。判别器无法区分生成数据和真实数据,只能随机猜测,所以在这个阶段的判别器输出会接近1/2。

数据集

数据集简介

MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。

本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。

数据集下载

使用download接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download安装download包。

下载解压后的数据集目录结构如下:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test
   ├─ t10k-images-idx3-ubyte
   └─ t10k-labels-idx1-ubyte

数据下载的代码如下:

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
# 数据下载
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 116MB/s]
Extracting zip file...
Successfully downloaded / unzipped to .

[3]:

'.'

数据加载

使用MindSpore自己的MnistDatase接口,读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理。

import numpy as np
import mindspore.dataset as ds

batch_size = 64
latent_size = 100  # 隐码的长度

# 载入数据集
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

# 定义数据加载函数
def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)
    # 数据增强
    mnist_ds = dataset1.map(
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
Iter size: 156

dataset1 是一个 GeneratorDataset,它从原始数据集中生成一个新的数据集并对其进行打乱(shuffle),限制样本数量为10000。operations 部分是一个匿名函数(lambda 函数),它接收输入的数据 x,然后返回两个新的值:将图像数据转换为 float32 类型、生成一个服从标准正态分布的随机向量,大小为 latent_size,并将其转化为 float32 类型。output_columns=["image", "latent_code"] 定义了输出数据集的列名。project 方法用于选择数据集中指定的列。["image", "latent_code"] 表示只保留图像和生成的潜在编码(latent code)列,丢弃其他不需要的列如 label

数据集可视化

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。

import matplotlib.pyplot as plt

data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
    image = data_iter['image'][idx]
    figure.add_subplot(rows, cols, idx)
    plt.axis("off")
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()

隐码构造

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)

生成隐码test_noise,用来生成图像。

模型构建

本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,我们在判别器和生成器中采用全连接网络架构和 ReLU 激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout 策略和可学习激活函数 Maxout

生成器

生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

定义了一个生成器 (Generator) 模型,用于生成与真实图像大小相同的灰度图像 (或RGB彩色图像),该生成器使用了五层 Dense 全连接层,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,最后一层使用 Tanh 激活函数。
使用五层 Dense 全连接层(全连接神经网络层)主要是为了逐步提升输入的维度以生成高质量的图像。到具体每层参数的选择,比如128、256、512、1024,是通过经验和实验选择的,这些数值通常能够在保持计算高效的同时提升生成图像的质量。
- 输入100维到128维:从一个100维的噪声输入开始,初步扩大到128维。
- 128维到256维:进一步扩大,增加特征复杂度。
- 256维到512维:继续递增维度,捕捉更多特征信息。
- 512维到1024维:最大层,提供足够的容量以生成复杂图像。
- 1024维到784维(28×28):最终输出尺寸为28x28(784个像素),适合常见的小图像数据集如MNIST。
BatchNorm1d 批归一化层
批归一化层(Batch Normalization)在神经网络训练中有如下作用:
1. 稳定训练过程:它通过标准化每个小批次的输入,使其均值为0方差为1,来稳定和加速神经网络的训练。
2. 减轻内部协变量偏移:每一层输入的分布保持相对稳定,使网络各层能够更容易学习。
3. 减轻过拟合:有轻微的正则化效果,从而减少过拟合的现象。
BatchNorm1d 对高维向量进行归一化。其参数(如256,512)与前后层输出的维度相匹配。
代码中使用了两个激活函数:
1. ReLU (Rectified Linear Unit):
用于隐藏层。ReLU激活函数具有稀疏激活和梯度消失问题的解决能力。
2. Tanh (Hyperbolic Tangent):
用于生成器的最后一层。
Tanh激活函数将输出限制在[-1, 1]之间。

在实例化生成器之后需要修改参数的名称,这是因为在静态图模式(如有些深度学习框架中的图计算模式)下,参数名称需要唯一且一致,不然可能会发生名称冲突或不一致的问题,导致内部图计算无法正确构建和运行。new_parameters_prefix 设为 generator,调用 net_g.update_parameters_name('generator') 后,net_g 内部所有参数的名称都会被自动地加上generator前缀,确保参数名称的唯一性。

判别器

如前所述,判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。

 # 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')

Dense层的参数(神经元数)主要是从输入图像大小逐层缩减到单一输出(表示图像为真实图的概率)。
输入层(784 -> 512):输入图像的每个像素点展平成一个一维向量,长度为28*28=784。使用512个神经元来处理这个输入,在信息传递过程中做到信息浓缩而不过多丢失细节。
隐藏层1(512 -> 256将512维的输入映射到256维。继续收缩特征空间,减少特征数量,同时保持重要的表示。
输出层(256 -> 1):最后一层只输出一个值,通过 Sigmoid 激活函数,将其映射到[0,1]区间,表示判别结果的概率分数。
两个激活函数:
LeakyReLU:在隐藏层(512和256)后面使用,它是一种改进的ReLU函数,通过线性保持负数部分而不是直接切断它们(默认斜率为0.2),这有助于缓解神经元的"死亡"问题并允许在负值区域有一定的梯度传播。
Sigmoid:在最后的输出层后面使用,将线性输出值压缩到[0,1]区间,使其成为一个概率值,适合二分类任务。

损失函数和优化器

定义了 Generator 和 Discriminator 后,损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

BCELoss(Binary Cross Entropy Loss)是二分类问题常用的一种损失函数。

Adam(Adaptive Moment Estimation)优化器是一种能很好地适应训练过程中参数调整的优化算法。

模型训练

训练分为两个主要部分。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 的值。

第二部分是训练生成器。如论文所述,最小化来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g


# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 10:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))
time of epoch 1 is 84.03s
time of epoch 2 is 7.17s
time of epoch 3 is 7.07s
time of epoch 4 is 6.86s
time of epoch 5 is 7.00s
time of epoch 6 is 6.96s
time of epoch 7 is 7.02s
time of epoch 8 is 7.02s
time of epoch 9 is 7.00s
time of epoch 10 is 6.88s
time of epoch 11 is 6.91s
time of epoch 12 is 6.93s

效果展示

运行下面代码,描绘DG损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量生成的图像。

import cv2
import matplotlib.animation as animation

# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):
    image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):
    plt.axis("off")
    show_list.append([plt.imshow(image_list[epoch], cmap='gray')])

ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)

epoch为100时:

epoch为200时:

训练过程测试动态图

从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch 达到100以上时,生成的手写数字图片与数据集中的较为相似。下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

模型推理

下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

import mindspore as ms

# test_ckpt = './result/checkpoints/Generator199.ckpt'

# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

epoch为12时效果:

epoch为200时效果:

GAN的基本原理:GAN由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是生成看起来像真实数据的假数据,而判别器的目标是区分真实数据和生成器生成的假数据。通过这种对抗训练过程,生成器和判别器相互竞争,最终生成器能够生成高质量的假数据。
数据集处理:使用了MNIST手写数字数据集进行训练。数据集通过MindSpore的MnistDataset接口进行加载和预处理。
模型构建:生成器和判别器都使用了全连接层(Dense)和激活函数(ReLU、LeakyReLU、Sigmoid)。生成器通过多层全连接层将随机噪声映射到图像空间,而判别器则通过多层全连接层将图像映射到概率空间。
损失函数和优化器:使用了二进制交叉熵损失函数(BCELoss)来计算生成器和判别器的损失。
使用了Adam优化器来更新生成器和判别器的参数。
训练过程:训练过程分为两个主要部分:训练判别器和训练生成器。在每个epoch结束时,生成一组固定的随机噪声输入到生成器中,以评估生成器的性能。训练过程中,生成器和判别器的损失被记录下来,并在训练结束后进行可视化。
结果展示:通过加载预训练的生成器模型参数文件,生成新的手写数字图像。生成的图像通过Matplotlib进行展示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930821.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS——MRC与ARC以及自动释放池深入底层学习

MRC与ARC再回顾 在前面&#xff0c;我们简单学了MRC与ARC。MRC指手动内存管理&#xff0c;需要开发者使用retain、release等手动管理对象的引用计数&#xff0c;确保对象在必要时被释放。ARC指自动内存管理&#xff0c;由编译器自动管理对象的引用计数&#xff0c;开发者不需要…

SQL注入问题

一、什么是sql注入 public class TestSql {public static void main(String[] args) {Scanner inScanner new Scanner(System.in);System.out.println("请输入用户名");String username inScanner.nextLine();System.out.println("请输入密码");String …

php基础: 三角形

包含&#xff1a;左三角、左上三角、右三角、右上三角、等腰三角、倒等腰三角。注意空格的数量&#xff0c;因为*号后面加了空格 /*** * 左三角形* param $n* return void*/ function triangleLeft($n){echo <pre>;for ($i 1; $i < $n; $i) {for ($j 1; $j < $i…

MongoDB常用命令大全

文章目录 一、MongoDB简介二、服务启动停止备份三、数据库相关四、集合操作五、文档操作六、其他常用命令 一、MongoDB简介 MongoDB是一款流行的NoSQL数据库&#xff0c;以其灵活的文档模型、高可用性、易于扩展等特性而受到广泛关注。 MongoDB 是由C语言编写的&#xff0c;是…

C# modbus 图表

控件&#xff1a;chart1(图表)&#xff0c;cartesianChart1(第三方添加图表)&#xff0c;timer(时间) 添加第三方&#xff1a; 效果&#xff1a;图标会根据连接的温度&#xff0c;湿度用timer时间进行改变 Chart1控件样式&#xff1a;Series添加线条&#xff0c;颜色&#xf…

劳易测应用案例 汽车零部件装配线光电传感器解决方案

汽车零部件种类繁多&#xff0c;形状、尺寸、功能各异&#xff0c;生产线的规划与布局必须紧密贴合产品的独特工艺、精细装配流程及高效生产需求。随着电动汽车时代的到来&#xff0c;生产标准愈加严格&#xff0c;对生产线的设计和装配周期提出了更高要求。市场要求生产线不仅…

EE trade:强平和爆仓的区别

在金融交易市场中&#xff0c;杠杆交易的引入&#xff0c;让投资者可以用少量的资金撬动更大的头寸&#xff0c;获取更大的收益。然而&#xff0c;杠杆交易也带来了更大的风险&#xff0c;一旦市场波动&#xff0c;投资者可能会面临强平或爆仓的风险。了解强平和爆仓的区别&…

MySQL-对数据库和表的DDL命令

文章目录 一、什么是DDL操作二、数据库编码集和数据库校验集三、使用步骤对数据库的增删查改1.创建数据库2.进入数据库3.显示数据库4.修改数据库mysqldump 5.删除数据库 对表的增删查改1.添加/创建表2.插入表内容3.查看表查看所有表查看表结构查看表内容 4.修改表修改表的名字修…

保障低压设备安全!中国星坤连接器精密工艺解析!

在现代电子设备中&#xff0c;连接器扮演着至关重要的角色&#xff0c;它们是电子系统之间沟通的桥梁。随着技术的发展&#xff0c;对连接器的需求也在不断提升&#xff0c;特别是在低电压应用领域。中国星坤最新推出的低压连接器&#xff0c;以其精密性和安全性&#xff0c;为…

Msql数据库之DDL(数据定义语言)的相关操作

数据定义语言(DDL)&#xff1a;用于创建、修改和删除数据库对象&#xff0c;如数据库、表、视图、索引等 一、数据库的相关操作&#xff1a; 1、创建数据库 语法&#xff1a;create database [if not exists ] 数据库名; 例&#xff1a;create database if not exists test…

2024-07-16 Unity插件 Odin Inspector6 —— Group Attributes

文章目录 1 说明2 Group 特性2.1 BoxGroup2.2 ButtonGroup2.3 FoldoutGroup2.4 ShowIfGroup / HideIfGroup2.5 HorizontalGroup2.6 ResponsiveButtonGroup2.7 TabGroup2.8 ToggleGroup2.9 VerticalGroup 1 说明 ​ 本文介绍 Odin Inspector 插件中有关 Group 特性的使用方法。…

【Apache POI】Java解析Excel文件并处理合并单元格-粘贴即用

同为牛马&#xff0c;点个赞吧&#xff01; 一、Excel文件样例 二、工具类源码 import org.apache.poi.ss.usermodel.*; import org.apache.poi.ss.util.CellRangeAddress; import org.apache.poi.xssf.usermodel.XSSFWorkbookFactory; import org.springframework.web.multip…

【B树、B-树、B+、B*树】

目录 一、B-树&#xff08;即B树&#xff09;的定义及操作1.1、定义1.2、操作1.2.1、查找1.2.2、插入1.2.3、删除 二、B树的定义及操作2.1、定义2.2、操作2.2.1、查找2.2.2、插入2.2.3、删除 三、B*树 一、B-树&#xff08;即B树&#xff09;的定义及操作 1.1、定义 B-tree即…

解决vue3中el-input在form表单按下回车刷新页面

问题&#xff1a;在input框中点击回车之后不是调用我写的回车事件&#xff0c;而是刷新页面 原因&#xff1a; 如果表单中只有一个input 框则按下回车会直接关闭表单 所以导致刷新页面 解决方法 &#xff1a; 再写一个input 表单 &#xff0c;并设置style"display:none&…

【对顶堆 优先队列】2102. 序列顺序查询

本文涉及知识点 对顶堆 优先队列 LeetCode 2102. 序列顺序查询 一个观光景点由它的名字 name 和景点评分 score 组成&#xff0c;其中 name 是所有观光景点中 唯一 的字符串&#xff0c;score 是一个整数。景点按照最好到最坏排序。景点评分 越高 &#xff0c;这个景点越好。…

再谈有关JVM中的四种引用

1.强引用 强引用就是我们平时使用最多的那种引用&#xff0c;就比如以下的代码 //创建一个对象 Object obj new Object();//强引用 这个例子就是创建了一个对象并建立了强引用&#xff0c;强引用一般就是默认支持的当内存不足的时候&#xff0c;JVM开始垃圾回收&#xff0c…

【Java--数据结构】二叉树oj题(上)

前言 欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 判断是否是相同的树 oj链接 要判断树是否一样&#xff0c;要满足3个条件 根的 结构 和 值 一样左子树的结构和值一样右子树的结构和值一样 所以就可以总结以下思路…

js补环境系列之剖析:原型、原型对象、实例对象三者互相转化(不讲废话、全是干货)

【作者主页】&#xff1a;小鱼神1024 【擅长领域】&#xff1a;JS逆向、小程序逆向、AST还原、验证码突防、Python开发、浏览器插件开发、React前端开发、NestJS后端开发等等 思考下&#xff1a;js补环境中&#xff0c;什么场景会用到原型、原型对象、实例对象&#xff1f; 举…

最大文件句柄数

优质博文&#xff1a;IT-BLOG-CN 灵感来源 一、什么是文件句柄 文件句柄File Handle是操作系统中用于访问文件的一种数据结构&#xff0c;通常是一个整数或指针。文件句柄用于标识打开的文件&#xff0c;每个打开的文件都有一个唯一的文件句柄。 它们是对文件、网络套接字或…

商业数据分析思维的培训PTT制作大纲分享

商业数据分析思维的培训PTT制作大纲: 基本步骤: 明确PPT的目的和主题 收集并整理相关内容资料 构思并确定PPT的框架大纲 编写PPT的内容文字 插入图片、图表等视觉元素 设计PPT的版式和模板 排练并修改PPT 输出并备份最终版本 目的:数据思维培养; 主题:商业数据分…