pytorch 生成手写数字图像

news2024/11/26 16:46:14

生成对抗网络的概念

最基本的GAN模型由一个生成器 G 和判别器 D 组成。生成器用于生成假样本,判别器用于判断样本是真实的还是假的。

  1. 生成器(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器
  2. 判别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”

首先,固定判别器D,训练生成器G。让生成器不断生成假数据,然后让判别器D去判断,一开始生成器G生成的结果很容易被判别器D识别,然而随着不断的训练,生成器G效果不断提升,直到判别器无法分辩出数据的真假,也就是说这时判别器判断真假数据的概率为0.5.

然后固定生成器G,训练判别器D。当判别器无法分辩生成器生成的数据的时候,这时继续训练生成器是没有意义的。这时,可以训练判别器D,提升判别器D的性能。

数据集的显示

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

transforms=transforms.Compose(
    [
    transforms.Resize(28),
    transforms.ToTensor(),
    # transforms.Normalize([0.5],[0.5]) ##均值,标准差
    ]
)
train_datasets=datasets.MNIST(root='./',train=True,download=True,transform=transforms)
test_datasets=datasets.MNIST(root='./',train=False,download=True,transform=transforms)

print('训练集的数量',len(train_datasets))
print('测试集的数量',len(test_datasets))

train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

print('训练集可视化')
fig=plt.figure()
for i in range(12):
    plt.subplot(3,4,i+1)
    img=train_datasets.train_data[i]
    label=train_datasets.train_labels[i]
    plt.imshow(img,cmap='gray')
    plt.title(label)
    plt.xticks([])
    plt.yticks([])
plt.show()

print('测试集可视化')
fig=plt.figure()
for i in range(12):
    plt.subplot(3,4,i+1)
    img=test_datasets.test_data[i]
    label=test_datasets.test_labels[i]
    plt.imshow(img,cmap='gray')
    plt.title(label)
    plt.xticks([])
    plt.yticks([])
plt.show()

 

网络结构

import numpy as np
import argparse
import torch.nn as nn
import torch
parser = argparse.ArgumentParser()
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
opt = parser.parse_args()
class Generator(nn.Module):
    """
    生成器,根据一组随机的向量生成一组图像
    """
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  #BatchNorm:在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布,momentum=0.8
            layers.append(nn.LeakyReLU(0.2, inplace=True))    #inplace = True ,直接覆盖原输入数据的值
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),     #opt.latent_dim,100维的随机噪声
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod((opt.channels, opt.img_size, opt.img_size)))),         #np.prod(img_shape),返回1*28*28
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # img = img.view(img.size()[0], *(opt.channels, opt.img_size, opt.img_size))
        img=img.view(img.size()[0],opt.channels,opt.img_size,opt.img_size)
        return img

class Discriminator(nn.Module):
    """
    判别器是用来判断生成器生成图片的真假,判别器效果越真越好,直到最后判别无法判别生成器的输出(即输出概率为0.5的时候)
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod((opt.channels, opt.img_size, opt.img_size))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# model=Discriminator()
# input=torch.rand(10,int(np.prod((opt.channels, opt.img_size, opt.img_size))))
# output=model(input)
# print('判别器的输出',output.shape)
#
# model=Generator()
# input=torch.rand(10,100)
# output=model(input)
# print('生成器的输出',output.shape)

  

模型的训练

先训练判别器,在训练生成器

import argparse
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from model import Generator,Discriminator

parser = argparse.ArgumentParser()   #创建一个参数对象
#调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
opt = parser.parse_args() # parse_args()返回我们定义的参数字典
print(opt)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device',device)
transforms=transforms.Compose(
    [
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]) ##均值,标准差
    ]
)
train_datasets=datasets.MNIST(root='./',train=True,download=True,transform=transforms)
lenth = 10000
train_datasets, _ = torch.utils.data.random_split(train_datasets, [lenth, len(train_datasets) - lenth])
# test_datasets=datasets.MNIST(root='./',train=False,download=True,transform=transforms)

print('训练集的数量',len(train_datasets))
# print('测试集的数量',len(test_datasets))

train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
# test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

# 损失函数
adversarial_loss = torch.nn.BCELoss().to(device)

# 定义网络结构
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 优化器的设置
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  #Betas是动量梯度的下降


for epoch in range(opt.n_epochs):
    total_d_loss=0
    total_g_loss=0
    #开始训练
    for i, (img, _) in enumerate(train_loader):

        ##将图片变为1维数据
        real_img=img.view(img.size()[0],-1)

        #定义真实的图片label为1
        real_label=torch.ones(img.size()[0],1)
        #定义假的图片label为0
        fake_label=torch.zeros(img.size()[0],1)

        #判别器训练

        #将真实图片输入到判别器中
        real_out=discriminator(real_img)

        #得到真实图片的loss
        d_loss_real=adversarial_loss(real_out,real_label)
        #得到真实图片的判别值,real_out输出的值越接近1越好
        real_scores=real_out

        #计算假图片的损失
        noise=torch.randn(img.size()[0],opt.latent_dim) ##随机生成一些噪声,

        ##将随机噪声放入生成网络中,生成一张假的图片
        #避免梯度传到生成器,这里生成器不用更新,detach分离
        fake_img=generator(noise).detach()
        #判别器判断假的图片
        fake_out=discriminator(fake_img)
        #得到假图片的loss
        d_loss_fake=adversarial_loss(fake_out,fake_label)
        #得到假图片的判别值,对于判别器来讲,假图片的d_loss_fake越接近越好
        d_loss=d_loss_real+d_loss_fake  ##损失包含判真损失和判假损失

        total_d_loss+=d_loss.data.item()

        optimizer_D.zero_grad()  #反向传播之前,将梯度归0
        d_loss.backward()  #将误差反向传播
        optimizer_D.step() #更新参数

        #训练生成器
        #原理:目的是希望生成的假图片可以被判别器判断为真的图片
        #在此过程中,将判别器固定,将假的图片传入判别器的结果real_label对应
        #使得生成的图片让判别器以为是真的。这样就达到了对抗的目的

        #计算假图片的损失
        noise=torch.randn(img.size()[0],opt.latent_dim) #随机生成一些噪声
        fake_img=generator(noise)  ##随机噪声输入到生成器中,得到一幅假的图片
        output=discriminator(fake_img)  ##经过判别器得到的结果
        g_loss=adversarial_loss(output,real_label)

        total_g_loss+=g_loss.data.item()

        #反向传播  更新参数
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    #打印每个epoch 的损失
    print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f}'.format(epoch,opt.n_epochs,total_d_loss/len(train_loader),total_g_loss/len(train_loader)))

    torch.save(generator,'./gen.pth')
    torch.save(discriminator,'./dis.pth')

先训练生成器,在训练判别器

import argparse
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from model import Generator,Discriminator

parser = argparse.ArgumentParser()   #创建一个参数对象
#调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息
parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
opt = parser.parse_args() # parse_args()返回我们定义的参数字典
print(opt)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device',device)
transforms=transforms.Compose(
    [
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]) ##均值,标准差
    ]
)
train_datasets=datasets.MNIST(root='./',train=True,download=True,transform=transforms)
# lenth = 60000
# train_datasets, _ = torch.utils.data.random_split(train_datasets, [lenth, len(train_datasets) - lenth])
# test_datasets=datasets.MNIST(root='./',train=False,download=True,transform=transforms)

print('训练集的数量',len(train_datasets))
# print('测试集的数量',len(test_datasets))

train_loader = DataLoader(train_datasets, batch_size=opt.batch_size, shuffle=True)
# test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

# 损失函数
adversarial_loss = torch.nn.BCELoss().to(device)

# 定义网络结构
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 优化器的设置
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  #Betas是动量梯度的下降


for epoch in range(opt.n_epochs):
    total_d_loss=0
    total_g_loss=0
    #开始训练
    for i, (img, _) in enumerate(train_loader):

        ##将图片变为1维数据
        real_img=img.view(img.size()[0],-1)

        #定义真实的图片label为1
        real_label=torch.ones(img.size()[0],1)
        #定义假的图片label为0
        fake_label=torch.zeros(img.size()[0],1)

        #训练生成器
        #原理:目的是希望生成的假图片可以被判别器判断为真的图片
        #在此过程中,将判别器固定,将假的图片传入判别器的结果real_label对应
        #使得生成的图片让判别器以为是真的。这样就达到了对抗的目的

        #计算假图片的损失
        noise=torch.randn(img.size()[0],opt.latent_dim) #随机生成一些噪声
        fake_img=generator(noise).detach()  ##随机噪声输入到生成器中,得到一幅假的图片
        output=discriminator(fake_img)  ##经过判别器得到的结果
        g_loss=adversarial_loss(output,real_label)

        total_g_loss+=g_loss.data.item()
        #反向传播  更新参数
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        #判别器训练
        #将真实图片输入到判别器中
        real_out=discriminator(real_img)

        #得到真实图片的loss
        d_loss_real=adversarial_loss(real_out,real_label)
        #计算假图片的损失
        noise=torch.randn(img.size()[0],opt.latent_dim) ##随机生成一些噪声,

        ##将随机噪声放入生成网络中,生成一张假的图片
        #避免梯度传到生成器,这里生成器不用更新,detach分离
        fake_img=generator(noise).detach()
        #判别器判断假的图片
        fake_out=discriminator(fake_img)
        #得到假图片的loss
        d_loss_fake=adversarial_loss(fake_out,fake_label)
        #得到假图片的判别值,对于判别器来讲,假图片的d_loss_fake越接近越好
        d_loss=d_loss_real+d_loss_fake  ##损失包含判真损失和判假损失

        total_d_loss+=d_loss.data.item()

        optimizer_D.zero_grad()  #反向传播之前,将梯度归0
        d_loss.backward()  #将误差反向传播
        optimizer_D.step() #更新参数

    #打印每个epoch 的损失
    print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f}'.format(epoch,opt.n_epochs,total_d_loss/len(train_loader),total_g_loss/len(train_loader)))

    torch.save(generator,'./gen.pth')
    torch.save(discriminator,'./dis.pth')

模型的测试

import argparse
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
import numpy as np
parser = argparse.ArgumentParser()   #创建一个参数对象
#调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
opt = parser.parse_args() # parse_args()返回我们定义的参数字典
print(opt)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device',device)
transforms=transforms.Compose(
    [
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]) ##均值,标准差
    ]
)
train_datasets=datasets.MNIST(root='./',train=True,download=True,transform=transforms)
lenth = 10000
train_datasets, _ = torch.utils.data.random_split(train_datasets, [lenth, len(train_datasets) - lenth])
test_datasets=datasets.MNIST(root='./',train=False,download=True,transform=transforms)

print('训练集的数量',len(train_datasets))
# print('测试集的数量',len(test_datasets))

train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
# test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

# 损失函数
adversarial_loss = torch.nn.BCELoss().to(device)

# 定义网络结构
generator = torch.load('./gen.pth',map_location=lambda storage, loc: storage)
discriminator = torch.load('./dis.pth',map_location=lambda storage, loc: storage)

noise=torch.randn(12,opt.latent_dim)
print('生成随机噪声',noise.shape)
image=generator(noise)
print('生成的图片',image.shape)

##判别器进行判断
output=discriminator(image)
#判别器判定大于0.5为真,小于0.5为假。所以判别器最好的结果是为0.5,即分不清楚真假
print('判别器的输出',output)
print('判别器输出的平均值',torch.mean(output))

for i in range(12):
    plt.subplot(3,4,i+1)
    img=image[i].reshape(28,28)
    plt.imshow(np.array(img.data),cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()

 参考文献:

基于Pytorch用GAN生成手写数字实例(附代码)_使者大牙的博客-CSDN博客_gan生成手写数字 pytorch

GAN学习总结三-Pytorch实现利用GAN进行MNIST手写数字生成_DaneAI的博客-CSDN博客_plt.rcparams['figure.figsize'] = (10.0, 8.0) # 设置画

【pytorch】基于mnist数据集的cgan手写数字生成实现_Xavier Jiezou的博客-CSDN博客_cgan mnist pytorch 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/88854.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jQuery - AJAX 简介

什么是 AJAX? AJAX 异步 JavaScript 和 XML(Asynchronous JavaScript and XML)。 简短地说,在不重载整个网页的情况下,AJAX 通过后台加载数据,并在网页上进行显示。 使用 AJAX 的应用程序案例&#xff…

个人简介网页设计作业 静态HTML个人介绍网页作业 DW个人网站模板下载 WEB静态大学生简单网页 个人网页作品代码 个人网页制作 学生个人网页

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

[附源码]Nodejs计算机毕业设计基于百度AI平台的财税报销系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

Redis中BIO、NIO、IO多路复用

1 BIO(阻塞IO) 阻塞IO就是两个阶段都必须阻塞等待 通常IO操作都是阻塞I/O的,也就是说当你调用read时,如果没有数据收到,那么线程或者进程就会被挂起,直到收到数据。 read直到数据复制到应用进程的缓冲区或者发生错误才会返回&am…

【数据结构与算法】第十六篇:图论(基础篇)

知识导航图形结构的引进图(Grapth)1.图的概念与应用2.有向图入度,出度3.无向图4.完全图无向完全图有向完全图5.连通图6.连通分量强连通分量图的实现方案1.邻接矩阵实现法2.邻接表实现法3.两种方法对比分析图形结构的引进 🌎 数据…

Linux基础-目录操作

该文章主要为完成实训任务及总结,详细实现过程及结果见【参考文章】 参考文章:https://howard2005.blog.csdn.net/article/details/126962205 文章目录一、常用权限操作1.1 常用权限操作1. chgrp命令2. chown命令3. chmod命令1.2 权限操作实战任务1 创建…

14、Redis_主从复制

文章目录14、Redis_主从复制14.1 是什么14.2. 能干嘛14.3 怎么玩:主从复制14.3.1 新建redis6379.conf,填写以下内容14.3.2 新建redis6380.conf,填写以下内容14.3.3 新建redis6381.conf,填写以下内容14.3.4 启动三台redis服务器14.…

java项目_第173期ssm高校二手交易平台_计算机毕业设计

java项目_第173期ssm高校二手交易平台_计算机毕业设计 【源码请到下载专栏下载】 今天分享的项目是《ssm高校二手交易平台》 该项目分为2个角色,管理员和用户。 用户可以浏览前台商品,并且进行购买商品,并在 个人后台查看自己的订单、查看商品…

DPDK源码分析之DPDK技术简介

Cache和内存技术 1. Cache一致性 多核处理器同时访问同一段cacheline时,会出现写回冲突的情况,操作系统解决这个问题会消耗一部分性能,DPDK采用了两个技术来解决这个问题: 对于共享的数据,每个核都定义自己的备份lc…

区块链学习2-合约开发

概述 智能合约本质上是运行在某种环境(例如虚拟机)中的一段代码逻辑。 长安链的智能合约是运行在长安链上的一组“动态代码”,类似于Fabric的chaincode,Fabric的智能合约称为链码(chaincode),…

对氯间二甲苯酚在活性污泥发酵过程中重塑ARGs的机制类别

2022年8月,凌恩生物客户河海大学罗景阳教授团队在《Science of the Total Environment》期刊上发表研究论文“Para-chloro-meta-xylenol reshaped the fates of antibiotic resistance genes during sludge fermentation: Insights of cell membrane permeability, …

ChatGPT技术解构

ChatGPT的训练主要分为三个步骤,如图所示: Step1: 使用有监督学习方式,基于GPT3.5微调训练一个初始模型;训练数据约为2w~3w量级(根据InstructGPT的训练数据量级估算,参照https://arxiv.org/pdf…

【内网安全-防火墙】防火墙、协议、策略

目录 一、基础知识 1、防火墙五个域 2、协议模型 二、出入站策略 1、单个机器防火墙 2、域控的防火墙 3、安全策略 一、基础知识 1、防火墙五个域 1、Untrust(不信任域,低级安全区域): 用来定义Internet等不安全的网络,用于网络入口线的接入 ——…

沁恒 CH32V003J4M6 开发测试

一、概述 具体看图,SOP8价格在0.6R,TSSOP20价格在0.7R,优势太大了 二、开发准备 通过原厂可以拿到样片,目前我拿到这颗是SOP8,另外官方淘宝可以买到TSSOP20的测试板,也带样片购买WCHLINK,TB…

Mysql 进阶(面向面试篇)锁

全局锁、表级锁(表锁、元数据锁、意向锁(意向共享锁、意向排它锁))、行级锁(行锁(共享锁、排它锁)、间隙锁、临键锁) 表级锁(表锁(表锁分为:表共…

仅差一步!如何缩短加入购物车与成单的距离?

不知不觉,2022年已接近尾声,经历了卡塔尔世界杯、黑色星期五等跨境电商狂欢节后,不少跨境电商卖家都在开展复盘行动,为接下来的圣诞节运营计划打下扎实基础。时常关注跨境电商行业的人都知道,衡量跨境电商广告效率的关…

Python函数、类和对象、流程控制语句if-else while的讲解及演示(图文解释 附源码)

一、函数 函数是完成某个功能的代码段,可被其他代码调用,调用的代码可以将数据传递给函数,函数可将对数据的处理结果返回给调用代码。 def mysubt( a, b 0 ): # 定义一个自己的减法函数,第二个参数为默认值为0的默认参数c a -…

2023年湖北报考施工员要多少钱?甘建二告诉您

2023年湖北报考施工员要多少钱?甘建二告诉您 2023年湖北报考施工员要多少钱,甘建二告诉您 2023年武汉报考施工员要多少钱,甘建二告诉您 2023年黄冈报考施工员要多少钱,甘建二告诉您 2023年黄石报考施工员要多少钱,甘…

HBase Java API 开发:批量操作 第3关:批量导入数据至HBase

每一次只添加一个数据显然不像是大数据开发,在开发项目的时候也肯定会涉及到大量的数据操作。 使用Java进行批量数据操作,其实就是循环的在Put对象中添加数据最后在通过Table对象提交。 如何进行批量操作呢,讲到批量操作,相信大…

秋招必备!阿里产出的高并发+JVM豪华套餐送给你,绝对硬核干货

**3、设计了方案,但细节掌握不透彻:**讲不出方案要关注的技术点和可能带来的消极影响。比如读性能有瓶颈会引入缓存,但是忽视了缓存命中率、数据一致性、热点key等问题。 面对马上就要到来的双十一的秒杀环节,你是否已经有备无患…