【Week-G1】调用官方GAN实现MNIST数字识别,Pytorch框架

news2024/11/18 1:31:11

文章目录

  • 1. 准备数据
    • 1.1 配置超参数
    • 1.2 下载数据
    • 1.3 配置数据
  • 2. 创建模型
    • 2.1 定义鉴别器
    • 2.2 定义生成器
  • 3. 训练模型
    • 3.1 创建实例
    • 3.2 开始训练
    • 3.3 保存模型
  • 4. 什么是GAN(对抗生成网络)?

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

说明:
(1)使用CPU时,屏蔽.cuda(),否则报错:
在这里插入图片描述

1. 准备数据

系统环境:
语言:Python3.7.8
编译器:VSCode
深度学习框架:torch 1.13.1

1.1 配置超参数

print("***********1.1 配置超参数*****************")
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch
## 创建文件夹
# 程序在路径 D:\jupyter notebook\DL-100-days\下运行,也就是下方的 ./
os.makedirs("./GAN/G1/images/", exist_ok=True)  # 记录训练过程的图片效果
os.makedirs("./GAN/G1/save/", exist_ok=True)    # 训练完成时,模型的保存位置
os.makedirs("./GAN/G1/mnist/", exist_ok=True)   # 下载数据集存放的位置
## 超参数配置
n_epochs  = 50
batch_size = 64
lr  = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500
#图像的尺寸(1, 28, 28),和图像的像素面积(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
#设置cuda: (cuda:0)
cuda = True if torch.cuda.is_available() else False
print("CUDA: ", cuda)
print("\n")

文件路径如下图:
在这里插入图片描述
使用CPU版本,所以打印的CUDA结果为FALSE;
在这里插入图片描述

1.2 下载数据

print("***********2. 下载数据*****************")
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, 
    transform=transforms.Compose(
         [transforms.Resize(img_size), 
          transforms.ToTensor(), 
          transforms.Normalize([0.5], [0.5])]),
)
print("\n")

在这里插入图片描述

1.3 配置数据

print("***********1.3 配置数据*****************")
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True
)
print("\n")

2. 创建模型

2.1 定义鉴别器

print("***********2. 创建模型********************")
print("***********2.1 定义鉴别器*****************")
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
 
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
print("\n")

2.2 定义生成器

print("***********2.2 定义生成器*****************")
class Generate(nn.Module):
    def __init__(self):
        super(Generate, self).__init__()
 
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
 
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh()
        )
 
    def forward(self, z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs
print("\n")

3. 训练模型

3.1 创建实例

print("***********3. 训练模型*****************")
print("***********3.1 创建实例****************")
generator = Generate()
discriminator = Discriminator()
 
criterion = torch.nn.BCELoss()
 
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
 
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator#.cuda()
    criterion = criterion.cuda()
print("\n")

3.2 开始训练

print("***********3.2 开始训练*****************")
for epoch in range(n_epochs):  # epoch:50
    for i, (imgs, _) in enumerate(dataloader):
 
        ## =============================训练判别器==================
        ## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)
        imgs = imgs.view(imgs.size(0), -1)  # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs).cuda()  # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1))#.cuda()  ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1))#.cuda()  ## 定义假的图片的label为0
 
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)  # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim))#.cuda()  ## 随机生成一些噪声, 大小为(128, 100)
        fake_img = generator(z).detach()  ## 随机噪声放入生成网络中,生成一张假的图片。
        fake_out = discriminator(fake_img)  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)  ## 得到假的图片的loss
        fake_scores = fake_out  ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()  # 在反向传播之前,先将梯度归0
        loss_D.backward()  # 将误差反向传播
        optimizer_D.step()  # 更新参数
 
        ## -----------------
        ##  Train Generator
        ## 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        ## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        ## 反向传播更新的参数是生成网络里面的参数,
        ## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        ## -----------------
        z = Variable(torch.randn(imgs.size(0), latent_dim))#.cuda()  ## 得到随机噪声
        fake_img = generator(z)  ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)  ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)  ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()  ## 梯度归0
        loss_G.backward()  ## 进行反向传播
        optimizer_G.step()  ## step()一般用在反向传播后面,用于更新生成网络的参数
 
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(),
                   fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./GAN/G1/images/%d.png" % batches_done, nrow=5, normalize=True)
print("\n")

训练结果:

***********3.2 开始训练*****************
[Epoch 0/50] [Batch 299/938] [D loss: 1.363185] [G loss: 1.811396] [D real: 0.922554] [D fake: 0.692913]
[Epoch 0/50] [Batch 599/938] [D loss: 0.753042] [G loss: 2.229975] [D real: 0.815112] [D fake: 0.410212]
[Epoch 0/50] [Batch 899/938] [D loss: 1.049122] [G loss: 1.940738] [D real: 0.789812] [D fake: 0.548839]
... ... ...
[Epoch 48/50] [Batch 299/938] [D loss: 0.956054] [G loss: 1.398938] [D real: 0.661061] [D fake: 0.327662]
[Epoch 48/50] [Batch 599/938] [D loss: 1.070262] [G loss: 0.950201] [D real: 0.538358] [D fake: 0.234096]
[Epoch 48/50] [Batch 899/938] [D loss: 1.012980] [G loss: 1.247620] [D real: 0.650552] [D fake: 0.319423]
[Epoch 49/50] [Batch 299/938] [D loss: 1.254801] [G loss: 1.048441] [D real: 0.522079] [D fake: 0.313869]
[Epoch 49/50] [Batch 599/938] [D loss: 0.884523] [G loss: 1.709880] [D real: 0.767361] [D fake: 0.402201]
[Epoch 49/50] [Batch 899/938] [D loss: 1.019181] [G loss: 1.608823] [D real: 0.739194] [D fake: 0.421154]

./GAN/G1/images/下的缩略图如下:
在这里插入图片描述
部分详细图如下:
在这里插入图片描述

3.3 保存模型

print("***********3.3 保存模型*****************")
torch.save(generator.state_dict(), "./GAN/G1/save/generator.pth")
torch.save(discriminator.state_dict(). "./GAN/G1/save/discriminator.pth")
print("\n")

保存的模型文件如下:
在这里插入图片描述

4. 什么是GAN(对抗生成网络)?

【详解1】
【详解2】

机器学习的模型大体分为两类:判别模型(Discriminative Model)和生成模型(Generative Model)。

  • 判别模型:输入变量,使用模型进行预测
  • 生成模型:给出目标的隐含信息,随机产生观测数据。比如:给出一系列猫的图片,来生成一张新的猫的图片。重要点在于“生成”二字。

GAN:适用于无监督学习,该网络的框架由(至少)两个模块构成,即判别模型(Discriminative Model)和生成模型(Generative Model),通过二者的互相博弈学习来产生相当好的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1867839.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

追求准确,还是追求举一反三,聊天机器人智能程度的困境 | Chatopera

在为企业客户上线聊天机器人客服的过程中,总会遇到一个问题,这让用户和我们都感到纠结。 到底是追求让机器人能准确的回答问题,还是让机器人可以举一反三的回答问题。 准确的回答问题,就是不容许回答错了,但是这样机…

windows中使用anaconda管理python版本

anaconda下载 python的版本问题实在是很大,版本低了高了都会影响脚本的执行,anaconda工具为此而生,不管是在windows下还是linux下,Anaconda的命令跟操作逻辑都是相同的,窥一斑而知全豹,本文在windows下示例如何使用anaconda anaconda的逻辑就是 他是一个全局的管理者,能创建工…

在Ubuntu中使用ROS搭建PX4 Gazebo 模拟飞行 四旋翼 固定翼

综合了网上很多教程以及踩了很多坑总结下来的教程 Ubuntu安装 此处不在详细说明,网上可随处搜到 ROS安装 感谢鱼香ROS大佬提供一键安装脚本 wget http://fishros.com/install -O fishros && sudo bash fishros 接下来按顺序按 1 1 2 3 1 再次运行 w…

红酒哲学:品味流转时光,探寻生活之深邃奥秘

在繁华的都市中,我们时常被各种声音和色彩所包围,追求着速度与激情。然而,在这喧嚣之中,总有那么一刻,我们渴望静下心来,品味一份不同的宁静与深度。这时,一杯雷盛红酒便成了我们与内心对话的桥…

太赞了!SD AI绘画,热门青衫映雪写真制作,一键出片,轻松复刻!【内含相关模型及ComfyUI工作流】

hello,大家好我是安琪! 今天安琪给大家带来了一篇关于写真制作,我通过SD WebUI进行本次青衫映雪主题的写真制作。(相关内容文末可自行扫描获取) 准备工作: 1.大模型准备真人写实大模型,我这里使用了TQing v3.4 2.…

Radxa 学习摘录

文章目录 一、参考资料二、硬件知识 一、参考资料 技术论坛(推荐) 官方资料下载 wiki资料 u-boot 文档 u-boot 源码 内核文档 内核源码 原理图 二、硬件知识 Radxa 3B 主板概览 MIPI接口 MIPI CSI(Camera Serial Interface)…

【前端】HTML+CSS复习记录【2】

文章目录 前言一、img(图片标签)二、a(链接标签)三、ul(无序列表)四、ol(有序列表)系列文章目录 前言 长时间未使用HTML编程,前端知识感觉忘得差不多了。通过梳理知识点…

智慧园区大数据云平台建设方案(Word原件)

第一章 项目建设背景及现状 第二章 园区创新发展趋势 第三章 工业园区大数据存在的问题 第四章 智慧工业园区大数据建设目的 第五章 智慧园区总体构架 第六章 系统核心组件 第七章 智慧工业园区大数据平台规划设计 获取方式:本文末个人名片直接获取。 软件资料清单…

文本生成sql模型(PipableAI/pip-sql-1.3b)

安装环境 pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install transformers 代码 question "What are the email address, town and county of the customers who are of the least common gender?"sc…

three.js - MeshStandardMaterial(标准网格材质)- 金属贴图、粗糙贴图

金属贴图、粗糙贴图 金属贴图:metalnessMap 和 粗糙贴图:roughnessMap,是用于模拟物体表面属性的两种重要贴图技术,这两种贴图,通常与基于物理的渲染(PBR)材质(如:MeshSt…

linux进程是什么?

进程概念 进程Process是指计算机中已运行的程序,是系统进行资源分配和调度的基本单位,是操作系统结构的基础。 在早期面向进程设计的计算机结构中,进程是程序的基本执行实体。在当代面向线程设计的计算机结构中,进程是线程的容器…

K210视觉识别模块学习笔记6: 识别苹果_图形化操作函数_

今日开始学习K210视觉识别模块: 图形化操作函数 亚博智能 K210视觉识别模块...... 固件库: canmv_yahboom_v2.1.1.bin 训练网站: 嘉楠开发者社区 今日学习如何在识别到目标的时候添加图形化操作:(获取坐标、框出目标等) 在识别苹果的基础上 学习与添加 这些操…

docker配置国内镜像加速器

1、搜索阿里云 2、搜索容器镜像服务 点击管理控制台 配置镜像加速器

鸿蒙NEXT开发:工具常用命令—install

安装三方库。 命令格式 ohpm install [options] [[<group>/]<pkg>[<version> | tag:<tag>]] ... ohpm install [options] <folder> ohpm install [options] <har file> alias: i 说明 group&#xff1a;三方库的命名空间&#xff0c;可…

count(*) over (partition by ……)用法详解

select id,count(*) over(partition by pro_id) from sal; 以pro_id分组&#xff0c;统计分组后每个pro_id的记录总数及对应的id&#xff1b; 类似还有count(*) over(order by ……)、sum(amount) over(partition by ……)等&#xff0c;略有区别

在Linux Ubuntu系统中使用Pascal语言

Pascal是一种结构化编程语言&#xff0c;而Free Pascal作为其现代编译器&#xff0c;不仅支持跨多种操作系统和处理器架构&#xff0c;还提供了高效的内存使用和函数重载等先进功能。Free Pascal继承了Pascal语言的核心特性&#xff0c;同时进行了扩展和优化&#xff0c;使其成…

Apache Flink类型及序列化研读生产应用|得物技术

一、背景 序列化是指将数据从内存中的对象序列化为字节流&#xff0c;以便在网络中传输或持久化存储。序列化在Apache Flink中非常重要&#xff0c;因为它涉及到数据传输和状态管理等关键部分。Apache Flink以其独特的方式来处理数据类型以及序列化&#xff0c;这种方式包括它…

彩虹PLM系统:引领汽车行业的数字化转型

彩虹PLM系统&#xff1a;引领汽车行业的数字化转型 彩虹PLM系统作为汽车行业数字化转型的引领者&#xff0c;凭借其卓越的技术实力和丰富的行业经验&#xff0c;为汽车行业带来了全面的解决方案。以下是彩虹PLM系统如何引领汽车行业数字化转型的详细分析&#xff1a; 一、整合全…

Redis 7.x 系列【7】数据类型之列表(List)

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Redis 版本 7.2.5 源码地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2. 常用命令2.1 RPUSH2.2 LPUSH2.3 LRANGE2.4 LINDEX2.6 LREM2.7 LLEN2.8 LPOP…

8.计算机视觉—增广和迁移

目录 1.数据增广数据增强数据增强的操作代码实现2.微调 迁移学习 Transfer learning(重要的技术)网络结构微调:当目标数据集比源数据集小得多时,微调有助于提高模型的泛化能力。训练固定一些层总结代码实现1.数据增广 CES上的真实故事 有一家做智能售货机的公司,发现他们…