PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

news2024/10/2 10:27:18

PyTorch深度学习实战(31)——生成对抗网络

    • 0. 前言
    • 1. GAN
    • 2. GAN 模型分析
    • 3. 利用 GAN 模型生成手写数字
    • 小结
    • 系列链接

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. GAN

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

GAN原理

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. GAN 模型分析

为了生成手写数字的图像,我们采取以下策略:

  • 导入 MNIST 数据
  • 初始化随机噪声
  • 定义生成网络模型
  • 定义判别网络模型
  • 使用生成网络生成伪造图像,生成网络在最初只能生成噪声图像,噪声图像是通过将一组噪声值通过权重随机的神经网络得到的图像
  • 交替训练两个模型
    • 将生成的图像与原始图像串联起来,判别网络预测每个图像是伪造图像还是真实图像,对判别网络进行训练,判别网络的损失是图像的预测值和实际值(标签)的二进制交叉熵,生成的伪造图像的实际值(标签)为 0,原始数据集中真实图像的实际值(标签)为 1
    • 训练生成网络利用输入噪声生成伪造图像,使其看起来更接近真实图像,从而使生成图像有可能欺骗判别网络
    • 输入噪声通过生成网络传递输出伪造图像,将生成网络生成的图像输入到判别网络中,此时,判别网络权重被冻结,因为生成网络的目标是欺骗判别网络,因此,假设生成的伪造图像实际值(标签)为 1,生成网络的损失是判别网络对输入图像的预测值和实际值 (1) 的二进制交叉熵

了解了 GAN 的基本原理后,在下一小节,我们实现 GAN 生成 MNIST 手写数字图像。

3. 利用 GAN 模型生成手写数字

(1) 导入相关库并定义设备:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
device = "cuda" if torch.cuda.is_available() else "cpu"

from torchvision.datasets import MNIST
from torchvision import transforms

(2) 导入 MNIST 数据,定义具有内置数据转换功能的数据加载器,以便缩放输入数据:

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

data_loader = torch.utils.data.DataLoader(MNIST('MNIST/', train=True, download=True, transform=transform),batch_size=128, shuffle=True, drop_last=True)

(3) 定义判别网络模型类:

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( 
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

在以上代码中,使用 LeakyReLU 激活函数替换 ReLU。打印判别网络的简要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (1,784)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1              [-1, 1, 1024]         803,840
         LeakyReLU-2              [-1, 1, 1024]               0
           Dropout-3              [-1, 1, 1024]               0
            Linear-4               [-1, 1, 512]         524,800
         LeakyReLU-5               [-1, 1, 512]               0
           Dropout-6               [-1, 1, 512]               0
            Linear-7               [-1, 1, 256]         131,328
         LeakyReLU-8               [-1, 1, 256]               0
           Dropout-9               [-1, 1, 256]               0
           Linear-10                 [-1, 1, 1]             257
          Sigmoid-11                 [-1, 1, 1]               0
================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.57
Estimated Total Size (MB): 5.61
----------------------------------------------------------------

(4) 定义生成网络模型类 Generator

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

生成网络根据 100 维随机噪声输入生成图像。打印生成网络模型的简要信息:

generator = Generator().to(device)
print(summary(generator, (1,100)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 1, 256]          25,856
         LeakyReLU-2               [-1, 1, 256]               0
            Linear-3               [-1, 1, 512]         131,584
         LeakyReLU-4               [-1, 1, 512]               0
            Linear-5              [-1, 1, 1024]         525,312
         LeakyReLU-6              [-1, 1, 1024]               0
            Linear-7               [-1, 1, 784]         803,600
              Tanh-8               [-1, 1, 784]               0
================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.67
Estimated Total Size (MB): 5.71
----------------------------------------------------------------

(5) 定义函数生成随机噪声并将其注册到设备中:

def noise(size):
    n = torch.randn(size, 100)
    return n.to(device)

(6) 定义函数来训练判别网络。

判别网络训练函数 (discriminator_train_step) 将真实数据 (real_data) 和伪造数据 (fake_data) 作为输入:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):

重置优化器梯度:

    d_optimizer.zero_grad()

在对损失值执行反向传播之前,预测真实数据 (real_data) 并计算损失 (error_real):

    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()

在真实数据上计算判别网络损失时,我们期望判别网络预测输出为 1。因此,在判别网络的训练过程中,使用 torch.ones 作为标签,期望判别网络在真实数据上的输出为 1,从而计算判别网络在真实数据上的损失。

在对损失值执行反向传播之前,预测伪造数据 (fake_data) 并计算损失 (error_fake):

    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()

在伪造数据上计算判别网络损失时,我们期望判别网络预测输出为 0。因此,在判别网络的训练过程中,使用 torch.zeros 作为标签,期望判别网络在伪造数据上的输出为 0,从而计算判别网络在伪造数据上的损失。

更新权重并返回整体损失(将模型在 real_dataerror_realfake_dataerror_fake 的损失值相加):

    d_optimizer.step()
    return error_real + error_fake

(7) 训练生成网络模型。

定义生成网络训练函数 generator_train_step 并传入伪造数据 fake_data 作为参数:

def generator_train_step(real_data, fake_data, loss, g_optimizer):

重置优化器梯度:

    g_optimizer.zero_grad()

预测判别网络对伪造数据 (fake_data) 的输出:

    prediction = discriminator(fake_data)

在计算生成网络的损失时,使用 torch.ones 作为标签,期望判别网络在伪造数据上的输出为 1,以在训练生成网络时欺骗判别网络输出值 1,以此来鼓励生成网络生成更加逼真的数据,并让判别网络无法区分其真伪:

    error = loss(prediction, torch.ones(len(real_data), 1).to(device))

执行反向传播,更新权重,并返回损失:

    error.backward()
    g_optimizer.step()
    return error

(8) 定义模型对象、生成网络和判别网络的优化器,以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

(9) 训练模型。

循环训练模型 200epochs (num_epochs):

num_epochs = 200

d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):
    N = len(data_loader)
    d_loss_items = []
    g_loss_items = []
    for i, (images, _) in enumerate(data_loader):

加载真实数据 (real_data) 和伪造数据,其中,伪造数据是通过将大小与真实数据样本数相同的噪声数据 (batch_size = len(real_data)) 传入生成网络网络获得的。需要注意的是,必须调用 fake_data.detach(),否则训练无法正常进行。通过 detach() 函数分离出来一个新的张量,这样在 discriminator_train_step() 中调用 error.backward() 时,与生成网络相关的张量(生成 fake_data )不会受到影响。使用 discriminator_train_step 函数训练判别网络:

        real_data = images.view(len(images), -1).to(device)
        fake_data = generator(noise(len(real_data))).to(device)
        fake_data = fake_data.detach()

训练判别网络后,继续训练生成网络。从噪声数据生成一组新的伪造图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(noise(len(real_data))).to(device)
        g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())
        g_loss_items.append(g_loss.item())
    d_loss_epoch.append(np.average(d_loss_items))
    g_loss_epoch.append(np.average(g_loss_items))

绘制判别网络和生成网络的损失随训练的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型性能检测

(10) 可视化模型训练后生成的伪造数据:

z = torch.randn(64, 100).to(device)
sample_images = generator(z).data.cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0), cmap='gray')
plt.show()

生成结果

在上图中,可以看到利用 GAN 生成逼真的图像,但仍有一定的改进空间,在之后的学习中,我们将介绍更多 GAN 的改进模型生成更逼真的图像。

小结

生成对抗网络是一种强大的深度学习模型,由生成器网络和判别器网络组成,通过彼此之间的竞争来提高性能,已经在图像生成、图像修复、图像转换和自然语言处理等领域取得了巨大的成功。其核心思想是通过生成器和判别器之间的博弈过程来实现真实样本的生成。生成器负责生成逼真的样本,而判别器则负责判断样本是真实还是伪造。通过不断的训练和迭代,生成器和判别器会相互竞争并逐渐提高性能。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1402314.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于taro搭建小程序多项目框架,记重点了!!!

前言 为什么需要这样一个框架,以及这个框架带来的好处是什么? 从字面意思上理解:该框架可以用来同时管理多个小程序,并且可以抽离公用组件或业务逻辑供各个小程序使用。当你工作中面临这种同时维护多个小程序的业务场景时&#…

react umi/max 页签(react-activation)

思路:通过react-activation实现页面缓存,通过umi-plugin-keep-alive将react-activation注入umi框架,封装页签组件最后通过路由的wrappers属性引入页面。 浏览本博客之前先看一下我的博客实现的功能是否满足需求,实现功能&#xf…

日常常见应用组件升级记录

一、前言 因近期安全扫描,发现java后端应用涉及多个引用组件版本过低,涉及潜在漏洞利用风险,特记录相关处理升级处理过程,以备后续确认; 二、升级处理过程 2.1、Java类应用内置Spring Boot版本升级 Spring Boot是一…

【数据结构与算法】之字符串系列-20240121

这里写目录标题 一、344. 反转字符串二、125. 验证回文串三、205. 同构字符串四、242. 有效的字母异位词五、290. 单词规律 一、344. 反转字符串 简单 编写一个函数,其作用是将输入的字符串反转过来。输入字符串以字符数组 s 的形式给出。 不要给另外的数组分配额…

java SSM政府采购管理系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM政府采购管理系统是一套完善的web设计系统(系统采用SSM框架进行设计开发,springspringMVCmybatis),对理解JSP java编程开发语言有帮助,系统具有完整的源代 码和数据库,系统主要采…

【UE5】第一次尝试项目转插件(Plugin)的时候,无法编译

VS显示100条左右的错误,UE热编译也不能通过。原因可能是[名字.Build.cs]文件的错误,缺少一些内容,比如说如果要写UserWidget类,那么就要在 ]名字.Build.cs] 中加入如下内容: public class beibaoxitong : ModuleRules …

UML序列图入门

使用工具:processOn可免费建9个文件 UML:Unified Modeling Language 统一建模语言 ProcessOn - 我的文件 一、介绍 序列图是基于对象而非基于类; 从上到下:展示了处理过程中,每个对象的生命周期; 从左到…

从白子画到东方青苍,你选择谁来守护你的修仙之旅?

从白子画到东方青苍,你选择谁来守护你的修仙之旅? 在繁花似锦的修仙世界中,每一位追梦者都渴望有那么一位守护者,与你共患难,共成长。热血与浪漫交织的《花千骨》与《苍兰诀》,给我们带来了两位风华绝代的守护者:白子…

ST-GCN 人体姿态估计模型 代码实战

构建一个ST-GCN(Spatio-Temporal Graph Convolutional Network)模型需要结合图卷积网络(GCN)的思想,以处理时空数据。以下是一个简单的例子,演示如何使用PyTorch构建ST-GCN模型: import torch …

偷偷浏览小网站时,原来有这么多人已经知道

最近看到一篇挺有意思文章,偷偷浏览小网站时,都有谁会知道你看了啥。思量之下,从更广泛的技术角度看,仍有大量补充的空间,于是就有了这样一篇文章。本文的目的在于增强大家的网络安全意识,非必要不要浏览不…

护眼灯品牌哪个好?护眼灯最好的品牌排行

很多家长在帮助孩子选购台灯时可能会遇到一些困扰,而市面上存在许多质量不合格的台灯。曾经有央视节目曝光了许多LED台灯存在严重不合格的情况。尽管目前对于台灯执行强制性产品认证,并不定期进行抽样检查,但仍然存在一些未强制认证的产品在市…

EF Core实操,数据库生成实体,迁移

EF Core实操,数据库生成实体,迁移 大家好,我是行不更名,坐不改姓的宋晓刚,下面将带领大家进入C#编程EF Core数据库基础入门知识,如何连接数据库,如何编写代码,跟上我的步伐进入EF Core数据库下的世界。 家人们,如果有什么不懂,可以留言,或者加我联系方式,一起进入…

python系列-顺序/条件/循环语句

🌈个人主页: 会编程的果子君 ​💫个人格言:“成为自己未来的主人~” 目录 顺序语句 条件语句 什么是条件语句 语法格式 缩进和代码块 空语句pass 循环语句 while循环 for循环 continue break 顺序语句 默认情况下,Python的代码执行…

Elasticsearch+Kibana 学习记录

文章目录 安装Elasticsearch 安装Kibana 安装 Rest风格API操作索引基本概念示例创建索引查看索引删除索引映射配置(不配置好像也行、智能判断)新增数据随机生成ID自定义ID 修改数据删除数据 查询基本查询查询所有(match_all)匹配查…

为什么 HTTPS 协议能保障数据传输的安全性?

HTTP 协议 在谈论 HTTPS 协议之前,先来回顾一下 HTTP 协议的概念。 HTTP 协议介绍 HTTP 协议是一种基于文本的传输协议,它位于 OSI 网络模型中的应用层。 HTTP 协议是通过客户端和服务器的请求应答来进行通讯,目前协议由之前的 RFC 2616 拆…

C++ | 五、哈希表 Hash Table(数组、集合、映射)、迭代器

哈希表基础 哈希表是一类数据结构(哈希表包含数组、集合和映射,和前两篇文章叙述的字符串、链表平级)哈希表概念:类似于Python里的字典类型,哈希表把关键码key值通过哈希函数来和哈希表上的索引对应起来,之…

BGP AS-Path 选路试验

一、拓朴图: 实验要求:R5 通过改变对端传入的 AS-Path 路由策略,使原来较长的 AS-Path 成为最优 二、步骤: 1、配置IP 2、R1、R2、R3 配置 IGP 3、R1、R2、R3 配置 BGP,将 R1 创建的 Loop 100.1.1.1 的环回口在 IBGP …

uni-app小程序:文件下载打开文件方法苹果安卓都适用

api: const filetype e.substr(e.lastIndexOf(.)1)//获取文件地址的类型 console.log(文档,filetype) uni.downloadFile({url: e,//e是图片地址success(res) {console.log(res)if (res.statusCode 200) {console.log(下载成功,);var filePath encodeURI(res.tempFilePath);…

【第三课课后作业】基于 InternLM 和 LangChain 搭建你的知识库

基于 InternLM 和 LangChain 搭建你的知识库 1. 基础作业: 环境配置 1.1 InternLM 模型部署 创建开发机 进入 conda 环境之后,使用以下命令从本地一个已有的 pytorch 2.0.1 的环境,激活环境,在环境中安装运行 demo 所需要的依…

【赠书第17期】Excel高效办公:文秘与行政办公(AI版)

文章目录 前言 1 了解Excel的强大功能和工具 2 提升Excel技能的方法 3 结合AI技术提升Excel应用 4 注意事项 5 推荐图书 6 粉丝福利 前言 随着人工智能(AI)技术的快速发展,我们的工作方式也在发生深刻变革。其中,Excel 作…