深度学习(四):pytorch搭建GAN(对抗网络)

news2025/2/27 1:50:48

1.GAN

生成对抗网络(GAN)是一种深度学习模型,由两个网络组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成假数据,而判别器则负责判断数据是真实的还是 fake的。这两个网络互相竞争,生成器试图生成更真实的数据以欺骗判别器,而判别器则试图更好地识别生成的数据。
在这里插入图片描述

GAN 的基本思想是:通过训练生成器和判别器,使得生成器能够生成与真实数据非常相似的数据,同时使得判别器能够更有效地识别这些数据。

1.1 概念

  1. 生成器(Generator):生成器是一个神经网络,其目的是生成假的数据,看起来像是真实的。生成器通常包含一些神经网络层,如卷积层、全连接层等。生成器接受随机噪声作为输入,并生成看起来像是真实数据的输出。
  2. 判别器(Discriminator):判别器也是一个神经网络,其目的是识别数据是真实的还是 fake的。判别器通常也包含一些神经网络层,如卷积层、全连接层等。判别器接受输入数据,并输出一个分数,表示输入数据是真实的还是 fake的。
  3. 生成对抗训练:生成对抗训练是指同时训练生成器和判别器。生成器试图生成更真实的数据,以欺骗判别器。判别器则试图更好地识别生成的数据,以避免被欺骗。生成器和判别器之间的竞争导致它们不断改进,以提高生成数据的真实性。
  4. 生成器损失和判别器损失:生成器损失是指生成器试图生成更真实数据的损失。生成器损失通常使用生成器的对抗损失和生成损失之和来计算。判别器损失是指判别器试图更好地识别真实数据和假数据的损失。判别器损失通常使用判别器识别真实数据和假数据的损失之和来计算。
  5. 对抗性训练:对抗性训练是指在训练过程中,使用生成器生成的假数据来训练判别器,以提高判别器的识别能力。同时,使用判别器识别的反馈来训练生成器,以提高生成器生成更真实数据的能力。

1.2 优势

GAN(Generative Adversarial Network)是一种生成对抗网络,主要由生成器和判别器组成。生成器负责生成假数据,而判别器负责判断数据是真实的还是 fake的。GAN 的训练过程相对复杂,但是它可以生成非常真实的数据,并且可以用来进行数据增强、图像生成、视频生成等应用。

GAN 的优势主要体现在以下几个方面:

  1. 生成数据非常真实:GAN 可以生成非常真实的数据,可以用来进行数据增强、图像生成、视频生成等应用。
  2. 可以生成大量数据:GAN 可以生成大量的数据,可以用来进行机器学习、深度学习等应用。
  3. 可以生成不同类型的数据:GAN 可以生成不同类型的数据,可以用来进行图像生成、视频生成等应用。
  4. 可以进行对抗训练:GAN 可以进行对抗训练,可以提高模型的鲁棒性和泛化能力。

虽然 GAN 具有优势,但是也存在一些挑战,例如训练过程复杂、生成器容易过拟合、对抗训练难以实现等。因此,在实际应用中,需要根据具体情况进行优化和调整。

1.3 训练技巧

  1. 使用批归一化(Batch Normalization):批归一化是一种在卷积神经网络中常用的加速训练和提高模型性能的方法。在 GAN 的生成器和判别器中可以使用批归一化来提高性能。
  2. 使用 Leaky ReLU 激活函数:Leaky ReLU 激活函数是一种在 ReLU 激活函数中加入一个小于 1 的常数,以避免神经元死亡的方法。在 GAN 的生成器和判别器中可以使用 Leaky ReLU 激活函数来提高性能。
  3. 使用 U-Net 结构:U-Net 是一种用于图像分割的网络结构,其结构可以同时实现编码器和解码器。在 GAN 的生成器中可以使用 U-Net 结构来提高生成图像的质量。
  4. 使用对抗性损失(Adversarial Loss):对抗性损失是一种可以增加生成器损失的方法,通过在损失函数中加入一个与真实数据接近的噪声来增加生成器的难度。在 GAN 的训练过程中可以使用对抗性损失来提高性能。
  5. 使用预训练模型:预训练模型是一种在已有数据集上训练好的模型,可以用于迁移学习和提高性能。在 GAN 的生成器和判别器中可以使用预训练模型来提高性能。
  6. 使用注意力机制(Attention):注意力机制是一种可以提高模型性能和泛化能力的方法,可以在 GAN 的生成器和判别器中使用注意力机制来提高性能。

总结起来,GAN 的训练过程需要综合考虑多个方面,包括数据预处理、损失函数选择、正则化、梯度裁剪、对抗性训练、数据增强和 early stopping 等技巧。同时,还可以使用一些额外的技巧,如批归一化、Leaky ReLU 激活函数、U-Net 结构、对抗性损失、预训练模型和注意力机制等来进一步提高 GAN 的性能。

2 代码实现

步骤:

  1. 导入所需的库和模块。
  2. 定义生成器的网络结构,包括全连接层和激活函数。
  3. 定义判别器的网络结构,也包括全连接层和激活函数。
  4. 定义训练函数,包括将模型移动到设备、定义损失函数和优化器、开始训练的循环等。
  5. 设置随机种子。
  6. 设置设备,如果有可用的GPU则使用GPU,否则使用CPU。
  7. 加载MNIST数据集,并进行数据预处理。
  8. 初始化生成器和判别器。
  9. 设置训练的参数,如训练轮数、生成器的输入维度等。
  10. 调用训练函数进行训练。
# 导入torch模块
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 定义生成器的网络结构
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),  # 全连接层,输入latent_dim维,输出256维
            nn.LeakyReLU(0.2),  # LeakyReLU激活函数
            nn.Linear(256, 512),  # 全连接层,输入256维,输出512维
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),  # 全连接层,输入512维,输出1024维
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),  # 全连接层,输入1024维,输出784维
            nn.Tanh()  # Tanh激活函数
        )

    def forward(self, x):
        return self.model(x)

# 定义判别器的网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),  # 全连接层,输入784维,输出512维
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),  # 全连接层,输入512维,输出256维
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # 全连接层,输入256维,输出1维
            nn.Sigmoid()  # Sigmoid激活函数
        )

    def forward(self, x):
        return self.model(x)

# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs, latent_dim, device):
    # 将模型移动到设备
    generator.to(device)
    discriminator.to(device)

    # 定义损失函数和优化器
    criterion = nn.BCELoss()  # 二分类交叉熵损失函数
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器的优化器
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器的优化器

    # 开始训练
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            # 将图像转换为向量
            real_images = real_images.view(-1, 784).to(device)
            # 获取图像的batch_size
            batch_size = real_images.size(0)
            # 定义真实标签和 fake标签
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # 训练判别器
            optimizer_D.zero_grad()
            # 计算真实图像的输出
            real_outputs = discriminator(real_images)
            # 计算真实图像的损失
            real_loss = criterion(real_outputs, real_labels)

            # 生成假图像
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(z)
            # 计算假图像的输出
            fake_outputs = discriminator(fake_images.detach())
            # 计算假图像的损失
            fake_loss = criterion(fake_outputs, fake_labels)

            # 计算判别器的损失
            d_loss = real_loss + fake_loss
            # 反向传播
            d_loss.backward()
            # 更新参数
            optimizer_D.step()

            # 训练生成器
            optimizer_G.zero_grad()
            # 计算假图像的输出
            fake_outputs = discriminator(fake_images)
            # 计算生成器的损失
            g_loss = criterion(fake_outputs, real_labels)

            # 反向传播
            g_loss.backward()
            # 更新参数
            optimizer_G.step()

            # 每200步打印一次损失
            if (i+1) % 200 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                      f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
        # 每1步打印一次图像
        if (epoch+1) % 1 == 0:
            # 生成图像
            with torch.no_grad():
                z = torch.randn(10, 100).to(device)
                generated_images = generator(z).cpu().view(-1, 28, 28)

            # 展示原始数据和生成数据的图像
            fig, axes = plt.subplots(2, 5, figsize=(10, 4))
            for i, ax in enumerate(axes.flat):
                if i < 5:
                    ax.imshow(real_images[i].view(28, 28), cmap='gray')
                    ax.set_title('Real')
                else:
                    ax.imshow(generated_images[i-5], cmap='gray')
                    ax.set_title('Generated')
                ax.axis('off')
            plt.tight_layout()
            plt.show()

# 设置随机种子
torch.manual_seed(42)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化生成器和判别器
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()

# 训练GAN模型
num_epochs = 50
train(generator, discriminator, train_dataloader, num_epochs, latent_dim, device)

2.1结果

第一轮:

在这里插入图片描述
训练之后:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278550.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

cnpm 安装后无法使用怎么办?

问题的原因 cnpm 安装成功&#xff0c;但是却无法使用&#xff0c;一般分为两种情况&#xff0c;一种是提示无法执行命令&#xff0c;另一种是可以执行但是执行时报错&#xff0c;下面分别说明遇到这两种情况的解决方案。 解决方案 问题一&#xff1a;无法执行相关命令 首先…

Docker快速入门(docker加速,镜像,容器,数据卷常见命令操作整理)

Docker本质是将代码所需的环境依赖进行打包运行,而在Docker中最重要的是镜像和容器 镜像:可以简单地理解为每启动一个docker镜像就会占用计算机一个进程,这个进程和另外起的docker镜像的进程是相互独立的,以数据库为例,每个镜像都会copy一份数据库,在他所在的进程中.别的镜像在…

根文件系统构建-对busybox进行配置

一. 简介 本文来学习 根文件系统的制作中&#xff0c;关于 busybox的配置。 本文继上一篇 busybox中文支持的设置&#xff0c;地址如下&#xff1a; 根文件系统构建-busybox中文支持-CSDN博客 二. 根文件系统构建-busybox配置 1. 配置 busybox 与我们编译 Uboot 、 Lin…

DBeaver 社区版(免费版)下载、安装、解决驱动更新出错问题

DBeaver 社区版&#xff08;免费版&#xff09; DBeaver有简洁版&#xff0c;企业版&#xff0c;旗舰版&#xff0c;社区版&#xff08;免费版&#xff09;。除了社区版&#xff0c;其他几个版本都是需要付费的&#xff0c;当然相对来说&#xff0c;功能也要更完善些&#xff…

HashMap源码全面解析

注&#xff1a;本篇文章是在JDK1.8版本源码进行分析。 一、概述 HashMap 是基于哈希表的 Map接口的实现&#xff0c;是以 key-value 存储形式存在&#xff0c;即主要用来存储键值对。 HashMap的类图&#xff1a; HashMap继承抽象类AbstractMap&#xff0c;实现了Map、Clonea…

select选择框里填充图片,下拉选项带图片

遇到一个需求&#xff0c;选择下拉框选取图标&#xff0c;填充到框里 1、效果展示 2、代码 <el-form-item label"工种图标" class"Form_icon Form_label"><el-select ref"select" :value"formLabelAlign.icon" placeholder&…

2023年第十二届数学建模国际赛小美赛B题工业表面缺陷检测求解分析

2023年第十二届数学建模国际赛小美赛 B题 工业表面缺陷检测 原题再现&#xff1a; 金属或塑料制品的表面缺陷不仅影响产品的外观&#xff0c;还可能对产品的性能或耐久性造成严重损害。自动表面异常检测已经成为一个有趣而有前景的研究领域&#xff0c;对视觉检测的应用领域有…

PyQt6 QRadioButton单选按钮控件

​锋哥原创的PyQt6视频教程&#xff1a; 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计33条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话…

Opencv框选黑色字体进行替换(涉及知识点:selectROI,在控制台输入字体大小,颜色,内容替换所选择的区域)

import cv2 from PIL import Image,ImageDraw,ImageFont import numpy as npimg_path ../img/ img_clean_path ../img_clean/ name xiao_ben suf .pngimg cv2.imread(img_pathnamesuf) cv2.imshow(original, img)# 选择ROI roi cv2.selectROI(windowName"original&q…

Linux多核飞控

Linux多核飞控是一种基于多核处理器构建的飞控系统&#xff0c;用于控制飞行器的飞行。这种飞控系统使用Linux操作系统作为主要的控制平台&#xff0c;可以支持多个处理器核心同时工作&#xff0c;以实现更高的性能和更快的响应速度。 Linux通常用于具有较高计算量和较大内存需…

Python读取json数据导出到Excel

一、JSON字符串转换为Python对象 导入Python的json模块。该模块包含两个重要的功能-loads和load,读取JSON文件&#xff0c;并将JSON数据解析为Python数据&#xff0c;除了JSON&#xff0c;我们还需要Python的原生函数open()。一般loads用于读取JSON字符串&#xff0c;而load()用…

【数据中台】开源项目(4)-BitSail

介绍 BitSail是字节跳动开源的基于分布式架构的高性能数据集成引擎, 支持多种异构数据源间的数据同步&#xff0c;并提供离线、实时、全量、增量场景下的全域数据集成解决方案&#xff0c;目前服务于字节内部几乎所有业务线&#xff0c;包括抖音、今日头条等&#xff0c;每天同…

CleanMyMac X2024Macos强大的系统优化工具

都说苹果的闪存是金子做的&#xff0c;这句话并非空穴来风&#xff0c;普遍都是256G起步&#xff0c;闪存没升级一个等级&#xff0c;价格都要增加上千元。昂贵的价格让多数消费者都只能选择低容量版本的mac。而低容量的mac是很难满足用户的需求的&#xff0c;伴随着时间的推移…

初始数据结构(加深对旋转的理解)

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/rotate-array/submissions/ 与字…

《堆》的模拟实现

目录 前言&#xff1a; 模拟实现《堆》&#xff1a; 1.自定义数据类型 2.初始化“堆” 3.销毁“堆” 4.进“堆” 关于AdjustUp() 5.删除堆顶元素 关于AdjustDown() 6.判断“堆”是否为空 7.求“堆”中的数据个数 8.求“堆”顶元素 总结&#xff1a; 前言&#xf…

锐捷RG-UAC应用网关 前台RCE漏洞复现

0x01 产品简介 锐捷RG-UAC系列应用管理网关是锐捷自主研发的应用管理产品。 0x02 漏洞概述 锐捷RG-UAC应用管理网关 nmc_sync.php 接口处存在命令执行漏洞&#xff0c;未经身份认证的攻击者可执行任意命令控制服务器权限。 0x03 复现环境 FOFA&#xff1a;app"Ruijie-R…

软著项目推荐 深度学习手势识别算法实现 - opencv python

文章目录 1 前言2 项目背景3 任务描述4 环境搭配5 项目实现5.1 准备数据5.2 构建网络5.3 开始训练5.4 模型评估 6 识别效果7 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习手势识别算法实现 - opencv python 该项目较为新颖…

宝塔面板:轻松玩转linux系统,实现服务器状态监控和运维部署!

. linux安装 安装命令概述基本设置软件安装设置安全设置文件管理日志模块终端模块计划任务卸载命令windows服务器安装 下载卸载遗留user.ini文件删除报错 宝塔面板是一款服务器管理软件&#xff0c;旨在提升运维效率。它支持一键安装LAMP/LNMP/集群/监控/网站/FTP/数据库/JAVA等…

【代码】计及碳捕集电厂低碳特性及需求响应的综合能源系统多时间尺度调度模型matlab/yalmip代码

程序名称&#xff1a;计及碳捕集电厂低碳特性及需求响应的综合能源系统多时间尺度调度模型 实现平台&#xff1a;matlab-yalmip-cplex/gurobi 代码简介&#xff1a;代码主要做的是一个虚拟电厂/微网多时间尺度电热综合能源系统低碳经济调度模型&#xff0c;源侧在碳捕集电厂中…

2024年美国大学生数学建模竞赛(MCM/ICM)论文写作方法指导

一、前言 谈笑有鸿儒&#xff0c;往来无白丁。鸟宿池边树&#xff0c;僧敲月下门。士为知己者死&#xff0c;女为悦己者容。吴楚东南坼&#xff0c;乾坤日夜浮。剪不断&#xff0c;理还乱&#xff0c;是离愁&#xff0c;别是一番滋味在心头。 重要提示&#xff1a;优秀论文的解…