什么是GAN?

news2025/1/12 10:51:59

一、基本概念

        生成对抗网络(Generative Adversarial Network,GAN)是一种由两个神经网络共同组成深度学习模型:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗的方式进行训练,生成器尝试伪造逼真的样本数据,而判别器则负责判断输入的数据是真实数据还是生成器伪造出来的数据。理想情况下,判别器对真实样本和生成样本的判断概率都是1/2,意味着判别器已经无法判断生成器生成的数据真假。

二、模型原理

        GAN的模型原理并不复杂。首先,GAN由以下两个子模型组成:

  • 生成器(Generator)从随机噪声中生成数据,目标是欺骗判别器,使其认为生成的数据是真实的。
  • 判别器(Discriminator):判断输入数据是来自真实数据分布还是生成器,目标是正确区分真实数据和生成数据。

        然后,GAN的损失函数是训练的核心,我们需要构建一个合适的损失函数用于衡量生成器和判别器的表现:

  • 生成器损失(G_loss):通常表示为最大化判别器对其生成样本的错误分类概率,也就是判别器判定所有生成数据均为真。
  • 判别器损失(D_loss):由两部分组成,一部分是真实样本的损失(标签为1),另一部分是生成样本的损失(标签为0)。

        最后,我们通过算法设计来交替训练生成器和判别器,例如生成器每训练5个Epoch,我们就训练一次判别器:

  • 训练判别器:提高其区分真实样本和生成样本的能力。
  • 训练生成器:提高其生成真实样本的能力,目标是最大化判别器将其生成样本识别为真实样本的概率。

三、python实现

1、导库

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

2、数据处理

        这里我们的目标是训练一个生成对抗网络来生成iris数据,使用sklearn的iris数据集训练。这意味着,我们输入给生成器的信息中需要包含类别信息,这样生成器才能生成对应类别的数据样本。当然,这一步不是必要的,在类别不敏感的任务中,只需要生成符合要求的数据即可。

# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target

# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)

# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
# torch.Size([100, 3])
labels = encoder.fit_transform(labels.reshape(-1, 1))

# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)

# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

3、构建生成器

        这里,我们构建一个全连接神经网络。生成器的输入包括随机初始化的x,以及x对应的期望类别,期望类别是可以真实标签,表示生成对应类别下的数据样本。

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, label_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

4、构建判别器

        这里,我们的判别器实际上是一个二分类模型。判别器的输入维度跟生成器一直,都需要考虑类别信息。

# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_dim, label_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

5、超参数设置

        值得注意的是,我们分别为生成器和判别器构造一个优化器,从而便于分开训练两个子模型。

# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200

# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 损失函数
criterion = nn.BCELoss()

6、模型训练

        这里,我们选择了分开训练生成器和判别器,在一个epoch中,先训练3次生成器,再训练一次判别器。这样的目的是增加生成器的学习时间,从而使得生成的样本更为真实。

# 训练GAN
for epoch in range(num_epochs):
    for i, (real_data, real_labels) in enumerate(dataloader):
        batch_size = real_data.size(0)
        
        # 当前仅训练生成器
        generator.train()
        discriminator.eval()

        # 迭代训练生成器,这里是每个epoch训练3次
        for _ in range(3):
            z = torch.randn(batch_size, latent_dim)
            # 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
            # 生成器生成的是各对应类别的数据
            fake_data = generator(z, real_labels)
            # 使用判别器对生成的假数据进行分类
            outputs = discriminator(fake_data, real_labels)
            # 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
            # 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
            # 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
            g_loss = criterion(outputs, torch.ones(batch_size, 1))

            # 反向传播生成器的梯度
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # 当前仅训练判别器
        generator.eval()
        discriminator.train()

        # 训练判别器,真实样本标签为1,生成样本标签为0
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # 真实数据损失
        outputs = discriminator(real_data, real_labels)
        d_loss_real = criterion(outputs, real_targets)
        real_score = outputs

        # 生成假数据,计算损失
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z, real_labels)
        outputs = discriminator(fake_data.detach(), real_labels)
        # 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
        # 所以fake_targets是一个全0向量
        d_loss_fake = criterion(outputs, fake_targets)
        fake_score = outputs

        # 总的判别器损失
        d_loss = d_loss_real + d_loss_fake

        # 反向传播判别器的梯度
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

    if epoch%10==0:
        # 打印损失
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
              f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

7、生成新数据

        最后,我们使用训练好的GAN中的生成器来生成一批新数据。可以看到,效果不错。

# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()

# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)

# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
    real_class_data = data_2d[iris.target == i]
    generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
    plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
    plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()

7aab63ce61704b35afa41a14652ac1e7.png

四、完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA


# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target

# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)

# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
labels = encoder.fit_transform(labels.reshape(-1, 1))

# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)

# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, label_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

# 判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_dim, label_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.model(x)

# 设置超参数
latent_dim = 100
data_dim = data.shape[1]
label_dim = labels.shape[1]
lr = 0.0002
num_epochs = 200

# 初始化生成器和判别器
generator = Generator(latent_dim, label_dim, data_dim)
discriminator = Discriminator(data_dim, label_dim)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 损失函数
criterion = nn.BCELoss()

# 训练GAN
for epoch in range(num_epochs):
    for i, (real_data, real_labels) in enumerate(dataloader):
        batch_size = real_data.size(0)
        
        generator.train()
        discriminator.eval()
        # 迭代训练生成器,这里是每个epoch训练3次
        for _ in range(3):
            z = torch.randn(batch_size, latent_dim)
            # 直接使用真实标签即可,这里的标签代表的是样本类别,目的是让模型学习到类别差异
            # 生成器生成的是各对应类别的数据
            fake_data = generator(z, real_labels)
            # 使用判别器对生成的假数据进行分类
            outputs = discriminator(fake_data, real_labels)
            # 基于判别器的结果计算生成器的损失,目标是让判别器认为生成的数据是真实的(标签为1)
            # 如果这里使用的是torch.zeros则生成器的结果将会非常差,几乎无法生成真实数据
            # 这是由于我们的目标是让outputs逼近全1向量,也就是让判别器认为所有生成的数据都是真实的,这样才能让生成样本越来越真实
            g_loss = criterion(outputs, torch.ones(batch_size, 1))

            # 反向传播生成器的梯度
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        generator.eval()
        discriminator.train()
        # 训练判别器,真实样本标签为1,生成样本标签为0
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # 真实数据损失
        outputs = discriminator(real_data, real_labels)
        d_loss_real = criterion(outputs, real_targets)
        real_score = outputs

        # 生成假数据,计算损失
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z, real_labels)
        outputs = discriminator(fake_data.detach(), real_labels)
        # 这里的目标与上面生成器部分相反,我们是要让outputs逼近全0向量,也就是全部预测出假数据
        # 所以fake_targets是一个全0向量
        d_loss_fake = criterion(outputs, fake_targets)
        fake_score = outputs

        # 总的判别器损失
        d_loss = d_loss_real + d_loss_fake

        # 反向传播判别器的梯度
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

    if epoch%10==0:
        # 打印损失
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
              f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

# 生成新数据
num_samples = 100
z = torch.randn(num_samples, latent_dim)
labels = np.array([0, 1, 2] * (num_samples // 3) + [0] * (num_samples % 3))
labels = encoder.transform(labels.reshape(-1, 1))
labels = torch.FloatTensor(labels)
generated_data = generator(z, labels).detach().numpy()

# 降维
pca = PCA(n_components=2)
data_2d = pca.fit_transform(data)
generated_data_2d = pca.transform(generated_data)

# 可视化生成的数据
plt.figure(figsize=(10, 5))
for i in range(3):
    real_class_data = data_2d[iris.target == i]
    generated_class_data = generated_data_2d[np.argmax(labels.numpy(), axis=1) == i]
    plt.scatter(real_class_data[:, 0], real_class_data[:, 1], label=f'Real Class {i}')
    plt.scatter(generated_class_data[:, 0], generated_class_data[:, 1], label=f'Generated Class {i}')
plt.legend()
plt.show()

五、总结

        生成对抗网络是一个很经典的深度学习模型,它在诸多领域中发挥着重要作用。除了超参数调整之外,训练GAN的另一个关键步骤是构造一个合适的训练策略。例如,可以同时训练生成器和判别器,也可以交替训练二者,或者先训练生成器再训练判别器等等。但是,这两个网络是相互博弈的,由于生成器参数是随机初始化的,一开始生成的数据质量往往较差。我们的策略一般是先让生成器变强(通过构造更复杂的网络结构或者更多的训练次数),让生成的数据质量先提升。这样随着训练的迭代,生成的样本越来越逼真,判别器也不得不为了最小化D_loss而提升自身的能力。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring |(八)AOP配置管理

文章目录 📚AOP切点表达式🐇语法格式🐇通配符 📚AOP通知类型🐇环境准备🐇通知类型的使用 📚AOP通知获取数据🐇环境准备🐇获取参数🐇获取返回值🐇获…

Flink 从入门到实战

Flink中的批和流 批处理的特点是有界、持久、大量,非常适合需要访问全部记录才能完成的计算工作,一般用于离线统计。 流处理的特点是无界、实时, 无需针对整个数据集执行操作,而是对通过系统 传输的每个数据项执行操作,一般用于实…

Ubuntu20.04运行LARVIO

文章目录 1.运行 Toyish 示例程序2.运行 ROS Nodelet参考 1.运行 Toyish 示例程序 LARVIO 提供了一个简化的toyish示例程序,适合快速验证和测试。 编译项目 进入 build 文件夹并通过 CMake 编译项目: mkdir build cd build cmake -D CMAKE_BUILD_TYPER…

小程序-基于java+SpringBoot+Vue的戏曲文化苑小程序设计与实现

项目运行 1.运行环境:最好是java jdk 1.8,我们在这个平台上运行的。其他版本理论上也可以。 2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA; 3.tomcat环境:Tomcat 7.x,8.x,9.x版本均可 4.硬件环境&#xff1a…

mybatis plus如何使用mybatis xml拼接sql

在 MyBatis Plus 中,如果你想使用 MyBatis 的 XML 文件来拼接 SQL,可以结合使用 MyBatis 和 MyBatis Plus 的功能。MyBatis Plus 是一个增强 MyBatis 的工具,它提供了很多便捷的操作,但有时你可能需要使用 XML 文件来定义更复杂的…

【uniapp】轮播图

前言 Uniapp的swiper组件是一个滑块视图容器组件&#xff0c;可以在其中放置多个轮播图或滑动卡片。它是基于微信小程序的swiper组件进行封装&#xff0c;可以在不同的平台上使用&#xff0c;如微信小程序、H5、App等。 效果图 前端代码 swiper组件 <template><vi…

Python爬虫爬取数据报错

报错&#xff1a; Error fetching the URL: (Connection aborted., ConnectionResetError(10054, 远程主机强迫关闭了一个现有的连接。, None, 10054, None)) 报错原因&#xff1a; 目标服务器限制&#xff1a; 目标网站可能已经检测到你的请求来自自动化工具&#xff08;如爬虫…

人工智能与传统控制系统的融合发展

在这个科技快速迭代的时代&#xff0c;人工智能技术正以前所未有的速度改变着我们的生活。在控制系统领域&#xff0c;AI技术的引入为传统控制带来了新的发展机遇和挑战。然而&#xff0c;这并不意味着传统控制将被完全取代&#xff0c;相反&#xff0c;AI与传统控制的深度融合…

shell综合

声明&#xff01; 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本人以及泷羽sec团队无关&#…

什么是串联谐振

比如有一个由电阻、电容和电感的串联电路中&#xff0c;存在一个频率能使这个电路的电流最大&#xff0c;这个现象就叫谐振。 那么这个频率是多少呢&#xff1f; 交流电频率与电路固有频率一致时&#xff0c;它就能发生谐振&#xff0c;此时这个电路的电流是最大的 这个固有频…

韦东山stm32hal库--定时器喂狗模型按键消抖原理+实操详细步骤

一.定时器按键消抖的原理: 按键消抖的原因: 当我们按下按键的后, 端口从高电平变成低电平, 理想的情况是, 按下, 只发生一次中断, 中断程序只记录一个数据. 但是我们使用的是金属弹片, 实际的情况就是如上图所示, 可能会发生多次中断,难道我们要记录3/4次数据吗? 答:按键按下…

雨云服务器搭建docker且用docker部署kali服务器教程

雨云 - 新一代云服务提供商 介绍 大家好今天教大家如何使用雨云的服务器安装docker并且用docker搭建kali服务器&#xff0c;实现大家做黑客的梦。 性价比比较高的云服务器提供参考&#xff1a;雨云 - 新一代云服务提供商 优惠码&#xff1a;MzkxODI4 什么是kali Kali L…

SQL进阶——JOIN操作详解

在数据库设计中&#xff0c;数据通常存储在多个表中。为了从这些表中获取相关的信息&#xff0c;我们需要使用JOIN操作。JOIN操作允许我们通过某种关系&#xff08;如相同的列&#xff09;将多张表的数据结合起来。它是SQL中非常重要的操作&#xff0c;广泛应用于实际开发中。本…

分析JHTDB数据库的Channel5200数据集的数据(SciServer服务器)

代码来自https://github.com/idies/pyJHTDB/blob/master/examples/channel.ipynb %matplotlib inline import numpy as np import math import random import pyJHTDB import matplotlib.pyplot as plt import time as ttN 3 T pyJHTDB.dbinfo.channel5200[time][-1] time …

数据分析:彩票中奖号码分析与预测

预测双色球彩票的中奖号码是一个典型的随机事件&#xff0c;因为每个号码的出现概率是独立的&#xff0c;且历史数据并不能直接用于预测未来的开奖结果。然而&#xff0c;我们可以通过统计分析来了解号码的分布规律&#xff0c;从而提供一些可能的参考。 样例数据【点击下载】…

详细分析 npm run build 基本知识 | 不同环境不同命令

目录 前言1. 基本知识2. 构建逻辑 前言 关于部署服务器的知识推荐阅读&#xff1a;npm run build部署到云服务器中的Nginx&#xff08;图文配置&#xff09; 1. 基本知识 npm run 是 npm 的一个命令&#xff0c;用于运行 package.json 中定义的脚本&#xff0c;可以通过 “s…

Jpype调用jar包

需求描述 ​   公司要求使用python对接口做自动化测试&#xff0c;接口的实现是Java&#xff0c;部分接口需要做加解密&#xff0c;因此需要使用python来调用jar包来将明文加密成密文&#xff0c;然后通过http请求访问接口进行测试。 如何实现 1.安装Jpype ​   首先我…

Realtek网卡MAC刷新工具PG8168.exe Version:2.34.0.4使用说明

本刷新工具虽然文件名叫PG8168.EXE&#xff0c;但不是只有RTL8168可用&#xff0c;是这一个系列的产品都可以使用。实验证明RTL8111也可以使用。 用法&#xff1a; PG8168 [/h][/?][/b][/c HexOffsetHexValue][/d NICNumber][/l][/r][/w][/v] [/# NICNumber] [/nodeidHexNOD…

【Unity】Unity编辑器扩展,替代预制体上重复拖拽赋值

今天做游戏时有个需求&#xff0c;游戏中需要给不同年份不同月份的奖牌制定不一样的非规则形状&#xff0c;其中形状为100个像素组成的不同图形&#xff0c;并且按照从1-100路径一个个解锁&#xff0c;所以需要全部手动放置。但是手动放置好后&#xff0c;发现再一个个挂到脚本…

c语言的qsort函数理解与使用

介绍&#xff1a;qsort 函数是 C 标准库中用于排序的快速排序算法函数。它的用法非常灵活&#xff0c;可以对任意类型的元素进行排序&#xff0c;只要提供了比较函数即可。 qsort 函数原型及参数解释&#xff1a; void qsort ( void* base, //指向要排序的数组的首元素…