【RL】Wasserstein距离-GAN背后的直觉

news2024/9/25 9:38:04

一、说明

        在本文中,我们将阅读有关Wasserstein GANs的信息。具体来说,我们将关注以下内容:i)什么是瓦瑟斯坦距离?,ii)为什么要使用它?iii) 我们如何使用它来训练 GAN?

二、Wasserstein距离概念

        Wasserstein距离,又称为Earth Mover's Distance (EMD),是衡量两个概率分布之间的差异程度的一种数学方式。它考虑了分布之间的距离和它们之间的“传输成本”。

        简单来说,Wasserstein距离将两个分布看作“堆积在地图上的土堆”,并计算将一个堆移到另一个的最小成本。这个距离度量的优点是它能够处理非均匀分布,并且能够考虑分布的形状和结构。

        Wasserstein距离在机器学习领域中应用非常广泛,特别是在生成模型中用来评估生成器生成的图像与真实图像之间的差异。

图1:学习区分两个高斯时的最优判别器和批评者[1]。

2.1 瓦瑟施泰因距离

        Wasserstein 距离(地球移动器的距离)是给定度量空间上两个概率分布之间的距离度量。直观地说,它可以被视为将一个分布转换为另一个分布所需的最小功,其中功被定义为必须移动的分布的质量和要移动的距离的乘积。在数学上,它被定义为:

方程 1:瓦瑟斯坦 分布P_r和P_g之间的距离。

        在方程1中,Π(P_r,P_g)是x和y上所有联合分布的集合,使得边际分布等于P_r和P_g。 γ(x, y)可以看作是必须从x移动到y才能将P_r转换为P_g的质量量[1]。因此,瓦瑟斯坦距离是最佳运输计划的成本。

2.2 瓦瑟斯坦距离 vs. 詹森-香农分歧

        最初的GAN目标被证明是Jensen-Shannon分歧的最小化[2]。JS背离定义为:

方程 2:P_r 和 P_g 之间的 JS 背离 P_m = (P_r + P_g)/2

        

        与JS相比,Wasserstein距离具有以下优点:

  • Wasserstein 距离是连续的,几乎可以在任何地方微分,这使我们能够训练模型达到最佳状态。
  • 随着鉴别器的变好,JS散度局部饱和,因此梯度变为零并消失。
  • Wasserstein 距离是一个有意义的度量,即当分布彼此靠近时,它收敛到 0,当它们越来越远时发散。
  • 作为目标函数的 Wasserstein 距离比使用 JS 散度更稳定。当使用Wasserstein距离作为目标函数时,模式崩溃问题也得到了缓解。

        从图 1 我们清楚地看到,最佳GAN鉴别器饱和并导致梯度消失,而优化Wasserstein距离的WGAN评论家在整个过程中具有稳定的梯度。

        有关数学证明和更详细的研究,请查看此处的论文!

三、瓦瑟斯坦·GAN

        现在可以清楚地看到,优化 Wasserstein 距离比优化 JS 散度更有意义,还需要注意的是,方程 1 中定义的 Wasserstein 距离非常棘手[3],因为我们不可能计算所有 γ ∈Π(Pr ,Pg) 的下界(最大下界)。然而,从坎托罗维奇-鲁宾斯坦二元性中,我们有,

公式3:1-利普希茨条件下的瓦瑟斯坦距离。

        这里我们有 W(P_r, P_g) 作为所有 1-Lipschitz 函数 f: X → R 的上确界(最低上限)。

        K-利普希茨连续性:给定 2 个度量空间 (X, d_X) 和 (Y, d_Y),变换函数 f: X → Y 是 K-利普希茨连续的,如果

公式3:K-Lipschitz连续性。

        其中d_X和d_Y是各自度量空间中的距离函数。当一个函数是 K-Lipschitz 时,从方程 2 开始,我们最终得到 K ∙ W(P_r, P_g)。

        现在,如果我们有一系列参数化函数 {f_w},其中 w∈W 是 K-Lipschitz 连续的,我们可以有

公式 4

即,w∈W 最大化方程 4 给出瓦瑟斯坦距离乘以一个常数。

四、WGAN评论家

        为此,WGAN引入了一个批评者,而不是我们在GAN中了解到的鉴别器。批评者网络在设计上类似于判别器网络,但通过优化找到将最大化方程 4 的 w* 来预测 Wasserstein 距离。为此,批评家的客观功能如下:

公式5:批评家客观函数。

       在这里,为了在函数f上强制执行Lipschitz连续性,作者诉诸于将权重w限制在一个紧凑的空间内。这是通过将砝码夹紧到一个小范围(论文中的[-1e-2,1e-2][1])来完成的。

鉴别器和批评者之间的区别在于,鉴别器经过训练以正确识别P_r样本和P_g样本,批评家估计P_r和P_g之间的Wasserstein距离。

这是训练批评家的python代码。

for ix in n_critic_steps:
  opt_critic.zero_grad()

  real_images = data[0].float().to(device)

  # * Generate images
  noise = sample_noise()
  fake_images = netG(noise)

  # * though they are name so, they are not logits!
  real_logits = netCritic(real_images)
  fake_logits = netCritic(fake_images)

  # * max E_{x~P_X}[C(x)] - E_{Z~P_Z}[C(g(z))]
  loss = -(real_logits.mean() - fake_logits.mean())

  loss.backward(retain_graph=True)
  opt_critic.step()

  # * Gradient clippling
  for p in netCritic.parameters():
      p.data.clamp_(-self.c, self.c)

五、WGAN生成器目标

        当然,发电机的目标是最小化P_r和P_g之间的瓦瑟斯坦距离。生成器试图找到最小化P_g和P_r之间的 Wasserstein 距离的 θ*。为此,生成器的目标函数如下:

        公式 6:生成器目标函数。

        在这里,WGAN生成器和标准生成器之间的主要区别再次在于,WGAN生成器试图最小化P_r和P_g之间的Wasserstein距离,而标准生成器试图用生成的图像欺骗鉴别器。

        以下是训练生成器的 python 代码:

opt_gen.zero_grad()

noise = sample_noise()

fake_images = netG(noise)

# again, these are not logits.
fake_logits = netCritic(fake_images)

# * - E_{Z~P_Z}[C(g(z))]
loss = -fake_logits.mean().view(-1)

loss.backward()
opt_gen.step()

六、培训结果

fig2:WGAN训练的早期结果[3]。

        图例.2显示了训练WGAN的一些早期结果。请注意,图 2 中的图像是早期结果,一旦确认模型按预期训练,训练就会停止。

七、代码

        Wasserstein GAN的完整实现可以在这里找到[3]。

八、结论

        WGAN提供非常稳定的培训和有意义的培训目标。本文介绍并直观地解释了什么是 Wasserstein 距离,Wasserstein 距离相对于标准 GAN 使用的 Jensen-Shannon 散度的优势,以及如何使用 Wasserstein 距离来训练 WGAN。我们还看到了用于训练 Critic 和生成器的代码片段,以及早期训练模型的大量输出。尽管WGAN比标准GAN具有许多优势,但WGAN论文的作者明确承认,权重裁剪不是执行Lipschitz连续性的最佳方法[1]。为了解决这个问题,他们提出了带有梯度惩罚的Wasserstein GAN[4],我们将在后面的文章中讨论。

        如果您喜欢这个,请查看本系列的下一篇文章,其中讨论了 WGAN-GP!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/838910.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件工程专业应该学什么?

昨天,我朋友的孩子报考了软件工程专业,问我软件工程到底学啥?所以我给他开列了一个书单。 现在高校开了一堆花名头的专业: 偏技术类:云计算、大数据、人工智能、物联网 偏应用类:电子商务、信息管理 但我个…

flink1.17 eventWindow不要配置processTrigger

理论上可以eventtime processtime混用,但是下面代码测试发现bug,输入一条数据会一直输出. flink github无法提bug/问题. apache jira账户新建后竟然flink又需要一个账户,放弃 bug复现操作 idea运行代码后 往source kafka发送一条数据 a,1,1690304400000 可以看到无限输出…

.net 6 efcore一个model映射到多张表(非使用IEntityTypeConfiguration)

现在有两张表,结构一模一样,我又不想创建两个一模一样的model,就想一个model映射到两张表 废话不多说直接上代码 安装依赖包 创建model namespace oneModelMultiTable.Model {public class Test{public int id { get; set; }public string…

【C语言进阶】数据的存储----浮点型篇

🍁 博客主页:江池俊的博客 💫收录专栏:C语言—探索高效编程的基石 💻 其他专栏:数据结构探索 ​💡代码仓库:江池俊的代码仓库 🎪 社区:GeekHub 🍁 如果觉得博…

部分常用CSS样式

目录 1.字体样式 2.文本样式 3.鼠标样式 cursor 4.背景样式 5.列表样式 6.CSS伪类 7.盒子模型 1.字体样式 font-family 字体类型:隶书” “楷体” font-size 字体大小:像素px font-weight 字体粗细:bold 定义粗体字…

8月5日上课内容 nginx的优化和防盗链

全部都是面试题 nginx的优化和防盗链 重点就是优化: 每一个点都是面试题,非常重要,都是面试题 1、隐藏版本号(重点,一定要会) 备份 cp nginx.conf nginx.conf.bak.2023.0805 方法一:修改配…

拦截器在SpringBoot中使用,HandlerInterceptor,WebMvcConfigurer

拦截器在Controller之前执行。 用于权限校验,日志记录,性能监控 在SpringBoot中使用 创建拦截器类:首先,创建一个Java类来实现拦截器逻辑。拦截器类应该实现Spring提供的HandlerInterceptor接口。实现拦截器方法:拦…

探索PostgreSQL的新功能:最新版本更新解析

PostgreSQL作为一种强大而开源的关系型数据库管理系统,不断在不断进化和改进。每一次的版本更新都带来了更多功能和改进,让用户在处理大规模数据和复杂查询时体验更好的性能和功能。在本文中,我们将深入探索PostgreSQL的最新版本更新&#xf…

进程上下文切换以及应用场景

各个进程之间是共享 CPU 资源的,在不同的时候进程之间需要切换,让不同的进程可以在 CPU 执行,那么这个一个进程切换到另一个进程运行,称为进程的上下文切换。 在详细说进程上下文切换前,我们先来看看 CPU 上下文切换 大…

VX-API-Gateway开源网关技术的使用记录

VX-API-Gateway开源网关技术的使用记录 官网地址 https://mirren.gitee.io/vx-api-gateway-doc/ VX-API-Gateway(以下称为VX-API)是基于Vert.x (java)开发的 API网关, 是一个分布式、全异步、高性能、可扩展、轻量级的可视化配置的API网关服务官网下载程序zip包 访问 https:/…

深入浅出 Typescript

TypeScript 是 JavaScript 的一个超集,支持 ECMAScript 6 标准(ES6 教程)。 TypeScript 由微软开发的自由和开源的编程语言。 TypeScript 设计目标是开发大型应用,它可以编译成纯 JavaScript,编译出来的 JavaScript …

AtcoderABC226场

A - Round decimalsA - Round decimals 题目大意 给定一个实数X,它最多可以使用三位小数表示,而且X的小数点后有三位小数。将X四舍五入到最接近的整数并打印结果。 思路分析 可以使用round函数进行四舍五入 知识点 round(x) 是一个用来对数字进行四…

SpringIoc-个人学习笔记

Spring的Ioc、DI、AOP思想 Ioc Ioc思想:Inversion of Control,控制反转,在创建Bean的权利反转给第三方 DI DI思想:Dependency Injection,依赖注入,强调Bean之间的关系,这种关系由第三方负责去设…

Redis 报错 RedisConnectionException: Unable to connect to x.x.x.x:6379

文章目录 Redis报错类型可能解决方案 Redis报错类型 org.springframework.data.redis.connection. spingboot调用redis出错 PoolException: Could not get a resource from the pool; 连接池异常:无法从池中获取资源; nested exception is io.lettuce.core. 嵌套异常 RedisConn…

针对高可靠性和高性能优化的1200V碳化硅沟道MOSFET

目录 标题:1200V SiC Trench-MOSFET Optimized for High Reliability and High Performance摘要信息解释研究了什么文章创新点文章的研究方法文章的结论 标题:1200V SiC Trench-MOSFET Optimized for High Reliability and High Performance 摘要 本文详…

FPGA----UltraScale+系列的PS侧与PL侧通过AXI-HP交互(全网唯一最详)附带AXI4协议校验IP使用方法

1、之前写过一篇关于ZYNQ系列通用的PS侧与PL侧通过AXI-HP通道的文档,下面是链接。 FPGA----ZCU106基于axi-hp通道的pl与ps数据交互(全网唯一最详)_zcu106调试_发光的沙子的博客-CSDN博客大家好,今天给大家带来的内容是&#xff0…

获取k8s scale资源对象的命令

kubectl get --raw /apis/<apiGroup>/<apiVersion>/namespaces/<namespaceName>/<resourceKind>/<resourceName>/scale 说明&#xff1a;scale资源对象用来水平扩展k8s资源对象的副本数&#xff0c;它是作为一种k8s资源对象的子资源存在&#xf…

This function has none of DETERMINISTIC, NO SQL, or READS SQL DATA in...错误解决

在创建函数的时候报错如下&#xff1a; 解决&#xff1a; 设置如下参数即可 SET GLOBAL log_bin_trust_function_creatorsTRUE;

[CKA]考试之PersistentVolumeClaims

由于最新的CKA考试改版&#xff0c;不允许存储书签&#xff0c;本博客致力怎么一步步从官网把答案找到&#xff0c;如何修改把题做对&#xff0c;下面开始我们的 CKA之旅 题目为&#xff1a; Task 创建一个名字为pv-volume的pvc&#xff0c;指定storageClass为csi-hostpath-…

面试热题(打家窃舍)

一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响小偷偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#xff0c;系统会自动报警。 给定一个代表每个房屋存放金额的非负…