GAN的入门理解

news2024/12/24 2:40:23

这一篇主要是关于生成对抗网络的模型笔记,有一些简单的证明和原理,是根据李宏毅老师的课程整理的,下面有链接。本篇文章主要就是梳理基础的概念和训练过程,如果有什么问题的话也可以指出的。

李宏毅老师的课程链接

1.概述

GAN是Generative Adversarial Networks的缩写,也就是生成对抗网络,最核心在于训练两个网络分别是generator和discriminator,generator主要是输入一个向量,输出要生成的目标,discriminator接受一个输出的目标,然后输出为真的概率(来说就是打分)。

假设任务是生成一组图片,现在输入是一组图片数据集,一开始随便生成乱七八糟的数据,训练共有两个核心的步骤:

  1. 更新discriminator,将真实数据标记为1,生成的数据标记为0,然后进行训练,那么discrimator就可以辨别生成图片。
  2. 更新generator,更新生成网络的参数,让生成网络生成的图片能让discriminator输出尽可能大(打分尽可能高,也就是骗过discriminator)。
  3. 回到1,重复这个过程。

下面是最原始的论文提出的伪代码:

在这里插入图片描述

可以看到第一个阶段是在更新discriminator,D(x)表示对输入图像x的判别,损失函数是两项累加,前面的 x i x^i xi表示真实输入,这些应该输出1,后面的 x ~ i \widetilde{x}^i x i表示生成数据,这些应该给低分(接近0),两项的目标都是越大越好,所以 V ~ \widetilde{V} V 越大越好,因此 θ d \theta_d θd是梯度上升优化。

第二阶段在更新generator, G ( z i ) G(z^i) G(zi)就是对一个向量生成一个目标,然后进行打分,也是越大越好,因此梯度上升优化,这一部分的目标就是让生成的图片尽量得高分。

循环多次迭代就可以得到预期网络。

当然目前我还有一些疑问:

  1. generator输出的图片是如何保证风格和数据集类似的?

    应该是必须要像原风格一样的才能得到高分。

  2. 输入的向量是随机的,如何可控输入向量和输出特征的关系?如何解释每个输入的数字?(比如我想生成蓝色的头发,那么这个是可控的吗)
    这个可能要看了一些具体的代码才能理解。

2.原理简单分析

生成一个图片或者一个语音本质是映射到一个高维点的问题,比如32×32的黑白图片就是 2 32 × 32 2^{32\times 32} 232×32空间中的一个点。下面都以图片生成任务为例,假设真实分布是 P d a t a P_{data} Pdata,生成的分布是 P G P_{G} PG,只有生成的点(图片)到了真实的分布中,才有极大可能是看上去真实的,因此目标就是让生成的分布 P G P_G PG尽可能接近真实的分布 P d a t a P_{data} Pdata,方法就是KL散度或者JS散度,因此一个理想的生成器应该是这样的:
G ∗ = a r g min ⁡ G D i v ( P G , P d a t a ) G^*=arg\min_G Div(P_G,P_{data}) G=argGminDiv(PG,Pdata)其中Div衡量两个分布的差异,而 G ∗ G^* G就是所有生成器 G G G中有着最小差异的那个,也就是最优的。

然而,实际情况中,真实的分布和实际的分布都是未知的,一些传统的算法可能假设高斯分布,但是很多时候可能不正确。

虽然不能直接得到分布,但是可以进行采样(Sample),在GAN中,discriminator就扮演了计算两个分布差异的角色,给出下式:
V ( G , D ) = E x ∼ P d a t a l o g ( D ( x ) ) + E x ∼ P G l o g ( 1 − D ( x ) ) V(G,D)=E_{x\sim P_{data}}log(D(x))+E_{x\sim P_{G}}log(1-D(x)) V(G,D)=ExPdatalog(D(x))+ExPGlog(1D(x))其中 D ( x ) D(x) D(x)表示一个discriminator对一个generator生成的结果进行打分,介于 [ 0 , 1 ] [0,1] [0,1],下面证明这个式子本质上也是JS散度或者KL散度:


证明
max ⁡ E x ∼ P d a t a l o g ( D ( x ) ) + E x ∼ P G l o g ( 1 − D ( x ) ) = max ⁡ ∫ x p d a t a ( x ) l o g ( D ( x ) ) + ∫ x p G ( x ) l o g ( 1 − D ( x ) ) = max ⁡ ∫ x p d a t a ( x ) l o g ( D ( x ) ) + p G ( x ) l o g ( 1 − D ( x ) ) \max E_{x\sim P_{data}}log(D(x))+E_{x\sim P_{G}}log(1-D(x))\\ =\max \int_xp_{data}(x)log(D(x))+\int_xp_{G}(x)log(1-D(x))\\ =\max \int_xp_{data}(x)log(D(x))+p_{G}(x)log(1-D(x)) maxExPdatalog(D(x))+ExPGlog(1D(x))=maxxpdata(x)log(D(x))+xpG(x)log(1D(x))=maxxpdata(x)log(D(x))+pG(x)log(1D(x))
这里假设D(x)可以拟合任何函数,那么对于任意一个x取值 x ∗ x^* x D ( x ∗ ) D(x^*) D(x)都可以对应任何数值,这就意味着可以对每个x都计算最大值,然后求和得到最大值。

a = p d a t a ( x ) , b = p G ( x ) , D ( x ) = t a=p_{data}(x),b=p_G(x),D(x)=t a=pdata(x),b=pG(x),D(x)=t,那么可以得到下式:
f ( t ) = a l o g ( t ) + b l o g ( 1 − t ) f(t)=alog(t)+blog(1-t) f(t)=alog(t)+blog(1t)求导计算最小值对应的t:(直接假设e为底了)
f ′ ( t ) = a x − b 1 − x f'(t)=\frac{a}{x}-\frac{b}{1-x} f(t)=xa1xb
f ′ ( t ) = 0 f'(t)=0 f(t)=0,得到 t = a a + b t=\frac{a}{a+b} t=a+ba,代入 a , b , t a,b,t a,b,t,假设这个值为最优值 D ∗ ( x ) D^*(x) D(x)
D ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p G ( x ) D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_G(x)} D(x)=pdata(x)+pG(x)pdata(x)此时每个x都有对应的 D ∗ ( x ) D^*(x) D(x),代入得到:
max ⁡ ∫ x p d a t a ( x ) l o g ( D ( x ) ) + p G ( x ) l o g ( 1 − D ( x ) ) = ∫ x p d a t a ( x ) l o g ( D ∗ ( x ) ) + p G ( x ) l o g ( 1 − D ∗ ( x ) ) = ∫ x p d a t a ( x ) l o g ( p d a t a ( x ) p d a t a ( x ) + p G ( x ) ) + p G ( x ) l o g ( p G ( x ) p d a t a ( x ) + p G ( x ) ) = − 2 l o g 2 + ∫ x p d a t a ( x ) l o g ( p d a t a ( x ) ( p d a t a ( x ) + p G ( x ) ) / 2 ) + p G ( x ) l o g ( p G ( x ) ( p d a t a ( x ) + p G ( x ) ) / 2 ) = − 2 l o g 2 + K L ( P d a t a ∣ ∣ P d a t a + P G 2 ) + K L ( P G ∣ ∣ P d a t a + P G 2 ) = − 2 l o g 2 + J S D ( P d a t a ∣ ∣ P G ) \max\int_xp_{data}(x)log(D(x))+p_{G}(x)log(1-D(x))\\ =\int_xp_{data}(x)log(D^*(x))+p_{G}(x)log(1-D^*(x))\\ =\int_xp_{data}(x)log(\frac{p_{data}(x)}{p_{data}(x)+p_G(x)})+p_{G}(x)log(\frac{p_{G}(x)}{p_{data}(x)+p_G(x)})\\ =-2log2+\int_xp_{data}(x)log(\frac{p_{data}(x)}{(p_{data}(x)+p_G(x))/2})+p_{G}(x)log(\frac{p_{G}(x)}{(p_{data}(x)+p_G(x))/2})\\ =-2log2+KL(P_{data}||\frac{P_{data}+P_G}{2})+KL(P_{G}||\frac{P_{data}+P_G}{2})\\ =-2log2+JSD(P_{data}||P_G) maxxpdata(x)log(D(x))+pG(x)log(1D(x))=xpdata(x)log(D(x))+pG(x)log(1D(x))=xpdata(x)log(pdata(x)+pG(x)pdata(x))+pG(x)log(pdata(x)+pG(x)pG(x))=2log2+xpdata(x)log((pdata(x)+pG(x))/2pdata(x))+pG(x)log((pdata(x)+pG(x))/2pG(x))=2log2+KL(Pdata∣∣2Pdata+PG)+KL(PG∣∣2Pdata+PG)=2log2+JSD(Pdata∣∣PG)
后面的几步其实不是很理解,不过到第三步,跟交叉熵形式很像,所以都是类似的衡量两个分布的差异。


训练一个discriminator,实际上就是为了更好区分真实和生成的样本,那么自然要让这个差异越大越好,此时这个discriminator可以最大程度区分生成和真实。 D ∗ D^* D给的打分实际上可以看做生成分布和实际分布的差异
D ∗ = a r g max ⁡ D V ( G , D ) D^*=arg\max_D V(G,D) D=argDmaxV(G,D)而训练generator的过程就是为了让discriminator不容易区分真实和生成样本,因此要减少这个差异:
D ∗ = a r g min ⁡ G V ( G , D ∗ ) = a r g min ⁡ G max ⁡ D V ( G , D ) D^*=arg\min_G V(G,D^*) =arg\min_G \max_D V(G,D) D=argGminV(G,D)=argGminDmaxV(G,D)
也就是现在有一个最优的discriminator D ∗ D^* D,要优化generator使得 D ∗ D^* D打分尽量高,也就是:
θ g = θ g − η ∂ V ( G , D ∗ ) θ g \theta_g=\theta_g-\eta \frac{\partial V(G,D^*)}{\theta_g} θg=θgηθgV(G,D)这里实际上是对 θ G \theta_G θG也就是生成网络的参数求导,实际的网络架构是: v e c t o r → θ G → o u t → θ D → s c o r e vector\rightarrow \theta_G \rightarrow out \rightarrow \theta_D \rightarrow score vectorθGoutθDscore,这里更新的时候,不更新 θ D \theta_D θD,这也就是固定discriminator的思想。

注意点:每次更新的时候,对discriminator的更新要彻底,对generator的更新次数不能多,如下图:

比如现在训练了一个discriminator是 D 0 ∗ D^*_0 D0,现在要让G变得更强,也就是让 D 0 ∗ D^*_0 D0对G生成的图辨别能力降低,直观体现就是 V ( G , D ) V(G,D) V(G,D)变小,但是因为更新参数对生成分布的影响是全局的,那么就可能导致生成图片和实际分布差异变得更大,因为变小的只有 D 0 ∗ D^*_0 D0的得分,可能这时候 D 0 ∗ D^*_0 D0已经不是最好的discriminator,而更好的discriminator可以将生成的和实际的分的更开,就像图二的最大值必原来的还大,那么对应于更高的 D ∗ D^* D计算得到的差异比原来还大。(有点绕这里)

所以有一个简单的假设,就是generator更新后的图形基本和原来保持一致,那么此时优化最大值让最大值变小,那么就相当于生成分布和实际分布差异更小,要达到这样的目的,那么不能更新generator太多;而对于discriminator,因为要找到最大值,应该要更新彻底。

3.实际操作

上面都是理论上的分析,下面讲一讲在实际的操作中是怎么做的。

3.1.训练discriminator

一个discriminator其实就是一个二分分类器,输入一个生成的数据,给出为真的概率,所以训练的过程也是和训练分类器是一样的,上面提到了 V ( G , D ) V(G,D) V(G,D)优化目标,里面有期望,期望一般会被转化为求多个样本的均值来获得。对于一个确定的生成器G,假设抽样取得m个真实样本X,生成了m个生成样本X’,那么期望可以转化为:
V ( D ) = E x ∼ P d a t a l o g ( D ( x ) ) + E x ∼ P G l o g ( 1 − D ( x ) ) = > V ~ = 1 m ∑ i = 1 m l o g ( D ( x i ) ) + 1 m ∑ i = 1 m l o g ( 1 − D ( x i ′ ) ) V(D)=E_{x\sim P_{data}}log(D(x))+E_{x\sim P_{G}}log(1-D(x))\\ =>\widetilde{V}=\frac{1}{m}\sum_{i=1}^{m}log(D(x_i))+\frac{1}{m}\sum_{i=1}^{m}log(1-D(x'_i)) V(D)=ExPdatalog(D(x))+ExPGlog(1D(x))=>V =m1i=1mlog(D(xi))+m1i=1mlog(1D(xi))
一般会采用梯度上升法:(因为要求最大值)
θ d = θ d + η ▽ θ d V ~ ( θ d ) \theta_d=\theta_d+\eta ▽_{\theta_d}\widetilde{V}(\theta_d) θd=θd+ηθdV (θd)

3.2.训练generator

训练generator实际上是为了减少 V ( G , D ∗ ) V(G,D^*) V(G,D),也就是让目前最好的分类器分不清,还是抽样,生成n个样本X’,那么目标如下:
V ( D ) = 1 m ∑ i = 1 m l o g ( 1 − D ( G ( x i ′ ) ) ) V(D)=\frac{1}{m}\sum_{i=1}^{m}log(1-D(G(x'_i))) V(D)=m1i=1mlog(1D(G(xi)))此时在变的是gegenerator的参数 θ g \theta_g θg,要通过改变生成参数让最好的discriminator得分降低,一般是梯度下降:
θ g = θ g − η ▽ θ g V ~ ( θ g ) \theta_g=\theta_g-\eta ▽_{\theta_g}\widetilde{V}(\theta_g) θg=θgηθgV (θg)要注意,不能训练次数太多(一般一次就可以)。

具体的代码实现我还没有去看过,就不进一步展开了,这一篇主要还是记录一些简单的原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1809229.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BC11 学生基本信息输入输出

BC11 学生基本信息输入输出 废话不多说上题目&#xff1a; 这道题表面上很简单&#xff0c;但是里面有很重要的点先给大家上正确的代码&#xff1a; #include<stdio.h> int main() {int stu 0;float c 0;float English 0;float math 0;scanf("%d;%f,%f,%f"…

Java入门教程上

常见的cmd命令 类 class 字面量 数据类型 输入 public static void main(String[] args) {Scanner anew Scanner(System.in);int na.nextInt();int ma.nextInt();System.out.println(mn);} } 算数运算符 package wclg;public class test {public static void main(String[] ar…

iOS调整collectionViewCell顺序

效果图 原理 就是设置collectionView调整顺序的代理方法&#xff0c;这里要注意一点 调整过代理方法之后&#xff0c;一定要修改数据源&#xff0c;否则导致错乱。 还有就是在collectionView上面添加一个长按手势&#xff0c;在长按手势的不同阶段&#xff0c;调用collectionV…

【数据结构】AVL树(平衡二叉树)

目录 一、AVL树的概念二、AVL树的节点三、AVL树的插入四、AVL树的旋转1.插入在较高左子树的左侧&#xff0c;使用右单旋2.插入在较高右子树的右侧&#xff0c;使用左单旋3.插入较高左子树的右侧&#xff0c;先左单旋再右单旋4.插入较高右子树的左侧&#xff0c;先右单旋再左单旋…

论文研读 A Comparison of TCP Automatic Tuning Techniques for Distributed Computing

论文《分布式计算中TCP自动调优技术的比较》由Eric Weigle和Wu-chun Feng撰写&#xff0c;探讨了自动调整TCP缓冲区大小以提升分布式应用性能的不同方法。文章首先讨论了手动优化TCP缓冲区大小的局限性&#xff0c;并介绍了研究人员提出的各种自动调优技术来应对这些挑战。 作者…

打造智慧工厂核心:ARMxy工业PC与Linux系统

智能制造正以前所未有的速度重塑全球工业格局&#xff0c;而位于这场革命核心的&#xff0c;正是那些能够精准响应复杂生产需求、高效驱动自动化流程的先进设备。钡铼技术ARMxy工业计算机&#xff0c;以其独特的设计哲学与卓越的技术性能&#xff0c;正成为众多现代化生产线背后…

2024全国大学生数学建模竞赛优秀参考资料分享

0、竞赛资料 优秀的资料必不可少&#xff0c;优秀论文是学习的关键&#xff0c;视频学习也非常重要&#xff0c;如有需要请点击下方名片获取。 一、赛事介绍 全国大学生数学建模竞赛(以下简称竞赛)是中国工业与应用数学学会主办的面向全国大学生的群众性科技活动&#xff0c;旨…

⌈ 传知代码 ⌋ 记忆大师

&#x1f49b;前情提要&#x1f49b; 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间&#xff0c;对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…

MySQL 函数与约束

MySQL 函数与约束 文章目录 MySQL 函数与约束1 函数1.1 字符串函数1.2 数值函数1.3 日期函数1.4 流程函数 2 约束2.1 概述2.2 约束演示2.3 外键约束2.4 删除/更新行为 1 函数 函数是指一段可以直接被另一程序调用的程序或代码。 1.1 字符串函数 MySQL中内置了很多字符串函数&…

Python学习打卡:day02

day2 笔记来源于&#xff1a;黑马程序员python教程&#xff0c;8天python从入门到精通&#xff0c;学python看这套就够了 8、字符串的三种定义方式 字符串在Python中有多种定义形式 单引号定义法&#xff1a; name 黑马程序员双引号定义法&#xff1a; name "黑马程序…

攻防演练之-网络集结号

每一次的网络安全攻防演练都是各个安全厂商期待的网络安全盛会&#xff0c;因为目前的安全生态导致了只有在网络安全攻防演练期间&#xff0c;网络安全的价值才会走向台前&#xff0c;收到相关方的重视。虽然每一次都会由于各种原因不能如期举行&#xff0c;但是这一次的推迟总…

浏览器阻止屏幕息屏,js阻止浏览器息屏,Web网页阻止息屏

场景: 比如打开一个浏览器页面(比如大屏),想让它一直显示着,而不是过几分钟不操作就屏幕黑了.(电脑有设置电脑不操作就会多长时间就会息屏睡眠,如果要求每个客户都去操作一下电脑设置一下从不睡眠,这很不友好和现实.而且我也只想客户在大屏的时候才这样,其他页面就正常,按电脑设…

eNSP学习——RIP的路由引入

目录 主要命令 原理概述 实验目的 实验内容 实验拓扑 实验编址 实验步骤 1、基本配置 2、搭建公司B的RIP网络 3、优化公司B的 RIP网络 4、连接公司A与公司B的网络 需要eNSP各种配置命令的点击链接自取&#xff1a;华为&#xff45;NSP各种设备配置命令大全PDF版_ensp…

RAG检索与生成的融合

1、rag定义 检索增强生成 (RAG) 模型代表了检索系统和生成模型两大不同但互补组件完美结合的杰作。通过无缝整合相关信息检索和生成与背景相关的响应&#xff0c;RAG模型在人工智能领域达到了前所未有的复杂程度。 2、rag工作流程 2.1、rag整体框架 query通过llm处理后&…

【Golang】Go语言中defer与return的精妙交织:探索延迟执行与返回顺序的微妙关系

【Golang】Go语言中defer与return的精妙交织&#xff1a;探索延迟执行与返回顺序的微妙关系 大家好 我是寸铁&#x1f44a; 总结了一篇defer 和 return 返回值 的执行顺序探讨的文章✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 在Go语言中&#xff0c;defer 和return是两…

Codeforces Round 951 (Div. 2) 题解分享

A. Guess the Maximum 思路 贪心 毫无疑问的是&#xff0c;Alice会选择所有区间最大值的最小值-1&#xff0c;即。 关键是如何选取。我们注意到区间长度越大&#xff0c;这个区间的最大值是随着它不减的&#xff0c;所以如果Bob要让Alice选的最小的话&#xff0c;选择的区间…

1 c++多线程创建和传参

什么是进程&#xff1f; 系统资源分配的最小单位。 什么是线程&#xff1f; 操作系统调度的最小单位&#xff0c;即程序执行的最小单位。 为什么需要多线程&#xff1f; &#xff08;1&#xff09;加快程序执行速度和响应速度, 使得程序充分利用CPU资源。 &#xff08;2&…

FastAPI:在大模型中使用fastapi对外提供接口

通过本文你可以了解到&#xff1a; 如何安装fastapi&#xff0c;快速接入如何让大模型对外提供API接口 往期文章回顾&#xff1a; 1.大模型学习资料整理&#xff1a;大模型学习资料整理&#xff1a;如何从0到1学习大模型&#xff0c;搭建个人或企业RAG系统&#xff0c;如何评估…

Helm离线部署Rancher2.7.10

环境依赖&#xff1a; K8s集群、helm 工具 Rancher组件架构 Rancher Server 包括用于管理整个 Rancher 部署的所有软件组件。 下图展示了 Rancher 2.x 的上层架构。下图中&#xff0c;Rancher Server 管理两个下游 Kubernetes 集群 准备Rancher镜像推送到私有仓库 cat >…

RPA-UiBot6.0数据整理机器人—杂乱数据秒变报表

前言 友友们是否常常因为杂乱的数据而烦恼?数据分类、排序、筛选这些繁琐的任务是否占据了友友们的大部分时间?这篇博客将为友友们带来一个新的解决方案,让我们共同学习如何运用RPA数据整理机器人,实现杂乱数据的快速整理,为你的工作减负增效! 在这里,友友们将了…