【论文笔记】生成对抗网络 GAN

news2025/3/26 0:25:11

GAN

2014 年,Ian Goodfellow 等人提出生成对抗网络(Generative Adversarial Networks),GAN 的出现是划时代的,虽然目前主流的图像/视频生成模型是扩散模型(Diffusion Models)的天下,但是我们仍然有必要了解 GAN 的思想。

GAN 的核心思想是训练两个模型,分别为生成器(Generator)和辨别器(Discriminator),生成器的目标是生成虚假的数据,尽可能混淆辨别器,使其无法判别真实数据和虚假数据,而辨别器的目标则是尽可能将真实数据和虚假数据区分开来。这个过程如下图所示:

gan-example

生成器和辨别器处于一个对抗的过程,它们的能力不断地提升。GAN 的一个缺点在于它的训练过程不稳定,因此在 GAN 出来后,跟 GAN 相关的论文层出不穷,包括改进 GAN 的损失函数、训练方式,或者采用更先进的模型结构,使 GAN 的生成能力更强,同时使其训练过程更加稳定,但是 GAN 的核心思想是不变的。

模型结构

GAN 的结构如下图所示:

gan

GAN 的生成器和辨别器是两个独立的模型,在原始 GAN 中采用的生成器和辨别器都是多层感知机(Multi Layer Perceptron),后来出现了许多模型结构的改进,例如 DCGAN 将 MLP 替换为卷积神经网络。

辨别器

辨别器本质上是一个分类器,用于区分真实数据和由生成器生成的虚假数据,输出是一个 0-1 范围的标量,表示为真实数据的概率值。辨别器有两个数据来源:真实和虚假数据,训练辨别器的过程中,保持生成器的参数不变,利用二分类损失计算梯度,执行反向传播更新辨别器的参数,过程如下。

gan-disc

生成器

生成器用于生成虚假数据,尽可能混淆辨别器,生成器接受一个随机噪声(Random Noise),随机噪声的采样可以来自于均匀分布、正态分布等等,甚至可以是一张图片。生成器的作用就是将随机噪声分布转换为真实数据的分布,在生成器训练的过程中,保持辨别器的参数不变,利用辨别器的梯度来更新生成器。

在这里插入图片描述

损失函数

GAN 采用了 minimax 损失,其数学表达式如下:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_z(z)}[\log(1-D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, V ( D , G ) V(D,G) V(D,G) 表示价值函数, x x x 为真实数据采样的样本, z z z 为生成器生成的样本。

minimax 损失本质上是一个二分类损失(Binary Cross Entropy),可以拆解为辨别器损失和生成器损失。

在训练辨别器的过程中,生成器参数保持不变,因此对于辨别器而言,\(G(z)\) 可以视为常数,其损失函数为:

L D = − E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] − E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D=-E_{x\sim p_{data}(x)}[\log D(x)]-E_{z\sim p_z(z)}[\log(1-D(G(z)))] LD=Expdata(x)[logD(x)]Ezpz(z)[log(1D(G(z)))]

在训练生成器的过程中,辨别器参数保持不变,因此对于辨别器而言,价值函数的第一项为常数,在求导时忽略不计,因此生成器的损失函数为:

L G = − E z ∼ p z ( z ) [ log ⁡ ( D ( G ( z ) ) ) ] L_G=-E_{z\sim p_z(z)}[\log(D(G(z)))] LG=Ezpz(z)[log(D(G(z)))]

对于上述两个损失函数一个直观的理解是,对于 L G L_G LG 而言,我们希望生成器生成的假数据使判别器无法区分,即希望判别器输出的概率接近于 1,取对数后即接近于 0,由于判别器的输出在于 0 - 1 之间,因此取 log 后为负数,即转变为最大化对数概率,或最小化负对数概率,由于优化的过程通常是梯度下降的过程,因此选择后者。

在 GAN 的论文中,给出了一张用于阐述 GAN 的训练过程的图。假设随机噪声 z z z 采样自一维均匀分布,真实数据分布为标准正态分布。图中的黑色点线表示真实数据分布,蓝色虚线表示辨别器输出的概率分布,绿色实线表示生成器输出的概率分布。随着 GAN 的不断训练,生成器生成的数据分布逐渐接近于真实数据分布,辨别器越来越难以区分真实数据和假数据,因此在理想情况下,生成器完全学习到了真实数据分布,辨别器再也无法进行区分,因此输出的概率都为 50%,也就是图(d) 所示的直线。

gan-process

GAN 的训练过程以及 PyTorch 实现

以下是原始 GAN 论文中的训练算法:

train

注意:这里生成器的损失函数并不是前面重写的形式,但是它们两个是等价的,在实际中,作者采用前面重写的形式,因为他们认为这样训练更加稳定。实际的情况是都不那么稳定:)。

下面是一个 GAN 的 PyTorch 实现例子,生成器和辨别器均采用 MLP,在数据集 MNIST 上进行训练的代码,具体代码可见:vanilla-gan。

import os
from argparse import Namespace, ArgumentParser
import torch
from torch import nn, Tensor
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader


class Discriminator(nn.Module):
    """
    Disrcminator in GAN.
    
    Model Architecture: [affine - leaky relu - dropout] x 3 - affine - sigmoid
    """
    def __init__(self, image_shape: tuple[int, int, int]) -> None:
        super(Discriminator, self).__init__()
        C, H, W = image_shape
        image_size = C * H * W
        self.model = nn.Sequential(
            nn.Linear(image_size, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, images: Tensor) -> Tensor:
        images = images.view(images.size(0), -1)
        return self.model(images)


class Generator(nn.Module):
    """
    Generator in GAN.

    Model Architecture: [affine - batchnorm - relu] x 4 - affine - tanh
    """
    def __init__(self, image_shape: tuple[int, int, int], latent_dim: int) -> None:
        super(Generator, self).__init__()
        C, H, W = image_shape
        image_size = C * H * W
        self.image_shape = image_shape
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
            nn.Linear(1024, image_size), nn.Tanh())
    
    def forward(self, z: Tensor) -> Tensor:
        images: Tensor = self.model(z)
        return images.view(-1, *self.image_shape)

# Image processing.
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# Device configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def denormalize(x: Tensor) -> Tensor:
    out = (x + 1) / 2
    return out.clamp(0, 1)


def get_args() -> Namespace:
    """Get commandline arguments."""
    parser = ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate for Adam optimizer')
    parser.add_argument('--beta1', type=float, default=0.5, help='first momentum term for Adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term for Adam')
    parser.add_argument('--batch_size', type=int, default=64, help='size of a mini-batch')
    parser.add_argument('--num_epochs', type=int, default=100, help='training epochs')
    parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
    parser.add_argument('--dataset', type=str, default='MNIST', help='training dataset(MNIST | FashionMNIST | CIFAR10)')
    parser.add_argument('--sample_dir', type=str, default='samples', help='directory of image samples')
    parser.add_argument('--interval', type=int, default=1, help='epoch interval between image samples')
    parser.add_argument('--logdir', type=str, default='runs', help='directory of running log')
    parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='directory for saving model checkpoints')
    parser.add_argument('--seed', type=str, default=10213, help='random seed')
    return parser.parse_args()


def setup(args: Namespace) -> None:
    torch.manual_seed(args.seed)
    # Create directory if not exists.
    if not os.path.exists(os.path.join(args.sample_dir, args.dataset)):
        os.makedirs(os.path.join(args.sample_dir, args.dataset))
    if not os.path.exists(os.path.join(args.ckpt_dir, args.dataset)):
        os.makedirs(os.path.join(args.ckpt_dir, args.dataset))


def get_data_loader(args: Namespace) -> DataLoader:
    """Get data loader."""
    if args.dataset == 'MNIST':
        data = datasets.MNIST(root='../data', train=True, download=True, transform=transform_mnist)
    elif args.dataset == 'FashionMNIST':
        data = datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform_mnist)
    elif args.dataset == 'CIFAR10':
        data = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_cifar)
    else:
        raise ValueError(f'Unkown dataset: {args.dataset}, support dataset: MNIST | FashionMNIST | CIFAR10')
    return DataLoader(dataset=data, batch_size=args.batch_size, num_workers=4, shuffle=True)


def train(args: Namespace, 
          G: Generator, D: Discriminator, 
          data_loader: DataLoader) -> None:
    """Train Generator and Discriminator.

    Args:
        args(Namespace): arguments.
        G(Generator): Generator in GAN.
        D(Discriminator): Discriminator in GAN.
    """
    writer = SummaryWriter(os.path.join(args.logdir, args.dataset))

    # generate fixed noise for sampling.
    fixed_noise = torch.rand(64, args.latent_dim).to(device)

    # Loss and optimizer.
    criterion = nn.BCELoss().to(device)
    optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    # Start training.
    for epoch in range(args.num_epochs):
        total_d_loss = total_g_loss = 0
        for images, _ in data_loader:
            m = images.size(0)
            images: Tensor = images.to(device)
            images = images.view(m, -1)
            # Create real and fake labels.
            real_labels = torch.ones(m, 1).to(device)
            fake_labels = torch.zeros(m, 1).to(device)
            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #

            # Forward pass
            outputs = D(images)
            d_loss_real: Tensor = criterion(outputs, real_labels)
            
            z = torch.rand(m, args.latent_dim).to(device)
            fake_images: Tensor = G(z).detach()
            outputs = D(fake_images)
            d_loss_fake: Tensor = criterion(outputs, fake_labels)

            # Backward pass
            d_loss: Tensor = d_loss_real + d_loss_fake
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
            total_d_loss += d_loss

            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #
            
            # Forward pass
            z = torch.rand(images.size(0), args.latent_dim).to(device)
            fake_images: Tensor = G(z)
            outputs = D(fake_images)
            
            # Backward pass
            g_loss: Tensor = criterion(outputs, real_labels)
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
            total_g_loss += g_loss
        print(f'''
=====================================
Epoch: [{epoch + 1}/{args.num_epochs}]
Discriminator Loss: {total_d_loss / len(data_loader):.4f}
Generator Loss: {total_g_loss / len(data_loader):.4f}
=====================================''')
        # Log Discriminator and Generator loss.
        writer.add_scalar('Discriminator Loss', total_d_loss / len(data_loader), epoch + 1)
        writer.add_scalar('Generator Loss', total_g_loss / len(data_loader), epoch + 1)
        fake_images: Tensor = G(fixed_noise)
        img_grid = make_grid(denormalize(fake_images), nrow=8, padding=2)
        writer.add_image('Fake Images', img_grid, epoch + 1)
        if (epoch + 1) % args.interval == 0:
            save_image(img_grid, os.path.join(args.sample_dir, args.dataset, f'fake_images_{epoch + 1}.png'))
    # Save the model checkpoints.
    torch.save(G.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'G.ckpt'))
    torch.save(D.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'D.ckpt'))


def main() -> None:
    args = get_args()
    setup(args)
    image_shape = (1, 28, 28) if args.dataset in ('MNIST', 'FashionMNIST') else (3, 32, 32)
    data_loader = get_data_loader(args)
    # Generator and Discrminator.
    G = Generator(image_shape=image_shape, latent_dim=args.latent_dim).to(device)
    D = Discriminator(image_shape=image_shape).to(device)
    train(args, G, D, data_loader)


if __name__ == '__main__':
    main()

参考

[1] I. Goodfellow et al., “Generative Adversarial Nets,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2014. Accessed: Sep. 12, 2024. [Online]. Available: https://papers.nips.cc/paper_files/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html

[2] eriklindernoren. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN

[3] 李沐. “GAN论文逐段精读【论文精读】”. Bilibili 2021. [Online]. Available: https://www.bilibili.com/video/BV1rb4y187vD/?spm_id_from=333.1387.collection.video_card.click&vd_source=c8a32a5a667964d5f1068d38d6182813
n. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2321625.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Agent】Dify Docker 安装问题 INTERNAL SERVER ERROR

总结:建议大家选择稳定版本的分支,直接拉取 master 分支,可能出现一下后面更新代码导致缺失一些环境内容。 启动报错 一直停留在 INSTALL 界面 我是通过 Docker 进行安装的,由于项目开发者不严谨导致,遇到一个奇怪的…

【Excel使用技巧】某列保留固定字段或内容

目录 ✅ 方法一:使用 Excel 公式提取 body 部分 🔍 解释: ✅ 方法二:批量处理整列数据 🚨 注意事项 🚨 处理效果 我想保留Excel某一列的固定内容,比如原内容是: thread entry i…

vue3,element-plus 表格搜索过滤数据

1、表格数据 // 表格数据 import type { User } from "/interface"; const tableData ref<User[]>([]); 2、 表格搜索过滤数据 // 搜索内容 const search ref(""); // 表格过滤数据 const tableFilterData computed(() >tableData.value.fi…

vue中上传接口file表单提交二进制文件流

1.使用elementui上传组件 要做一个选择文件后&#xff0c;先不上传&#xff0c;等最后点击确定后&#xff0c;把file二进制流及附加参数一起提交上去。 首先使用elementui中的上传组件&#xff0c;设置auto-uploadfalse&#xff0c;也就是选择文件后不立刻上传。 <el-uplo…

【学习笔记】卷积网络简介及原理探析

作者选择了由 Ian Goodfellow、Yoshua Bengio 和 Aaron Courville 三位大佬撰写的《Deep Learning》(人工智能领域的经典教程&#xff0c;深度学习领域研究生必读教材),开始深度学习领域学习&#xff0c;深入全面的理解深度学习的理论知识。 之前的文章参考下面的链接&#xf…

element-plus中Cascader级联选择器组件的使用

目录 一.基本使用 二.进阶使用 1.如何获取最后一级选项的值&#xff1f; 2.如何让级联选择器的输入框只展示最后一级&#xff1f; 三.实战 1.场景描述 2.实现步骤 ①设计后端返回值Vo ②编写controller ③编写service ④编写mapper层 ⑤在前端&#xff0c;通过发送…

【华为Pura先锋盛典】华为Pura X“阔折叠”手机发布:首次全面搭载HarmonyOS 5

文章目录 前言一、阔感体验&#xff0c;大有不同二、鸿蒙AI&#xff0c;大有智慧三、便携出行&#xff0c;大有不同四、首款全面搭载 HarmonyOS 5 的手机五、卓越性能&#xff0c;可靠安心六、红枫影像&#xff0c;大放光彩预热&#xff1a;鸿蒙电脑HarmonyOS 5 升级计划小结 前…

MQ,RabbitMQ,MQ的好处,RabbitMQ的原理和核心组件,工作模式

1.MQ MQ全称 Message Queue&#xff08;消息队列&#xff09;&#xff0c;是在消息的传输过程中 保存消息的容器。它是应用程序和应用程序之间的通信方法 1.1 为什么使用MQ 在项目中&#xff0c;可将一些无需即时返回且耗时的操作提取出来&#xff0c;进行异步处理&#xff0…

ETL:数据清洗、规范化和聚合的重要性

在当今这个数据呈爆炸式增长的时代&#xff0c;数据已成为企业最为宝贵的资产之一。然而&#xff0c;数据的海量增长也伴随着诸多问题&#xff0c;如数据来源多样、结构复杂以及质量问题等&#xff0c;这些问题严重阻碍了数据的有效处理与深度分析。在此背景下&#xff0c;ETL&…

电机控制常见面试问题(十八)

文章目录 一.电机控制高级拓扑结构1.LLC 二.谈谈电压器饱和后果三.电压器绕组连接方式的影响四.有源逆变的条件 一.电机控制高级拓扑结构 1.LLC LLC是什么&#xff1f;—— 一个会"变魔术"的电源盒子 想象你有一个魔法盒子&#xff0c;能把电池的电压变大或变小&…

stable diffusion本地安装

1. 基本环境准备 安装conda 环境 pytorch基础学习-CSDN博客 创建虚拟环境&#xff1a; conda create -n sd python3.10 一定要指定用3.10&#xff0c;过高的版本会提示错误&#xff1a; 激活启用环境&#xff1a; conda activate sd 设置pip国内镜像源&#xff1a; pip conf…

【内网穿透】Linux部署FRP0.61.2实现rk3566 Wechat iPad协议内网穿透教程

写在前面 FRP&#xff08;Fast Reverse Proxy&#xff09;是一个由Go语言编写的开源项目&#xff0c;用于内网穿透&#xff0c;即通过公网服务器将内网服务暴露给外部访问。这对于需要在内网环境中部署但又希望外部用户能够访问这些服务的场景非常有用 Github&#xff1a;htt…

VM虚拟机安装Ubuntu系统

前言 我现在装的Ubuntu总是死机&#xff0c;经常黑屏&#xff0c;所以我决定换个版本&#xff0c;顺便写一下笔记&#xff0c;给大家分享如何安装虚拟机 下载 这里我选择的是Ubuntu 22.04.5 LTS&#xff0c;下载链接&#xff1a;Ubuntu 22.04.5 LTS 如果访问不了网站的话&…

从JVM底层揭开Java方法重载与重写的面纱:原理、区别与高频面试题突破

&#x1f31f;引言&#xff1a;一场由方法调用引发的"血案" 2018年&#xff0c;某电商平台在"双十一"大促期间遭遇严重系统故障。 技术团队排查发现&#xff0c;问题根源竟是一个继承体系中的方法重写未被正确处理&#xff0c;导致订单金额计算出现指数级…

芋道 Spring Cloud Alibaba 消息队列 RocketMQ 入门

1. 概述 RocketMQ 是一款开源的分布式消息系统&#xff0c;基于高可用分布式集群技术&#xff0c;提供低延时的、高可靠的消息发布与订阅服务。同时&#xff0c;广泛应用于多个领域&#xff0c;包括异步通信解耦、企业解决方案、金融支付、电信、电子商务、快递物流、广告营销…

html css js网页制作成品——HTML+CSS+js迪奥口红网站网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…

PPT 转高精度图片 API 接口

PPT 转高精度图片 API 接口 文件处理 / 图片处理&#xff0c;将 PPT 文件转换为图片序列。 1. 产品功能 支持将 PPT 文件转换为高质量图片序列&#xff1b;支持 .ppt 和 .pptx 格式&#xff1b;保持原始 PPT 的布局和样式&#xff1b;转换后的图片支持永久访问&#xff1b;全…

python学习笔记--实现简单的爬虫(二)

任务&#xff1a;爬取B站上最爱欢迎的编程课程 网址&#xff1a;编程-哔哩哔哩_bilibili 打开网页的代码模块&#xff0c;如下图&#xff1a; 标题均位于class_"bili-video-card__info--tit"的h3标签中&#xff0c;下面通过代码来实现&#xff0c;需要说明的是URL中…

【颠覆性缓存架构】Caffeine双引擎缓存实战:CPU和内存双优化,命中率提升到92%,内存减少75%

千万级QPS验证&#xff01;Caffeine智能双缓存实现 92%命中率&#xff0c;内存减少75% 摘要&#xff1a; 本文揭秘千万级流量场景下的缓存革命性方案&#xff01;基于Caffeine打造智能双模式缓存系统&#xff0c;通过冷热数据分离存储与精准资源分配策略&#xff0c;实现CPU利…

智能汽车图像及视频处理方案,支持视频智能包装能力

美摄科技的智能汽车图像及视频处理方案&#xff0c;通过深度学习算法与先进的色彩管理技术&#xff0c;能够自动调整图像中的亮度、对比度、饱和度等关键参数&#xff0c;确保在各种光线条件下&#xff0c;图像都能呈现出最接近人眼的自然色彩与细节层次。这不仅提升了驾驶者的…