Pytorch从零开始实战17

news2024/9/27 23:28:37

Pytorch从零开始实战——生成对抗网络入门

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成对抗网络入门
    • 环境准备
    • 模型定义
    • 开始训练
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch1.8+cpu,本次实验目的是了解生成对抗网络。

生成对抗网络(GAN)是一种深度学习模型。GAN由两个主要组成部分组成:生成器和判别器。这两个部分通过对抗的方式共同学习,使得生成器能够生成逼真的数据,而判别器能够区分真实数据和生成的数据。

生成器的任务是生成与真实数据相似的样本。它接收一个随机噪声向量,然后通过深度神经网络将这个随机噪声转换为具体的数据样本。在图像生成的场景中,生成器通常将随机噪声映射为图像。生成器的目标是欺骗判别器,使其无法区分生成的样本和真实的样本。生成器的训练目标是最小化生成的样本与真实样本之间的差异。

判别器的任务是对给定的样本进行分类,判断它是来自真实数据集还是由生成器生成的。它接收真实样本和生成样本,然后通过深度神经网络输出一个概率,表示输入样本是真实样本的概率。判别器的目标是准确地分类样本,使其能够正确地区分真实数据和生成的数据。判别器的训练目标是最大化正确分类的概率。

导入相关包。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

创建文件夹,分别保存训练过程中的图像、模型参数和数据集。

os.makedirs("./images/", exist_ok=True) # 训练过程中图片效果
os.makedirs("./save/", exist_ok=True) # 训练完成时模型保存位置
os.makedirs("./datasets/", exist_ok=True) # 数据集位置

设置超参数。
b1、b2为Adam优化算法的参数,其中b1是梯度的一阶矩估计的衰减系数,b2是梯度的二阶矩估计的衰减系数。
latent_dim表示生成器输入的随机噪声向量的维度。这个噪声向量用于生成器产生新样本。
sample_interval表示在训练过程中每隔多少个batch保存一次生成器生成的样本图像,以便观察生成效果。

epochs = 20
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim=100
img_size=28
channels=1
sample_interval=500

设定图像尺寸并检查cuda,本次使用的设备没有cuda。

img_shape = (channels, img_size, img_size) # (1, 28, 28)
img_area = np.prod(img_shape) # 784
 
## 设置cuda
cuda = True if torch.cuda.is_available() else False
print(cuda) # False

本次使用GAN来生成手写数字,首先下载mnist数据集。

mnist = datasets.MNIST(root='./datasets/', 
                       train=True, 
                       download=True,
                       transform=transforms.Compose([transforms.Resize(img_size),
                                                     transforms.ToTensor(), 
                                                     transforms.Normalize([0.5], [0.5])]))

使用dataloader划分批次与打乱。

dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

len(dataloader) # 938

模型定义

首先定义鉴别器模型,代码中LeakyReLU是ReLU激活函数的变体,它引入了一个小的负斜率,在负输入值范围内,而不是将它们直接置零。这个斜率通常是一个小的正数,例如0.01。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),        
            nn.LeakyReLU(0.2, inplace=True),  
            nn.Linear(512, 256),             
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Linear(256, 1),              
            nn.Sigmoid(),                    
        )
 
    def forward(self, img):
        img_flat = img.view(img.size(0), -1) 
        validity = self.model(img_flat)     
        return validity         # 返回的是一个[0, 1]间的概率

定义生成器模型,用于输出图像。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):      
            layers = [nn.Linear(in_feat, out_feat)]          
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) 
            layers.append(nn.LeakyReLU(0.2, inplace=True))   
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), 
            *block(128, 256),                         
            *block(256, 512),                         
            *block(512, 1024),                       
            nn.Linear(1024, img_area),                
            nn.Tanh()                                
        )
    def forward(self, z):                           
        imgs = self.model(z)                       
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)
        return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

开始训练

创建生成器、判别器对象。

generator = Generator()
discriminator = Discriminator()

定义损失函数。这个其实就是二分类的交叉熵损失。

criterion = torch.nn.BCELoss()

定义优化函数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

开始训练,实现GAN训练过程,其中生成器和判别器交替训练,通过对抗过程使得生成器生成逼真的图像,而判别器不断提高对真实和生成图像的判别能力。

for epoch in range(epochs):                   # epoch:50
    for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)
 
        imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs)     # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1))    ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1))    ## 定义假的图片的label为0
 
 
        real_out = discriminator(real_img)            # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
        real_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 随机生成一些噪声, 大小为(128, 100)
        fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。
        fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的loss
        fake_scores = fake_out
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0
        loss_D.backward()                   # 将误差反向传播
        optimizer_D.step()                  # 更新参数
 
 
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## 得到随机噪声
        fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)                                    ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()                                             ## 梯度归0
        loss_G.backward()                                                   ## 进行反向传播
        optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数
 
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
            

在这里插入图片描述
保存模型。

torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

查看最初的噪声图像。
在这里插入图片描述
查看后面生成的图像。
在这里插入图片描述

总结

对于GAN,生成器的任务是从随机噪声生成逼真的数据样本,判别器的任务是对给定的数据样本进行分类,判断其是真实数据还是由生成器生成的。生成器和判别器通过对抗的方式进行训练。在每个训练迭代中,生成器试图生成逼真的样本以欺骗判别器,而判别器努力提高自己的能力,以正确地区分真实和生成的样本。这种对抗训练的动态平衡最终导致生成器生成高质量、逼真的样本。

总之,GAN实现了在无监督情况下学习数据分布的能力,被广泛用于生成逼真图像、视频等数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1396732.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MIdjourney】一些材质相关的关键词

1.多维剪纸(Multidimensional papercut) "Multidimensional papercut"(多维剪纸)是一种剪纸艺术形式,通过多层次的剪纸技巧和设计来创造出立体感和深度感。这种艺术形式通常涉及在不同的纸层上剪裁不同的图案,并将它们…

Node.js基础知识点(四)

本节介绍一下最简单的http服务 一.http 可以使用Node 非常轻松的构建一个web服务器,在 Node 中专门提供了一个核心模块:http http 这个模块的就可以帮你创建编写服务器。 1. 加载 http 核心模块 var http require(http) 2. 使用 http.createServe…

Java学习(二十一)--JDBC/数据库连接池

为什么需要 传统JDBC数据库连接,使用DriverManager来获取; 每次向数据库建立连接时都要将Connection加载到内存中,再验证IP地址、用户名和密码(0.05s~1s)时间。 需要数据库连接时候,就向数据库要求一个&#xf…

卷积神经网络简介-AI快速进阶系列

1. 概述 在本教程中,我们将研究卷积神经网络背后的理论及其架构。 我们将首先讨论通常使用卷积神经网络 (CNN) 执行的任务和特征提取问题。然后,我们将讨论为什么需要CNN,以及为什么传统的前馈神经网络是不够的。 然…

Redis实战之-分布式锁

一、基本原理和实现方式对比 分布式锁:满足分布式系统或集群模式下多进程可见并且互斥的锁。 分布式锁的核心思想就是让大家都使用同一把锁,只要大家使用的是同一把锁,那么我们就能锁住线程,不让线程进行,让程序串行…

20230117-yolov5训练环境搭建

文章目录 1.参考资料2.服务器环境3.安装过程4.问题与解决5.补充6.其它技巧 1.参考资料 https://blog.csdn.net/qq_43573527/article/details/132963466 long错误解决方案 https://pytorch.org/get-started/previous-versions/ pytorch下载的位置 2.服务器环境 conda环境&…

RabbitMQ入门精讲

1. 什么是消息队列 消息指的是两个应用间传递的数据。数据的类型有很多种形式,可能只包含文本字符串,也可能包含嵌入对象。 “消息队列(Message Queue)”是在消息的传输过程中保存消息的容器。在消息队列中,通常有生产者和消费者两个角色。…

NAT实验

一:实验要求 二:实验分析 拓扑图 三:实验配置 1:路由器配置 R1配置IP R2配置IP 2:缺省路由 查看路由表 3:端口映射 4:pc、HTTP配置 5:DNS、client配置 四:实验结果 pc可以ping…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例4-5 select

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>select</title> </head><body> <!--单选下拉菜单可设置默认选中项--> 所在城市&#xff08;单选&#xff09;:<br> <select>…

如何使用Portainer部署web站点并实现无公网ip远程访问

文章目录 前言1. 安装Portainer1.1 访问Portainer Web界面 2. 使用Portainer创建Nginx容器3. 将Web静态站点实现公网访问4. 配置Web站点公网访问地址4.1公网访问Web站点 5. 固定Web静态站点公网地址6. 固定公网地址访问Web静态站点 前言 Portainer是一个开源的Docker轻量级可视…

第十一章 请求响应

第十一章 请求响应 1.概述2.请求-postman工具3.请求-简单参数&实体参数4.请求-数组集合参数5.请求-日期参数&JSON参数6.请求-路径参数7.响应-ResponseBody&统一响应结果8.响应-案例 1.概述 将前端发送的请求封装为HttpServletRequest对象 在通过HttpServletRespo…

JVM工作原理与实战(十六):运行时数据区-Java虚拟机栈

专栏导航 JVM工作原理与实战 RabbitMQ入门指南 从零开始了解大数据 目录 专栏导航 前言 一、运行时数据区 二、Java虚拟机栈 1.栈帧的组成 2.局部变量表 3.操作数栈 4.帧数据 总结 前言 JVM作为Java程序的运行环境&#xff0c;其负责解释和执行字节码&#xff0c;管理…

深入解析 Java 方法引用:Lambda 表达式的进化之路

前言 方法引用是 Java 8 提供的一种新特性&#xff0c;它允许我们更简洁地传递现有方法作为参数。这项特性实际上是对 Lambda 表达式的一种补充&#xff0c;通过方法引用&#xff0c;我们可以直接引用现有方法&#xff0c;而无需编写完整的Lambda表达式。最近在使用方法引用的…

ElasticSearch扫盲概念篇[ES系列] - 第500篇

历史文章&#xff08;文章累计500&#xff09; 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 E…

每日一题——LeetCode1252.奇数值单元格的数目

进阶&#xff1a;你可以设计一个时间复杂度为 O(n m indices.length) 且仅用 O(n m) 额外空间的算法来解决此问题吗&#xff1f; 方法一 直接模拟&#xff1a; 创建一个n x m的矩阵&#xff0c;初始化所有元素为0&#xff0c;对于indices中的每一对[ri,ci]&#xff0c;将矩…

5W紫外激光打标机优势特点

紫外激光打标机在当今市场上备受关注&#xff0c;而5W紫外激光打标机更是其中的佼佼者。作为一种高精度、高效率的激光加工设备&#xff0c;5W紫外激光打标机在各个领域都有着广泛的应用。 首先&#xff0c;让我们来了解一下5W紫外激光打标机的基本原理。紫外激光打标机利用高能…

Springboot日志框架logback与log4j2

目录 Springboot日志使用 Logback日志 日志格式 自定义日志格式 日志文件输出 Springboot启用log4j2日志框架 Springboot日志使用 Springboot底层是使用slf4jlogback的方式进行日志记录 Logback日志 trace&#xff1a;级别最低 debug&#xff1a;调试级别的&#xff0c…

Google play 应用批量下架的可能原因及应对指南

想必大多数上架马甲包或矩阵式上架的开发者们&#xff0c;都遭遇过应用包批量被下架、账号被封的情况。这很令人苦恼&#xff0c;那造成这种情况的可能原因有哪些呢&#xff1f;以及如何降低这种情况发生&#xff1f; 1、代码问题 通常上架成功后被下架的应用&#xff0c;很可…

基于局部信息提取的人脸标志检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1 人脸检测 4.2 局部区域选择 4.3 特征提取 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 .........................................…

2024最新 8 款电脑数据恢复软件推荐分享

数据恢复是一个涉及从设备硬盘驱动器检索已删除文件的过程。这可能需要存储在工作站、笔记本电脑、移动设备、服务器、相机、闪存驱动器上的数据——任何在独立或镜像/阵列驱动器上存储数据的东西&#xff0c;无论是内部还是外部。 在某些情况下&#xff0c;文件可能被意外或故…