使用pytorch构建一个初级的无监督的GAN网络模型

news2025/1/14 19:12:24

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torch

def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )
    def forward(self, noise):
        return self.gen(noise)
    def get_gen(self):
        return self.gen

def get_discriminator_block(input_dim, output_dim):
    return nn.Sequential(
         nn.Linear(input_dim, output_dim), #Layer 1
         nn.LeakyReLU(0.2, inplace=True)
    )

class Discriminator(nn.Module):
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, image):
        return self.disc(image)
    def get_disc(self):
        return self.disc

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples,z_dim,device=device)

criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'

dataloader = DataLoader(
    MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss
    
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Update optimizer
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(
                f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:
在这里插入图片描述
在这里插入图片描述
到后面的fake逐渐变成这样:
在这里插入图片描述

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。
在这里插入图片描述
生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。
在这里插入图片描述
生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。
在这里插入图片描述
每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

在这里插入图片描述
首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。
detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:
在这里插入图片描述
一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。
make_grid 函数用于将多个图像组成一个网格,方便显示。
在这里插入图片描述

在这里插入图片描述
然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

在这里插入图片描述
损失函数的原理在上面的“原理”中有讲解,这里不再赘述。
在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播

在这里插入图片描述
retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。
Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。
所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1553287.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3.28作业

#include <iostream> using namespace std; // 构造函数示例 class MyClass { private: int data; public: // 默认构造函数 MyClass() { data 0; } // 带参数的构造函数 MyClass(int value) { data value; } …

【vue核心技术实战精讲】1.9 Vue指令之v-model双向数据绑定

文章目录 前言本节内容1、v-model2、总结v-model 双向的数据绑定双向数据流&#xff08;绑定&#xff09; v-bind 和 v-model 的区别? 3-1、实战 <input>A、 输入框 type"text"示例效果 B、 单选按钮 type"radio"示例效果 C、 复选框 type"che…

|行业洞察·医药|《医药行业年终总结报告:政策篇-143页》

报告各部分的详细解读&#xff1a; 1. 政策概览 政策导读&#xff1a;2023年作为“十四五”发展时期的第三年&#xff0c;国家发布了许多关键性文件&#xff0c;对医药行业的采购、医保、医疗、医药等方面提出了明确的目标和规划。政策发布情况&#xff1a;截至12月19日&…

[LeetCode]516. 最长回文子序列[记忆化搜索解法详解]

最长回文子序列 LeetCode 原题链接 题目 给你一个字符串 s &#xff0c;找出其中最长的回文子序列&#xff0c;并返回该序列的长度。 子序列定义为&#xff1a;不改变剩余字符顺序的情况下&#xff0c;删除某些字符或者不删除任何字符形成的一个序列。 示例 1&#xff1a…

苹果应用商店上架工具的最新趋势与未来发展展望

摘要 移动应用app上架是开发者关注的重要环节&#xff0c;但常常会面临审核不通过等问题。为帮助开发者顺利完成上架工作&#xff0c;各种辅助工具应运而生。本文探讨移动应用app上架原理、常见辅助工具功能及其作用&#xff0c;最终指出合理使用工具的重要性。 引言 移动应…

YonBuilder移动开发基础——友开发App与自定义Loader

概述 在使用 YonBuilder移动开发 技术进行 App 项目开发时&#xff0c;我们需要使用YonStuido开发工具的 WIFI同步 功能进行项目代码的真机调试&#xff0c;友开发App 与 自定义Loader 都支持 WIFI同步 功能&#xff0c;那么两款 App 软件到底有什么区别&#xff0c;在开发过程…

C语言 | qsort()函数使用

目录&#xff1a; 1.qsort介绍 2.使⽤qsort函数 排序 整型数据 3.使⽤qsort函数 排序 结构体数据 4. qsort函数的模拟实现冒泡排序 qsort()函数 是一个 C语言编译器函数库自带的排序函数&#xff0c; 它可以对指定数组&#xff08;包括字符串&#xff0c;二维数组&#x…

STM32CubeMX学习笔记28---FreeRTOS软件定时器

一、软件定时器简介 1 、基本概念 定时器&#xff0c;是指从指定的时刻开始&#xff0c;经过一个指定时间&#xff0c;然后触发一个超时事件&#xff0c;用户 可以自定义定时器的周期与频率。类似生活中的闹钟&#xff0c;我们可以设置闹钟每天什么时候响&#xff0c; 还能设置…

【车体坐标系与世界坐标系的互相转换】能够一眼看懂的知识点!!!

本文讲解车体坐标系与世界坐标系互相转换的数学推导&#xff0c;如下图所示 将waypoint坐标从车体坐标系转换到世界坐标系&#xff1a; [ x ′ y ′ z ′ ] [ x y z ] [ cos ⁡ θ sin ⁡ θ 0 − sin ⁡ θ cos ⁡ θ 0 0 0 1 ] \left[\begin{array}{lll} x^{\prime} & …

产品经理的自我修养

点击下载《产品经理的自我修养》 1. 前言 在产品领域取得成功的关键在于持续的激情。只有保持热情不减,我们才能克服各种困难,打造出卓越的产品。 如果你真心渴望追求产品之路,我强烈建议你立即行动起来,亲自参与实际的产品创作。无论是建立一个网站、创建一个社群,还是…

axios发送get请求但参数中有数组导致请求路径多出了“[]“的处理办法

一、情况 使用axios发送get请求携带了数组参数时&#xff0c;请求路径中就会多出[]字符&#xff0c;而在后端也会报错 二、解决办法 1、安装qs 当前项目的命令行中安装 npm install qs2、引入qs库(使用qs库来将参数对象转换为字符串) // 全局 import qs from qs Vue.proto…

WPF中获取TreeView以及ListView获取其本身滚动条进行滚动

实现自行调节scoll滚动的位置(可相应获取任何控件中的内部滚动条) TreeView:TreeViewAutomationPeer lvap new TreeViewAutomationPeer(treeView); var svap lvap.GetPattern(PatternInterface.Scroll) as ScrollViewerAutomationPeer; var scroll svap.Owner as ScrollVie…

免费翻译pdf格式论文

进入谷歌翻译网址https://translate.google.com/?slauto&tlzh-CN&opdocs 将需要全文翻译的pdf放进去 选择英文到中文&#xff0c;然后点击翻译 可以选择打开译文或者下载译文&#xff0c;下载译文会下载到电脑上&#xff0c;打开译文会在浏览器打开。

华为数通方向HCIP-DataCom H12-821题库(多选题:221-240)

第221题 下面哪些路由协议支持通过命令配置发布缺省路由? A、OSPF B、IGMP C、ISIS D、BGP 【正确答案】ACD 【答案解析】 第222题 在route-policy中,能够用于apply子句的BGP属性有哪些? A、Local-Preference B. AS_Path C、Tag D、MED 【正确答案】ABD 【答案解析】 第22…

机器学习之决策树现成的模型使用

目录 须知 DecisionTreeClassifier sklearn.tree.plot_tree cost_complexity_pruning_path(X_train, y_train) CART分类树算法 基尼指数 分类树的构建思想 对于离散的数据 对于连续值 剪枝策略 剪枝是什么 剪枝的分类 预剪枝 后剪枝 后剪枝策略体现之威斯康辛州乳…

Redis 特性,为什么要用Redis,Redis到底是多线程还是单线程

一、Redis介绍 Redis&#xff08;Remote Dictionary Server )&#xff0c;即远程字典服务&#xff0c;是一个开源的&#xff0c;使用C语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并提供多种语言的API。 二、特性(为什么要用Redis&#x…

Docker 夺命连环 15 问

目录 什么是Docker&#xff1f; Docker的应用场景有哪些&#xff1f; Docker的优点有哪些&#xff1f; Docker与虚拟机的区别是什么&#xff1f; Docker的三大核心是什么&#xff1f; 如何快速安装Docker&#xff1f; 如何修改Docker的存储位置&#xff1f; Docker镜像常…

护眼大路灯智商税吗?五款最佳护眼落地灯分享!

大路灯能够提供更加舒适健康的光线&#xff0c;而且大路灯还提供许多能够提高使用便捷度的大路灯&#xff0c;尤其是对于学生党以及上班族来说&#xff0c;大路灯是一款很好的用眼照明小帮手。然而&#xff0c;对于护眼大路灯智商税吗这个问题&#xff0c;很冥想不是&#xff0…

四川易点慧电子商务抖音小店:前景无忧,创新引领未来零售风潮

在数字经济高速发展的今天&#xff0c;电子商务已成为推动经济增长的重要引擎。四川易点慧电子商务有限公司紧跟时代步伐&#xff0c;积极布局抖音小店&#xff0c;展现出强劲的发展势头和广阔的前景。 抖音小店作为抖音平台上的重要商业生态&#xff0c;凭借其庞大的用户群体和…

jira安装与配置

1. 环境准备 环境要求 1) JDK1.8以上环境配置 2) Mysql数据库5.7.13 3) Jira版本7及破解包 1.1 JDK1.8安装配置 1) 首先下载 JDK1.8&#xff0c; - 网址&#xff1a;https://www.oracle.com/cn/java/technologies/javase/javase-jdk8-downloads.html - windows64 版&am…