昇思25天学习打卡营第17天|GAN图像生成

news2024/11/18 10:38:39

模型简介

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型G和估计样本是否来自训练数据的判别模型D 。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

gan

数据集

使用MNIST手写数字数据集

数据集下载和加载

# 数据下载
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)

import numpy as np
import mindspore.dataset as ds

batch_size = 64
latent_size = 100  # 隐码的长度

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)
    # 数据增强
    mnist_ds = dataset1.map(
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)

隐码构造

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)

模型构建

生成器

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器

# 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')

损失函数和优化器

lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

模型训练

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。

第二部分是训练生成器。以产生更好的虚拟图像。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g


# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 10:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

生成图像:

效果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量生成的图像。

import cv2
import matplotlib.animation as animation

# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):
    image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):
    plt.axis("off")
    show_list.append([plt.imshow(image_list[epoch], cmap='gray')])

ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)

训练过程测试动态图

可见随着训练次数的增多,图像质量也越来越好。

模型推理

通过加载生成器网络模型参数文件来生成图像,代码如下:

import mindspore as ms

# test_ckpt = './result/checkpoints/Generator199.ckpt'

# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

总结

生成对抗网络GAN通过生成器和判别器之间的对抗训练,从而获得高质量的数据扩充。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1896797.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

怎么在表格后添加文字行行

Ctrl Shift 回车 解决方案 在表格末尾添加一行(表格行)。 (此时光标应该默认在这个新行中,如果没有,自己手动点一下) 按 Ctrl Shift 回车 将此行与前面的表格拆分开,中间会出现一个空文本行…

离散傅里叶变换(Discrete Fourier Transform,DFT)

离散傅里叶变换(Discrete Fourier Transform,DFT)是信号分析中的一种基本方法,它将离散时序信号从时间域变换到频率域,是傅里叶变换在时域和频域都呈现离散的形式。以下是关于离散傅里叶变换的详细介绍: 一…

尽量不写一行if...elseif...写出高质量可持续迭代的项目代码

背景 无论是前端代码还是后端代码,都存在着定位困难,不好抽离,改造困难的问题,造成代码开发越来越慢,此外因为代码耦合较高,总是出现改了一处地方,然后影响其他地方,要么就是要修改…

文本超长省略的几种方式(vue)

第一种&#xff0c;纯css 在给容器设置宽度后&#xff0c;使用css来省略文本超长部分&#xff0c;但是这样就看不到全部的内容 <template><div class"content"><div class"text">{{ text }}</div></div> </template>&…

ubuntu 安装说明

最近准备学习Linux&#xff0c;所以下载了最新的ubuntu server版本24.04&#xff0c;将安装步骤记录下来供参考。 1.安装 挂载光驱和iso文件&#xff0c;启动虚拟机。启动后&#xff0c;你会看到 GRUB 菜单上有两个选项&#xff1a; Try or Install Ubuntu Server 和 Test mem…

防水M7/8“航空法兰插座端子

防水M7/8"航空法兰插座广泛应用于传感器与执行器、电机马达、包装与传送系统、户外LED模块、轨道交通、船舶雷达与导航&#xff0c;以及现场总线DeviceNet与NMEA 2000开放型网络系统等应用领域。M7/8"插座作为一种常见的电气连接器件&#xff0c;在传感器领域中扮演着…

快手矩阵系统源码:技术优势解析

在短视频和直播行业迅猛发展的今天&#xff0c;快手凭借其强大的矩阵系统源码&#xff0c;为用户提供了多端管理、多账号管理、素材管理、视频批量上传、AI视频制作和定时发布等一系列高效功能。本文将深入探讨快手矩阵系统源码的多项优势&#xff0c;以及这些功能如何助力内容…

如何改善提示词,让 GPT-4 更高效准确地把视频内容整体转换成文章?

&#xff08;注&#xff1a;本文为小报童精选文章。已订阅小报童或加入知识星球「玉树芝兰」用户请勿重复付费&#xff09; 让我们来讨论一下大语言模型应用中的一个重要原则 ——「欲速则不达」。 作为一个自认为懒惰的人&#xff0c;我一直有一个愿望&#xff1a;完成视频制作…

气象观测站:观测和记录各种气象要素

在广袤无垠的蓝天下&#xff0c;气象观测站如同一个个静默的守护者&#xff0c;默默记录着风云变幻&#xff0c;守护着大地的安宁。 一、气象观测站&#xff1a;守护天空的“千里眼” 气象观测站&#xff0c;顾名思义&#xff0c;就是专门用于观测和记录各种气象要素的站点。它…

UVa1265/LA4848 Tour Belt

UVa1265/LA4848 Tour Belt 题目链接题意分析AC 代码 题目链接 本题是2010年icpc亚洲区域赛大田赛区的F题 题意 给出一个有n个结点m条边的加权无向图G&#xff08;2≤n≤5000&#xff0c;1≤m≤n(n-1)/2&#xff09;&#xff0c;满足如下条件的结点集B&#xff08;2≤|B|≤n&am…

2025深圳国际消费电子展览会

2025深圳国际消费电子展览会 时间&#xff1a;2025年06月25-27日 地点&#xff1a;深圳国际会展中心(新馆) 详询主办方陆先生 I38&#xff08;前三位&#xff09; I82I&#xff08;中间四位&#xff09; 9I72&#xff08;后面四位&#xff09; 展会介绍&#xff1a; 20…

SAP-SD同一物料下单价格确不同

业务说明&#xff1a; 业务部门反馈&#xff0c;同一物料下销售订单时&#xff0c;价格确不同。 那么这个价格是怎么取到的呢&#xff1f; 逻辑说明&#xff1a; 1、首先查看销售订单 可以看到相同物料价格是不同的&#xff0c;条件类型都是ZPR5&#xff0c;但是客户是不同…

【新能源时代!看大模型(LLMs)如何助力汽车自动驾驶!】

文末有福利&#xff01; 引言 本文主要介绍大模型(LLMs)如何助力汽车自动驾驶&#xff0c;简单来说&#xff0c;作者首先带大家了解大模型的工作模式&#xff0c;然后介绍了自动驾驶大模型的3大应用场景&#xff0c;最后指出自动驾驶大模型将会是未来的发展趋势&#xff0c;只…

View->裁剪框View的绘制,手势处理

XML文件 <?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android…

CFS三层内网渗透——第二层内网打点并拿下第三层内网(三)

目录 八哥cms的后台历史漏洞 配置socks代理 ​以我的kali为例,手动添加 socks配置好了&#xff0c;直接sqlmap跑 ​登录进后台 蚁剑配置socks代理 ​ 测试连接 ​编辑 成功上线 上传正向后门 生成正向后门 上传后门 ​内网信息收集 ​进入目标二内网机器&#xf…

【linux进程】进程地址空间(什么是进程地址空间?为什么要有进程地址空间?)

目录 一、前言 二、 程序的地址空间是真实的 --- 物理空间吗&#xff1f; 三、进程地址空间 &#x1f525; 操作系统是如何建立起进程与物理内存之间的联系的呢&#xff1f; &#x1f525;什么是进程地址空间&#xff1f; &#x1f525;为什么不能直接去访问物理内存&a…

Protobuf(三):理论学习,简单总结

1. Protocol Buffers概述 Protocol Buffers&#xff08;简称protobuf&#xff09;&#xff0c;是谷歌用于序列化结构化数据的一种语言独立、平台独立且可扩展的机制&#xff0c;类似XML&#xff0c;但比XML更小、更快、更简单protobuf的工作流程如图所示 1.1 protobuf的优点…

Python酷库之旅-第三方库Pandas(003)

目录 一、用法精讲 4、pandas.read_csv函数 4-1、语法 4-2、参数 4-3、功能 4-4、返回值 4-5、说明 4-6、用法 4-6-1、创建csv文件 4-6-2、代码示例 4-6-3、结果输出 二、推荐阅读 1、Python筑基之旅 2、Python函数之旅 3、Python算法之旅 4、Python魔法之旅 …

工业废水中镍超标怎么办?含镍废水处理方法有哪些?

镍是一种存在于自然界中的过渡金属。镍在土壤和岩石中的存量丰富&#xff0c;大部分镍已被氧化&#xff0c;或与其他元素结合成化合物。   含镍废水主要来源于电镀、合金制造、金属表面处理、电子等行业。这些行业在生产过程中&#xff0c;通常会使用含有镍离子的化学试剂&a…

kali改回官方源后更新失败

官方源&#xff1a; deb http://http.kali.org/kali kali-rolling main non-free contrib deb-src http://http.kali.org/kali kali-rolling main non-free contrib在文件 /etc/cat/sources.list中将官方源修改为&#xff1a; deb http://http.kali.org/kali kali-rolling ma…