G1 GAN生成MNIST手写数字图像

news2024/11/26 22:38:45
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

G1 GAN生成MNIST手写数字图像

1. 生成对抗网络 (GAN) 简介

生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成图像、视频等数据。GAN 由两个网络组成:

  • 生成器 (Generator):用于生成假的数据样本,试图让判别器无法分辨其为假的。
  • 判别器 (Discriminator):用于区分输入的数据是真实的还是生成器生成的。

GAN 的核心思想是,生成器和判别器通过相互对抗学习,生成器逐渐提高生成逼真数据的能力,而判别器逐渐提高区分真假数据的能力。最后,生成器生成的样本与真实样本之间的差异会越来越小。

GAN 的基本流程

  1. 判别器输入真实数据,判别器输出一个接近1的值,表示为真;
  2. 生成器生成假的数据,并试图欺骗判别器;
  3. 判别器输出接近0的值,表示为假;
  4. 生成器通过更新自身的参数,试图让判别器认为生成的数据是真实的。

GAN 的目标是使得生成器生成的假数据,能骗过判别器。

GAN 的损失函数

GAN 的训练目标是让生成器和判别器进行对抗训练,其损失函数分为两个部分:生成器损失和判别器损失。生成器的目标是最大化判别器判断生成数据为真的概率,判别器的目标是最大化正确判断真实数据和生成数据的概率。

判别器的损失函数定义为:

L D = − [ E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ] \mathcal{L}_D = - \left[ \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \right] LD=[Expdata[logD(x)]+Ezpz[log(1D(G(z)))]]

生成器的损失函数定义为:

L G = − E z ∼ p z [ log ⁡ D ( G ( z ) ) ] \mathcal{L}_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right] LG=Ezpz[logD(G(z))]

其中:

  • ( D(x) ) 表示判别器对真实数据 ( x ) 判别为真的概率;
  • ( G(z) ) 是生成器通过噪声 ( z ) 生成的假数据;
  • ( D(G(z)) ) 表示判别器对生成器生成数据的输出(希望趋向于1)。

2. PyTorch 实现

下面使用 PyTorch 实现 GAN 生成 MNIST 手写数字图像。

2.1 导入库与超参数设置

import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)

# 超参数设置
n_epochs = 50
batch_size = 64
lr = 0.0002
latent_dim = 100
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

cuda = True if torch.cuda.is_available() else False

2.2 数据预处理

使用 torchvision.datasets.MNIST 下载并处理 MNIST 数据集。数据会被标准化到 [-1, 1] 区间,并通过 DataLoader 转化为可迭代数据集。

# 下载MNIST数据集并进行预处理
mnist = datasets.MNIST(root='./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize([0.5], [0.5])
                       ]))

dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

2.3 定义生成器模型

生成器接受一个随机噪声向量 ( z ),通过多层线性变换和激活函数逐步生成一个 28x28 的图像。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *img_shape)

2.4 定义判别器模型

判别器是一个二分类网络,输入一个 28x28 的图像,输出一个表示真假概率的值。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

2.5 定义优化器与损失函数

generator = Generator()
discriminator = Discriminator()

# 定义损失函数
criterion = nn.BCELoss()

# 定义生成器和判别器的优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

if cuda:
    generator.cuda()
    discriminator.cuda()
    criterion.cuda()

2.6 训练过程

2.6.1 训练判别器

判别器需要区分真实图像和生成的假图像,通过两个损失值相加,更新判别器的参数。

real_img = Variable(imgs.type(torch.cuda.FloatTensor))
real_label = Variable(torch.ones(imgs.size(0), 1).cuda())
fake_label = Variable(torch.zeros(imgs.size(0), 1).cuda())

real_out = discriminator(real_img)
loss_real = criterion(real_out, real_label)

z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake = criterion(fake_out, fake_label)

loss_D = loss_real + loss_fake
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
2.6.2 训练生成器

生成器的目标是让判别器认为生成的数据是真实的,因此生成器的损失是判别器对假图像的输出。

z = Variable(torch.randn(imgs.size(0), latent_dim).cuda())
fake_img = generator(z)
output = discriminator(fake_img)

loss_G = criterion(output, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

在这里插入图片描述

2.7 保存与可视化生成图像

if batches_done % sample_interval == 0:
    save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述

4. 总结

这周学习了如何使用 PyTorch 实现生成对抗网络 (GAN) 来生成 MNIST 手写数字图像。GAN 通过生成器与判别器之间的对抗学习,不断提升生成图像的质量,是一种非常强大的生成模型。可以在论文中将其作为数据增强的一种方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2220870.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何调试浏览器中的内存泄漏?

聚沙成塔每天进步一点点 本文回顾 ⭐ 专栏简介⭐ 如何调试浏览器中的内存泄漏?1. 什么是内存泄漏?2. 调试内存泄漏的工具3. 如何使用 Memory 面板进行内存调试3.1 获取内存快照(Heap Snapshot)获取内存快照的步骤:快照…

即时通讯增加Redis渠道

情况说明 在本地和服务器分别启动im服务,当本地发送消息时,会发现服务器上并没有收到消息 初版im只支持单机版,不支持分布式的情况。此次针对该情况对项目进行优化,文档中贴出的代码非完整代码,可自行查看参考资料[2] 代码结构调…

C Primer Plus 第9章——第一篇

你该逆袭了 文章目录 一、复习函数1、定义带形式参数的函数2、声明带形式参数函数的原型3、使用 return 从函数中返回值(1)、返回值不仅可以赋给变量,也可以被用作表达式的一部分。(2)、返回值不一定是变量的值&#x…

springboot redisTemplate hash 序列化探讨

前提提要:这个是个人小白总结,写完博客后开始厌蠢。 redisTemplate 有两种插入hash的方式 redisTemplate.opsForHash().putAll(key, map);redisTemplate.opsForHash().put(key, field, value);在使用的过程中,难免会疑问为什么 key field v…

Windows下部署autMan

一、安装autMan 下载autMan压缩包 https://github.com/hdbjlizhe/fanli/releases 解压安装包 二、运行(注意,无论是交互运行还是静默运行,终端均不可关闭) 基本运行 双击autMan.exe运行。 高级运行 在autMan文件夹&#xff0…

Sigrity Power SI Model Extraction模式如何提取电源网络的S参数和阻抗操作指导(一)

Sigrity Power SI Model Extraction模式如何提取电源网络的S参数和阻抗操作指导(一) Sigrity PowerSI是频域电磁场仿真工具,以下图为例介绍如果用它观测电源的网络的S参数以及阻抗的频域曲线. 观测IC端电源网络的自阻抗 1. 用powerSi.exe打开该SPD文件

工业相机详解及选型

工业相机相对于传统的民用相机而言,具有搞图像稳定性,传输能力和高抗干扰能力等,目前市面上的工业相机大多数是基于CCD(Charge Coupled Device)或CMOS(Complementary Metal Oxide Semiconductor)芯片的相机。 一,工业相机的分类 …

sentinel原理源码分析系列(六)-统计指标

调用链和统计节点构建完成,进入统计指标插槽,统计指标在最后执行的,等后面的插槽执行完,资源调用完成了,根据资源调用情况累计。指标统计是最重要的插槽,所有的功能都依靠指标数据,指标的正确与…

你知道什么叫数控加工中心吗?

加工中心是一种高度机电一体化的数控机床,具有刀库,自动换刀功能,对工件一次装夹后进行多工序加工的数控机床。通过计算的控制系统和稳定的机械结构,加工中心能够实现高精度的加工,确保工件的尺寸精度和表面质量。通过…

实用好助手

在现代职场中,拥有高效且适用的工具能够显著提升我们的工作效率与质量。除了常见的办公软件,还有许多小众但非常实用的工具可以大幅度优化工作流程。以下是另外五个推荐的工作软件,它们各自具备独特的功能与优势,值得一试。 1 …

【Docker系列】在 Docker 容器中打印和配置环境变量

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

双十一有哪些值得买的东西?2024年最全双十一好物推荐榜单来了!

双十一能够入手的好东西那肯定是非常多的,不过要想买到性价比高、实用性强的好物,就必须得做些功课了。作为一个智能家居和数码领域的博主,自然知道每年双十一买什么是最划算的。如果有朋友正在为双十一不知道买什么而发愁,那就快…

python+大数据+基于热门视频的数据分析研究【内含源码+文档+部署教程】

博主介绍:✌全网粉丝10W,前互联网大厂软件研发、集结硕博英豪成立工作室。专注于计算机相关专业毕业设计项目实战6年之久,选择我们就是选择放心、选择安心毕业✌ 🍅由于篇幅限制,想要获取完整文章或者源码,或者代做&am…

登录后端笔记(一):注册、登录;基于MD5加密

一、注册 一、参数:lombok pom.xml里引入依赖; 二、响应数据:Result 原视频 两个注解对应有参无参生成构造方法; data类型是泛型T,即data在使用时可对应object可对应string字符串可对应bean对象可对应map等&#x…

微信碰一碰支付系统有哪些好的?教程详解抢先看!

支付宝“碰一碰支付”的风刚刚刮起来,它的老对手微信便紧随其后,推出了自己的碰一碰支付设备,再次印证了这个项目市场前景广阔的同时,也让与碰一碰支付系统相关问题的热度又上了一层楼,尤其是微信碰一碰支付系统有哪些…

炒股VS炒游戏装备,哪个更好做

这个项目,赚个10%都是要被嫌弃的 虽然天天都在抒发自己对股市的看法,但自己自始至终也没有买进任何一支股票。之所以对这个话题感兴趣,着实是因为手上的游戏搬砖项目也是国际性买卖,跟国际形势,国际汇率挂钩&#xff0…

RAG总结及前沿之Meta-Chunking切分思路及VisRAG多模态实现机制解读

今天我们来看两个工作,一个是关于RAG的切分策略,Meta-Chunking,里面基于数学常识提到的边际采样分块(Margin Sampling Chunking)通过LLMs对连续句子是否需要分割进行二元分类,基于边际采样得到的概率差异来…

基于SSM+微信小程序的房屋租赁管理系统(房屋2)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于SSM微信小程序的房屋租赁管理系统实现了有管理员、中介和用户。 1、管理员功能有,个人中心,用户管理,中介管理,房屋信息管理&#xff…

Nest.js 实战 (十五):前后端分离项目部署的最佳实践

☘️ 前言 本项目是一个采用现代前端框架 Vue3 与后端 Node.js 框架 Nest.js 实现的前后端分离架构的应用。Vue3 提供了高性能的前端组件化解决方案,而 Nest.js 则利用 TypeScript 带来的类型安全和模块化优势构建了一个健壮的服务端应用。通过这种技术栈组合&…

信雅纳Chimera 100G网络损伤仪助力Parallel Wireless开展5G RAN无线前传网络的损伤模拟

背景介绍 Parallel Wireless 为移动运营商提供唯一全覆盖的(5G/4G/3G/2G)软件支持的本地 OpenRAN (ORAN) 解决方案。该公司与全球 50 多家领先运营商合作,并被 Telefonica 和 Vodafone 评为表现最佳的供应商。Parallel Wireless 在多技术、开放式虚拟化…