Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现

news2025/1/6 16:17:14

什么是GAN

GAN是生成对抗网络,将会根据一个随机向量,实现数据的生成(如生成手写数字、生成文本等)。

GAN的训练过程中,需要有一个生成器G和一个鉴别器D.

生成器用于生成数据,鉴定器用于鉴定数据的准确性,其实就是在鉴别数据是人生成的还是机器生成的,因为生成器需要以假乱真。

鉴别器将会与生成器一起训练。鉴别器将会先训练,这样才有适当的能力去鉴定生成器生成数据的准确性。

鉴别器的训练过程中,需要先给它准确的数据,和通过随机向量传入生成器产生的数据(一律视为负样本),并通过损失函数对其进行训练;生成器训练过程中,会先给它一个随机向量进行前向传播,然后让鉴别器判断其正确性,并通过损失函数(不正确的数据意味着有损失)进行训练:

生成器训练过程中,需要先通过随机向量获取其结果,然后让鉴别器进行鉴别,在通过鉴别器的鉴别结果计算损失(如果鉴别器认为这是生成器生成的,则产生损失),最后更新梯度和参数:

训练过程直到生成器拟合训练集(收敛),判别器的输出总是0.5(均方误差损失函数应为0.25)为止.

形象的GAN的例子

想象一场由一位“名画伪造者”和一位“艺术鉴定家”参与的猫捉老鼠游戏。

在这个场景中,名画伪造者(即GAN中的生成器)的目标是创造出一幅足以欺骗艺术鉴定家(即GAN中的判别器)的假画。开始时,伪造者的技艺并不精湛,他制作的假画充满了破绽,很容易被鉴定家一眼识破。

然而,随着伪造者不断尝试和失败,他逐渐从每一次的失败中学习,逐渐提升了自己的技艺。他开始注意到真画的每一个细节,从笔触、色彩到构图,都尽量模仿得惟妙惟肖。每一次的失败都让他更接近成功,他制作的假画也越来越难以辨别真伪。

而艺术鉴定家也不甘示弱。他开始时能够轻易地识别出伪造者的假画,但随着伪造者技艺的提升,他也需要不断提升自己的鉴定能力。他开始深入研究真画的每一个特点,以便更准确地识别出伪造者的假画。

这个过程就像GAN中的训练过程一样。生成器不断尝试生成新的数据(在这里是假画),而判别器则不断尝试区分这些数据是真实的还是生成的。两者在相互竞争的过程中不断提升自己的能力,最终达到了一个平衡状态。

在这个例子中,名画伪造者就是GAN中的生成器,他负责生成新的数据;而艺术鉴定家则是GAN中的判别器,他负责区分数据的真伪。两者在相互竞争的过程中共同进步,使得生成的数据越来越接近真实的数据。

代码实现

本文将以基于MNIST(手写数据集)为数据集,实现一个生成手写数字的GAN模型:

首先创建models.py,用于定义判别器和生成器:

import paddle


# Generator Code
class Generator(paddle.nn.Layer):
    def __init__(self, ):
        super(Generator, self).__init__()
        self.gen = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=100, out_features=256),
            paddle.nn.ReLU(True),
            paddle.nn.Linear(in_features=256, out_features=512),
            paddle.nn.ReLU(True),
            paddle.nn.Linear(in_features=512, out_features=1024),
            paddle.nn.Tanh(),
        )

    def forward(self, x):
        x = self.gen(x)
        out = paddle.reshape(x,[-1,1,32,32])
        return out


# Discriminator Code
class Discriminator(paddle.nn.Layer):
    def __init__(self, ):
        super(Discriminator, self).__init__()
        self.dis = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=1024, out_features=512),
            paddle.nn.LeakyReLU(0.2),
            paddle.nn.Linear(in_features=512, out_features=256),
            paddle.nn.LeakyReLU(0.2),
            paddle.nn.Linear(in_features=256, out_features=1),
            paddle.nn.Sigmoid()
        )

    def forward(self, x):
        x = paddle.reshape(x, [-1, 1024])
        out = self.dis(x)
        return out

其中,生成器将接收一个长度为100的张量(随机向量),输出一个长度为1024的张量(生成的图片);鉴别器将接收一个长度为1024的张量(图片) ,输出长度为1的张量(鉴别结果)

然后创建main.py,用于训练:

import paddle
import matplotlib.pyplot as plt
from models import Generator, Discriminator
import numpy as np

dataset = paddle.vision.datasets.MNIST(mode='train',
                                       transform=paddle.vision.transforms.Compose([
                                           paddle.vision.transforms.Resize((32, 32)),
                                           paddle.vision.transforms.Normalize([0], [255])
                                       ]))

dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)

netG = Generator()
netD = Discriminator()

if 1:
    try:
        mydict = paddle.load('generator.params')
        netG.set_dict(mydict)
        mydict = paddle.load('discriminator.params')
        netD.set_dict(mydict)
    except:
        print('fail to load model')

optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)

# 最大迭代epoch
max_epoch = 10

for epoch in range(max_epoch):
    now_step = 0
    for step, (data, label) in enumerate(dataloader):
        ############################
        # (1) 更新鉴别器
        ###########################

        # 清除D的梯度
        optimizerD.clear_grad()

        # 传入正样本,并更新梯度
        pos_img = data
        label = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')
        pre = netD(pos_img)
        loss_D_1 = paddle.nn.functional.mse_loss(pre, label)
        loss_D_1.backward()

        # 通过randn构造随机数,制造负样本,并传入D,更新梯度
        noise = paddle.randn([pos_img.shape[0], 100], 'float32')
        neg_img = netG(noise)
        label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')
        pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算
        loss_D_2 = paddle.nn.functional.mse_loss(pre, label)
        loss_D_2.backward()

        # 更新D网络参数
        optimizerD.step()
        optimizerD.clear_grad()

        loss_D = loss_D_1 + loss_D_2

        ############################
        # (2) 更新生成器
        ###########################

        # 清除D的梯度
        optimizerG.clear_grad()

        noise = paddle.randn([pos_img.shape[0], 100], 'float32')
        fake = netG(noise)
        label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32, )
        output = netD(fake)
        # 这个写法没有问题,因为这个mse_loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度
        loss_G = paddle.nn.functional.mse_loss(output, label)
        loss_G.backward()

        # 更新G网络参数
        optimizerG.step()
        optimizerG.clear_grad()

        now_step += 1

        ###########################
        # 输出日志
        ###########################
        if now_step % 100 == 0:
            print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')

paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

如果是第一次训练或不使用原有训练参数,可以将if 1改成if 0.

接下来创建use.py,用于生成图片:

import paddle
from models import Generator
import matplotlib.pyplot as plt

import paddle
from models import Generator
import matplotlib.pyplot as plt
import numpy as np

# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)

# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格

# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):
    noise = paddle.randn([1, 100], 'float32')
    img = netG(noise)
    img = img.numpy()[0][0]  # img.numpy():张量转np数组
    img[img < 0] = 0  # 将img中所有小于0的元素赋值为0
    img = np.clip(img, 0, 1)  # 将img中所有小于0的元素设为0,大于1的设为1(如果需要)

    # 显示图片
    ax.imshow(img)
    ax.axis('off')  # 不显示坐标轴

# 显示图像
plt.show()

进行多轮训练后,生成结果:

 

可以看到,它很好的生成了我们想要的图片。

GANs

但是,我们这个模型只能随机产生数字,还不能生成指定的数字(如让机器生成一个1).为了解决这个问题,我们可以针对每一个数字生成一个对应的GAN,所有这样的GAN组合起来,就是GANs. 这里不展开讲解。

参考

MNIST数据集下用Paddle框架的动态图模式玩耍经典对抗生成网络(GAN)-使用文档-PaddlePaddle深度学习平台

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN_四天搞懂生成对抗网络(一)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1657788.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2-6 任务 猜数小游戏(单次版)

本任务要求编写一个猜数小游戏&#xff08;单次版&#xff09;&#xff0c;游戏规则是计算机产生一个0到100之间的随机整数&#xff0c;用户通过输入猜测的数字进行猜测&#xff0c;根据猜测情况给出提示&#xff0c;直到猜对为止。编程思路是利用while循环和多分支结构实现永真…

Linux 第二十四章

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C&#xff0c;linux &#x1f525;座右铭&#xff1a;“不要等到什么都没有了…

python面向函数

组织好的&#xff0c;可重复利用的&#xff0c;用来实现单一&#xff0c;或相关联功能的代码段&#xff0c;避免重复造轮子&#xff0c;增加程序复用性。 定义方法为def 函数名 (参数) 参数可动态传参&#xff0c;即使用*args代表元组形式**kwargs代表字典形式&#xff0c;代替…

探索智能编程新境界:我与Baidu Comate的独特体验之旅

文章目录 一、认识Baidu Comate二、VS Code安装Baidu Comate教程三、Baidu Comate功能体验功能概览具体功能1.根据注释自动生成代码2.函数注释3.行间注释4.代码解释5.生成单元测试6.代码优化7.答疑解惑 四、交互体验五、总结 一、认识Baidu Comate ✨Baidu Comate插件是一款基…

如何在PPT中插入网页?这样操作,免费还高效!

融合课、跨学科课&#xff0c;已经是近两年来教育界的热门词。 在公开课、微课比赛中&#xff0c;不添融合一些较为先进的信息技术&#xff0c;都不好意思拿出手了。 最近&#xff0c;由不坑老师开发制作的Office插件——不坑盒子&#xff0c;实现了在PPT中插入网页&#xff…

鸿蒙开发接口Ability框架:【(StaticSubscriberExtensionAbility)】

StaticSubscriberExtensionAbility StaticSubscriberExtensionAbility模块提供静态订阅者扩展能力的类别的能力。 说明&#xff1a; 本模块首批接口从API version 9 开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 本模块接口仅可在Stage模型下…

987: 输出用先序遍历创建的二叉树是否为完全二叉树的判定结果

解法&#xff1a; 一棵二叉树是完全二叉树的条件是&#xff1a; 对于任意一个结点&#xff0c;如果它有右子树而没有左子树&#xff0c;则这棵树不是完全二叉树。 如果一个结点有左子树但是没有右子树&#xff0c;则这个结点之后的所有结点都必须是叶子结点。 如果满足以上条…

ADOP带你了解:温度如何影响您的室外以太网电缆?

温度&#xff1a;室外以太网电缆的隐形敌人 在构建和维护室外以太网网络时&#xff0c;我们通常会考虑到许多物理因素&#xff0c;如电缆的长度、宽带容量和连接质量。然而&#xff0c;有一个不那么显眼但同样重要的因素常常被忽视&#xff0c;那就是温度。温度的波动不仅影响…

物联网实战--平台篇之(四)账户后台交互

目录 一、交互逻辑 二、请求验证码 三、帐号注册 四、帐号/验证码登录 五、重置密码 本项目的交流QQ群:701889554 物联网实战--入门篇https://blog.csdn.net/ypp240124016/category_12609773.html 物联网实战--驱动篇https://blog.csdn.net/ypp240124016/category_12631…

《21天学通C++》(第二十章)STL映射类(map和multimap)

为什么需要map和multimap&#xff1a; 1.查找高效&#xff1a; 映射类允许通过键快速查找对应的值&#xff0c;这对于需要频繁查找特定元素的场景非常适合。 2.自动排序&#xff1a; 会自动根据键的顺序对元素进行排序 3.多级映射&#xff1a; 映射类可以嵌套使用&#xff0c;创…

java.net.SocketInputStream.socketRead0 卡死导致 tomcat 线程池打满的问题

0 TL;DR; 问题与原因&#xff1a;某些特定条件下 java.net.SocketInputStream.socketRead0 方法会卡死&#xff0c;导致运行线程一直被占用导致泄露采用的方案&#xff1a;使用监控线程异步监控卡死事件&#xff0c;如果发生直接关闭网络连接释放链接以及对应的线程 1. 问题 …

贪心算法--将数组和减半的最小操作数

本题是力扣2208---点击跳转题目 思路&#xff1a; 要尽快的把数组和减小&#xff0c;那么每次挑出数组中最大的元素减半即可&#xff0c;由于每次都是找出最值元素&#xff0c;可以用优先队列来存储这些数组元素 每次取出最值&#xff0c;减半后再放入优先队列中&#xff0c;操…

最新:Lodash 严重安全漏洞背后你不得不知道的 JavaScript 知识

可能有信息敏感的同学已经了解到&#xff1a;Lodash 库爆出严重安全漏洞&#xff0c;波及 400万 项目。这个漏洞使得 lodash “连夜”发版以解决潜在问题&#xff0c;并强烈建议开发者升级版本。 我们在忙着“看热闹”或者“”升级版本”的同时&#xff0c;静下心来想&#xf…

如何通过代理IP实现搜索引擎优化

目录 前言 一、代理IP的基本概念 二、通过代理IP访问其他地区的搜索引擎 三、对比不同地区搜索结果 结论 前言 搜索引擎优化&#xff08;Search Engine Optimization&#xff0c;SEO&#xff09;是指通过优化网站的结构、内容和关键词等因素&#xff0c;提高网站在搜索引…

ubuntu挂载固态硬盘

ubuntu挂载固态硬盘 两种情况 包装盒拆出来的新硬盘用过的需要后处理的硬盘 新硬盘 一、确认硬盘设备 插上主机后输入 lsblk检查是否识别到你插入的硬盘 可以看到上图的nvme0n1是我挂载的硬盘&#xff08;目前已经挂载完成并映射到 ~/ssd目录&#xff09;&#xff0c;nvm…

如果你这样使用电路仿真软件,你就无敌了!

在电子设计领域&#xff0c;电路仿真软件如同一把锋利的宝剑&#xff0c;掌握它&#xff0c;你就能在复杂的电子世界中游刃有余。今天&#xff0c;就让我们一起探讨如何高效利用电路仿真软件&#xff0c;让你在电子设计领域所向披靡&#xff01; 一、熟悉软件界面与基础操作 …

点击短信链接唤起Android App实战

一.概述 在很多业务场景中,需要点击短信链接跳转到App的指定页面。在Android系统中,想要实现这个功能,可以通过DeepLink或AppLink实现。二.方案 2.1 DeepLink 2.1.1 方案效果 DeepLink是Android系统最基础、最普遍、最广泛的外部唤起App的方式,不受系统版本限制。当用户…

基于Vue3与ElementUI Plus的酷企秀场景可视化DIY设计器:前端技术引领下的数字化展示新篇章

一、引言 在当今信息化高速发展的时代&#xff0c;企业对于展示自身形象、提升用户体验以及增强品牌知名度的需求日益迫切。针对这一市场需求&#xff0c;我们推出了基于Vue3与ElementUI Plus的酷企秀场景可视化DIY设计器。该产品不仅具备电子画册、VR全景、地图秀三大核心功能…

2024年自动驾驶、车辆工程与智能交通国际会议(ICADVEIT2024)

2024年自动驾驶、车辆工程与智能交通国际会议&#xff08;ICADVEIT2024&#xff09; 会议简介 2024年自动驾驶、车辆工程和智能交通国际会议&#xff08;ICADVEIT 2024&#xff09;将在中国深圳举行。会议主要聚焦自动驾驶、车辆工程和智能交通等研究领域&#xff0c;旨在为从…

pytest教程-42-钩子函数-pytest_runtest_makereport

领取资料&#xff0c;咨询答疑&#xff0c;请➕wei: June__Go 上一小节我们学习了pytest_runtest_teardown钩子函数的使用方法&#xff0c;本小节我们讲解一下pytest_runtest_makereport钩子函数的使用方法。 pytest_runtest_makereport 钩子函数在 pytest 为每个测试生成报…