深度学习 GAN生成对抗网络-1010格式数据生成简单案例

news2024/9/27 19:26:38

一、前言

本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络。

二、GAN概念

生成对抗网络(Generative Adversarial Networks,GAN)包含生成器(Generator)和鉴别器(Discriminator)两个神经网络。生成器用于生成虚假的数据,经过训练后能够生成以假乱真的数据;鉴别器使用真实数据和虚假数据训练后,能够辨别数据的真假;生成器和鉴别器相互博弈,最终达到鉴别器难以区分生成数据真假的状态。

三、案例实战

我们会创建一个GAN,生成器通过学习训练,来创建符合1010格式规律的值。这个任务比生成图像要简单。通过这个任务,我们可以了解GAN的基本代码框架,观察训练进程,进而帮助我们为接下来生成图像的任务做好准备。

我们先引入依赖库:

import matplotlib.pyplot as plt
import pandas
import torch
import torch.nn as nn

2.1 构造真实数据源

真实数据源可以是一个返回1010格式数据的函数,如下所示:

def generate_real():
    real_data = torch.FloatTensor([1,0,1,0])
    return real_data

执行:

generate_real()

结果:

tensor([1., 0., 1., 0.])

但是,在实际生活中,数据往往不是那么精准,我们让其有一定随机性:

def generate_real():
    real_data = torch.FloatTensor(
        [random.uniform(0.8, 1.0),
         random.uniform(0.0, 0.2),
         random.uniform(0.8, 1.0),
         random.uniform(0.0, 0.2)])
    return real_data

random.uniform(0.8, 1.0)产生0.8-1.0之间的随机小数。
执行:

generate_real()

结果:

tensor([0.9782, 0.0673, 0.8500, 0.1788])

2.2 构造随机数据

产生4个随机数,可能满足1010格式,也可能不满足,函数如下:

def generate_random(size):
    random_data = torch.rand(size)
    return random_data

执行:

generate_random(4)

结果:

tensor([0.4241, 0.0611, 0.7684, 0.2931])

2.3 构造鉴别器

鉴别器是一个神经网络,我们的目的是训练出一个能区分真实数据与随机噪声数据的鉴别器。下面代码定义了一个非常简单的神经网络:输入层有4个节点,用于接受输入的4个值;隐藏层有3个节点;输出层输出0~1的单个值,表示真或假。

class Discriminator(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        # 定义神经网络层
        self.model = nn.Sequential(
            nn.Linear(4, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid()
        )
        
        # 创建损失函数,使用均方误差
        self.loss_function = nn.MSELoss()

        # 创建优化器,使用随机梯度下降
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
    
    # 前向传播函数
    def forward(self, inputs):
        return self.model(inputs)
    
    # 训练函数
    def train(self, inputs, targets):
        # 前向传播,计算网络输出
        outputs = self.forward(inputs)
        
        # 计算损失值
        loss = self.loss_function(outputs, targets)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 每10000次输出训练次数   
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)
            

        # 梯度清零, 反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    # 绘制损失变化图
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

2.4 测试鉴别器

由于还没有创建生成器,所以无法测试能够与其竞争的鉴别器,目前能做的是,检验鉴别器是否能将真实数据与随机数据区分开。

训练

D = Discriminator()
for i in range(10000):
    # 真实数据
    D.train(generate_real(), torch.FloatTensor([1.0]))
    # 随机数据
    D.train(generate_random(4), torch.FloatTensor([0.0]))

结果:

counter =  10000
counter =  20000

上述代码虽然迭代了10000次,但是在每次迭代中分别对真实数据和随机数据进行了训练,累计训练20000次。

损失值变化

我们来看看训练过程中的损失值变化:

D.plot_progress()

在这里插入图片描述
如上图所示,损失值一开始接近0.25,随着训练次数增加,损失值逐渐接近0。

鉴别效果

我们再来测试一下鉴定器的效果,现在分别输入1010格式数据与随机数据,代码和运行结果如下:

print(D.forward(generate_real()).item())
print(D.forward(generate_random(4)).item())

结果:

0.8134430050849915
0.05087679252028465

得出的结果分别接近1和0,这说明鉴别器能够区分真实数据与随机噪声。

2.5 构造生成器

生成器也是一个神经网络,目的是尽量生成满足1010格式的4个值。为了使生成器与鉴别器不相伯仲地相互竞争与提高,生成器与鉴别器的结构正好相反:输入层只有1个节点;隐藏层有3个节点;输出层有4个节点,输出4个值。
代码如下,注意训练函数稍有不同,引入了鉴别器的损失函数进行反向传播,进而更新生成器权重

class Generator(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        # 定义神经网络层
        self.model = nn.Sequential(
            nn.Linear(1, 3),
            nn.Sigmoid(),
            nn.Linear(3, 4),
            nn.Sigmoid()
        )

        # 注意这里没有损失函数,在训练时使用鉴别器的损失函数。

        # 创建优化器,使用随机梯度下降
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
        
    # 前向传播函数
    def forward(self, inputs):
        return self.model(inputs)
    
    # 训练函数
    def train(self, D, inputs, targets):
        # 前向传播,计算网络输出
        g_output = self.forward(inputs)
        
        # 将生成器输出,传入鉴别器,输出分类结果
        d_output = D.forward(g_output)
        
        # 计算鉴别误差
        loss = D.loss_function(d_output, targets)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 梯度清零, 反向传播, 更新权重。注意这里是对鉴别器的误差进行反向传播,但只更新生成器的权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    # 绘制损失变化图
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

2.6 检查生成器输出

同样地,我们也可以单独对生成器进行测试,以检查是否正常工作:

G = Generator()
G.forward(torch.FloatTensor([0.5]))

结果:

tensor([0.6172, 0.5979, 0.5700, 0.6622], grad_fn=<SigmoidBackward0>)

可以看到输出了4个值,但不符合1010格式,因为我们还没有对其进行训练。

2.7 训练GAN

训练

先看代码:

D = Discriminator()
G = Generator()

for i in range(10000):
    
    # 用真实样本数据训练鉴别器
    D.train(generate_real(), torch.FloatTensor([1.0]))
    
    # 用生成数据训练鉴别器
    # 此处训练是为了更新鉴别器权重,不需要更新生成器权重,使用detach()以避免计算生成器中的梯度
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))
    
    # 训练生成器,更新生成器权重
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))

在迭代过程中,每次循环都会重复训练GAN的3个步骤:

  1. 用真实样本数据训练鉴别器,更新鉴别器权重
  2. 用生成的数据训练鉴别器,更新鉴别器权重。此处不需要更新生成器权重,detach()的作用是将其从计算图中分离出来
  3. 训练生成器,更新生成器权重

损失值变化

训练完成后,我们来看看鉴别器损失值的变化:

D.plot_progress()

在这里插入图片描述
这是一个非常有意思的结果,损失值最终保持在0.25附近。这说明鉴别器无法判断数据是真实的还是伪造的,于是输出0.5,由于我们损失函数使用的是均方误差,所以损失值是0.5的平方,即0.25。

下图是生成器的损失图,与鉴别器损失是互补的:

G.plot_progress()

在这里插入图片描述

生成数据

现在我们用训练好的生成器来生成数据:

G.forward(torch.FloatTensor([0.5]))

结果:

tensor([0.9537, 0.0367, 0.9493, 0.0507], grad_fn=<SigmoidBackward0>)

可以看到生成的数据符合1010格式。效果相当不错!

通过上面的训练,相信你已经熟悉GAN的结构了,后面我们将使用GAN来实现手写数字生成等更加酷炫的任务 😃

参考资料

《PyTorch生成对抗网络编程》(PS:写得太好了,强烈推荐。)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170244.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HyperLogLog和Set比较 !!!

HyperLogLog和Set比较 HyperLogLog HyperLogLog常用于大数据量的统计&#xff0c; 比如页面访问量统计或者用户访问量统计&#xff0c;作为一种概率数据结构&#xff0c;HyperLogLog 以完美的精度换取高效的空间利用率。Redis HyperLogLog 实现最多使用 12 KB&#xff0c;并提…

docker推送镜像至阿里私有镜像仓库

文章目录一、注册阿里私有镜像仓库二、将公共镜像推送至私有镜像仓库1、首先拉取到mysql镜像2、登录阿里云Docker Registry&#xff08;这里的信息要更换成自己的&#xff09;3、将mysql镜像推送至Registry4、查看5、拉取镜像三、将正在启动的容器导出并推送至私有仓库1、将启动…

二分查找的最多比较次数

答案 对于二分搜索次数最多的问题&#xff0c;计算公式为&#xff0c;其中a , b , n 均为整数 当顺序表有n个关键字时候&#xff0c;查找失败&#xff0c;至少需要比较a次关键字 查找成功&#xff0c;至少需要b次 举例 已有从小到大排序的10000个数据&#xff0c;用二分查…

密码框限制xxs注入字符处理

<template><a-form-model-item ref"password" prop"password"><a-input-passwordplaceholder"请输入登录密码"v-model"cusForm.password"/></a-form-model-item> </template><script> export def…

「自控原理」3.2 二阶系统时域分析

本节介绍二阶系统的时域分析&#xff0c;主要介绍欠阻尼情况下的时间响应与动态性能指标 文章目录概述极点的表示方法无阻尼响应临界阻尼响应过阻尼响应欠阻尼响应欠阻尼系统的单位阶跃响应动态性能与极点分布的关系例题改善二阶系统动态性能的措施概述 二阶系统时间响应比较重…

elementUI如何设置input不可编辑

打开一个vue文件&#xff0c;添加一个input标签。如图&#xff1a; 添加disabled设置不可编辑。如图&#xff1a; 保存vue文件后使用浏览器打开&#xff0c;页面上显示的input已经实现不可编辑效果。如图&#xff1a; 参考&#xff1a;elementUI如何设置input不可编辑-百度…

出现死锁的场景分析及解决方法

在上一篇互斥锁的时候最后使用Account.class作为互斥锁&#xff0c;来解决转载问题&#xff0c;所有的账户转账操作都是串行的&#xff0c;性能太差。 我们可以考虑缩小锁定的范围&#xff0c;使用细粒度的锁&#xff0c;来提高并行度。例如用两把锁&#xff0c;转出账本一把&…

Python - 数据容器set(集合)

目录 集合的定义 集合的常用操作 添加新元素 add 移除元素 remove 从集合中随机取出元素 pop 清空集合 clear 取出2个集合的差集 difference 消除2个集合的交集 difference_update 2个集合合并 union for循环遍历 set的实用应用 集合的定义 不支持元素的重复&#…

软件设计师学习笔记-程序设计语言基础知识

前言 备战2023年5月份的软件设计师考试&#xff0c;在此记录学习之路。 知识点总结&#xff0c;具体内容请查看对应的模块。 提示&#xff1a;这里有软件设计师资料&#xff0c;包含软件设计师考试大纲、软件设计师第五版官方教程、历年考试真题。 通过百度网盘分享的文件&am…

好好学习,天天向上——“C”

各位uu们我又来啦&#xff0c;今天小雅兰来给大家分享一个有意思的东西&#xff0c;是为&#xff1a;天天向上的力量 基本问题&#xff1a;持续的价值 一年365天&#xff0c;每天进步1%&#xff0c;累积进步多少呢&#xff1f; 1.01^365 一年365天&#xff0c;每天退步1%&#…

python(运算符,顺序,选择,循环语句)

专栏&#xff1a;python 个人主页&#xff1a;HaiFan. 专栏简介&#xff1a;本专栏主要更新一些python的基础知识&#xff0c;也会实现一些小游戏和通讯录&#xff0c;学时管理系统之类的&#xff0c;有兴趣的朋友可以关注一下。 python基础语法2前言一、输入输出1.通过控制台输…

CSS 伪类

CSS 伪类 CSS 伪类是添加到选择器的关键字&#xff0c;用于指定所选元素的特殊状态。例如&#xff0c;伪类 :hover 可以用于选择一个按钮&#xff0c;当用户的指针悬停在按钮上时&#xff0c;设置此按钮的样式。 举例说明: button:hover {color: blue; }伪类由冒号&#xff…

【应用】SpringBoot -- Webflux + R2DBC 操作 MySQL

SpringBoot -- Webflux R2DBC 操作 MySQLWebflux 概述Webflux 基本使用Webflux R2DBC 操作 MySQLController ServiceRoute Handler参考文章Webflux 概述 简单来说&#xff0c;Webflux 是响应式编程的框架&#xff0c;与其对等的概念是 SpringMVC。两者的不同之处在于 Webf…

贪心策略(四)合并区间合集

目录 435. 无重叠区间移除元素的最小个数 无重叠区间 区间组合结果 合并区间_牛客题霸_牛客网 435. 无重叠区间 移除元素的最小个数 给定一个区间的集合 intervals &#xff0c;其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量&#xff0c;使剩余区间互不重…

【Linux调试器-gdb使用】

目录 1. 背景 2. 使用 3 命令总结 1. 背景 通过c语言的学习我们知道程序的发布方式有两种&#xff0c;debug模式和release模式&#xff0c;debug模式就是我们所说的调试模式。我们已经熟悉了在Windows平台下VS系列的调试&#xff0c;接下来我们一起在无图形化界面的Linux下来…

2023-01-18 flink 11.6 时间水印 和 窗口周期的关系计算方法

forBoundedOutOfOrderness 和 TumblingEventTimeWindowsforBoundedOutOfOrderness&#xff08;M&#xff09;TumblingEventTimeWindows&#xff08;N&#xff09;第一条数据的时间TS1第一个窗口期公式&#xff1a;窗口开始时间&#xff1a;win_start ((TS1-M)/N) * N窗口结束时…

Docker上部署goweb项目

文章目录一、安装go语言环境①下载go语言环境②解压go语言环境到指定目录③验证是否成功④配置镜像加速二、go语言项目配置第一种&#xff1a;先编译后打包&#xff08;分步部署&#xff0c;靠谱&#xff09;第二种&#xff1a;直接打包法三、成功运行一、安装go语言环境 ①下…

Zabbix 监控 Linux操作系统的监控指标

一、Zabbix 监控 Linux操作系统的监控指标 (仅供参考) Zabbi x默认使用Zabbix agent监控操作系统,其内置的监控项可以满足系统大部分的指标监控,因此,在完成Zabbix agent的安装后,只需在前端页面配置并关联相应的系统监控模板就可以了。 如果内置监控项不能满足监控需求…

走向开放世界强化学习、IJCAI2022论文精选、机器人 RL 工具、强化学习招聘、《强化学习周刊》第73期...

No.73智源社区强化学习组强化学习周刊订阅《强化学习周刊》已经开启“订阅功能”&#xff0c;扫描下面二维码&#xff0c;进入主页&#xff0c;选择“关注TA”&#xff0c;我们会向您自动推送最新版的《强化学习周刊》。本期贡献者&#xff1a;&#xff08;李明&#xff0c;刘青…

【Kotlin】类的继承 ① ( 使用 open 关键字开启类的继承 | 使用 open 关键字开启方法重写 )

文章目录一、使用 open 关键字开启类的继承二、使用 open 关键字开启方法重写一、使用 open 关键字开启类的继承 Kotlin 中的类 默认都是 封闭的 , 无法被继承 , 如果要想类被继承 , 需要在定义类时 使用 open 关键字 ; 定义一个普通的 Kotlin 类 : class Person(val name: S…