Gan论文阅读笔记

news2024/11/28 2:51:05

GAN论文阅读笔记

2014年老论文了,主要记录一些重要的东西。论文链接如下:

Generative Adversarial Nets (neurips.cc)

文章目录

  • GAN论文阅读笔记
    • 出发点
    • 创新点
    • 设计
    • 训练代码
    • 网络结构代码
    • 测试代码

出发点

Deep generative models have had less of an impact, due to the difficulty of approximating many intractable probabilistic computations that arise in maximum likelihood estimation and related strategies, and due to difficulty of leveraging the benefits of piecewise linear units in the generative context.

​ 当时的生成模型效果不佳在于近似许多棘手的概率计算十分困难,如最大似然估计等。除此之外,把利用分段线性单元运用到生成场景中也有困难。于是作者提出新的生成模型:GAN。

​ 我的理解是,当时的生成模型都是去学习模型生成数据的分布,比如确定方差,确定均值之类的参数,然而这种方法十分难以学习,而且计算量大而复杂,作者考虑到这一点,对生成模型采用端到端的学习策略,不去学习生成数据的分布,而是直接学习模型,只要这个模型的生成结果能够逼近Ground-Truth,那么就可以直接用这个模型代替分布去生成数据。这是典型的黑箱思想。

创新点

adiscriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles.

创新点1:提出对抗学习策略:提出两个model之间相互对抗,相互抑制的策略。一个model名为生成器Generator,一个model名为判别器Discriminator,生成器尽可能生成接近真实的数据,判别器尽可能识别出生成器数据是Fake。

In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron.

创新点2:当两个model都使用神经网络时,可以运用反向传播和Dropout等算法进行学习,这样就可以避免使用马尔科夫链。

设计

To learn the generator’s distribution pgover data x, we define a prior on input noise variables pz(z), then represent a mapping to data space as G(z; θg), where G is a differentiable function represented by a multilayer perceptron with parameters θg. We also define a second multilayer perceptron D(x; θd) that outputs a single scalar. D(x) represents the probability that x came from the data rather than pg.

1.输入:为了让生成器G生成的数据分布pg与真实数据分布x接近,策略是给G输入一个噪音变量z,然后学习参数θg,这个θg是G网络权重。因此,G可以被写作:G(z;θg)。
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \underset{G}{min}\underset{D}{max}V(D, G) =\mathbb{E}_{x \sim p_{data}(x)}\left[ logD(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[log(1 - D(G(z)))\right] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
2.对抗性损失函数:从代码可知,对抗性损失是两个BCELoss的和,V尽可能使D(x)更大,在此基础上尽可能使G(z)更小。这是有先后顺序的,在后面会做说明。

在代码中可知,先人为生成两个标签,第一个标签是用torch.ones生成的全为1的矩阵,形状为(batch,1)。其中batch是输入噪声的batch,第二维度只是一个数字——1,这个标签用于判别器D的BCELoss中,代入BCELoss即可得到上面对抗性损失中左侧的期望。第二个标签是用torch.zeors生成的全为0的矩阵,形状同理为(batch,1),运用于生成器G的BCELoss中,代入即可得到对抗性损失的右侧期望。

we alternate between k steps of optimizing D and one step of optimizing G.

This results in D being maintained near its optimal solution, so long as G changes slowly enough.

3.D与G的训练有先后顺序:判别器D先于生成器G训练,而且要求先对D训练k步,再为G训练1步,这就保证G的训练比D足够慢。

如果生成器G足够强大,那么判别器无法再监测生成器,也就没有对抗的必要了。相反,如果判别器D太过于强大,那么生成器也训练地十分缓慢。

在这里插入图片描述

4.算法图如上。

训练代码

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from Model import generator
from Model import discriminator

import os

if not os.path.exists('gan_train.py'):  # 报错中间结果
    os.mkdir('gan_train.py')


def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 96
num_epoch = 200
z_dimension = 100


# 数据预处理
img_transform = transforms.Compose([
    transforms.ToTensor(),  # 图像数据转换成了张量,并且归一化到了[0,1]。
    transforms.Normalize([0.5], [0.5])  # 这一句的实际结果是将[0,1]的张量归一化到[-1, 1]上。前面的(0.5)均值, 后面(0.5)标准差,
])
# MNIST数据集
mnist = datasets.MNIST(
    root='./data', train=True, transform=img_transform, download=True)
# 数据集加载器
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

D = discriminator()  # 创建生成器
G = generator()  # 创建判别器
if torch.cuda.is_available():  # 放入GPU
    D = D.cuda()
    G = G.cuda()

criterion = nn.BCELoss()  # BCELoss 因为可以当成是一个分类任务,如果后面不加Sigmod就用BCEWithLogitsLoss
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)  # 优化器
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)  # 优化器

# 开始训练
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):  # img[96,1,28,28]
        G.train()
        num_img = img.size(0)  # num_img=batchsize
        # =================train discriminator
        img = img.view(num_img, -1)  # 把图片拉平,为了输入判别器 [96,784]
        real_img = img.cuda()  # 装进cuda,真实图片

        real_label = torch.ones(num_img).reshape(num_img, 1).cuda()  # 希望判别器对real_img输出为1 [96,1]
        fake_label = torch.zeros(num_img).reshape(num_img, 1).cuda()  # 希望判别器对fake_img输出为0  [96,1]

        # 先训练鉴别器
        # 计算真实图片的loss
        real_out = D(real_img)  # 将真实图片输入鉴别器 [96,1]
        d_loss_real = criterion(real_out, real_label)  # 希望real_out越接近1越好 [1]
        real_scores = real_out  # 后面print用的

        # 计算生成图片的loss
        z = torch.randn(num_img, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,1]
        #   这个z维度和生成器第一个Linear第一个参数一致
        # 避免计算G的梯度
        fake_img = G(z).detach()  # 生成伪造图片 [96,748]
        fake_out = D(fake_img)  # 给判别器判断生成的好不好 [96,1]

        d_loss_fake = criterion(fake_out, fake_label)  # 希望判别器给fake_out越接近0越好 [1]
        fake_scores = fake_out  # 后面print用的

        d_loss = d_loss_real + d_loss_fake

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器
        # 计算生成图片的loss
        z = torch.randn(num_img, z_dimension).cuda()  # 生成随机噪声 [96,100]

        fake_img = G(z)  # 生成器伪造图像 [96,784]
        output = D(fake_img)  # 将伪造图像给判别器判断真伪 [96,1]
        g_loss = criterion(output, real_label)  # 生成器希望判别器给的值越接近1越好 [1]

        # 更新生成器
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print(
                f'Epoch [{epoch}/{num_epoch}], d_loss: {d_loss.cpu().detach():.6f}, g_loss: {g_loss.cpu().detach():.6f}',
                f'D real: {real_scores.cpu().detach().mean():.6f}, D fake: {fake_scores.cpu().detach().mean():.6f}')
    if epoch == 0:  # 保存图片
        real_images = to_img(real_img.detach().cpu())
        save_image(real_images, './img_gan/real_images.png')

    fake_images = to_img(fake_img.detach().cpu())
    save_image(fake_images, f'./img_gan/fake_images-{epoch + 1}.png')

    G.eval()
    with torch.no_grad():
        new_z = torch.randn(batch_size, 100).cuda()
        test_img = G(new_z)
        print(test_img.shape)
        test_img = to_img(test_img.detach().cpu())
        test_path = f'./test_result/the_{epoch}.png'
        save_image(test_img, test_path)

# 保存模型
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

网络结构代码

import torch
from torch import nn


# 判别器 判别图片是不是来自MNIST数据集
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),  # 784=28*28
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
            #   sigmoid输出这个生成器是或不是原图片,是二分类
        )

    def forward(self, x):
        x = self.dis(x)
        return x


# 生成器 生成伪造的MNIST数据集
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),  # 输入为100维的随机噪声
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            #   生成器输出的特征维和正常图片一样,这是一个可参考的点
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x


class FinetuneModel(nn.Module):
    def __init__(self, weights):
        super(FinetuneModel, self).__init__()
        self.G = generator()
        base_weights = torch.load(weights)

        model_parameters = dict(self.G.named_parameters())
        #   不是对model进行named_parameters,而是对model里面的具体网络进行named_parameters取出参数,否则取出的是model冗余的参数去测试
        pretrained_weights = {k: v for k, v in base_weights.items() if k in model_parameters}

        new_state_dict = {k: pretrained_weights[k] for k in model_parameters.keys()}
        self.G.load_state_dict(new_state_dict)

    def forward(self, input):
        output = self.G(input)
        return output

测试代码

import os
import sys
import numpy as np
import torch
import argparse
import torch.utils.data
from PIL import Image
from Model import FinetuneModel
from Model import generator
from torchvision.utils import save_image

parser = argparse.ArgumentParser("GAN")
parser.add_argument('--save_path', type=str, default='./test_result')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--model', type=str, default='generator.pth')

args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)


def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


def main():
    if not torch.cuda.is_available():
        print("no gpu device available")
        sys.exit(1)

    model = FinetuneModel(args.model)
    model = model.to(device=args.gpu)
    model.eval()

    z_dimension = 100

    with torch.no_grad():
        for i in range(100):
            z = torch.randn(96, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,100]
            output = model(z)
            print(output.shape)
            u_name = f'the_{i}.png'
            print(f'processing {u_name}')
            u_path = save_path + '/' + u_name
            output = to_img(output.cpu().detach())
            save_image(output, u_path)


if __name__ == '__main__':
    main()

本文毕

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1295907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

06 JQuery调用接口

文章目录 一、Qs.js库介绍1. Qs简介2. Qs.parse3. Qs.stringify 二、jQuery调用接口1. 增加(Create)2. 删除(Delete)3. 读取(Read)4. 更新(Update) 三、示例 一、Qs.js库介绍 1. Qs…

机器连接和工业边缘计算

软件应用和IT创新是制造业投资的主要驱动力。解决方案架构应围绕特定标准进行整合,并采用架构蓝图和最佳实践来满足最终用户的需求。此外,边缘计算(Edge Computing)也将在制造业中加速部署。 边缘计算是制造业的下一个变革驱动力。…

视频剪辑高手揭秘:如何批量减少时长并调整播放速度,提升视频效果

随着社交媒体的兴起,视频制作的需求越来越大。然而往往视频文件存在一些问题,例如时长过长,或者要调整播放速度以更好地传达信息。这些问题不仅影响了视频的观看体验,也可能导致视频难以在社交媒体上获得广泛的传播。那么&#xf…

电子学会C/C++编程等级考试2021年06月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数字变换 给定一个包含5个数字(0-9)的字符串,例如 “02943”,请将“12345”变换到它。 你可以采取3种操作进行变换 1. 交换相邻的两个数字 2. 将一个数字加1。如果加1后大于9,则变为0 3. 将一个数字加倍。如果加倍后大于…

react Hooks实现原理

Fiber 上篇文章fiber简单理解记录了react fiber架构,Hooks是基于fiber链表来实现的。阅读以下内容时建议先了解react fiber。 jsx -> render function -> vdom -> fiber树 -> dom vdom 转 fiber 的过程称为 recocile。diff算法就是在recocile这个过程…

深入探讨Go语言协程调度:GRM模型解析与优化策略

一、线程调度 1、早期单线程操作系统 一切的软件都是跑在操作系统上,真正用来干活(计算)的是 CPU早期的操作系统每个程序就是一个进程,直到一个程序运行完,才能进行下一个进程,就是“单进程时代”一切的程…

数据可视化:解锁企业经营的智慧之道

在现代企业管理中,数据可视化已经成为了一项重要的工具。它不仅仅是简单地展示数据,更是提供了深入理解数据、做出更明智决策的方法。作为一名可视化设计从业人员,我经手过一些企业自用的数据可视化项目,今天就来和大家聊聊数据可…

C语言第十七集(待修)

11.30的视频 1.结构体可以这样重新赋值 注:字符数组不能用来赋值 2.匿名结构体重新赋值方法: 注:在创建x时就已经使用过一次匿名结构体了 但是,在使用匿名结构体时,可以一次性创立多个变量 3.结构体内存对齐和对其规则详细搜: 4.总之,我们在创建结构体时,要将占用空间小的成…

一个newman命令行让某大厂瘫痪半天,速看!

newman简介 newman是为Postman而生,专门用来运行Postman编写好的脚本; 使用newman,你可以很方便的用命令行来执行postman collections。 newman的安装 1.先下载Node.js;https://nodejs.org/en/ 2.安装NodeJs(很容易安装&#x…

cmd命令 常用的命令

网络工作为常年公司里的背锅侠,不得不集齐十八般武艺很难甩锅。像cmd命令这种好用又好上手的技术,就是网络工程师上班常备技能。 只要按下快捷键 winR,输入cmd回车,然后输入cmd命令。 像我自己,我就经常用cmd命令检测…

初识优先级队列与堆

1.优先级队列 由前文队列queue可知,队列是一种先进先出(FIFO)的数据结构,但有些情况下,操作的数据可能带有优先级,一般出队列时,可能需要优先级高的元素先出队列,在此情况下,使用队列queue显然不…

xilinx的XVC协议

文章目录 概述JTAG工作方式XVC协议 其他Debug Bridge IP 概述 JTAG工作方式 XVC协议 其他 Debug Bridge IP

cookie总结

cookie和session: 一、Cookie和Session二、使用Cookie保存用户上次的访问时间。三、Cookie常用方法总结乱码问题解决: 一、Cookie和Session 会话:用户从打开浏览器到关闭的整个过程就叫1次会话。 比如有的网站登录过一次,下次再进…

python的websocket方法教程

WebSocket是一种网络通信协议,它在单个TCP连接上提供全双工的通信信道。在本篇文章中,我们将探讨如何在Python中使用WebSocket实现实时通信。 websockets是Python中最常用的网络库之一,也是websocket协议的Python实现。它不仅作为基础组件在…

Spring AOP 概念及其使用

目录 AOP概述 什么是AOP? 什么是Spring AOP ? Spring AOP 快速入门 1.引⼊ AOP 依赖 2.编写AOP程序 Spring AOP 核心概念 1.切点 2.连接点 3.通知 4.切面 通知类型 注意事项: PointCut(定义切点) 切面优先级 Order 切点表达…

IDEA删除最近打开的文件记录

IDEA删除最近打开的文件记录 遇见问题:如何删除IDEA中最近打开的文件记录 解决方法 先关闭IDEA 找到 recentProjects.xml 文件 windows 位置:(AppData是隐藏文件夹) 1.C:\Users\电脑用户名\AppData\Roaming\JetBrains\IntelliJIde…

Bash脚本调用百度翻译API进行中文到英文的翻译

写一个bash脚本调用百度翻译API进行中文到英文的翻译,首先需要进行相关的申请。看百度给出的文档链接: 百度翻译API文档 需要先注册一个百度账号,然后申请APPID。脚本中会用到appid和key这两个值。按照文档给出的提示可以获得。如下是脚本: #…

零基础如何入门HarmonyOS开发?

HarmonyOS鸿蒙应用开发是当前非常热门的一个领域,许多人都想入门学习这个技术。但是,对于零基础的人来说,如何入门确实是一个问题。下面,我将从以下几个方面来介绍如何零基础入门HarmonyOS鸿蒙应用开发学习。 一、了解HarmonyOS鸿…

markdown记录

文章目录 基础操作使用一级列表、二级列表 博文链接 基础操作 使用一级列表、二级列表 博文链接 CSDN-Markdown语法集锦 CSDN-markdown语法之如何使用LaTeX语法编写数学公式 CSDN Markdown简明教程1-关于Markdown CSDN Markdown简明教程2-基本使用 CSDN Markdown简明教程3-表…

X86汇编语言:从实模式到保护模式(代码+注释)--c6

X86汇编语言:从实模式到保护模式(代码注释)–c6 标志寄存器FLAGS: 6th:ZF位(Zero Flag):零标志,执行算数或者逻辑运算之后,会将该位置位。10th:D…