第G1周:生成对抗网络(GAN)入门

news2025/1/11 11:19:04

目录

  • 一、课题背景和开发环境
  • 二、理论基础
    • 1.生成器
    • 2. 判别器
    • 3. 基本原理
  • 三、前期准备工作
    • 1. 定义超参数
    • 2.下载数据
    • 3. 配置数据
  • 四、定义模型
    • 1. 定义鉴别器
    • 2. 定义生成器
  • 五、训练模型
    • 1. 创建实例
    • 2. 训练模型
    • 3. 保存模型

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

一、课题背景和开发环境

📌第G1周:生成对抗网络(GAN)入门📌
Python 3.8.12
numpy==1.21.5 -> 1.24.3
pytorch==1.8.1+cu111
📌 基础任务:
了解什么是生成对抗网络(GAN)
学习本文代码,并跑通代码
🎈进阶任务:
调用训练好的模型生成新图像

二、理论基础

生成对抗网络(Generative Adversarial Networks, GAN)是近年来深度学习领域的一个热点方向。
GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。
GAN由两个分别被称为生成器(Generator)和判别器(Discriminator)的神经网络组成。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。

1.生成器

GANs中,生成器 G 选取随机噪声 z 作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本 G ( z ) G(z)G(z) 。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。
从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设;然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。

2. 判别器

GANs中,判别器 D 对于输入的样本 x,输出一个 [ 0 , 1 ] [0,1][0,1] 之间的概率数值 D ( x ) D(x)D(x)。x 可能是来自于原始数据集中的真实样本 x,也可能是来自于生成器 G 的人工样本 G ( z ) G(z)G(z)。通常约定,概率值 D ( x ) D(x)D(x) 越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明 GAN 是一个无监督的学习过程。

3. 基本原理

GAN是博弈论和机器学习相结合的产物,于2014年Ian Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能自动生成苹果的图片,具备些功能的算法即认为具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足Ian Goodfellow提出了GAN。
在这里插入图片描述
那么GAN是如何完成生成图片这项功能的呢,如上图所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别模型1D另其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了欺瞒一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

三、前期准备工作

1. 定义超参数

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms

# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)
os.makedirs('./output/', exist_ok=True)
os.makedirs('./data/MNIST/', exist_ok=True)

# 超参数配置
n_epochs = 50
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

# 图像的尺寸:(1, 28, 28),和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

2.下载数据

## mnist数据集下载
mnist = datasets.MNIST(root='./data/', 
                       train=True, 
                       download=False,
                       transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))

3. 配置数据

# 配置数据到加载器
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

四、定义模型

1. 定义鉴别器

'''
定义判别器 Discriminator
将图片28x28展开成784,然后通过多层感知器,
中间经过斜率设置为0.2的LeakyReLU激活函数,
最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
                     nn.Linear(img_area, 512),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Linear(512, 256),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Linear(256, 1),
                     nn.Sigmoid())
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
        validity = self.model(img_flat)       # 通过鉴别器网络
        return validity                       # 鉴别器返回的是一个[0, 1]间的概率

2. 定义生成器


'''
定义生成器 Generator
输入一个100维的0~1之间的高斯分布,
然后通过第一层线性变换将其映射到256维,
然后通过LeakyReLU激活函数,
接着进行一个线性变换,
再经过一个LeakyReLU激活函数,
然后经过陷先变换将其变成784维,
最后经过Tanh激活函数,
是希望生成的假的图片数据分布,能够再-1~1之间。
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 模型中间块
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        # prod():返回给定轴上的数组元素的乘积:1*28*28=784
        self.model = nn.Sequential(
                     *block(latent_dim, 128, normalize=False),
                     *block(128, 256),
                     *block(256, 512),
                     *block(512, 1024),
                     nn.Linear(1024, img_area),
                     nn.Tanh())
    
    # view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1,28, 28)
    def forward(self, z):                          # 输入的是(64, 100)的噪声数据
        imgs = self.model(z)                       # 噪声数据通过生成器模型
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1,28, 28)
        return imgs                                # 输出为64张大小为(1, 28, 28)的图像

五、训练模型

1. 创建实例

# 创建生成器、判别器对象
generator = Generator()
discriminator = Discriminator()
# 定义loss的度量方式(二分类的交叉熵)
criterion = torch.nn.BCELoss()
# 定义又换函数,学习率为0.0003
# betas: 用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 若有显卡,在cuda模式中运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

2. 训练模型

# 进行多个epoch的训练
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        ''' 训练判别器 Train Discriminator '''
        # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        imgs = imgs.view(imgs.size(0), -1)
        real_img = Variable(imgs).cuda()
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
        # 计算真实图片的损失
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out
        # 计算假的图片的损失
        # detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out
        # 损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        
        ''' 训练生成器 Train Generator '''
        # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        # 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        # 反向传播更新的参数是生成网络里面的参数,
        # 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        # 损失函数和优化
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        
        ''' 打印训练过程中的日志 '''
        # item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

3. 保存模型

# 保存模型
torch.save(generator.state_dict(), './output/generator.pth')
torch.save(discriminator.state_dict(), './output/discriminator.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/753807.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

检测到目标Content-Security-Policy响应头缺失

详细描述 HTTP 响应头Content-Security-Policy允许站点管理者控制用户代理能够为指定的页面加载哪些资源。除了少数例外情况,设置的政策主要涉及指定服务器的源和脚本结束点。 Content-Security-Policy响应头的缺失使得目标URL更易遭受跨站脚本攻击。 解决办法 …

浅谈炼钢厂能源计量管理系统的设计与应用

安科瑞 华楠 摘要: 从能源计量和管理的角度,论述了炼钢厂的能源计量管理系统的基本组成及功能。该系统的建立,将使炼钢厂能源介质的计量管理工作实现自动采集、瞬时监测、故障报警、能流监视;完成报表统计、离线输入、成本分析、预测参考等功…

【正点原子STM32连载】 第五十五章 录音机实验摘自【正点原子】STM32F103 战舰开发指南V1.2

1)实验平台:正点原子stm32f103战舰开发板V4 2)平台购买地址:https://detail.tmall.com/item.htm?id609294757420 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html# 第五…

大学生用一周时间给麦当劳做了个App(Flutter版)

背景 有个大学生粉丝最近私信联系我,说基于我之前开源的多语言项目做了个仿麦当劳的项目,虽然只是个样子货,但是收获颇多,希望把自己写的代码开源出来供大家一起学习进度。这个小伙伴确实是非常积极上进,很多大学生&a…

Django admin管理工具TabularInline表格内联

详解 TabularInline 是 Django Admin 中的一个内联模型选项,用于在父模型的编辑页面中以表格形式显示关联的子模型对象。下面是对 TabularInline 的一些详解: 显示方式:TabularInline 以表格的形式显示子模型对象。每个子模型对象将以一行的…

12.0、Java_IO流 - 字节数组输入输出流

12.0、Java_IO流 - 字节数组输入输出流 字节数组流: ByteArrayInputStream 和 byteArrayOutputStream 经常用在需要流和数组之间转化的情况; 字节数组输入流: 说白了,FileInputStream 是把文件当做数据源;ByteArrayInp…

变动率ROC指标详解及改进版选股公式

ROC指标(变动率指标)是一种基于动量的技术指标,衡量当前价格与一定天数前价格之间变化的百分比。ROC指标围绕零轴上下波动,如果价格变化向上,指标会移动到零轴之上;如果价格变动向下,则指标会移…

NestJS 编写 SSE 接口推送数据

做项目的时候遇到了顺便就记一下相关的内容。 SSE Server-Sent Events(SSE)技术,它是一种用于实现服务器向客户端实时推送数据的Web技术。SSE基于HTTP协议,允许服务器将数据以事件流(Event Stream)的形式…

深度学习(29)—— DETR

深度学习(29)—— DETR DETR代码欢迎光临Jane的GitHub:在这里等你 看完YOLO 之后,紧接着看了DETR。作为Transformer在物体检测上的开山之作,虽然他的性能或许不及其他的模型,但是想法是OK的。里面还有一些…

数据结构day1(2023.7.13)

一、Xmind整理: 二、课上练习: 练习1:static(全局变量、局部变量作用域) int a0;//全局变量 生命周期和作用于都是从定义开始到整个文件结束 void fun() { int b0;//局部变量 static int c0;//局部变量 作用于&#x…

智头条|第25届中国建博会(广州)成功举行,马斯克组建xAI公司

行业动态: 第25届中国建博会(广州)成功举行 7月8日至11日期间,2023中国建博会(广州)暨首届广州卫博会在广州如火如荼地进行。本届展会以“冠军企业首秀平台”为定位,以“建装理想家,服务新格局”为主题&a…

我的创作纪念日——创作的第2048天

创作机缘 今天收到私信,在CSDN已经7年码龄,创作2048天了,刚开始写作的时候似乎还是在大二,那个懵懂无知的年纪,也是在那个时候开始接触开发,接触编程。 之后便是无尽的探索与尝试,没有明确的发…

DuiLib的消息传递机制

前言 学会了怎么写XML文件,但是我还是不知道怎么实现各个控件之间的消息传递。于是我对源代码好好研究了一下,发现duilib作为一个界面库有自己独立的封装的窗口类,也就是WindowsImplBase。 在这个类中,实现对windows窗口传过来的消息的处理…

【每日一题】2673. 使二叉树所有路径值相等的最小代价

【每日一题】2673. 使二叉树所有路径值相等的最小代价 2673. 使二叉树所有路径值相等的最小代价题目描述解题思路 2673. 使二叉树所有路径值相等的最小代价 题目描述 给你一个整数 n 表示一棵 满二叉树 里面节点的数目,节点编号从 1 到 n 。根节点编号为 1 &#…

前端day06笔记

数组遍历 for let i in 数组名 函数 没有return返回的是underfined var是全局作用域 匿名函数 具名函数 值传递 引用传递 arguments:接实参 箭头函数 递归 闭包 对象的增删改查 对象的遍历 数组对象 获取数组对象 内置对象 从2-10

保险企业如何做好数据安全合规与敏感数据保护

监管部门多次“重拳出击”,保险企业如何做好敏感数据保护工作? 继《个人信息保护法》、《中华人民共和国消费者权益保护法》、《中国银保监会办公厅关于印发银行保险机构信息科技外包风险监管办法的通知》、《互联网保险业务监管办法》等相关法规之后&am…

CNN从搭建到部署实战(pytorch+libtorch)

模型搭建 下面的代码搭建了CNN的开山之作LeNet的网络结构。 import torchclass LeNet(torch.nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv torch.nn.Sequential(torch.nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_sizetorch.nn.Sig…

mysql 第三章

目录 1.索引 2.事务 3.总结 1.索引 2.事务 3.总结 事务是一种机制,一个操作序列。包含了一组数据库操作命令。

静态数码管——FPGA

文章目录 前言一、数码管1、数码管简介2、共阴极数码管or共阳极数码管3、共阴极与共阳极的真值表 二、系统设计1、模块框图2、RTL视图 三、源码1、seg_led_static模块2、time_count模块3、top_seg_led_static(顶层文件) 四、效果五、总结六、参考资料 前言 环境: 1、…

大学生用一周时间给麦当劳做了个App(微信小程序版)

背景 有个大学生粉丝最近私信联系我,说基于我之前开源的多语言项目做了个仿麦当劳的项目,虽然只是个样子货,但是收获颇多,希望把自己写的代码开源出来供大家一起学习进度。这个小伙伴确实是非常积极上进,很多大学生&a…