基础GAN生成式对抗网络(pytorch实验)

news2024/9/29 23:26:49

(Generative Adversarial Network)

一、理论

https://zhuanlan.zhihu.com/p/307527293?utm_campaign=shareopn&utm_medium=social&utm_psn=1815884330188283904&utm_source=wechat_session
大佬的文章中的“GEN的本质”部分
在这里插入图片描述

二、实验

1、数据集介绍

采用MNIST数据集,如下是训练集中的一张图片
在这里插入图片描述

2、代码

引入包

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # 输出范围在 -1 到 1 之间
        )

    def forward(self, x):
        return self.net(x)

定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出范围在 0 到 1 之间
        )

    def forward(self, x):
        return self.net(x)

训练
先训练判别器,后训练生成器
训练时先训练判别器:将训练集数据(Training Set)打上真标签(1)和生成器(Generator)生成的假图片(Fake image)打上假标签(0)一同组成batch送入判别器(Discriminator),对判别器进行训练。计算loss时使判别器对真数据(Training Set)输入的判别趋近于真(1),对生成器(Generator)生成的假图片(Fake image)的判别趋近于假(0)。此过程中只更新判别器(Discriminator)的参数,不更新生成器(Generator)的参数。

然后再训练生成器:将高斯分布的噪声z(Random noise)送入生成器(Generator),然后将生成器(Generator)生成的假图片(Fake image)打上真标签(1)送入判别器(Discriminator)。计算loss时使判别器对生成器(Generator)生成的假图片(Fake image)的判别趋近于真(1)。此过程中只更新生成器(Generator)的参数,不更新判别器(Discriminator)的参数。

transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


def train_gan(generator, discriminator, dataloader, num_epochs=25):
   criterion = nn.BCELoss()
   optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
   optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

   for epoch in range(num_epochs):
       for i, (imgs, _) in enumerate(dataloader):
           batch_size = imgs.size(0)

           real_imgs = imgs.view(batch_size, -1)  # 将图像展平成一维

           # 标签
           real_labels = torch.ones(batch_size, 1)
           fake_labels = torch.zeros(batch_size, 1)

           # 训练判别器
           outputs = discriminator(real_imgs)
           d_loss_real = criterion(outputs, real_labels)
           real_score = outputs

           z = torch.randn(batch_size, 100)
           fake_imgs = generator(z)
           outputs = discriminator(fake_imgs.detach())
           d_loss_fake = criterion(outputs, fake_labels)
           fake_score = outputs

           d_loss = d_loss_real + d_loss_fake
           optimizer_d.zero_grad()
           d_loss.backward()
           optimizer_d.step()

           # 训练生成器
           outputs = discriminator(fake_imgs)
           g_loss = criterion(outputs, real_labels)

           optimizer_g.zero_grad()
           g_loss.backward()
           optimizer_g.step()

       print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

       if (epoch+1) % 10 == 0:
           with torch.no_grad():
               fake_imgs = generator(torch.randn(64, 100)).view(-1, 1, 28, 28)
               grid = torchvision.utils.make_grid(fake_imgs, nrow=8, normalize=True)
               plt.imshow(grid.permute(1, 2, 0).cpu())
               plt.title(f'Epoch {epoch+1}')
               plt.show()




generator = Generator()
discriminator = Discriminator()

train_gan(generator, discriminator, dataloader)

3、结果

输入一个随机噪声图像,由生成器能得到如下的图片(训练1step的结果)
在这里插入图片描述

输入一个随机噪声图像,由生成器能得到如下的图片(训练10step的结果)
在这里插入图片描述

拓展

可以看看VAE、CGAN等模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2136806.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SIP ACK method

SIP ACK同样在RFC3261中定义。 ACK仅仅用于对INVITE request的response的回复,例如在通话结束时,MO要断开连接,此时就会生成一条BYE 消息。BYE不会经过代理,而是直接路由到MT。MT通过200 (OK) 响应确认收到 BYE,然后就…

vagrant+virtualbox+ubuntu22.04无法上网问题

一、过程 vagrantfile配置私有网络 config.vm.network "private_network", ip: "192.168.56.10"启动虚拟机,可以ping通百度的实际IP,ping不通域名修改/etc/netplan/50-vagrant.yaml,配置DNS network:renderer: Networ…

计算机毕业设计Python知识图谱美团美食推荐系统 美团餐厅推荐系统 美团推荐系统 美食价格预测 美团爬虫 美食数据分析 美食可视化大屏

《Python知识图谱美团美食推荐系统》开题报告 一、研究背景与意义 随着信息技术的飞速发展和互联网应用的普及,人们的消费习惯逐渐从线下转移到线上,外卖行业迎来了前所未有的发展机遇。美团作为国内领先的生活服务电子商务平台,拥有庞大的…

基于SpringBoot+Vue的考务报名平台(带1w+文档)

基于SpringBootVue的考务报名平台(带1w文档) 基于SpringBootVue的考务报名平台(带1w文档) 当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进…

026.(娱乐)魔改浏览器-任务栏图标右上角加提示徽章

一、目标: windows中,打开chromium,任务栏中会出现一个chromium的图标。我们的目标是给这个图标的右上角,加上"有1条新消息"的小提示图标,也叫徽章(badge)注意:本章节纯属娱乐,有需要…

【DS】AVL树

目录 AVL树的介绍AVL树节点的定义认识AVL树的抽象图AVL树的插入BSTree规则插入更新平衡因子平衡因子的判断 AVL树的旋转左单旋右单旋左右双旋右左双旋 AVL树的验证AVL树的查找AVL树的性能 在上篇搜索二叉树末尾提到过,得益于搜索二叉树的性质:大于根往右…

计算机毕业设计Python深度学习垃圾邮件分类检测系统 朴素贝叶斯算法 机器学习 人工智能 数据可视化 大数据毕业设计 Python爬虫 知识图谱 文本分类

基于朴素贝叶斯的邮件分类系统设计 摘要:为了解决垃圾邮件导致邮件通信质量被污染、占用邮箱存储空间、伪装正常邮件进行钓鱼或诈骗以及邮件分类问题。应用Python、Sklearn、Echarts技术和Flask、Lay-UI框架,使用MySQL作为系统数据库,设计并实…

java: 程序包org.junit.jupiter.api不存在

明明idea没有报错,引用包也没问题,为啥提示java: 程序包org.junit.jupiter.api不存在? 配置!还TMD是配置! 如果是引用包的版本不对或者其他,直接就是引用报错或者pom里面飘红了。 这个应该是把generat…

微服务-- Sentinel的使用

目录 Sentinel:微服务的哨兵 生态系统景观 sentinel与spring cloud Hystrix 对比 Sentinel 主要分为两部分 Sentinel安装与使用 Sentinel的控制规则 流控规则 流控规则的属性说明 新增流控规则 关联流控模式 SentinelResource注解的使用 SentinelResou…

如何将自己的项目发布到Maven中央仓库

1.背景 本教程为2024年9月最新版 我有一个java项目想发布到maven中央仓库&#xff0c;然后任何人都可以在pom文件中引用我写的代码 引用格式如下&#xff1a; <!-- 这是引用rocketmq的坐标 --> <dependency><groupId>org.apache.rocketmq</groupId>&l…

炫酷HTML蜘蛛侠登录页面

全篇使用HTML、CSS、JavaScript&#xff0c;建议有过基础的进行阅读。 一、预览图 二、HTML代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-w…

ESP8266做httpServer提示Header fields are too long for server to interpret

CONFIG_HTTP_BUF_SIZE512 CONFIG_HTTPD_MAX_REQ_HDR_LEN1024 CONFIG_HTTPD_MAX_URI_LEN512CONFIG_HTTPD_MAX_REQ_HDR_LEN由512改为1024

C++ | Leetcode C++题解之第404题左叶子之和

题目&#xff1a; 题解&#xff1a; class Solution { public:bool isLeafNode(TreeNode* node) {return !node->left && !node->right;}int sumOfLeftLeaves(TreeNode* root) {if (!root) {return 0;}queue<TreeNode*> q;q.push(root);int ans 0;while …

氢能源多旋翼无人机技术详解

1. 技术背景与优势 随着全球对低碳、环保和高效能源解决方案的需求日益增长&#xff0c;氢能源作为一种清洁、高效的能源形式&#xff0c;在多个领域展现出巨大的应用潜力。在无人机领域&#xff0c;氢能源多旋翼无人机因其独特的优势逐渐受到关注。相比传统锂电池无人机&…

Linux sh命令

目录 一. 基本语法二. 选项2.1 -c 字符串中读取内容&#xff0c;并执行2.1.1 基本用法2.1.2 获取当前目录下失效的超链接 2.2 -x 每个命令执行之前&#xff0c;将其打印出来2.3 结合Here文档使用 一. 基本语法 ⏹Linux 和 Unix 系统中用于执行 shell 脚本 或 运行命令 的命令。…

【我的 PWN 学习手札】Fastbin Double Free

前言 Fastbin的Double Free实际上还是利用其特性产生UAF的效果&#xff0c;使得可以进行Fastbin Attack 一、Double Free double free&#xff0c;顾名思义&#xff0c;free两次。对于fastbin这种单链表的组织结构&#xff0c;会形成这样一个效果&#xff1a; 如果我们mallo…

卡西莫多的手信

通过网盘分享的文件&#xff1a;卡西莫多的手信2022-2024.9.15-A5.pdf 链接: 百度网盘 请输入提取码 提取码: gig1

oracle数据库安装和配置详细讲解

​ 大家好&#xff0c;我是程序员小羊&#xff01; 前言&#xff1a; Oracle 数据库是全球广泛使用的关系型数据库管理系统 (RDBMS)&#xff0c;提供高性能、可靠性、安全性和可扩展性&#xff0c;广泛应用于企业关键任务系统。下面详细介绍如何在 CentOS 系统上安装和配置 Or…

非金属失效与典型案例分析培训

随着生产和科学技术的发展&#xff0c;人们不断对高分子材料提出各种各样的新要求。因为技术的全新要求和产品的高要求化&#xff0c;而客户对产品的高要求及工艺理解不一&#xff0c;于是高分子材料断裂、开裂、腐蚀、变色等之类失效频繁出现&#xff0c;常引起供应商与用户间…

O2O营销,中小企业数字化转型的加速器

嘿&#xff0c;小伙伴们&#xff0c;今天咱们要聊的&#xff0c;可是那让中小企业焕发新生的O2O营销魔法&#xff01;它就像是一位时空穿梭者&#xff0c;轻松跨越线上与线下的鸿沟&#xff0c;带着商家们开启了一场数字化转型的奇妙之旅。 O2O营销&#xff1a;不只是连接&…