基于PyTorch的MNIST手写数字GAN生成器

news2024/11/15 17:00:28

文章目录

    • 前言小笔记
      • 关键特性
      • 技术栈
      • 使用场景
      • 贡献者:
    • 完整代码
    • 代码解析
      • 1. 导入必要的库
      • 2. 设备配置
      • 3. 超参数设置
      • 4. 创建样本目录
      • 5. 图像处理
      • 6. 加载MNIST数据集
      • 7. 创建数据加载器
      • 8. 定义判别器(Discriminator)D
      • 9. 定义生成器(Generator)G
      • 10. 设备设置
      • 11. 损失函数和优化器
      • 12. 辅助函数
      • 13. 训练循环
      • 14. 打印训练进度
      • 15. 保存图像
      • 16. 保存模型检查点
    • 运行效果

前言小笔记

这份代码是利用深度学习技术,通过生成对抗网络(GAN)模型,实现了对手写数字图像的生成。MNIST数据集是一个广泛使用的数据库,包含了大量的手写数字灰度图像,是机器学习和计算机视觉领域的标准测试集。

关键特性

  • 设备适配性:自动检测并使用可用的GPU资源,提高计算效率。
  • 超参数配置:提供了灵活的超参数设置,包括潜在空间大小、隐藏层大小、图像尺寸等,以适应不同的训练需求。
  • 图像预处理:实现了图像的归一化处理,将像素值标准化到[-1, 1]区间,以利于神经网络的训练。
  • 数据加载:使用DataLoader高效地加载和批处理数据,同时支持数据打乱,提高模型泛化能力。
  • 判别器与生成器网络:定义了两个神经网络模型,判别器用于区分真实图像与生成图像,生成器用于生成逼真的数字图像。
  • 损失函数与优化器:选用了二元交叉熵损失函数和Adam优化器,确保了模型的有效训练。
  • 训练循环:实现了完整的训练逻辑,包括判别器和生成器的交替训练,以及梯度的更新。
  • 进度监控:在训练过程中提供了详细的进度输出,方便用户监控训练状态。
  • 图像保存:训练过程中会生成并保存真实图像和假图像的样本,用于可视化训练效果。
  • 模型保存:训练完成后,模型参数会被保存,方便后续的模型加载和使用。

技术栈

  • PyTorch:主要的深度学习框架,用于构建和训练神经网络。
  • torchvision:PyTorch的扩展包,提供图像处理和数据加载工具。

使用场景

本项目适用于深度学习研究、教育、数据科学竞赛等场景,特别是在需要生成图像数据或理解GAN工作原理的场合。

贡献者:

本代码来源于
https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/generative_adversarial_network/main.py#L41-L57适用于希望快速入门GAN或在MNIST数据集上实践GAN模型的用户。

完整代码

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

代码解析

1. 导入必要的库

代码开始处导入了多个Python库,这些库提供了后续操作所需的功能。

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

2. 设备配置

设置设备为GPU(如果可用),否则使用CPU。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3. 超参数设置

定义了网络训练过程中将使用的超参数,如潜在空间大小、隐藏层大小等。

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

4. 创建样本目录

如果不存在,创建一个目录来保存生成的样本图像。

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

5. 图像处理

定义图像预处理步骤,包括转换为张量和归一化。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

6. 加载MNIST数据集

加载MNIST数据集,并应用上面定义的转换。

mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

7. 创建数据加载器

创建一个DataLoader对象,用于批量加载数据,并在训练时打乱数据顺序。

data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

8. 定义判别器(Discriminator)D

使用nn.Sequential定义了一个简单的神经网络结构作为判别器。

D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    # ... 其他层 ...
    nn.Sigmoid()
)

9. 定义生成器(Generator)G

同样使用nn.Sequential定义生成器网络结构。

G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    # ... 其他层 ...
    nn.Tanh()
)

10. 设备设置

将判别器和生成器移动到之前设置的设备上。

D = D.to(device)
G = G.to(device)

11. 损失函数和优化器

定义了二元交叉熵损失函数和两个Adam优化器,分别用于判别器和生成器。

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

12. 辅助函数

定义了两个辅助函数,denorm用于将归一化的图像反归一化,reset_grad用于清除梯度。

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

13. 训练循环

实现了GAN的训练过程,包括训练判别器和生成器的逻辑。

for epoch in range(num_epochs):
    # ... 训练逻辑 ...

14. 打印训练进度

在训练过程中,每隔一定步数打印当前的训练状态。

15. 保存图像

在训练过程中,保存真实图像和生成的假图像。

# Save real images
# Save sampled images

16. 保存模型检查点

训练结束后,保存生成器和判别器的模型参数。

torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

这段代码实现了一个典型的GAN训练流程,包括数据预处理、模型定义、训练循环、图像保存和模型保存等步骤。

运行效果

源代码直接运行:在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2054411.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言(15)——顺序表的应用

目录 1.基于动态顺序表实现通讯录项⽬ 1.1功能要求 1.2代码实现 2. 顺序表经典算法 1.基于动态顺序表实现通讯录项⽬ 1.1功能要求 1)⾄少能够存储100个⼈的通讯信息 2)能够保存⽤⼾信息:名字、性别、年龄、电话、地址等 3)…

生活垃圾填埋场污染监测:新标准下的技术革新与环境保护

随着城市化进程的加速,生活垃圾产生量急剧增加,如何有效处理并控制其带来的环境污染成为亟待解决的问题。近日,生态环境部发布了新修订的《生活垃圾填埋场污染控制标准》(GB 16889-2024),将自2024年9月1日起…

【Redis】哈希类型详解及缓存方式对比:从命令操作到实际应用场景

目录 Hash 哈希命令命令⼩结内部编码使⽤场景缓存方式对比 Hash 哈希 ⼏乎所有的主流编程语⾔都提供了哈希(hash)类型,它们的叫法可能是哈希、字典、关联数组、映射。在 Redis 中,哈希类型是指值本身又是⼀个键值对结构&#xff…

万维网与HTTP协议:基础知识简明指南

引言 在当今的数字时代,了解万维网(World Wide Web, WWW)和HTTP协议(Hyper Text Transfer Protocol)是至关重要的。本文将为基础小白们简明扼要地介绍万维网及其核心协议HTTP,并通过简单的例子和清晰的段落…

三级_网络技术_34_网络管理技术

1.在某主机上用浏览器无法访问到域名为www.tipu.edu.cn的网站,并且在该主机上执行tracert命令时有如下信息 分析以上信息,会造成这种现象的原因是 相关路由器上进行了访问控制 服务器 wwww.tjipu.edu.cn工作不正常 该计算机设置的DNS服务器工作不正常…

知行科技半年报显示商业化进展提速,下一个亮点在出海?

中国智驾落地竞速比拼愈演愈烈,让智驾公司陷入颇为紧张的竞争氛围。然而烈火出真金,这场角逐也成为领先企业脱颖而出的机会。 8月16日晚,智驾解决方案提供商知行科技(HK:01274)发布2024年上半年财报。数据显示,知行科技维持了营收…

SPI驱动学习一(协议原理)

目录 一、SPI协议介绍1. SPI 协议概述2. SPI 总线的主要组成部分3. SPI 协议的工作原理3. SPI 通信模式 二、SPI 协议的优点与缺点三、应用实例与常见问题1. 常用外设设备2. 常见问题3. 同时上电问题详细分析可能的原因解决方案 一、SPI协议介绍 1. SPI 协议概述 SPI&#xff…

【DP动态规划】学习笔记大全

-------------------------------------------------------本篇文章尚未完结,大家可以先看已有部分------------------------------------------------------- 【DP动态规划】学习笔记大全 Part 1 背包DP1.1 01背包1.1.1 题意解释1.1.2 为什么不使用贪心1.1.3 该如何…

【机器学习西瓜书学习笔记——规则学习】

机器学习西瓜书学习笔记【第十五章】 第十五章 规则学习15.1 基本概念15.2 序贯覆盖最简单的做法两种产生规则的策略 15.3 剪枝优化预剪枝后剪枝 15.4 一阶规则学习**FOIL算法** 15.5 归纳逻辑程序设计( I L P ILP ILP)最小一般泛化逆归结 第十五章 规则学习 15.1 基本概念 规…

干货|嵌入式分析产品选型指南

在当今数据驱动的商业环境中,业务系统的嵌入式分析能力正成为企业决策的关键能力。将数据分析能力嵌入到企业的核心业务流程中,能够帮助企业快速洞察业务趋势,做出更加明智的业务决策。随着市场对数据分析工具的需求日益增长,选择…

本地生活服务平台源码在哪里?2大获取渠道源码质量解析!

当前,本地生活赛道的发展潜力和收益前景已经日渐显化,本地生活服务商的数量也随之不断增长。不过,由于官方平台对于其本地生活服务商的申请条件并未放宽,因此,新增本地生活服务商中的绝大多数都会选择部署本地生活服务…

letcode 分类练习 654. 最大二叉树 617.合并二叉树 700.二叉搜索树中的搜索 98.验证二叉搜索树

letcode 分类练习 654. 最大二叉树 617.合并二叉树 700.二叉搜索树中的搜索 98.验证二叉搜索树 654. 最大二叉树617.合并二叉树700.二叉搜索树中的搜索98.验证二叉搜索树 654. 最大二叉树 class Solution { public:TreeNode* build(vector<int>& nums, int left, int…

Spring MVC中获取请求参数的方式

在Spring MVC中获取请求方式参数的主要方式有RequestParam&#xff0c;PathVariable&#xff0c;RequestBody&#xff0c;HttpServletRequest&#xff0c;RequestHeader等方式&#xff0c;接下来我们分别对其请求获取参数的方式进行相关介绍和使用。 RequestParam 用于获取请…

AMR 机器人底盘分析(补充中)

AMR 机器人底盘分析 1 介绍2 不同轮系底盘类型单舵轮双舵轮底盘四舵轮底盘麦克纳姆轮底盘两驱差速底盘四驱差速底盘单差速总成四差速总成底盘 3 行业专利分析CN220701198U -- 某柔CN110758038A谋星翼*菲谋工 参考 1 介绍 AGV 广泛应用于物流、制造业、安防巡检等领域&#xff…

C语言部分内存函数详解

C语言部分内存函数详解 前言1.memcpy1.1基本用法1.2注意事项**目标空间与原空间不能重叠****目标空间原数据会被覆盖****目标空间要够大****拷贝字节数需小于原空间大小** 1.3模拟实现 2.memmove2.1基本用法2.2注意事项2.3模拟实现 3.memset3.1基本用法 4.memcmp4.1基本用法4.2…

C#使用onnxruntime加载模型,部署到别人的PC上报错

C#使用onnxruntime加载模型&#xff0c;部署到别人的PC上报错 C#使用onnxruntime加载模型&#xff0c;部署到别人的PC上报错解决方案 C#使用onnxruntime加载模型&#xff0c;部署到别人的PC上报错 C#使用onnxruntime加载模型&#xff0c;部署到别人的PC上报错&#xff1a; Sys…

Python Web 应用和数据处理任务库之Redis Queue (RQ) 使用详解

概要 在现代 Web 应用和数据处理任务中,后台任务处理是一个非常重要的部分。Redis Queue (RQ) 是一个使用 Redis 作为消息队列的简单 Python 库,专注于处理异步任务。RQ 易于设置和使用,适用于需要后台处理的 Web 应用或数据处理项目。本文将详细介绍 RQ 库,包括其安装方法…

火狐如何离线继承配置

陪伴自己6年的电脑&#xff0c;因为CPU烧了&#xff0c;导致一些配置没导出来&#xff0c;其中包括浏览器的收藏记录、网站密码。 火狐浏览器离线继承配置 把老电脑的C盘取出&#xff0c;插入硬盘盒中&#xff0c;找到C:\Users\用户名\AppData\Roaming\Mozilla\Firefox\Profile…

【MySQL】JDBC的基础使用

系列文章目录 第一章 数据库基础 第二章 数据库基本操作 第三章数据库约束 第四章表的设计 第五章查询进阶 第六章索引和事务 文章目录 系列文章目录前言一、JDBC基本概念二、JDBC的准备工作三、JDBC-Demo小结 四、JDBC进阶写法总结 前言 在前面对MySQL已经有了基本的认知&am…

分类预测|基于白鲸优化混合核极限学习机结合Adaboost的数据分类预测Matlab程序BWO-HKELM-Adaboost

分类预测|基于白鲸优化混合核极限学习机结合Adaboost的数据分类预测Matlab程序BWO-HKELM-Adaboost 文章目录 前言分类预测|基于白鲸优化混合核极限学习机结合Adaboost的数据分类预测Matlab程序BWO-HKELM-Adaboost 一、BWO-HKELM-Adaboost模型1. 模型组成1.1 白鲸优化算法&#…