(深度学习快速入门)第五章第一节2:GAN经典案例之MNIST手写数字生成

news2025/1/11 22:46:24
  • 获取pdf:密码7281

文章目录

  • 一:数据集介绍
  • 二:GAN简介
    • (1)简介
    • (2)损失函数
  • 三:代码编写
    • (1)参数及数据预处理
    • (2)生成器与判别器模型
    • (3)优化器和损失函数
    • (4)训练
  • 三:效果查看
    • (1)tensorboard
    • (2)生成图片效果

一:数据集介绍

MNIST数据集:MNIST是个手写数字图片集,每张图片都做了归一化处理,大小是28x28,并且是灰度图像,所以每张图像格式为1x28x28

  • 数据集下载地址

包括如下四个文件

在这里插入图片描述

含义如下

类别文件名描述
训练集图片train-images-idx3-ubyte.gz(9.9M)包含60000个样本
训练集标签train-labels-idx1-ubyte.gz(29KB)包含60000个标签
测试集图片t10k-images-idx3-ubyte.gz(1.6M)包含10000个样本
测试集标签t10k-labels-idx1-ubyte.gz(5KB)包含10000个样本

当然torchvision.datasets中也内置了这个数据集,可以通过如下代码从网络上下载

train_data = dataset.MNIST(root='./mnist/',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)
test_data = dataset.MNIST(root='./mnist/',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)
  • root:表示数据集待存放的目录
  • train:如果为true将会使用训练集的数据集(training.pt),如果为false将会使用测试集数据集(test.pt
  • download:如果为true将会从网络上下载并放入root中,如果数据集已下载则不会再次下载
  • transform:接受PIL图片并返回转换后的图片,常用的就是转换为tensor(这里便会调用torchvision.transform

数据集加载成功后,文件布局如下

在这里插入图片描述

二:GAN简介

(1)简介

GAN(Generative Adversial Nets,生成式对抗网络):这是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型有两个模型:生成模型(Generative Model)辨别模型(Discriminative Model)的互相博弈学习产生相当好的输出。实际使用时一般会选择DNN作为G和D

如下图,以论文中所述的制作假钞的例子为例进行说明

  • 生成模型G的目的是尽量能够生成足以以假乱真的假钞去欺骗判别模型D,让它以为这是真钞
  • 判别模型D的目的是尽量能够鉴别出生成模型G生成的假钞是假的

在这里插入图片描述

(2)损失函数

GAN损失函数如下
在这里插入图片描述

其中参数含义如下

  • x x x:真实的数据样本
  • z z z:噪声,从随机分布采集的样本
  • G G G:生成模型
  • D D D:判别模型
  • G ( z ) G(z) G(z):输入噪声生成一条样本
  • D ( x ) D(x) D(x):判别真实样本是否来自真实数据(如果是则为1,如果不是则为0)
  • D ( G ( z ) ) D(G(z)) D(G(z)):判别生成样本是否来自真实数据(如果是则为1,如果不是则为0)

该损失函数整体分为两个部分

第一部分:给定 G G G找到使 V V V最大化的 D D D,因为使 V V V最大化的 D D D会使判别器效果最好

  • 对于①:判别器的输入为真实数据 x x x E x ∼ p d a t a [ l o g D ( x ) ] E_{x}\sim p_{data}[logD(x)] Expdata[logD(x)]值越大表示判别器认为输入 x x x为真实数据的概率越大,也即表示判别器的能力越强,所以这一项输出越大对判别器越有利
  • 对于②:判别器的输入伪造数据 G ( z ) G(z) G(z),此时 D ( G ( z ) ) D(G(z)) D(G(z))越小那么就表示判别器将此伪造数据鉴别为真实数据的概率也越小,也即表示判别器的能力越强。注意此时第二项是 l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))的期望 E x ∼ p d a t a [ l o g ( 1 − D ( G ( z ) ) ) ] E_{x}\sim p_{data}[log(1-D(G(z)))] Expdata[log(1D(G(z)))]。所以当判别器能力越强时, D ( G ( z ) ) D(G(z)) D(G(z))越小同时 E x ∼ p d a t a [ l o g ( 1 − D ( G ( z ) ) ) ] E_{x}\sim p_{data}[log(1-D(G(z)))] Expdata[log(1D(G(z)))]也就越大

在这里插入图片描述

第二部分:给定 D D D找到使 V V V最小化的 G G G,因为使 V V V最小化的 G G G会使生成器效果最好

  • 对于①:由于固定了 D D D,而这一部分只和 D D D有关,因此这一部分是常量,所以可以舍去
  • 对于②:判别器的输入伪造数据 G ( z ) G(z) G(z),与上面不同的是,我们期望生成器的效果要好,尽可能骗过辨别器,所以 D ( G ( z ) ) D(G(z)) D(G(z))要尽可能大( D ( G ( z ) ) D(G(z)) D(G(z))越大表示辨别器鉴定此数据为真实数据的概率越大), E x ∼ p d a t a [ l o g ( 1 − D ( G ( z ) ) ) ] E_{x}\sim p_{data}[log(1-D(G(z)))] Expdata[log(1D(G(z)))]也就越小

三:代码编写

(1)参数及数据预处理

# 设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print("GPU上运行")
else:
    print("CPU上运行")
# 图片格式
img_size = [1, 28, 28]

# batchsize
batchsize = 64

# latent_dim
latent_dim = 100

# 数据集及变化
data_transforms = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]
)
dataset = torchvision.datasets.MNIST(root='~/autodl-tmp/dataset', train=True, download=False, transform=data_transforms)

(2)生成器与判别器模型

# 生成器模型
"""
根据输入生成图像
"""

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, np.prod(img_size, dtype=np.int32)),

            nn.Tanh()
        )
    def forward(self, x):
        # [batchsize, latent_dim]
        output = self.model(x)
        image = output.reshape(x.shape[0], *img_size)
        return image

# 判别器模型
"""
判别图像真假
"""
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear( np.prod(img_size, dtype=np.int32), 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, 256),
            nn.ReLU(inplace=True),

            nn.Linear(256, 128),
            nn.ReLU(inplace=True),

            nn.Linear(128, 1),
            nn.ReLU(inplace=True),

            nn.Sigmoid(),
        )

    def forward(self, x):
        # [batch_size, 1, 28, 28]
        x = x.reshape(x.shape[0], -1)
        output = self.model(x)

        return output

(3)优化器和损失函数

# 优化器和损失函数
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
loss_func = nn.BCELoss()

(4)训练

def train():
    step = 0
    dataloader = DataLoader(dataset=dataset, batch_size=batchsize, shuffle=True, drop_last=True, num_workers=8)
    for epoch in range(1, 100):
        print("-----------当前epoch:{}-----------".format(epoch))
        for i, batch in enumerate(dataloader):
            print("-----------当前batch:{}/{}-----------".format(i, (len(dataloader))))
            # 拿到真实图片
            X, _ = batch
            X = X.to(device)
            # 采用标准正态分布得到的batchsize × latent_dim的向量
            z = torch.randn(batchsize, latent_dim)
            z = z.to(device)
            # 送入生成器生成假图片
            pred_X = generator(z)

            g_optimizer.zero_grad()
            """
            生成器损失:
            让生成的图像与通过辨别器与torch.ones(batchsize, 1)越接近越好
            
            """
            g_loss = loss_func(discriminator(pred_X), torch.ones(batchsize, 1).to(device))
            g_loss.backward()
            g_optimizer.step()

            d_optimizer.zero_grad()
            """
            辨别器损失:
            一方面让真实图片通过辨别器与torch.ones(batchsize, 1)越接近越好
            另一方面让生成图片通过辨别器与torch.zeros(batchsize, 0)越接近越好
            """

            d_loss = 0.5 * (loss_func(discriminator(X), torch.ones(batchsize, 1).to(device)) + loss_func(discriminator(pred_X.detach()), torch.zeros(batchsize, 1).to(device)))

            d_loss.backward()
            d_optimizer.step()

            print("生成器损失{}".format(g_loss), "辨别器损失{}".format(d_loss))

            logger.add_scalar('g_loss', g_loss, step)
            logger.add_scalar('d_loss', d_loss, step)
            
            step = step+1
            if step % 1000 == 0:
                save_image(pred_X.data[:25], "./image_save/image_{}.png".format(step), nrow=5)

三:效果查看

(1)tensorboard

在这里插入图片描述

(2)生成图片效果

每1000个step保存一次照片,最后生成了92张图片,每张图片由每个batch的前25张图片构成

在这里插入图片描述


1000-step
在这里插入图片描述

5000-step
在这里插入图片描述

10000-step
在这里插入图片描述

20000-step
在这里插入图片描述

30000-step
在这里插入图片描述

50000-step
在这里插入图片描述

70000-step

在这里插入图片描述

80000-step

在这里插入图片描述

90000-step
在这里插入图片描述

920000-step(final)
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/331855.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Acwing---112.雷达设备

雷达设备1.题目2.基本思想3.代码实现1.题目 假设海岸是一条无限长的直线,陆地位于海岸的一侧,海洋位于另外一侧。 每个小岛都位于海洋一侧的某个点上。 雷达装置均位于海岸线上,且雷达的监测范围为 d,当小岛与某雷达的距离不超…

2023-02-09 mysql/innodb存储

前言 大家都知道 MySQL 的数据都是保存在磁盘的,那具体是保存在哪个文件呢?MySQL 存储的行为是由存储引擎实现的,MySQL 支持多种存储引擎,不同的存储引擎保存的文件自然也不同。InnoDB 是我们常用的存储引擎,也是 MySQL 默认的存储引擎。本文主要以 InnoDB 存储引擎展开讨…

centos7.6 设置防火墙

1、查看系统版本 cat /etc/redhat-release2、查看防火墙运行状态 systemctl status firewalld这里可以看到当前是未运行状态(inactive)。 3、关闭开机自启动防火墙 systemctl disable firewalld.service4、启动防火墙并查看状态,系统默认 22 端口是开启的。 sy…

云原生微服务应用平台 EDAS 2022 年度报告

作者:孤戈 最近一年来,随着我们的客户对于云技术的诉求从资源快速交付的服务,转变为对资源精益运用的服务。EDAS 团队结合公共云上所服务的企业类客户的几万个应用,选取了 8 个最具代表性的指标,进行了一次系统性的分…

进度条实时显示request下载文件的解决方案

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,科大讯飞比赛第三名,CCF比赛第四名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

Linux -文件打包和解压缩

1、概念 在 Windows 上最常见的不外乎这两种 .zip,.7z 后缀的压缩文件。而在 Linux 上面常见的格式除了以上两种外,还有 .rar,.gz,.xz,.bz2,.tar,.tar.gz,.tar.xz,*.tar…

python--turtle

前言 就随便练练,学习一下turtle库的使用 正文 1.语法学习 import turtle #导入库 turtle.showturtle() #画笔显示箭头 turtle.write("我是大帅逼") #写下字符串 turtle.forward(300) …

拦截器interceptor总结

拦截器一. 概念拦截器和AOP的区别:拦截器和过滤器的区别:二. 入门案例2.1 定义拦截器bean2.2 定义配置类2.3 执行流程2.4 简化配置类到SpringMvcConfig中一. 概念 引入: 消息从浏览器发送到后端,请求会先到达Tocmat服务器&#x…

C语言一维数组篇【下】——每日刷题经验分享

一维数组篇——每日刷题经验分享~😎前言🙌有序序列插入一个整数 😊序列中删除指定数字 😊序列中整数去重小乐乐查找数字筛选法求素数总结撒花💞😎博客昵称:博客小梦~ 😊最喜欢的座右…

【FFMPEG源码分析】从ffplay源码摸清ffmpeg框架(一)

ffplay入口 ffmpeg\fftools\ffplay.c int main(int argc, char **argv) {/*******************start 动态库加载/网络初始化等**************/int flags;VideoState *is;init_dynload();av_log_set_flags(AV_LOG_SKIP_REPEATED);parse_loglevel(argc, argv, options);/* regis…

SSJ-21A AC220V静态【时间继电器】

系列型号: SSJ-11B静态时间继电器;SSJ-21B静态时间继电器 SSJ-21A静态时间继电器;SSJ-22A静态时间继电器 SSJ-22B静态时间继电器SSJ-42B静态时间继电器 SSJ-42A静态时间继电器SSJ-41A静态时间继电器 SSJ-41B静态时间继电器SSJ-32B静态时间继电…

dev-c++解决中文输出乱码问题

之前写c程序老是出现编译输出乱码的问题,就去博客,百度查阅了一番找了找了办法。废话不多说直接上操作。 首先打开dev-c-》工具-》编译选项 在第一行输入 -fexec-charsetgbk//生成gbk格式 -fexec-charsetUFT-8//生成UFT-8格式 也可以输入UFT-8格式&…

获取主机RDP连接凭据

为了避免每次连接服务器都进行身份验证,经常使用RDP的用户可能勾选保存连接凭据,以便进行快速的身份验证。这些凭据都使用数据保护API以加密形式存储在windows的凭据管理器中,路径为“%USERPROFILE%\AppData\Local\Microsoft\Credentials”执…

C语言---宏

专栏:C语言 个人主页:HaiFan. 专栏简介:本专栏主要更新一些C语言的基础知识,也会实现一些小游戏和通讯录,学时管理系统之类的,有兴趣的朋友可以关注一下。 #define预处理预定义符号define#define定义标识符…

云片验证码分析思路

本文仅供学习参考,又不懂的可以联系博主 目标链接: aHR0cHM6Ly93d3cueXVucGlhbi5jb20vcHJvZHVjdC9jYXB0Y2hh接口分析 captcha/get 验证码图片获取接口,GET请求,包含四个参数cb、i、k、captchaId 接口返回,如果是滑动验证码&…

「敏捷建模」敏捷设计理念的纪律

本文概述了敏捷软件开发团队的设计策略。这些策略对于扩展敏捷软件开发以满足现代IT组织的实际需求至关重要。敏捷的设计方法与传统方法截然不同,显然也更有效。重要的是要了解: 敏捷设计实践 敏捷设计理念 整个敏捷生命周期的设计 1.敏捷设计实践 从高…

大数据概述

一、大数据时代 大数据时代 三次信息化浪潮:个人计算机80年-互联网95年-物联网、云计算和大数据(2010年) 发展时间较短,大数据人才缺失大数据人才 培训出来的:Java-》大数据 优点:对于大数据技术的细节会比较清楚 缺点&#xff1…

为什么Redis集群的最大槽数是16384个?

对于客户端请求的key,根据公式HASH_SLOTCRC16(key) mod 16384,计算出映射到哪个分片上,然后Redis会去相应的节点进行操作! 为什么有16384个槽? Redis集群并没有使用一致性hash而是引入了哈希槽的概念。Redis 集群有16…

金仓数据库事务日志与检查点

事务日志与检查点 WAL文件,在金仓数据库中,事务日志文件称为Write Ahead Log(预写式日志,简称WAL)。 WAL存储了数据库系统中所有更改和操作的历史,相当于Oracle的REDO。 WAL机制是在这个写数据的过程中加…

[Android开发基础4] 意图与意图过滤器

文章目录 意图(Intent) 简介 显式意图 隐式意图 意图过滤器(IntentFiler) action data category 意图(Intent) 简介 Intent被称为意图,是程序中各组件进行交互的一种重要方式&#xff0c…