GAN | 代码简单实现生成对抗网络(GAN)(PyTorch)

news2024/9/22 15:45:46

2014年GAN发表,直到最近大火的AI生成全部有GAN的踪迹,快来简单实现它!!!

GAN通过计算图和博弈论的创新组合,他们表明,如果有足够的建模能力,相互竞争的两个模型将能够通过普通的旧反向传播进行共同训练。

这些模型扮演着两种不同的(字面意思是对抗的)角色。给定一些真实的数据集R,G是生成器,试图创建看起来像真实数据的假数据,而D鉴别器,从真实集或G获取数据并标记差异。 G就像一造假机器,通过多次画画练习,使得画出来的话像真图一样。而D是试图区分的侦探团队。(除了在这种情况下,伪造者G永远看不到原始数据——只能看到D的判断。他们就像盲人摸象的探索伪造的人

Sourse

GAN实现代码

#!/usr/bin/env python

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

matplotlib_is_available = True
try:
  from matplotlib import pyplot as plt
except ImportError:
  print("Will skip plotting; matplotlib is not available.")
  matplotlib_is_available = False

# Data params
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final

def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)

def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            #  1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # ones = true
            d_real_error.backward() # compute/store gradients, but don't change params

            #  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters
            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.show()


train()

代码输出结果

个人总结

GAN从编程的角度来看(纯个人理解,不对可指正)

  • 利用numpy的random方法,随机生成多维的噪音向量

  • 创建一个G网络用来生成

  • 创建一个D网络用来判断

  • 俩个网络在训练时分别进行优化

  • 先训练D网络去判断真假:如果训练D为真时,进行传播;如果训练D为假时,进行传播,投入优化器(1为真,0为假)

  • 在D的基础上训练G。

*因为是随机生成,所以每次生成结果不同

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/377464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GMSL相机的相关配置(1)

文章目录一:GMSL相机的信息二:相关配置1.emmc系统下运行upgrade文件2.连接GMSL相机3.给ui可执行文件赋权限4.进入图为GMSL相机配置ui图形界面5.运行程序,打开摄像头一:GMSL相机的信息 我选择相机适配于基于Jetson AGX Orin的图为…

Docker在Windows环境的搭建和使用

文章目录安装WSL安装Docker安装Docker镜像下载Docker镜像启动gpu启动传送文件训练yolov5安装WSL Windows10和11支持Docker的安装,安装需要用到WSL。所以,我们先安装WSL。 参考文章:旧版 WSL 的手动安装步骤 以管理员身份打开powershell, 执行…

matplotlib综合学习

1.arange函数arange函数需要三个参数,分别为起始点、终止点、采样间隔。采样间隔默认值为1看例子: import numpy as np #import matplotlib.pyplot as plt xnp.arange(-5,5,1) print(x)2.绘制sin(x)曲线import numpy as np import matplotlib.pyplot as …

Python jieba分词如何添加自定义词和去除不需要长尾词

Python jieba分词如何添加自定义词和去除不需要长尾词 作者:虚坏叔叔 博客:https://xuhss.com 早餐店不会开到晚上,想吃的人早就来了!😄 通过如下代码,读取一个txt的高频词汇: # 找到高频词汇t…

苹果触控笔有必要买吗?开学季性价比电容笔推荐

Apple Pencil的性能的确不错,但是由于它的售价实在是太高了,一般人还是舍不得花那么多钱买下来。目前市场上有很多平替的电容笔,不仅价格便宜,而且使用方便。那么,我们应该选择那个牌子的平替电笔呢?在购买…

“智能”创造未来:PDU智能化全面提升IDC数据中心用电能效!

一个月前,万众期盼的《流浪地球2》如期上映,无论是剧情还是特效,让广大观众享受到一次久违的来自中国科幻的震撼,时至今日仍是大家茶余饭后津津乐道的热点谈资。说起这部片子里,最让人紧张的部分,还得数为了…

解决MySQL的 Row size too large (> 8126).

📢欢迎点赞 :👍 收藏 ⭐留言 📝 如有错误敬请指正,赐人玫瑰,手留余香!📢本文作者:由webmote 原创📢作者格言:无尽的折腾后,终于又回到…

电脑系统崩溃怎么修复教程

系统崩溃了怎么办? 如今的软件是越来越复杂、越来越庞大。由系统本身造成的崩溃即使是最简单的操作,比如关闭系统或者是对BIOS进行升级都可能会对PC合操作系统造成一定的影响。下面一起来看看电脑系统崩溃修复方法步骤。 工具/原料: 系统版本&#xf…

LeetCode-47. 全排列 II

目录题目思路回溯法拓展题目来源 47. 全排列 II 题目思路 这道题目和46.全排列的区别在与给定一个可包含重复数字的序列,要返回所有不重复的全排列。 强调的是去重一定要对元素进行排序,这样我们才方便通过相邻的节点来判断是否重复使用了。 我以示例中…

CC2530+ESP8266使用MQTT协议上传阿里云的问题

ATMQTTPUB<LinkID>,<"topic">,<"data">,<qos>,<retain>LinkID: 当前只支持 0 topic: 发布主题, 最长 64 字节 data: 发布消息, data 不能包含 \0, 请确保整条 ATMQTTPUB 不超过 AT 指令的最大长度限制 qos: 发布服务质量, 参…

项目管理软件排行榜!盘点前十名!

项目管理软件排行榜&#xff01;盘点前十名&#xff01; 如今企业规模不断扩大&#xff0c;业务逐渐复杂化&#xff0c;项目管理已经成为现代企业管理中不可或缺的一环。作为协调管理者、团队成员和客户之间交流的工具&#xff0c;项目管理软件不仅可以提高工作效率&#xff0…

数据结构入门--时间 空间复杂度

数据结构入门 时间 空间复杂度解析 目录 一. 算法效率 二. 时间复杂度 2.1 时间复杂度的概念 2.2 大O的渐进表示法 2.3 题目练习 题目一 题目二 题目三 题目四 题目五 题目六 题目七 三. 空间复杂度 3.1 题目练习 题目一 题目二 题目三 一. 算法效率 算法效率…

Vim常用命令汇总

目录1 普通模式2 插入模式3 可视模式4 命令行模式4 文件操作5 动作命令1 普通模式 命令操作符合命令作用等同命令.重复上次修改x删除光标下的字符dd删除整行>G从当前行到文档末尾处的缩进层级a在当前光标之后添加内容i在当前光标之前添加内容A在当前行的结尾添加内容$aI在当…

Docker之安装redis

下面记录一下在docker中安装redis过程 1.查看redis镜像 docker search redis2.拉去镜像到Linux //指定拉取redis版本 docker pull redis:6.0 //不指定版本默认拉取最新 docker pull redis3.查看镜像是否拉取成功 docker images4.启动redis //2f66aad5324为redis的image id do…

驾驭云安全:2023年云安全展望

由于其的良好的可扩展性和优质的事件处理效率&#xff0c;云技术已成为现代企业的必备的管理技术之一&#xff0c;目前他已经成为所有行业及企业的热门选择。然而&#xff0c;攻击面积的增加以及不针对云技术衍生出来的多类攻击方式&#xff0c;使许多企业更容易受到威胁和数据…

Docker学习总结

1、镜像操作 1.1 拉取、查看镜像 步骤一&#xff1a; 首先去镜像仓库搜索nginx镜像&#xff0c;比如[DockerHub]( Docker Hub Container Image Library | App Containerization ) : 步骤二&#xff1a; 根据查看到的镜像名称&#xff0c;拉取自己需要的镜像 通过命令&…

代码随想录 NO54 |单调栈_leetcode 503.下一个更大元素II 42. 接雨水

单调栈_leetcode 503.下一个更大元素II 42. 接雨水单调栈第二天&#xff0c;也是本轮刷题任务倒数第二天&#xff0c;加油&#xff01; 503.下一个更大元素II 这道题和739. 每日温度几乎如出一辙。在遍历的过程中模拟走了两遍nums。 class Solution:def nextGreaterElements(…

算法设计与分析——十大经典排序算法一(1--5)

目录 算法设计与分析——十大经典排序算法 第1关&#xff1a;冒泡排序 参考代码 第2关&#xff1a;选择排序 参考代码 第3关&#xff1a;插入排序 参考代码 第4关&#xff1a;希尔排序 参考代码 第5关&#xff1a;归并排序 参考代码 作者有言 一个不知名大学生&#x…

软考信息系统监理师备考建议

用好备考方法&#xff0c;两三个月就可以过的。信息系统监理师备考最好以教材和历年真题为主&#xff0c;教学视频模拟题为辅。考试介绍与复习建议&#xff1a;考试设置的科目包括&#xff1a;&#xff08;1&#xff09;信息系统工程监理基础知识&#xff0c;考试时间150分钟&a…

回顾1-idea创建Java项目

创建Java项目 创建项目和模块的区别 环境前置 IDEA开发工具JDK及配置环境变量 创建项目/工程 新建项目 选择Java模块 > SDK( 已配置的JDK ) > 下一步 直接下一步 填写项目信息 QQ游戏工程 里的 叫项目 所以 QQgame目录下 可以放 > 斗地主项目 / 美女来找茬等… …