深度学习——生成对抗网络GAN

news2024/11/24 23:08:28

基本概念

概述

GAN是一种深度学习模型,它是一种无监督学习算法,用于从随机噪声中生成逼真的数据,比如图像、音频、文本等。GAN的结构由两个神经网络组成:生成器(Generator)和判别器(Discriminator),它们彼此竞争,从而推动整个模型学习。

两个主要组件:

1.生成器(Generator):

生成器的目标是将随机噪声(通常是从正态分布或均匀分布中采样的向量)转换成逼真的数据样本。这个过程可以理解为生成器学习了数据的分布,并尝试创建与真实数据相似的新样本。初始阶段,生成器的输出可能是随机的,但随着训练的进行,它会逐渐生成更逼真的数据,以欺骗判别器。

2.判别器(Discriminator):

判别器的任务是对输入的数据样本进行分类,即判断它是真实数据还是由生成器产生的假数据。判别器是一个二元分类器,它的目标是尽可能准确地区分真实数据和生成器生成的假数据。

训练过程

1.在训练开始时,生成器随机产生一些假数据样本,并与真实数据一起提供给判别器。
2.判别器根据输入的数据对其进行分类,并输出概率估计(0代表假数据,1代表真实数据)。
3.根据判别器的输出,计算生成器生成数据被判别为真实数据的概率,并将这个概率作为生成器的“损失”(loss)。
4.接下来,根据生成器的损失,更新生成器的参数,使生成器能够生成更逼真的数据样本。
5.然后,再次随机产生一批假数据样本,并将它们与真实数据一起提供给判别器,重复以上过程。

通过这种竞争和博弈的过程,生成器和判别器逐渐优化自己的能力,直到生成器可以生成高度逼真的数据样本,而判别器无法准确区分真假。

代码与注释

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Hyper Parameters
BATCH_SIZE = 64
# 生成器学习率
LR_G = 0.0001  # learning rate for generator
# 判别器学习率
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 5  # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15  # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])


# 定义函数artist_works,用于生成来自著名艺术家的真实画作数据
def artist_works():
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings


# 定义生成器(Generator)和判别器(Discriminator)
# 初级画家
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # 生成器输入为随机噪声数据
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # 生成器输出为生成的艺术作品
)

# 初级鉴赏家
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # 判别器输入为艺术作品数据
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # 判别器输出为对艺术作品的真假概率
)

# 定义两个优化器,分别用于优化生成器和判别器的参数
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

# 开始GAN的训练
plt.ion()  # 打开交互式绘图

for step in range(10000):
    # 获取来自艺术家的真实画作数据
    artist_paintings = artist_works()
    # 生成随机的噪声数据
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)
    # 生成器生成假的艺术画作
    G_paintings = G(G_ideas)

    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings)
    # 计算生成器的损失
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_G.zero_grad()  # 清空生成器的梯度
    G_loss.backward()  # 反向传播计算生成器的梯度
    opt_G.step()  # 优化生成器的参数

    # 判别器对真实画作进行判断,试图增大判别器对真实画作的概率
    prob_artist0 = D(artist_paintings)
    # 判别器对生成的画作进行判断,试图减小判别器对生成画作的概率
    prob_artist1 = D(G_paintings.detach())
    # 计算判别器的损失
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()  # 清空判别器的梯度
    D_loss.backward(retain_graph=True)  # 反向传播计算判别器的梯度(保留计算图以供下一次计算)
    opt_D.step()  # 优化判别器的参数

    if step % 50 == 0:  # 每隔一段时间进行绘图显示
        # 绘制生成的画作、上界和下界
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()  # 关闭交互式绘图
plt.show()  # 展示绘制的图像

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/778962.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谈谈面试大厂中碰到的问题

面试IT公司的小技巧 非常不建议在简历上造假,简历上能起到关键作用、有分量的部分,别人都是有办法去核实的,比如教育背景、关键性的证书、奖项等;核实不了的,又基本上也对结果产生不了太大影响,又何必去画…

USG6000v防火墙的基本使用:制定安全策略让不同安全区域的设备进行访问

目录 一、首先配置环境: 二、实验拓扑及说明 拓扑: PC1和PC2配置ip地址:​编辑​编辑 r4路由器配置ip: 进行防火墙的设置: 1、创建trust1区域和untrust1区域 2、制定防火墙的策略: 3、为防火墙增加可以…

【PostgreSQL内核学习(四)—— 查询规划】

查询规划 查询规划总体处理流程pg_plan_queries函数standard_planner函数subquery_planner函数inheritance_planner函数grouping_planner函数 总结 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊重他人的知识产权和学术成果,力求…

SpringBoot原理分析 | 任务:异步、邮件、定时

💗wei_shuo的个人主页 💫wei_shuo的学习社区 🌐Hello World ! 任务 异步任务 Java异步指的是在程序执行过程中,某些任务可以在后台进行,而不会阻塞程序的执行。通常情况下,Java异步使用线程池来…

Apikit 自学日记:如何使用定时执行测试用例功能呢?

API自动化测试其实可以设置定时任务,实现项目在无人值守的情况下自动测试并且发送报告给相应的邮箱,监控项目监控情况。 这样一来,就能大大提高工作效率。 目前在 APIkit中这一部分主要功能有: 1.允许对测试任务进行分组&#xf…

Llama 2: Open Foundation and Fine-Tuned Chat Models

文章目录 TL;DRIntroduction背景本文方案 实现方式预训练预训练数据训练细节训练硬件支持预训练碳足迹 微调SFTSFT 训练细节 RLHF人类偏好数据收集奖励模型迭代式微调(RLHF)拒绝采样(Rejection Sampling)PPO多轮一致性的系统消息&…

GAMS---典型优化模型和算法介绍、GAMS安装和介绍、GAMS程序编写、GAMS程序调试、实际应用算例演示与经验分享

优化分析是很多领域中都要面临的一个重要问题,求解优化问题的一般做法是:建立模型、编写算法、求解计算。常见的问题类型有线性规划、非线性规划、混合整数规划、混合整数非线性规划、二次规划等,优化算法包括人工智能算法和内点法等数学类优…

S32K144 GPIO外设分析

1. S32K144 GPIO外设特性 下面的内容来自于S32K用户手册的翻译,或者网上关于S32K系列的一些pdf文件介绍。有些内容可能会出现理解不到位或者翻译错误方面,如果大家有疑问最好可以查阅用户手册。 GPIO和PORT的数量 从用户手册,对于PCR&#x…

python try/except/finally

稍微总结一下&#xff0c;否则总是忘。 x abc def fetcher(obj, index): return obj[index] fetcher(x, 4) 输出&#xff1a; File "test.py", line 6, in <module> fetcher(x, 4) File "test.py", line 4, in fetcher return obj[index] …

实验四 回溯法

实验四 回溯法 售货员问题 1.实验内容 1、理解回溯法的深度优先搜索策略&#xff0c;掌握用回溯法解题的算法框架 2、设计并实现旅行售货员问题问题&#xff0c;掌握回溯算法。 2.实验环境 Java 3.问题描述 旅行售货员问题&#xff1a;设有一个售货员从城市1出发&#…

docker-compose自建RustDesk远程控制服务器

github&#xff1a; rustdesk/rustdesk-server: RustDesk Server Program (github.com) 一、创建 docker-compose.yml 文件&#xff0c;复制以下 docker-compose 配置文件内容到文件 version: 3networks:rustdesk-net:external: falseservices:hbbs:container_name: hbbspor…

开源ThinkMusic搭建音乐网站,并实现公网连接

1、前言 在我们的日常生活中&#xff0c;音乐已经成为不可或缺的要素之一&#xff0c;听几首喜欢的音乐&#xff0c;能让原本糟糕的心情变得好起来。虽然现在使用电脑或移动电子设备听歌都很方便&#xff0c;但难免受到诸多会员或VIP限制&#xff0c;难免让我们回想起音乐网站…

DAY6,C++(将顺序栈,顺序循环队列定义成模板类);

1.将顺序栈定义成模板类&#xff1b;​​​​​​ 顺序栈模板代码--- #include <iostream>using namespace std;template<typename T> class Stack { private:T *data; //指向堆区空间int top; //记录栈顶位置public:Stack(); //无参构造Stack(T size); //有…

子网划分路由网卡安全组

1."IPv4 CIDR" "IPv4 CIDR" 是与互联网协议地址&#xff08;IP address&#xff09;和网络的子网划分有关的概念。 - "IPv4" 代表 "Internet Protocol version 4"&#xff0c;也就是第四版互联网协议&#xff0c;这是互联网上最广泛使…

动态规划入门第4课,经典DP问题3 ----公共最长子序列

练习 第1题 最长公共子串 查看测评数据信息 给出2个小写字母组成的字符串&#xff0c;求它们最长的公共子串的长度是多少&#xff1f; 例如&#xff1a;”abcdefg” 与”xydoeagab”。有最长的公共子串”deg”&#xff0c; 答案为&#xff1a;3。 输入格式 第一行&#xff…

Java并发编程学习笔记(一)线程的入门与创建

一、进程与线程 认识 程序由指令和数据组成&#xff0c;简单来说&#xff0c;进程可以视为程序的一个实例 大部分程序可以同时运行多个实例进程&#xff0c;例如记事本、画图、浏览器等少部分程序只能同时运行一个实例进程&#xff0c;例如QQ音乐、网易云音乐等 一个进程可以…

【密码学】三、DES

DES 1、DES的加密过程2、初始置换3、16轮迭代变换过程3.1 扩展变换/位选择函数E3.2 S盒代换3.3P盒置换 4、初始逆置换5、密钥扩展5.1 选择置换PC_15.2选择置换PC_2 6、DES的解密过程7、多重DES 美国正式公布实施的DES是一个众所周知的分组密码&#xff0c;其 分组长度是64bit&…

Redis 缓存机制介绍

.Redis 缓存 缓存&#xff08;cache&#xff09;&#xff0c;原始意义是指访问速度比一般随机存取存储器&#xff08;RAM&#xff09;快的一种高速存储器&#xff0c;通常它不像系统主存那样使用 DRAM 技术&#xff0c;而使用昂贵但较快速的 SRAM 技术。缓存的设置是所有现代计…

arm点灯

.text .global _start _start: /**********LED1点灯**************/RCC_INIT:LDR R0,0X50000A28LDR R1,[R0]orr R1,R1,#(0x1<<4)orr R1,R1,#(0X1<<5)STR R1,[R0] LED1_INIT:/**/LDR R0,0X50006000LDR R1,[R0]and R1,R1,#(~(0X3<<20))orr R1,R1,#(0x1<<…

windows下载pytorch gpu时遇见的问题以及解决方案

一些很奇怪的问题 使用官方命令下载失效离线下载之后使用pip安装又md报错了 使用官方命令下载失效 这是官方的下载命令&#xff0c;我在运行这个命令的时候咋的都报错&#xff0c;真的无语。 报错信息如下&#xff08;当时没截图&#xff0c;我再创建个新环境运行此命令给大家…