生成对抗网络——GAN(代码+理解)

news2024/12/27 13:48:27

目录

一、GAN模型介绍

二、GAN模型的训练过程

1. 初始化网络:

2. 训练判别器:

3. 训练生成器:

4. 重复步骤 2和步骤 3:

三、GAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

3. 运行结果展示

四、GAN模型的应用领域

五、学习中产生的疑问,及文心一言回答

1. 生成器(Generator)模型理解

2. 为什么要使用 block 函数?

3. 函数解释

4. 为什么要将像素值从[0, 255]缩放到[0.0, 1.0] ?

5. 详细解释一下是怎样对Tensor进行标准化的,以及为什么要这么做?


一、GAN模型介绍

        GAN,全称 Generative Adversarial Network,即生成对抗网络,是一种基于 对抗学习的深度生成模型。该模型由Ian Goodfellow在 2014年 首次提出,并迅速成为 学术界研究的热点,推动了生成模型领域的发展。

        GAN模型主要由两部分组成:生成器(Generator)和判别器(Discriminator)

   1. 生成器:生成器模型可以是 任意结构的神经网络,其 输入是 随机噪声torch.randn,输出则是 生成的样本。生成器的 目标是使生成的样本尽可能接近真实样本的分布,以欺骗判别器

    2. 判别器:判别器模型同样可以是任意结构的神经网络,其 输入是真实样本或生成器生成的样本,输出是一个 概率值,表示 输入样本是真实样本的概率。判别器的 目标是尽可能准确地判断输入样本是真实样本还是生成样本

        这两个网络 在训练过程中 相互对抗、相互优化,形成了一种 零和博弈

二、GAN模型的训练过程

1. 初始化网络

        生成器(Generator)和判别器(Discriminator)的参数需要随机初始化。这两个网络都是神经网络,通常使用随机权重开始训练。

# 实例化
generator = Generator()
discriminator = Discriminator()

2. 训练判别器

1从真实数据集中 随机选择一批数据,将其输入到 判别器中进行训练。

for i, (img, _) in enumerate(dataloader):   # 内层迭代次数为 10000 // 64 = 157次,每次 64个数据

2同时,从生成器的当前状态生成一批假数据(也称为生成数据),也将这些数据输入到判别器中进行训练。

# 假数据的生成
fake_img = torch.randn(size, 100)

3在训练判别器时,需要固定生成器的参数.detach函数),只更新判别器的参数。

output_fake = generator(fake_img)
fake_socre = discriminator(output_fake.detach()) 

(4判别器的目标是将真实数据和假数据区分开来,因此其损失函数通常定义为二元交叉熵损失函数.BCELoss函数

# 损失函数
criterion = torch.nn.BCELoss()    # 对应 Sigmoid,计算二元交叉墒损失

(5使用反向传播算法更新判别器的参数,以最小化损失函数。

3. 训练生成器

1生成器的目标是生成与真实数据相似的假数据,使得判别器无法区分真实数据和假数据。

2生成器的 损失函数通常定义为 判别器对 假数据的 输出结果的 交叉熵损失函数的 相反数。换句话说,生成器希望判别器 对假数据的判断结果 尽可能接近真实数据

fake_G_socre = discriminator(output_fake)
G_fake_loss = criterion(fake_G_socre, torch.ones_like(fake_G_socre))

(3同样使用反向传播算法更新生成器的参数,以最小化其损失函数。

4. 重复步骤 2和步骤 3

1在每一轮训练中,先 训练判别器,然后 训练生成器。这样可以确保两个网络都能得到足够的优化。

2重复这个过程,直到达到预设的迭代次数或满足某种收敛条件(如生成器生成的假数据与真实数据的差距达到一定程度本案例没有设计)。

三、GAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


mnist = datasets.MNIST(
    root='./others/',
    train=False,
    download=False,
    transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
)

dataloader = DataLoader(
    dataset=mnist,
    batch_size=64,
    shuffle=True
)

def gen_img_plot(model, epoch, text_input):
    prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

# 生成器定义
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.mean = nn.Sequential(
            *block(100, 256, normalize=False),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        imgs = self.mean(x)
        imgs = imgs.view(-1, 1, 28, 28)
        return imgs

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.mean = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        img = self.mean(x)  # 对 64条数据的每一条都进行模型运算
        return img

# 实例化
generator = Generator()
discriminator = Discriminator()

# 定义优化器
G_Apim = torch.optim.Adam(generator.parameters(), lr=0.0001)
D_Apim = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# 损失函数
criterion = torch.nn.BCELoss()    # 对应 Sigmoid,计算二元交叉墒损失


epoch_num = 100
G_loss_save = []
D_loss_save = []
for epoch in range(epoch_num):  # 将 10000 条数据迭代了两遍
    G_epoch_loss = 0
    D_epoch_loss = 0
    count = len(dataloader)
    for i, (img, _) in enumerate(dataloader):   # 内层迭代次数为 10000 // 64 = 157次,每次 64个数据
        # 训练 Discriminator
        # 判断出假的
        size = img.size(0)  # 0 维有多少个数据
        fake_img = torch.randn(size, 100)

        output_fake = generator(fake_img)
        fake_socre = discriminator(output_fake.detach())    # .detach() 返回一个关闭梯度的 output_fake,这样前向传播不会修改 generater 的 grad
        D_fake_loss = criterion(fake_socre, torch.zeros_like(fake_socre))
        # 判断出真的
        real_socre = discriminator(img)
        D_real_loss = criterion(real_socre, torch.ones_like(real_socre))

        D_loss = D_fake_loss + D_real_loss
        D_Apim.zero_grad()
        D_loss.backward()
        D_Apim.step()

        # 训练 Generater
        # G_fake_img = torch.randn(size, 100)
        # G_output_fake = generator(G_fake_img)
        # fake_G_socre = discriminator(G_output_fake)
        fake_G_socre = discriminator(output_fake)
        G_fake_loss = criterion(fake_G_socre, torch.ones_like(fake_G_socre))
        G_Apim.zero_grad()
        G_fake_loss.backward()
        G_Apim.step()

        with torch.no_grad():   # 其中所有的 requires_grad 都被默认设置为 False
            G_epoch_loss += G_fake_loss
            D_epoch_loss += D_loss

    with torch.no_grad():
        G_epoch_loss /= count
        D_epoch_loss /= count

        G_loss_save.append(G_epoch_loss.item())
        D_loss_save.append(D_epoch_loss.item())

        print('Epoch: [%d/%d] | G_loss: %.3f | D_loss: %.3f'
              % (epoch, epoch_num, G_epoch_loss, D_epoch_loss))
        text_input = torch.randn(64, 100)
        gen_img_plot(generator, epoch, text_input)


x = [epoch + 1 for epoch in range(epoch_num)]
plt.figure()
plt.plot(x, G_loss_save, 'r')
plt.plot(x, D_loss_save, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

四、GAN模型的应用领域

        GAN模型 在图像生成、视频生成、文本生成等领域都有 广泛的应用。例如,在图像生成领域,GAN模型可以生成高质量的图像、进行图像修复 和 超分辨率重建 等任务;在视频生成领域,GAN模型可以生成连贯的视频序列;在文本生成领域,GAN模型可以生成逼真的文本内容等。此外,GAN模型还在 AI绘画领域 发挥着 重要作用,成为AI绘画工作流中的 关键辅助模型。

五、学习中产生的疑问,及文心一言回答

1. 生成器(Generator)模型理解

2. 为什么要使用 block 函数?

3. 函数解释

4. 为什么要将像素值从[0, 255]缩放到[0.0, 1.0] ?

5. 详细解释一下是怎样对Tensor进行标准化的,以及为什么要这么做?


                                                后续更新GAN的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1830938.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动态 ETL 管道:使用非结构化 IO 将 AI 与 MinIO 和 Weaviate 的 Web

在现代数据驱动的环境中,网络是一个无穷无尽的信息来源,为洞察力和创新提供了巨大的潜力。然而,挑战在于提取、构建和分析这片浩瀚的数据海洋,使其具有可操作性。这就是Unstructured-IO 的创新,结合MinIO的对象存储和W…

存储器的分类以及介绍

1.存储器的分类 2.按存储介质分 按照存储介质可以分为三类,电/磁/光 1.半导体存储器(电) 存储元件由半导体器件组储层的存储器称为半导体存储器。 现代的半导体存储器都是超大规模集成电路工艺制成芯片。 其优点是:体积小、功…

Nature 苏浩团队发表创新人工智能“仿真中学习”框架,实现外骨骼的智能性和通用性

北京时间2024年6月12日23时,美国北卡罗来纳州立大学与北卡罗来纳大学教堂山分校的苏浩团队在《自然》(Nature)上发表了一篇关于机器人和人工智能算法相结合服务人类的突破性研究论文,标题为“Experiment-free Exoskeleton Assista…

transformers 不同精度float16、bfloat16、float32加载模型对比

参考: https://github.com/chunhuizhang/pytorch_distribute_tutorials/blob/main/tutorials/amp_autocast_mixed_precision_training.ipynb from transformers import AutoModelForCausalLM, AutoTokenizer device "cuda" # the device to load the m…

MySQL初学知识总篇

MySQL入门篇 MySQL下载并安装教程推荐:聚精会神搞学习的文章 图形化工具使用:Dbeaver下载官网 目录 🍉概述:什么是MySQL?一、🍉MySQL语言特点:二、🍉数据库管理系统(数据…

家庭智能助手:Kompas AI引领家居智能化新纪元

一、引言 在数字化浪潮的推动下,现代家庭生活正迅速向智能化转型。从简单的自动化设备到复杂的智能家居系统,智能技术正悄无声息地改变我们的日常生活。Kompas AI作为一款前沿的家庭智能助手,不仅预示着家庭生活的未来趋势,更以其…

Unity EasyRoads3D插件使用

一、插件介绍 描述 Unity 中的道路基础设施和参数化建模 在 Unity 中使用内置的可自定义动态交叉预制件和基于您自己导入的模型的自定义交叉预制件,直接创建独特的道路网络。 添加额外辅助对象,让你的场景栩栩如生:桥梁、安全护栏、栅栏、墙壁…

不可思议!这款 Python 库竟然能自动生成GUI界面:MagicGUI

目录 什么是MagicGUI? ​编辑 MagicGUI的工作原理 安装MagicGUI 创建你的第一个GUI ​编辑 其他案例 输入值对话框 大家好,今天我们来聊一聊一个非常有趣且实用的Python库——MagicGUI。这个库可以让你用最少的代码,快速创建图形用户…

GStreamer——教程——基础教程7:Multithreading and Pad Availability

基础教程7:多线程和Pad可用性 目标 GStreamer自动处理多线程,但是在某些情况下,用户可能需要手动解耦线程。这篇教程将展示如何解耦线程以及完善关于Pad Availability的描述。更准确来说,这篇文档解释了: 如何为pipe…

不会策划营销活动?教你一步步成为策划高手

要想让活动大获成功,不仅需要创意十足,更要有严谨的策划和执行,确实新人会有点感觉不知所措。 但其实也不用怕,只要按照以下五个关键步骤,一步步来,也可以轻松策划及格的好活动。 步骤一:锁定目…

AIGC绘画设计基础——十分钟读懂Stable Diffusion

写在最前面: 由于Stable Diffusion里面有关扩散过程的描述,描述方法有很多版本,比如前向过程也可以叫加噪过程,为了便于理解,这里把各种描述统一说明一下。 Diffusion扩散模型:文章里面所有出现Diffusion…

志全重庆官网下载

baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? 现在越来越多的人抱怨说搜索引擎收录很难做,站群程序似乎不在是那么重要, 花费高价购买域名成为了做出高收录站群的越来越重要的建站前提。实上…

Python文本处理:初探《三国演义》

Python文本处理:初探《三国演义》 三国演义获取文本文本预处理分词与词频统计引入停用词后进行词频统计分析人物出场次数结果可视化完整代码 三国演义 《三国演义》是中国古代四大名著之一,它以东汉末年到晋朝统一之间的历史为背景,讲述了魏…

2024下《软件设计师》50个高频考点汇总,背就有效!

宝子们!上半年软考已经结束一段时间了,准备考下半年软考中级-软件设计师的小伙伴们可以开始准备了,这里给大家整理了50个高频考点,涵盖全书90%以上重点,先把这个存下!再慢慢看书,边看书边背这个…

CNN和Transformer创新结合,模型性能炸裂!

CNN结合Transformer 【CNNTransformer】这个研究方向通过结合卷积神经网络(CNN)的局部特征提取能力和Transformer的全局上下文建模优势,旨在提升模型对数据的理解力。这一方向在图像处理、自然语言处理等多个领域展现出强大的应用潜力&#…

告诉你提升UI质感的两个秘密,谁用谁知道。

秘密一:善用头部装饰 秘密二:设计好瓷片区

老电脑焕发第二春,玩转 Stable Diffusion 3

几年前,我头脑一热,配置了一台顶配级消费 PC(RTX 2080 Ti GPU i9 CPU),打算用来学习 AI。然而,起初我并没有找到合适的切入点。深度学习早期阶段,消费级显卡根本无法承担训练大模型、微调大模型…

优思学院|精益管理是什么?3大问题帮你彻底搞懂

有一位朋友他喜欢投资,他偶然看中了一家公司,从公司的一些新闻稿中表示他们因为实施了“精益管理”(Lean Management),因此每股盈余(EPS)长期稳定增长,甚至在行业内的重要指标——库…

微信小游戏备案 之 游戏内容介绍编写实例

微信小游戏备案 之 游戏内容介绍编写实例 前言一,编写规范二,内容填写2.1 本游戏不涉及2.2 游戏场景2.3 游戏玩法2.4 功能系统2.5 主要特点三,小结前言 对于游戏开发者来说,微信小游戏备案是让游戏合法上线的重要步骤,而其中游戏内容介绍的编写尤为关键。下面为大家提供一…

Python实现管线建模 - 3.同心变径管

往期回顾 Python实现管线建模 || 1.圆直管、方管https://blog.csdn.net/Xxy9426/article/details/138836778?spm1001.2014.3001.5501 对依赖库的补充 随着后续内容的深入,我发现单纯靠trimesh库已经无法完成后续的建模(涉及到多个几何体拼接或者是创建…