昇思25天学习打卡营第5天|GAN图像生成

news2025/1/12 16:16:26

文章目录

      • 昇思MindSpore应用实践
        • 基于MindSpore的生成对抗网络图像生成
          • 1、生成对抗网络简介
            • 零和博弈 vs 极大极小博弈
            • GAN的生成对抗损失:
          • 2、基于MindSpore的 Vanilla GAN
          • 3、基于MindSpore的手写数字图像生成
            • 导入数据
            • 数据可视化
            • 模型训练
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的生成对抗网络图像生成
1、生成对抗网络简介
零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失:

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

近十年来著名的GAN网络结构:
在这里插入图片描述

2、基于MindSpore的 Vanilla GAN

生成器部分:生成器 Generator 的功能是将隐码映射到数据空间。通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内,并返回一张28x28的图像作为生成结果。

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)28x28

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,通过第一层线性变换将其映射到128维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # 通过第二层线性变换将其映射到256维
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器部分:判别器 Discriminator 是一个二分类网络模型,在训练时,判别器接收生成器的生成图像与对应的真实数据相对比,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。

 # 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
3、基于MindSpore的手写数字图像生成
导入数据
import numpy as np
import mindspore.dataset as ds

batch_size = 128
latent_size = 100  # 潜在编码的长度

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)
    
    # 数据增强
    mnist_ds = dataset1.map(  # 通过map方法给每张图像映射一个潜在编码
    # 将图像数据转换为 float32 类型
    # 生成一个长度为 latent_size 的服从正态分布的随机向量,并将其转换为 float32 类型
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
数据可视化
import matplotlib.pyplot as plt

data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
    image = data_iter['image'][idx]
    figure.add_subplot(rows, cols, idx)
    plt.axis("off")
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述
潜在编码(latent code)的构造:
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过这组固定的潜在编码(也叫隐码)所生成的图像效果来评估生成器的生成质量。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型训练

定义损失函数和优化器:

lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

训练分为两个主要部分,也就是要训练两个网络:生成与对抗网络。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1D(G(z)) 的值。

第二部分是训练生成器。如论文所述,最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z))) 来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将固定隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 24  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径

# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g

# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 0:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'Wayn_Fan-sail')

使用cpu进行12个epoch的生成效果如下:
在这里插入图片描述

Reference

昇思官方文档-GAN图像生成
昇思大模型平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879345.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【单片机毕业设计选题24034】-基于STM32的手机智能充电系统

系统功能: 系统可以设置充电时长,启动充电后按设置的充电时长充电,充电时间到后自动 停止充电,中途检测到温度过高也会结束充电并开启风扇和蜂鸣器报警。 系统上电后,OLED显示“欢迎使用智能充电系统请稍后”,两秒钟…

《强化学习的数学原理》(2024春)_西湖大学赵世钰 Ch8 值函数拟合 【基于近似函数的 TD 算法:Sarsa、Q-leaning、DQN】

PPT 截取有用信息。 课程网站做习题。总体 MOOC 过一遍 1、学堂在线 视频 习题 2、相应章节 过电子书 复习 【下载: 本章 PDF GitHub 页面链接】 3、 MOOC 习题 跳过的 PDF 内容 学堂在线 课程页面链接 中国大学MOOC 课程页面链接 B 站 视频链接 PPT和书籍下载网址…

大厂面试官问我:哨兵怎么实现的,Redis 主节点挂了,获取锁时会有什么问题?【后端八股文六:Redis集群八股文合集】

往期内容: 大厂面试官问我:Redis处理点赞,如果瞬时涌入大量用户点赞(千万级),应当如何进行处理?【后端八股文一:Redis点赞八股文合集】-CSDN博客 大厂面试官问我:布隆过滤…

Shell 脚本编程保姆级教程(下)

七、Shell 流程控制 7.1 if #!/bin/bash num1100 if test $[num1] 100 thenecho num1 是 100 fi 7.2 if else #!/bin/bash num1100 num2100 if test $[num1] -eq $[num2] thenecho 两个数相等! elseecho 两个数不相等! fi 7.3 if else-if else #!/…

深度学习 - Transformer 组成详解

整体结构 1. 嵌入层(Embedding Layer) 生活中的例子:字典查找 想象你在读一本书,你不认识某个单词,于是你查阅字典。字典为每个单词提供了一个解释,帮助你理解这个单词的意思。嵌入层就像这个字典&#xf…

代码随想录-二叉搜索树(1)

目录 二叉搜索树的定义 700. 二叉搜索树中的搜索 题目描述: 输入输出示例: 思路和想法: 98. 验证二叉搜索树 题目描述: 输入输出示例: 思路和想法: 530. 二叉搜索树的最小绝对差 题目描述&#x…

IOS Swift 从入门到精通:ios 连接数据库 安装 Firebase 和 Firestore

创建 Firebase 项目 导航到Firebase 控制台并创建一个新项目。为项目指定任意名称。 在这里插入图片描述 下一步,启用 Google Analytics,因为我们稍后会用到它来发送推送通知。 在这里插入图片描述 在下一个屏幕上,选择您的 Google Analytics 帐户(如果已创建)。如果没…

java第三十课 —— 面向对象练习题

面向对象编程练习题 第一题 定义一个 Person 类 {name, age, job},初始化 Person 对象数组,有 3 个 person 对象,并按照 age 从大到小进行排序,提示,使用冒泡排序。 package com.hspedu.homework;import java.util.…

使用slenium对不同元素进行定位实战篇~

单选框Radio定位: 单选框只能点击一个,并且点击之后并不会被取消,而多选框,能够点击多个,并且点击之后可以取消 import org.junit.Test; import org.openqa.selenium.By; import org.openqa.selenium.WebDriver; imp…

基于python和opencv实现边缘检测程序

引言 图像处理是计算机视觉中的一个重要领域,它在许多应用中扮演着关键角色,如自动驾驶、医疗图像分析和人脸识别等。边缘检测是图像处理中的基本任务之一,它用于识别图像中的显著边界。本文将通过一个基于 Python 和 OpenCV 的示例程序&…

intellij idea安装R包ggplot2报错问题求解

1、intellij idea安装R包ggplot2问题 在我上次解决图形显示问题后,发现安装ggplot2包时出现了问题,这在之前高版本中并没有出现问题, install.packages(ggplot2) ERROR: lazy loading failed for package lifecycle * removing C:/Users/V…

Android 10.0 关于定制自适应AdaptiveIconDrawable类型的动态时钟图标的功能实现系列二(拖动到文件夹部分功能实现)

1.前言 在10.0的系统rom定制化开发中,在关于定制动态时钟图标中,原系统是不支持动态时钟图标的功能,所以就需要从新 定制动态时钟图标关于自适应AdaptiveIconDrawable类型的样式,就是可以支持当改变系统图标样式变化时,动态时钟 图标的背景图形也跟着改变,本篇实现在拖…

HBuilder X 小白日记02-布局和网页背景颜色

html&#xff1a; 例子1&#xff1a; 整个&#xff1a; css案例&#xff1a; 1.首先右键&#xff0c;创建css文件 2.在html文件的头部分&#xff0c;引用css&#xff0c;快捷方式&#xff1a;linkTab键 <link rel"stylesheet" href" "> 3.先在css…

操作系统精选题(二)(综合模拟题一)

&#x1f308; 个人主页&#xff1a;十二月的猫-CSDN博客 &#x1f525; 系列专栏&#xff1a; &#x1f3c0;操作系统 &#x1f4aa;&#x1f3fb; 十二月的寒冬阻挡不了春天的脚步&#xff0c;十二点的黑夜遮蔽不住黎明的曙光 目录 前言 简答题 一、进程由计算和IO操作组…

论文阅读之旋转目标检测ARC:《Adaptive Rotated Convolution for Rotated Object Detection》

论文link&#xff1a;link code&#xff1a;code ARC是一个改进的backbone&#xff0c;相比于ResNet&#xff0c;最后的几层有一些改变。 Introduction ARC自适应地旋转以调整每个输入的条件参数&#xff0c;其中旋转角度由路由函数以数据相关的方式预测。此外&#xff0c;还采…

【Unity】Timeline的倒播和修改速度(无需协程)

unity timeline倒播 一、核心&#xff1a; 通过playableDirector.playableGraph.GetRootPlayable(i).SetSpeed(speed)接口&#xff0c;设置PlayableDirector的速度。 二、playableGraph报空 若playableDirector不勾选Play On Awake&#xff0c;则默认没有PlayableGraph&…

Redis基础教程(三):redis命令

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…

oj E : 投资项目的方案

Description 有n种基础的投资项目&#xff0c;每一种的单位收益率为profitn&#xff0c;存在m种投资组合&#xff0c;限制每一种的投资总额不能超过invest_summ 每种投资组合中项目所需的单位投入是不同的&#xff0c;为costmn 求&#xff1a;使得收益率之和最高的每种项目投…

Meven

目录 1.简介2.Maven项目目录结构2.1 约定目录结构的意义2.2 约定大于配置 3. POM.XML介绍3.2 依赖引用3.3 属性管理 4 Maven生命周期4.1 经常遇到的生命周期4.1 全部生命周期 5.依赖范围&#xff08;Scope&#xff09;6. 依赖传递6.1 依赖冲突6.2 解决依赖冲突6.2.1 最近依赖者…

1、线性回归模型

1、主要解决问题类型 1.1 预测分析(Prediction) 线性回归可以用来预测一个变量(通常称为因变量或响应变量)的值,基于一个或多个输入变量(自变量或预测变量)。例如,根据房屋的面积、位置等因素预测房价。 1.2 异常检测(Outlier Detection) 线性回归可以帮助识别数…