PyTorch训练简单的生成对抗网络GAN

news2024/9/25 11:19:36

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim): # z_dim 噪声的维度
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim), # 28x28 -> 784
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.gen(x)
    
# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)
    [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)

dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()

# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device) # view相当于reshape
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise) # G(z)
        disc_real = disc(real).view(-1) # flatten
        # BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]

        # max log(D(real)) 相当于 min -log(D(real))
        # ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项
        # 得到 min -lnx,这里的x就是我们的real图片
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        # max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))
        # zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项
        # 得到 min -ln(1-x),这里的x就是我们的fake噪声
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))
        # 依然可使用BCEloss来做
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \ "
                f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )

                step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/890049.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FreeRTOS qemu mps2-an385 bsp 移植制作 :系统运行篇

相关文章 FreeRTOS qemu mps2-an385 bsp 移植制作 &#xff1a;环境搭建篇 FreeRTOS qemu mps2-an385 bsp 移植制作 &#xff1a;系统启动篇 开发环境 Win10 64位 VS Code&#xff0c;ssh 远程连接 ubuntu VMware Workstation Pro 16 Ubuntu 20.04 FreeRTOSv202212.01&a…

DataFrame.rename()函数--Pandas

1. 函数作用 修改DataFrame的行名、列名 2. 函数语法 DataFrame.rename(mapperNone, *, indexNone, columnsNone, axisNone, copyNone, inplaceFalse, levelNone, errorsignore)3. 函数参数 参数含义mapper与axis结合使用&#xff0c;表示运用到axis上的值&#xff1a;类字…

SRS流媒体服务(四)WebRTC实现实时视频通话和低延时互动直播

CentOS版本号&#xff1a;7.9 SRS版本号&#xff1a;4.0.215 服务器IP&#xff1a;192.168.5.104 注意需要开启端口号&#xff1a;1935、1985、8000&#xff08;UDP端口&#xff09;、8080。 注意需要开启服务&#xff1a;http 文章目录 webRTC介绍getUserMediaRTCPeerCon…

编译老版本c++程序 报错 msvcrt.dll 以及 0x000000 内存 不能为 “read“ 问题已解决

一般 win10 编译 xp对应老版本软件 调试采用 虚拟机形式进行测试&#xff0c;但是虚拟机中&#xff0c;无独立显卡&#xff0c;运行程序提示有&#xff0c;无法调用动态库&#xff0c;或者 内存无法读取&#xff0c;炸一看以为 winxp32位 内存识别只能3.7G.其实是显存无法使用…

【网工常用的CMD窗口命令行,你一定得记的】

作为网络工程师&#xff0c;工作中经常是容易被作为甩锅的对象。凡是业务系统出现问题或其他各种问题&#xff0c;人们总是会认为是网络问题&#xff0c;最初的怀疑都指向网工。然而最终结果往往是自己服务器出了问题。 现在我将列举几种我在日常工作中经常使用的查看网络状态…

七月 NFT 行业解读:游戏和音乐 NFT 引领增长,Opepen 掀起热潮

作者&#xff1a;lesleyfootprint.network 2023 年 7 月&#xff0c;NFT 市场的波动性持续存在&#xff0c;交易量呈下降趋势。然而&#xff0c;游戏和音乐 NFT 等领域的增长引人注目。参与这些细分领域的独立用户数量不断增加&#xff0c;反映了这些领域的复苏。 本综合报告…

实时会话简易版

1、数据存储 Redis缓存、pgsql数据库 2、存储使用 2.1、Redis缓存 1&#xff09;无序集合set&#xff1a;存储未读会话id 2&#xff09;list&#xff08;左进右出&#xff09;&#xff1a;存储会话未读消息 2.2、pgsql数据库 存储用户信息&#xff0c;存储会话id&#…

Maven之mirrorof范围

mirrorOf 是 central 还是 * 的问题 在配置阿里对官方中央仓库的镜像服务器时&#xff0c;我们使用到了 <mirror> 元素。 <mirror><id>aliyunmaven</id><mirrorOf>central</mirrorOf><name>阿里云公共仓库</name><url>…

【计算机视觉|生成对抗】非配对图像到图像的翻译:使用循环一致对抗网络(CycleGAN)

本系列博文为深度学习/计算机视觉论文笔记&#xff0c;转载请注明出处 标题&#xff1a;Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks 链接&#xff1a;[1703.10593] Unpaired Image-to-Image Translation using Cycle-Consistent Adver…

基于metrics-server弹性伸缩

目录 Kubernetes部署方式 基于kubeadm部署K8S集群 一、环境准备 1.1、主机初始化配置 1.2、部署docker环境 二、部署kubernetes集群 2.1、组件介绍 2.2、配置阿里云yum源 2.3、安装kubelet kubeadm kubectl 2.4、配置init-config.yaml 2.5、安装master节点 2.6、安装…

2022年09月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;统计误差范围内的数 统计一个整数序列中与指定数字m误差范围小于等于X的数的个数。 时间限制&#xff1a;5000 内存限制&#xff1a;65536 输入 输入包含三行&#xff1a; 第一行为N&#xff0c;表示整数序列的长度(N < 100); 第二行为N个整数&#xff0c;…

在C中使用Socket实现多线程异步TCP消息发送

目录 基础知识开始实现主要函数说明结束语 在本篇文章中&#xff0c;我们会探讨如何在C语言中使用socket来实现多线程&#xff0c;异步发送TCP消息的系统。虽然C标准库并没有原生支持异步和多线程编程&#xff0c;但是我们可以结合使用POSIX线程&#xff08;pthread&#xff09…

React2023电商项目实战 - 1.项目搭建

古人学问无遗力&#xff0c;少壮工夫老始成。 纸上得来终觉浅&#xff0c;绝知此事要躬行。 —— 陆游《《冬夜读书示子聿》》 系列文章目录 项目搭建App登录及网关App文章自媒体平台&#xff08;博主后台&#xff09;内容审核(自动) 文章目录 系列文章目录一、项目介绍1.页面…

二维码网络钓鱼攻击泛滥!美国著名能源企业成主要攻击目标

近日&#xff0c;Cofense发现了一次专门针对美国能源公司的网络钓鱼攻击活动&#xff0c;攻击者利用二维码将恶意电子邮件塞进收件箱并绕过安全系统。 Cofense 方面表示&#xff0c;这是首次发现网络钓鱼行为者如此大规模的使用二维码进行钓鱼攻击&#xff0c;这表明他们可能正…

云计算——ACA学习 云计算核心技术

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​ 写在前面 本系列将会持续更新云计算阿里云ACA的学习&#xff0c;了解云计算及网络安全相关…

使用 AI 将绘画和照片转换为动画

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑器的3D应用场景 华盛顿大学和Facebook的研究人员最近发表了一篇论文&#xff0c;展示了一种基于深度学习的系统&#xff0c;可以将静止图像和绘画转换为动画。称为照片唤醒的算法使用卷积神经网络从单个静止图像以 …

性能分析之MySQL慢查询日志分析(慢查询日志)

一、背景 MySQL的慢查询日志是MySQL提供的一种日志记录,他用来记录在MySQL中响应的时间超过阈值的语句,具体指运行时间超过long_query_time(默认是10秒)值的SQL,会被记录到慢查询日志中。 慢查询日志一般用于性能分析时开启,收集慢SQL然后通过explain进行全面分析,一…

【STM32】FreeRTOS事件组学习

事件组&#xff08;Event Group&#xff09; 一个任务执行之前需要经过多个条件进行判断&#xff0c;当条件全部满足或多个条件中的某一个条件满足才执行。 实验&#xff1a;创建两个任务&#xff0c;一个事件组&#xff0c;当按键一二三都按过一遍才打印。 实现&#xff1a…

【论文阅读】SHADEWATCHER:使用系统审计记录的推荐引导网络威胁分析(SP-2022)

SHADEWATCHER: Recommendation-guided CyberThreat Analysis using System Audit Records S&P-2022 新加坡国立大学、中国科学技术大学 Zengy J, Wang X, Liu J, et al. Shadewatcher: Recommendation-guided cyber threat analysis using system audit records[C]//2022 I…

python3 0学习笔记之基本知识

0基础学习笔记之基础知识 &#x1f4da; 基础内容1. 条件语句 if - elif - else2. 错误铺捉try - except(一种保险策略&#xff09;3. 四种开发模式4. 函数&#xff1a;def用来定义函数的5. 最大值最小值函数&#xff0c;max &#xff0c;min6. is 严格的相等&#xff0c;is no…