【Pytorch】生成对抗网络实战

news2024/9/28 7:26:51

GAN框架基于两个模型的竞争,Generator生成器和Discriminator鉴别器。生成器生成假图像,鉴别器则尝试从假图像中识别真实的图像。作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

目录

创建数据集

定义生成器

定义鉴别器

初始化模型权重

定义损失函数

定义优化器

训练模型

部署生成器


创建数据集

使用 PyTorch torchvision 包中提供的 STL-10 数据集,数据集中有 10 个类:飞机、鸟、车、猫、鹿、狗、马、猴、船、卡车。图像为96*96像素的RGB图像。数据集包含 5,000 张训练图像和 8,000 张测试图像。在训练数据集和测试数据集中,每个类分别有 500 和 800 张图像。

 STL-10数据集详细参考http://t.csdnimg.cn/ojBn6中数据加载和处理部分 

from torchvision import datasets
import torchvision.transforms as transforms
import os

# 定义数据集路径
path2data="./data"
# 创建数据集路径
os.makedirs(path2data, exist_ok= True)

# 定义图像尺寸
h, w = 64, 64
# 定义均值
mean = (0.5, 0.5, 0.5)
# 定义标准差
std = (0.5, 0.5, 0.5)
# 定义数据预处理
transform= transforms.Compose([
           transforms.Resize((h,w)),  # 调整图像尺寸
           transforms.CenterCrop((h,w)),  # 中心裁剪
           transforms.ToTensor(),  # 转换为张量
           transforms.Normalize(mean, std)])  # 归一化
    
# 加载训练集
train_ds=datasets.STL10(path2data, split='train', 
                        download=False,
                        transform=transform)

 展示示例图像张量形状、最小值和最大值

import torch
for x, _ in train_ds:
    print(x.shape, torch.min(x), torch.max(x))
    break

 展示示例图像

from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
plt.imshow(to_pil_image(0.5*x+0.5))

 

创建数据加载器 

import torch
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, 
                                       batch_size=batch_size, 
                                       shuffle=True)

 示例

for x,y in train_dl:
    print(x.shape, y.shape)
    break

定义生成器

GAN框架是基于两个模型的竞争,generator生成器和discriminator鉴别器。生成器生成假图像,鉴别器尝试从假图像中识别真实的图像。

作为这种竞争的结果,生成器将生成更好看的假图像,而鉴别器将更好地识别它们。

定义生成器模型 

from torch import nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, params):
        super(Generator, self).__init__()
        # 获取参数
        nz = params["nz"]
        ngf = params["ngf"]
        noc = params["noc"]
        # 定义反卷积层1
        self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,
                                         stride=1, padding=0, bias=False)
        # 定义批归一化层1
        self.bn1 = nn.BatchNorm2d(ngf * 8)
        # 定义反卷积层2
        self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        # 定义批归一化层2
        self.bn2 = nn.BatchNorm2d(ngf * 4)
        # 定义反卷积层3
        self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        # 定义批归一化层3
        self.bn3 = nn.BatchNorm2d(ngf * 2)
        # 定义反卷积层4
        self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        # 定义批归一化层4
        self.bn4 = nn.BatchNorm2d(ngf)
        # 定义反卷积层5
        self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, 
                                         stride=2, padding=1, bias=False)

# 前向传播
    def forward(self, x):
        # 反卷积层1
        x = F.relu(self.bn1(self.dconv1(x)))
        # 反卷积层2
        x = F.relu(self.bn2(self.dconv2(x)))            
        # 反卷积层3
        x = F.relu(self.bn3(self.dconv3(x)))        
        # 反卷积层4
        x = F.relu(self.bn4(self.dconv4(x)))    
        # 反卷积层5
        out = torch.tanh(self.dconv5(x))
        return out

设定生成器模型参数、移动模型到cuda设备并打印模型结构 

params_gen = {
        "nz": 100,
        "ngf": 64,
        "noc": 3,
        }
model_gen = Generator(params_gen)
device = torch.device("cuda:0")
model_gen.to(device)
print(model_gen)

定义鉴别器

定义鉴别器模型, 用于鉴别真实图像

class Discriminator(nn.Module):
    def __init__(self, params):
        super(Discriminator, self).__init__()
        # 获取参数
        nic= params["nic"]
        ndf = params["ndf"]
        # 定义卷积层1
        self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)
        # 定义卷积层2
        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)
        # 定义批归一化层2
        self.bn2 = nn.BatchNorm2d(ndf * 2)            
        # 定义卷积层3
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)
        # 定义批归一化层3
        self.bn3 = nn.BatchNorm2d(ndf * 4)
        # 定义卷积层4
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)
        # 定义批归一化层4
        self.bn4 = nn.BatchNorm2d(ndf * 8)
        # 定义卷积层5
        self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)

    def forward(self, x):
        # 使用leaky_relu激活函数对卷积层1的输出进行激活
        x = F.leaky_relu(self.conv1(x), 0.2, True)
        # 使用leaky_relu激活函数对卷积层2的输出进行激活,并使用批归一化层2进行批归一化
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)
        # 使用leaky_relu激活函数对卷积层3的输出进行激活,并使用批归一化层3进行批归一化
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)
        # 使用leaky_relu激活函数对卷积层4的输出进行激活,并使用批归一化层4进行批归一化
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        
        
        # 使用sigmoid激活函数对卷积层5的输出进行激活,并返回结果
        # Sigmoid激活函数是一种常用的非线性激活函数,它将输入值压缩到0和1之间,[ \sigma(x) = \frac{1}{1 + e^{-x}} ]
        out = torch.sigmoid(self.conv5(x))
        return out.view(-1)

设置模型参数,移动模型到cuda设备,打印模型结构 


params_dis = {
    "nic": 3,
    "ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)

初始化模型权重

定义函数,初始化模型权重 

def initialize_weights(model):
    # 获取模型类的名称
    classname = model.__class__.__name__
    # 如果模型类名称中包含'Conv',则初始化权重为均值为0,标准差为0.02的正态分布
    if classname.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    # 如果模型类名称中包含'BatchNorm',则初始化权重为均值为1,标准差为0.02的正态分布,偏置为0
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)

初始化生成器模型和鉴别器模型的权重 

# 对生成器模型应用初始化权重函数
model_gen.apply(initialize_weights);
# 对判别器模型应用初始化权重函数
model_dis.apply(initialize_weights);

定义损失函数

定义二元交叉熵(BCE)损失函数 

loss_func = nn.BCELoss()

定义优化器

定义Adam优化器

from torch import optim
# 学习率
lr = 2e-4 
# Adam优化器的beta1参数
beta1 = 0.5
# 定义鉴别器模型的优化器,学习率为lr,beta1参数为beta1,beta2参数为0.999
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))
# 定义生成器模型的优化器
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))

训练模型

 示例训练1000个epochs

# 定义真实标签和虚假标签
real_label = 1
fake_label = 0
# 获取生成器的噪声维度
nz = params_gen["nz"]
# 设置训练轮数
num_epochs = 1000
# 定义损失历史记录
loss_history={"gen": [],
              "dis": []}
# 定义批次数
batch_count = 0
# 遍历训练轮数
for epoch in range(num_epochs):
    # 遍历训练数据
    for xb, yb in train_dl:
        # 获取批大小
        ba_si = xb.size(0)
        # 将判别器梯度置零
        model_dis.zero_grad()
        # 将输入数据移动到指定设备
        xb = xb.to(device)
        # 将标签数据转换为指定设备
        yb = torch.full((ba_si,), real_label, device=device)
        # 判别器输出
        out_dis = model_dis(xb)
        # 将输出和标签转换为浮点数
        out_dis = out_dis.float()
        yb = yb.float()
        # 计算真实样本的损失
        loss_r = loss_func(out_dis, yb)
        # 反向传播
        loss_r.backward()

        # 生成噪声
        noise = torch.randn(ba_si, nz, 1, 1, device=device)
        # 生成器输出
        out_gen = model_gen(noise)
        # 判别器输出
        out_dis = model_dis(out_gen.detach())
        # 将标签数据填充为虚假标签
        yb.fill_(fake_label)    
        # 计算虚假样本的损失
        loss_f = loss_func(out_dis, yb)
        # 反向传播
        loss_f.backward()
        # 计算判别器的总损失
        loss_dis = loss_r + loss_f  
        # 更新判别器的参数
        opt_dis.step()   

        # 将生成器梯度置零
        model_gen.zero_grad()
        # 将标签数据填充为真实标签
        yb.fill_(real_label)  
        # 判别器输出
        out_dis = model_dis(out_gen)
        # 计算生成器的损失
        loss_gen = loss_func(out_dis, yb)
        # 反向传播
        loss_gen.backward()
        # 更新生成器的参数
        opt_gen.step()

        # 记录生成器和判别器的损失
        loss_history["gen"].append(loss_gen.item())
        loss_history["dis"].append(loss_dis.item())
        # 更新批次数
        batch_count += 1
        # 每100个批打印一次损失
        if batch_count % 100 == 0:
            print(epoch, loss_gen.item(),loss_dis.item())

 绘制损失图像

plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()

存储模型权重 

import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)

部署生成器

通常情况下,训练完成后放弃鉴别器模型而保留生成器模型,部署经过训练的生成器来生成新的图像。为部署生成器模型,将训练好的权重加载到模型中,然后给模型提供随机噪声。

# 加载生成器模型的权重
weights = torch.load(path2weights_gen)
# 将权重加载到生成器模型中
model_gen.load_state_dict(weights)
# 将生成器模型设置为评估模式
model_gen.eval()

 生成图像

import numpy as np
with torch.no_grad():
    # 生成固定噪声
    fixed_noise = torch.randn(16, nz, 1, 1, device=device)
    # 打印噪声形状
    print(fixed_noise.shape)
    # 生成假图像
    img_fake = model_gen(fixed_noise).detach().cpu()    
# 打印假图像形状
print(img_fake.shape)
# 创建画布
plt.figure(figsize=(10,10))
# 遍历假图像
for ii in range(16):
    # 在画布上绘制图像
    plt.subplot(4,4,ii+1)
    # 将图像转换为PIL图像
    plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))
    # 关闭坐标轴
    plt.axis("off")
    

其中一些可能看起来扭曲,而另一些看起来相对真实。为改进结果,可以在单个数据类上训练模型,而不是在多个类上一起训练。GAN在使用单个类进行训练时表现更好。此外,可以尝试更长时间地训练模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2088488.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

改变潜意识,让梦想照进现实

你是否经常感到困惑,为何努力不得其果?今天我们将一起探索如何通过重新规划潜意识,让你的内心世界和外在行动达到和谐统一,让潜意识成为你坚不可摧的后盾。 想象一下,如果你的潜意识全天候无休止地为你的梦想努力&…

故障电弧探测器在工业与民用建筑电气线路中的设计与应用

安科瑞徐赟杰 【摘要】:电气设备是建筑中不可缺少的一部分,具有较为重要的作用和意义,在应用过程中不仅能够提升建筑本身实用性能,而且可为消费者提供更加优良的生活环境。但设备一旦在运行过程中出现故障,不仅会影响…

合宙低功耗4G模组Air780EQ——硬件设计手册02

Air780EQ是一款基于移芯EC716E平台设计的LTECat1无线通信模组。 支持FDD-LTE/TDD-LTE的4G远距离无线 传输技术。 另外,模组提供了USB/UART/I2C等通用接口满足IoT行业的各种应用诉求。 本文将继续介绍合宙Air780EQ的硬件设计中的 应用接口,射频接口&am…

一分钟学会万用表

目录: 1、电池的安装 1)指针万用表 2)数字万用表 3)高精度表 2、表笔的分类 3、表笔安装 5、常用测量方法 1)二极管测量 2)电阻与通断测量 3)电压测量 4)电流测量 …

面对孩子自闭症,我们该怎么办?

当得知孩子被诊断为自闭症时,家长们往往会感到震惊、无助甚至绝望。然而,面对这一挑战,我们需要做的是保持冷静,积极寻找应对策略,为孩子创造一个充满爱与希望的环境。 深入了解自闭症是关键。自闭症是一种复杂的神经发…

八款精品图纸加密软件强力推荐2024年图纸加密软件最佳选择!

在数字化时代,设计图纸的安全问题越来越受到企业的重视。为了保障企业的知识产权和核心竞争力,选择一款合适的图纸加密软件显得尤为重要。以下是2024年八款精品图纸加密软件的强力推荐,它们各具特色,能够满足不同企业的需求。 1.…

视频美颜SDK的核心技术:打造智能化主播美颜工具详解

视频美颜SDK不仅提升了视频质量,还为主播们提供了智能化、个性化的美颜功能。那么,视频美颜SDK的核心技术究竟是什么?又是如何为主播打造智能化美颜工具的呢? 1.人脸检测与特征点识别 视频美颜SDK技术通过深度学习算法&#xff…

H5带建站时长可自定义背景官网/引导页源码

源码名称:带建站时长可自定义背景官网/引导页源码 源码介绍:一款带动态时间显示建站时长的引导页源码,可用于引导页、工作室官网、个人主页等。源码为H5自适应手机端、电脑端。 需求环境:H5 下载地址: https://www.…

nefu暑假集训2 ST表 个人模板+例题汇总

前言: 比较简单的一个算法了,原理相当于是用二进制优化的区间dp了,用于求一个区间的最大或最小值。其实这类问题一般用线段树就可以直接解决,但如果查询次数过多的话可能会超时,这时就是ST表出场的时候了,因…

遗产系统 legacy system 的定义和演化策略

原始英语叫做legacy system,被国内翻译成了“遗产系统”。实际上,legacy system,可以翻译为遗留系统、旧系统、老系统。 下文部分摘自《遗产系统及其解决方案的综述》一论文。 遗产系统的定义 遗产系统是 “一个已经运行了很长时间的&…

【初学人工智能原理】【13】LSTM网络:自然语言处理实践

前言 本文教程均来自b站【小白也能听懂的人工智能原理】,感兴趣的可自行到b站观看。 代码及工具箱 本专栏的代码和工具函数已经上传到GitHub:1571859588/xiaobai_AI: 零基础入门人工智能 (github.com),可以找到对应课程的代码 正文 上节…

虹科技术|全新Linux环境PCAN驱动程序发布!CAN/CAN FD通信体验全面升级!

全新8.17.0版本的PCAN-Linux驱动程序正式发布,专为CAN和CAN FD接口量身打造。无论是CAN 2.0 a/b还是CAN FD的PCAN硬件产品,都能在我们的新驱动下“驰骋自如”。想要体验字符模式设备驱动接口(chardev)的便捷,还是Socke…

Navicat Lite导入为SQL,然后到服务器的SQLServer Management 里执行时,报各种错误,是文件的Encoding不一致导致的解决

1、好多时候,本地的操作系统与服务器的操作系统不一致,有的时候也是历史原因,我们不得不用老旧的版本的数据库,比如 SQLServer 2008R2的数据库系统。 2、然后本地因为操作系统是win11的,导致这个SQLServer 2008R2根本…

【自动化测试】处理页面加载元素过慢以及页面中存在frame框架页问题

在自动化测试中,处理页面加载元素过慢以及页面中存在frame框架页等问题,需要采用一些特定的策略和技术来确保测试的顺利进行。下面我将分别针对这两个问题给出一些解决方案: 1. 处理页面加载元素过慢的问题 1.1 等待机制 显式等待&#xf…

如何在Mac上使用VMware配置Windows虚拟机

作者:CSDN-PleaSure乐事 欢迎大家阅读我的博客 希望大家喜欢 使用环境:VMware Fusion 目录​​​​​​​ 1.下载windows虚拟机arm文件 2.打开VMware并拖入刚刚下载完成的iso文件 3.导入完成 4.固件类型 5.选择加密 6.完成 7.默认安装 8.现在安装…

NC 反转字符串

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 描述 写出一个程序…

《React Hooks:让你的组件更灵活》

前端开发中非常流行的React框架。React是一个用于构建用户界面的JavaScript库,尤其适用于构建复杂的单页应用。 React Hooks:让你的组件更灵活 React 是当今最受欢迎的前端 JavaScript 库之一,用于构建用户界面。自从 React 16.8 版本开始&a…

图表操作——图表保存为图片+多个图表批量保存为压缩包——js技能提升

使用场景: echarts图表:生成的柱状图/折线图/饼图等可以实现图表的导出,导出格式为一个图片。也可以支持多个图表同时导出为图片,以压缩包的形式下载下来。 下面介绍单个导出批量导出的具体用法: 1.单个导出功能——…

使用seamless-scroll-v3 实现无缝滚动,自动轮播平滑的滚动效果

安装&#xff1a;npm地址&#xff1a;https://www.npmjs.com/package/seamless-scroll-v3 yarn add seamless-scroll-v3# 或者使用 npm npm install seamless-scroll-v3# 或者使用 pnpm pnpm add seamless-scroll-v3 实现效果&#xff1a; template中的代码&#xff1a; <…

pmp证书为何会被骂?他真的就是个垃圾证书?

说是垃圾到不至于。 毕竟PMP证书今年被北京市列入急需紧缺专业人才人员名单&#xff01;同时可以在创新创业、社会保障、评价激励等方面得到优先支。 其次&#xff0c;证书&#xff0c;其内容可以夯实基础&#xff0c;理清一个项目从启动、执行到最后的收尾做ppt结案的整个流…