学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

news2024/9/27 7:17:57

2023.8.27

       在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以“将一堆随机噪声经过生成器变成一张图片”,特此记录一下学习心得。

一、生成对抗网络百科

        2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址: https://arxiv.org/abs/1406.2661),将生成对抗网络引入 深度学习领域。2016年,GAN热潮席卷AI领域顶级会议, 从ICLR到NIPS,大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

机器学习的模型可大体分为两类,生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来 预测 。生成模型是给定某种隐含信息,来随机产生观 测数据。

GAN百科:

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)_打灰人的博客-CSDN博客

二、GAN代码

训练代码:

                epoch=1000时的效果就不错啦

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


class Discriminator(nn.Module):  # 判别器
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)
        validity = self.model(img)
        return validity


def gen_img_plot(model, test_input):
    pred = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((pred[i] + 1) / 2)
        plt.axis('off')
    plt.show(block=False)
    plt.pause(3)  # 停留0.5s
    plt.close()


# 调用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 超参数设置
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 1000

# 数据集载入和数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 训练数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据 torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数
# test_data = torch.randn(batch_size, latent_dim).to(device)
test_data = torch.FloatTensor(batch_size, latent_dim).to(device)

# 实例化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 开始训练模型
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        batch_size = imgs.shape[0]
        real_imgs = imgs.to(device)

        # 训练判别器
        z = torch.FloatTensor(batch_size, latent_dim).to(device)
        z.data.normal_(0, 1)
        fake_imgs = generator(z)  # 生成器生成假的图片

        real_labels = torch.full((batch_size, 1), 1.0).to(device)
        fake_labels = torch.full((batch_size, 1), 0.0).to(device)

        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器
        z.data.normal_(0, 1)
        fake_imgs = generator(z)

        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        torch.save(generator.state_dict(), "Generator_mnist.pth")

    print(f"Epoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")

# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)

测试代码:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')


class Generator(nn.Module):  # 生成器
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img


# test_data = torch.FloatTensor(128, 100).to(device)
test_data = torch.randn(128, 100).to(device)  # 随机噪声

model = Generator(100).to(device)
model.load_state_dict(torch.load('Generator_mnist.pth'))
model.eval()

pred = np.squeeze(model(test_data).detach().cpu().numpy())

for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow((pred[i] + 1) / 2)
    plt.axis('off')
plt.savefig(fname='image.png', figsize=[5, 5])
plt.show()

三、结果

       在超参数设置 epoch=1000,batch_size=128,lr=0.0001,latent_dim = 100 时,gan生成的权重测的结果如图所示

四,GAN的损失函数曲线

                一开始训练时,我的gan的损失函数的曲线是类似这样的,就是知乎这文章里一样,生成器损失函数的曲线一直发散。首先,这个loss的曲线一看就是网络崩了,一般正常的情况,d_loss的值会一直下降然后收敛,而g_loss的曲线会先增大后减少,最后同样也会收敛。其次,网络拿到手以后先不要训练太多次,容易出现过拟合的情况。

生成对抗网络的损失函数图像如下合理吗? - 知乎

这是训练了10轮的生成器和鉴别器的损失函数值变化吧:

效果如图所示: 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/934648.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记230827--vue项目中,子组件拿不到父组件异步获取数据的问题

问题描述 父组件的数据是请求后台所得&#xff0c;因为是异步数据&#xff0c;就会出现&#xff0c;父组件的值传递过去了&#xff0c;子组件加载不到&#xff0c;拿不到值的问题。 下面从同步数据传递和异步数据传递开始论述问题 1. 父组件传递的是同步数据 父组件 <…

【Spring】什么是 AOP(面向切面编程) ? 为什么要有 AOP ? 如何实现 Spring AOP ?

文章目录 前言一、什么是 AOP ?二、为什么要使用 AOP ?三、 AOP 的组成四、Spring AOP 的实现1, 添加依赖2, 定义切面3, 定义切点4, 定义通知5, 创建连接点 总结 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#x1f4d5; JavaSE基础: 基础语法…

计算机视觉 – Computer Vision | CV

计算机视觉为什么重要&#xff1f; 人的大脑皮层&#xff0c; 有差不多 70% 都是在处理视觉信息。 是人类获取信息最主要的渠道&#xff0c;没有之一。 在网络世界&#xff0c;照片和视频&#xff08;图像的集合&#xff09;也正在发生爆炸式的增长&#xff01; 下图是网络上…

Linux操作系统--shell编程(helloworld初体验)

1.shell概述 shell是一个命令行解释器,它接受应用程序/用户命令,然后调用操作系统的内核,以完成所谓的功能指令。 Linux中常用的解析器 CentOS7使用的解析器是bash,这里的sh是指向bash 2.Shell脚本入门 下面我们开始学习编写shell脚本,我们从HelloWorld开始。

怎么找到真实可用的淘宝拼多多京东API?(商品数据订单数据销量价格接口)

要找到真实可用的淘宝、拼多多、京东API&#xff0c;可以采取以下步骤&#xff1a; 打开相应电商平台的开放平台网站&#xff0c;例如淘宝开放平台、拼多多开放平台、京东开放平台等。在网站中注册并登录&#xff0c;找到API文档或开发者文档等页面。在文档中搜索与所需功能相…

代码随想录算法训练营之JAVA|第三十八天|494. 目标和

今天是第38天刷leetcode&#xff0c;立个flag&#xff0c;打卡60天。 算法挑战链接 494. 目标和https://leetcode.cn/problems/target-sum/ 第一想法 题目理解&#xff1a;题目给出一个数组&#xff0c;使用 或 - 算术符号&#xff0c;有多少种组合可以得到target的值。 拿…

URL中传递JSON字符串

今天遇见了一个需求&#xff0c;从post请求中在url里传递json字符串&#xff0c; 就是路径?参数11那种情况 最后怎么解决的呢&#xff1f; 需要使用前端方法&#xff0c;先用JSON.stringify格式化成字符串&#xff0c;再用encodeURIComponent把JSON里面的符号转转为url支持的…

1.2 Kali Linux的网络配置

前言 最新文章请见此处&#xff0c;持续更新&#xff0c;敬请订阅&#xff01;https://blog.csdn.net/algorithmyyds/category_12418682.html 网络在如今的社会已是十分重要的媒介&#xff0c;如果没有网络&#xff0c;很多事情将难以办成。渗透测试也是一样——毕竟在攻击机…

新生报到:无压力的数字自我介绍

&#x1f338; 新生报到&#xff1a;无压力的数字自我介绍 &#x1f338; 开学季又来临&#xff0c;每个学校、每个班级都迎来了一批新鲜面孔。作为新生&#xff0c;面对陌生的环境和同学&#xff0c;首次的自我介绍无疑是一个让许多人感到紧张和迷茫的挑战。你是否曾因为害羞…

理解底层— —Golang的log库,二开实现自定义Logger

理解底层— —Golang的log库&#xff0c;实现自定义Logger 1 分析实现思路 基于golang中自带的log库实现&#xff1a;对日志实现设置日志级别&#xff0c;每天生成一个文件&#xff0c;同时添加上前缀以及展示文件名等 日志级别&#xff0c;通过添加prefix&#xff1a;[INFO]、…

C 字符串处理

字符数组 输入输出 输入函数 scanf(%s, s)读入字符串&#xff0c;在第一个空白符( 、\n 、\t )处停止&#xff0c;不读入空白符&#xff0c;在串尾自动添加\0’ 。gets(s)读入一行字符&#xff0c;直到遇到\n &#xff0c;读入换行符并将其舍弃&#xff0c;在串尾自动添加\…

H5如何做性能测试?

说起H5性能测试&#xff0c;可能许多同学有所耳闻&#xff0c;但是不知道该如何去做性能测试&#xff0c;或者不知道H5应该关注哪些性能指标。今天我们就来看下。希望阅读本文后&#xff0c;能够有所了解。 常用指标 1、H5性能相关参数介绍 白屏时间&#xff1a;用户首次看到…

[LitCTF 2023]PHP是世界上最好的语言!!

进入环境看起来还是挺牛逼的&#xff0c;但是在右边输入框下有一个执行代码&#xff0c;有点牛 真的可以直接执行&#xff0c;那么 根据题目提示&#xff0c;我们得知flag&#xff0c;在根目录&#xff0c;所以我们可以直接利用 查看到flag位置 得到flag

IDEA插件反编译jar包

安装插件Java Decompiler 安装插件Java Decompiler成功之后重启idea 找到已安装插件的jar包 执行反编译 反编译 在已安装插件Java Decompiler的jar包位置下cmd命令执行反编译 java -cp "插件路径" org.jetbrains.java.decompiler.main.decompiler.ConsoleDec…

在线SM4(国密)加密解密工具

在线SM4(国密)加密解密工具

基于安卓的考研助手系统app 微信小程序

&#xff0c;设计并开发实用、方便的应用程序具有重要的意义和良好的市场前景。HBuilder技术作为当前最流行的操作平台&#xff0c;自然也存在着大量的应用服务需求。 本课题研究的是基于HBuilder技术平台的安卓的考研助手APP&#xff0c;开发这款安卓的考研助手APP主要是为了…

【管理运筹学】第 6 章 | 运输问题(2,表上作业法 | 初始可行解的确定)

文章目录 引言二、表上作业法2.1 初始基可行解的确定2.1.1 最小元素法2.1.2 伏格尔法 写在最后 引言 承接前文&#xff0c;在对运输问题有了基本的了解后&#xff0c;我们开始深入学习表上作业的具体内容。 二、表上作业法 2.1 初始基可行解的确定 2.1.1 最小元素法 基本思…

攻防世界-倒立屋

原题 解题思路 用StegSolve打开文件&#xff0c;调通道没用&#xff0c;wp说用RGB信道打开可以找到&#xff0c;但说实话用大括号也没找到在哪&#xff0c;得是预先知道答案才找得到。

Linux常用命令_文件处理命令:su root

文章目录 1. 命令格式与目录处理命令ls1.1 命令格式1.2 目录处理命令&#xff1a;ls 2. 目录处理命令2.1 目录处理命令&#xff1a;mkdir2.2 目录处理命令&#xff1a;cd2.3 目录处理命令&#xff1a;pwd2.4 目录处理命令&#xff1a;rmdir2.5 目录处理命令&#xff1a;cp2.6 目…

C语言文件操作收尾【随机读写 + 结束判定 + 文件缓冲区】

全文目录 前言fseek 重定位位置指示器函数ftell 获取当前文件指示器的位置rewind 重置位置指示器文本文件和二进制文件文件读取结束的判定feof 和 ferror 文件缓冲区总结 前言 有了文件的顺序读写基础&#xff0c;那么肯定会好奇文件的随机读写&#xff0c;毕竟顺序读写对于有…