生成对抗网络(GAN,Generative Adversarial Network)

news2024/11/24 5:32:14

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成逼真的样本,而判别器的目标是区分真实样本与生成样本。它们通过对抗过程相互训练,最终使生成器能够生成高度逼真的数据。

基本概念

  1. 生成器:从随机噪声(通常是高斯噪声)生成数据,表示为 G ( z ) G(z) G(z),其中 z z z 是潜在变量(噪声)。

  2. 判别器:判断输入数据是否真实,表示为 D ( x ) D(x) D(x),其中 x x x 是输入数据。判别器输出一个值,表示其对输入数据为真实的概率。

目标函数

GAN 的目标是通过最小化以下对抗损失来训练生成器和判别器:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

  • p d a t a ( x ) p_{data}(x) pdata(x):真实数据分布。
  • p z ( z ) p_z(z) pz(z):潜在变量分布(通常为高斯分布)。
  • D ( x ) D(x) D(x):判别器对真实样本的判别概率。
  • G ( z ) G(z) G(z):生成器生成的样本。

训练过程

  1. 判别器训练:通过真实样本和生成样本的损失来优化判别器。
  2. 生成器训练:通过判别器的反馈,优化生成器,使得生成的样本更逼近真实样本。

通过不断的对抗训练,生成器最终能够生成接近真实数据的样本,判别器则不断提高其区分能力。

以下是一个使用 PyTorch 实现的简单 GAN 案例,目标是生成手写数字(MNIST 数据集)。代码包括生成器和判别器的定义,以及训练过程。

GAN 案例代码

如果你要使用自定义的 MNISTDataset 类来加载数据,可以将它集成到之前的 GAN 示例中。以下是完整的代码示例,结合你的 MNISTDataset 实现。

完整的 GAN 示例代码

epoch设置为20,以作示例。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# 自定义 MNIST 数据集类
class MNISTDataset(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.images = self.load_images(images_path)
        self.labels = self.load_labels(labels_path)
        self.transform = transform

    def load_images(self, path):
        with open(path, 'rb') as f:
            f.read(16)  # 跳过前16个字节
            images = np.frombuffer(f.read(), np.uint8).reshape(-1, 1, 28, 28)
        return torch.tensor(images, dtype=torch.float32) / 255.0  # 归一化到 [0, 1]

    def load_labels(self, path):
        with open(path, 'rb') as f:
            f.read(8)  # 跳过前8个字节
            labels = np.frombuffer(f.read(), np.uint8)
        return labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# 超参数
latent_size = 100
num_epochs = 20
batch_size = 64
learning_rate = 0.0002

# 数据准备
data_root = r'./MNIST'
train_dataset = MNISTDataset(
    images_path=os.path.join(data_root, 'train-images-idx3-ubyte'),
    labels_path=os.path.join(data_root, 'train-labels-idx1-ubyte')
)

data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# 生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()  # 输出在[-1, 1]范围内
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# 判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出在[0, 1]范围内
        )

    def forward(self, x):
        return self.model(x.view(-1, 28 * 28))

# 初始化模型和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检查是否有 GPU
generator = Generator().to(device)  # 移动到 GPU
discriminator = Discriminator().to(device)  # 移动到 GPU
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 训练过程
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.to(device)  # 移动输入数据到 GPU

        # 标签
        real_labels = torch.ones(images.size(0), 1).to(device)  # 移动到 GPU
        fake_labels = torch.zeros(images.size(0), 1).to(device)  # 移动到 GPU

        # 判别器训练
        optimizer_D.zero_grad()
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        z = torch.randn(images.size(0), latent_size).to(device)  # 随机噪声,移动到 GPU
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # 生成器训练
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# 可视化生成的图像
with torch.no_grad():
    z = torch.randn(64, latent_size).to(device)  # 随机噪声,移动到 GPU
    fake_images = generator(z)

# 显示生成的图像
grid_img = fake_images.cpu().numpy()  # 移动到 CPU 以便绘图
grid_img = grid_img.reshape(-1, 28, 28)
plt.figure(figsize=(8, 8))
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow(grid_img[i], cmap='gray')
    plt.axis('off')
plt.show()

在这里插入图片描述

代码说明

  1. 自定义数据集MNISTDataset 类用于从指定的路径加载 MNIST 数据。
  2. 数据归一化:在 load_images 方法中,将图像数据归一化到 [0, 1] 范围。
  3. 数据加载:使用 DataLoader 创建训练数据集的加载器。
  4. GAN 模型:包含生成器和判别器的定义。
  5. 训练过程:判别器和生成器交替更新。
  6. 可视化生成图像:训练结束后生成并显示手写数字图像。

你可以运行这个代码,并观察生成的手写数字。确保 MNIST 数据集文件在指定的路径下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2203332.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Xinstall品牌揭秘:如何成为App拉新的行业翘楚?

在移动互联网时代,App作为连接用户与服务的桥梁,其重要性不言而喻。然而,随着市场竞争的加剧,App拉新(即吸引新用户下载并使用App)的难度也在逐渐增大。传统的营销方式往往面临着成本高、效率低、用户留存差…

理解PID(零)——什么是PID

PID控制器是一种广泛用于各种工业控制场合的控制器,它结构简单,可以根据工程经验整定参数Kp,Ki,Kd. 虽然现在控制专家提出了很多智能的控制算法,比如神经网络,模糊控制等,但是PID仍然被广泛使用。常见的PID控制器有位置…

视频怎么转gif动图?5个简单转换方法快来学(详细教程)

相信大家在社交平台上会经常看到一些有趣的gif动图表情包,有些小伙伴就会问:这些GIF动图是如何制作的呢?一般GIF动图表情包可以用视频来制作,今天小编就来给大家分享几个视频转成GIF动图的方法,相信通过以下的几个方法…

文献阅读CONCH模型--相关知识点罗列

文章链接:A visual-language foundation model for computational pathology | Nature MedicineThe accelerated adoption of digital pathology and advances in deep learning have enabled the development of robust models for various pathology tasks across…

【可答疑】基于51单片机的智能家居系统(含仿真、代码、报告、演示视频等)

✨哈喽大家好,这里是每天一杯冰美式oh,985电子本硕,大厂嵌入式在职0.3年,业余时间做做单片机小项目,有需要也可以提供就业指导(免费)~ 🐱‍🐉这是51单片机毕业设计100篇…

ceph基础

ceph基础搭建 存储基础 传统的存储类型: DAS设备: SAS,SATA,SCSI,IDW,USB 无论是那种接口,都是存储设备驱动下的磁盘设备,而磁盘设备其实就是一种存储是直接接入到主板总线上去的。直连存储。 NAS设备: NFS CIFS FTP 几乎所有的…

商标恶意维权形式及应对策略

在商业领域,商标恶意维权的现象时有出现,给正常的市场秩序和企业经营带来了不良影响。以下将介绍其常见形式及应对方法。 一、商标恶意维权的形式1、囤积商标后恶意诉讼。一些人或企业大量注册与知名品牌相似或具有一定通用性的商标,并非用于…

留学生毕业论文设计问卷questionnaire的基本步骤

在上一期内容中,小编介绍了留学毕业论文的定量研究和相关的问卷设计。然而在一些研究中,定量研究和问卷数据并不能满足我们的研究需求。这种情况下,我们可以采取其他的数据收集方式,例如observation,case study和inter…

软件设计之SSM(11)

软件设计之SSM(11) 路线图推荐: 【Java学习路线-极速版】【Java架构师技术图谱】 尚硅谷新版SSM框架全套视频教程,Spring6SpringBoot3最新SSM企业级开发 资料可以去尚硅谷官网免费领取 学习内容: Springboot 配置文件整合SpringMVC整合Dr…

【学术会议征稿】第十届能源资源与环境工程研究进展国际学术会议(ICAESEE 2024)

第十届能源资源与环境工程研究进展国际学术会议(ICAESEE 2024) 2024 10th International Conference on Advances in Energy Resources and Environment 第十届能源资源与环境工程研究进展国际学术会议(ICAESEE 2024)定于2024年…

拓扑排序与入度为0的结点算法解析及实现

拓扑排序与入度为0的结点算法解析及实现 算法思想时间复杂度分析伪代码C语言实现环路检测结论拓扑排序是一种用于有向无环图(DAG, Directed Acyclic Graph)的重要操作,它可以对图中的结点进行排序,使得对于每一条有向边 (u, v),顶点 u 在排序中都出现在顶点 v 之前。本文介…

Qt和c++面试集合

目录 Qt面试 什么是信号(Signal)和槽(Slot)? 什么是Meta-Object系统? 什么是Qt的MVC模式? 1. QT中connect函数的第五个参数是什么?有什么作用? 3. 在QT中&#xff…

ROS2官方文档(2024-10-10最新版)

ROS 2 Documentation — ROS 2 Documentation: Jazzy documentation (armfun.cn) ROS 2 文档 — ROS 2 文档:Humble 文档 (armfun.cn) 翻译中文方法:使用windows11自带Edge浏览器打开,右上角点击翻译为中文

pytest框架之fixture测试夹具详解

前言 大家下午好呀,今天呢来和大家唠唠pytest中的fixtures夹具的详解,废话就不多说了咱们直接进入主题哈。 一、fixture的优势 ​ pytest框架的fixture测试夹具就相当于unittest框架的setup、teardown,但相对之下它的功能更加强大和灵活。 …

DBMS-3.3 SQL(3)——DML的INSERT、UPDATE、DELETE空值的处理DCL

本文章的素材与知识来自李国良老师和王珊老师。 DML——INSERT、UPDATE、DELETE 一. INSERT 1.语法 (1)INTO子句 (2)VALUES子句 (3)示例 2.插入子查询 若插入的是子查询则不需要VALUES子句 二. UPDATE …

大数据法律监督模型平台实现常态化法律监督

大数据法律监督模型平台充分挖掘大数据价值,利用大数据关联、碰撞、比对,从海量数据中自动筛查出法律监督线索,推送给检察官,有利于提升法律监督质效。 大数据法律监督模型平台建设目标 1、提升监察机关主动监督、精准…

基于DCGM+Prometheus+Grafana的GPU监控方案

目录 前言一、指标导出器1、DCGM:获取远程节点的信息 2、 DCGM-Exporter收集多节点信息更改收集指标 二、 Prometheus - From metrics to insight修改配置文件查看收集结果 三、Grafana仪表板展示导入数据源创建仪表板更多仪表板 前言 基于DCGM(NVIDIA …

[SAP ABAP] LIKE TABLE OF

LIKE TABLE OF语句是用来参照结构体(工作区)对象定义内表数据类型的语句 在SAP ABAP中有标准表&#xff0c;排序表和哈希表三种内表数据类型 *定义标准表 DATA: <ty_tab_standard_name> LIKE [STANDARD] TABLE OF <dtype> [WITH NON-UNIQUE KEY <k1 k2 ... kn…

Python自动给课本文字标注拼音

环境&#xff1a; Ubuntu20.04&#xff0c;ubuntu20.04自带python版本 3.8.10&#xff0c;pip的版本是 20.0.2 pip install pypinyin # 安装失败&#xff0c;检查更新pip确保pip是最新版本&#xff1a; pip install --upgrade pip 检查是否安装成功 pip show pypinyin pinyin…

【电路笔记】-求和运算放大器

求和运算放大器 文章目录 求和运算放大器1、概述2、反相求和放大器3、同相求和放大器4、减法放大器5、应用5.1 音频混合器5.2 数模转换器 (DAC)6、总结1、概述 在我们之前有关运算放大器的大部分文章中,仅将一个输入应用于反相或非反相运算放大器的输入。在本文中,将讨论一种…