Gumbel-Softmax简介

news2024/11/25 20:49:04

一、Gumbel Softmax trick的使用场景

1. argmax简介

在NLP领域的强化学习或者对抗学习中,token的生成是离散的。比如,一个token的产生是一个大小为vocab size的one-hot向量。比如,对于character level的token: [ 1 , 0 , 0 , 0 , . . . , 0 ] [1, 0, 0, 0, ..., 0] [1,0,0,0,...,0]代表a, [ 0 , 1 , 0 , 0 , . . . , 0 ] [0, 1, 0, 0,..., 0] [0,1,0,0,...,0] 代表b。具体选择哪个token就根据输出的每个维度的大小,选择预测概率最大作为输出token,即 a r g   m a x arg \ max arg max操作。

如图:
在这里插入图片描述
以四维向量 v v v为例,对其做argmax操作,得到的one-hot vector为 [ 0 , 1 , 0 , 0 ] [0, 1, 0, 0] [0,1,0,0]。虽然该方法可以得到正确的分类,但是显而易见,argmax是不可导的。

2. softmax简介

在一般的分类问题中,为了解决argmax不可导的问题,通常选择softmax方法,softmax即是argmax的光滑近似。这种方法通过把向量归一化,既可以计算梯度,同时值的大小还可以代表概率的含义。

如图:

在这里插入图片描述

在经过softmax后,既不会改变动作或者说类别的选取,同时softmax还倾向于让最大值的概率显著大于其他值(比如10和6.2在经过softmax后变成了0.59和0.01),这样更有利于将网络训练成一个one-hot形式。

但是,softmax还有一个问题,就是softmax后的向量并不能真正显示概率的含义。比如对于两个softmax后的向量 [ 0 , 0.59 , 0.39 , 0.01 ] [0, 0.59, 0.39, 0.01] [0,0.59,0.39,0.01] [ 0 , 0.99 , 0.01 , 0 ] [0, 0.99, 0.01, 0] [0,0.99,0.01,0],两者都是选择第二个分类,但是其在概率上的表示可谓是天差地别。

因此,我们需要一种算法,既可以选出动作,还要遵从概率的含义。这时,最直觉的办法就是根据概率采样,这既可以选出动作,又遵从概率的含义,但是,采样不能求导。

3. 为什么采样过程需要求导

对于一般的分类问题,我们只需计算最后一层的softmax,然后与标签(one-hot vector)求交叉熵损失就可以完成网络的训练,这种问题其实是不需要sample的。因为sample就是最终的目的,生成的one-hot就是最后要完成的任务,是固定的(即标签)。

但是对于另一些问题,sample只是中间的步骤,sample是不固定的(即没有具体的one-hot),是需要训练的,如VAE和GAN,这个时候sample变成了一种优化的任务,因此必须要保证其可导性。

二、Reparameterization Trick

我们知道,模型的训练图需要各处都能传回梯度进行训练,而采样这一操作会打破这一链条。采样的意义无非是引入随机性。既然这样,就把“随机性的引入”和“计算图的构建”这两个属性剥离开。Reparameterization Trick就是这个思路。

以离散情况为例:

假设从一个模型中得到一个概率分布 p p p,需要从p中得到一个具体的sample进行后面的计算。假设 p = [ 0.1 , 0.6 , 0.1 , 0.2 ] p=[0.1, 0.6, 0.1, 0.2] p=[0.1,0.6,0.1,0.2],分别对应四个不同的选择,现在需要按概率进行采样。直觉上来说,直接选择第二个就行了,但是真是这样吗。仔细想一下,我们现在是要训练模型,模型还没有训练好, p p p的输出分布也很不靠谱,没道理选最大的。因此,需要为sample引入随机性。sample过程引入随机性的意义就是“搜索”,让模型“搜索”所有可能的选择,然后根据loss回调参数,最终训练到合理的sample策略。

为了在采样的同时不破坏计算图的梯度传播,我们不直接在 p p p上进行随机操作,而是引入Gumbel分布,通过它来提供sample需要的随机性。

三、Gumble-Softmax Trick

1. Gumbel Max Trick

Gumbel Max提供了一种从类别分布中采用的算法。

z z z是一个分类变量,类概率为 π 1 , π 2 , . . . , π k \pi_1, \pi_2,...,\pi_k π1,π2,...,πk,从类别概率为 π \pi π的分布中提取样本 z z z,加上Gumbel噪声,可得:
z = a r g   m a x i [ g i + l o g   π i ] z=arg \ max_i[g_i + log \ \pi_i] z=arg maxi[gi+log πi]
其中, g g g是独立同分布的标准Gumbel分布的随机变量。标准Gumbel分布的CDF为 F ( x ) = e − e − x F(x)=e^{-e^{-x}} F(x)=eex

g i g_i gi通过Gumbel分布求逆从均匀分布中生成,即
g i = − l o g ( − l o g ( ϵ i ) ) ,   ϵ i ∼ U ( 0 , 1 ) g_i = -log(-log(\epsilon_i )), \ \epsilon_i \sim U(0, 1) gi=log(log(ϵi)), ϵiU(0,1)
数学上可以证明(网上很多,此处略),这个过程精确等价于依概率 π 1 , π 2 , . . . , π k \pi_1, \pi_2, ..., \pi_k π1,π2,...,πk采样一个类别。即,输出的 z i z_i zi的概率刚好是 π i \pi_i πi。由于现在的随机性已经转移到 U [ 0 , 1 ] U[0,1] U[0,1]上去了,并且 U [ 0 , 1 ] U[0, 1] U[0,1]无未知参数,因此Gumbel Max就是离散分布的一个重参数过程。

考虑到arg max是不可导的,我们这里仍需要argmax的光滑近似,softmax。

2. Gumbel Softmax Trick

y i = e x p ( ( l o g ( π i ) + g i ) / τ ) ∑ j = 1 k e x p ( ( l o g ( π j ) + g i ) / τ ) ,   f o r   i = 1 , . . . , k y_i = \frac{exp((log(\pi_i) + g_i)/\tau )}{\sum^k_{j=1}exp((log(\pi_j) + g_i)/\tau)}, \ for \ i=1, ..., k yi=j=1kexp((log(πj)+gi)/τ)exp((log(πi)+gi)/τ), for i=1,...,k

其中,参数 τ > 0 \tau>0 τ>0称为退火参数,它越小结果就越接近one-hot形式(同时梯度消失越严重),越大就越接近均匀分布。

注意:Gumbel Softmax不是采样类别的等价形式,Gumbel Max才是。Gumbel Max可以看做Gumbel Softmax在 τ → 0 \tau \rightarrow 0 τ0时的极限。因此在应用Gumbel Softmax时,可以先选择较大的 τ \tau τ,之后再慢慢退火到一个接近0的数。

引用:

漫谈重参数:从正态分布到Gumbel Softmax

Gumbel Softmax 是什么?

重参数化技巧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/844975.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

阻抗是什么?什么时候要考虑阻抗匹配?

在电路设计中,我们常常碰到跟阻抗有关的问题,那么到底什么是阻抗? 在具有电阻、电感和电容的电路里,对电路中电流所起的阻碍作用叫做阻抗。常用Z来表示,它的值由交流电的频率、电阻R、电感L、电容C相互作用来决定。由…

Mybatis异常Invalid bound statement (not found)原因之Mapper文件配置不匹配

模拟登录操作 $.post("/admin/login", {aname, pwd }, rt > {if (rt.code 200) {location.href "manager/index.html";return;}alert(rt.msg)});网页提示服务器代码错误 POST http://localhost:8888/admin/login 500后端显示无法找到Mapper中对应的…

ros tf

欢迎访问我的博客首页。 tf 1. tf 命令行工具1.1 发布 tf1.2 查看 tf 2.参考 1. tf 命令行工具 1.1 发布 tf 我们根据 cartographer_ros 的 launch 文件 backpack_2d.launch 写一个 tf.launch,并使用命令 roslaunch cartographer_ros tf.launch 启动。该 launch 文件…

wpf 项目中使用 Prism + MaterialDesign

1.通过nuget安装MaterialDesign 2.通过nuget安装Prism 3.修改App.xmal <prism:PrismApplication x:Class"VisionMeasureGlue.App"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/…

简单易懂的Transformer学习笔记

1. 整体概述 2. Encoder 2.1 Embedding 2.2 位置编码 2.2.1 为什么需要位置编码 2.2.2 位置编码公式 2.2.3 为什么位置编码可行 2.3 注意力机制 2.3.1 基本注意力机制 2.3.2 在Trm中是如何操作的 2.3.3 多头注意力机制 2.4 残差网络 2.5 Batch Normal & Layer Narmal 2.…

C++入门篇5---模板

相信大家都遇到过这么一种情况&#xff0c;为了满足不同类型的需求&#xff0c;我们要写多个功能相同&#xff0c;参数类型不同的代码&#xff0c;为此&#xff0c;C引入了泛型编程这一概念&#xff0c;而模板就是实现泛型编程的基础&#xff0c;其实本质就是我们写一个类似”模…

JVM、JRE、JDK三者之间的关系

JVM、JRE和JDK是与Java开发和运行相关的三个重要概念。 再了解三者之前让我们先来了解下java源文件的执行顺序&#xff1a; 使用编辑器或IDE(集成开发环境)编写Java源文件.即demo.java程序必须编译为字节码文件&#xff0c;javac(Java编译器)编译源文件为demo.class文件.类文…

JavaScript + GO 通过 AES + RSA 进行数据加解密

浏览器端搞些小儿科的加密&#xff0c;就好比在黑暗夜空中&#xff0c;点缀了几颗星星&#xff0c;告诉黑客「这里有宝贵信息&#xff0c;快来翻牌」 浏览器端的加密&#xff0c;都是相对安全的。 它的具体安危&#xff0c;取决于里面存在的信息价值&#xff0c;是否值得破解者…

GO学习之 网络通信(Net/Http)

GO系列 1、GO学习之Hello World 2、GO学习之入门语法 3、GO学习之切片操作 4、GO学习之 Map 操作 5、GO学习之 结构体 操作 6、GO学习之 通道(Channel) 7、GO学习之 多线程(goroutine) 8、GO学习之 函数(Function) 9、GO学习之 接口(Interface) 10、 文章目录 GO系列前言一、H…

CAPL - XML和TestModule结合实现测试项可选(续)

二、xml文件编写 1、设置xml文件版本号 这个方便我们对xml文件进行文件管理,对于后续工作有进一步帮助。 <?xml version="1.0" ?> 2、设置xml根元素 在CANoe中使用的xml文件根元素我统一都会设置为testmodule,这也是我们在CANoe软件中选择测试用例的最大…

微服务间消息传递

微服务间消息传递 微服务是一种软件开发架构&#xff0c;它将一个大型应用程序拆分为一系列小型、独立的服务。每个服务都可以独立开发、部署和扩展&#xff0c;并通过轻量级的通信机制进行交互。 应用开发 common模块中包含服务提供者和服务消费者共享的内容provider模块是…

七、ESP32 4位数码管显示数字

1. 运行后的效果 可以显示0~9999之间的任何数字 2. 4位数码管与ESP32链接方式 3. 代码</

java网络编程概述及例题

网络编程概述 计算机网络 把分布在不同地理区域的计算机与专门的外部设备用通信线路连成一个规模大、功能强的网络系统&#xff0c;从而使众多的计算机可以方便地互相传递信息、共享硬件、软件、数据信息等资源。 网络编程的目的 直接或间接地通过网络协议与其他计算机实现…

每天一道leetcode:剑指 Offer 53 - II. 0~n-1中缺失的数字(适合初学者二分查找)

今日份题目&#xff1a; 一个长度为n-1的递增排序数组中的所有数字都是唯一的&#xff0c;并且每个数字都在范围0&#xff5e;n-1之内。在范围0&#xff5e;n-1内的n个数字中有且只有一个数字不在该数组中&#xff0c;请找出这个数字。 示例1 输入: [0,1,3] 输出: 2 示例2 …

Linux 信号signal处理机制

Signal机制在Linux中是一个非常常用的进程间通信机制&#xff0c;很多人在使用的时候不会考虑该机制是具体如何实现的。signal机制可以被理解成进程的软中断&#xff0c;因此&#xff0c;在实时性方面还是相对比较高的。Linux中signal机制的模型可以采用下图进行描述。 每个进程…

网络编程——数据报的组装和拆解

数据包的组装和拆解 一、数据包在各个层之间的传输 二、各个层的封包格式 1、链路层封包格式 -------------------------------------------------------------------------------------------------------------------------------------- | 目标MAC地址&#xff08;6字节&a…

Chatgpt AI newbing作画,文字生成图 BingImageCreator 二次开发,对接wxbot

开源项目 https://github.com/acheong08/BingImageCreator 获取cookie信息 cookieStore.get("_U").then(result > console.log(result.value)) pip3 install --upgrade BingImageCreator import os import BingImageCreatoros.environ["http_proxy"]…

一、Webpack相关(包括webpack-dev-server用以热更新和html-webpack-plugin)

概念与功能&#xff1a; webpack是前端项目工程化的具体解决方案。它提供了友好的前端模块化开发支持&#xff0c;以及代码压缩混淆、处理浏览器端JavaScript的兼容性、性能优化等强大的功能。 快速上手&#xff1a;隔行变色 -S实际是--save的简写&#xff0c;表示安装的第三方…

Mysql存储引擎InnoDB

一、存储引擎的简介 MySQL 5.7 支持的存储引擎有 InnoDB、MyISAM、Memory、Merge、Archive、Federated、CSV、BLACKHOLE 等。 1、InnoDB存储引擎 从MySQL5.5版本之后&#xff0c;默认内置存储引擎是InnoDB&#xff0c;主要特点有&#xff1a; &#xff08;1&#xff09;灾难恢…

分享21年电赛F题-智能送药小车-做题记录以及经验分享

这里写目录标题 前言一、赛题分析1、车型选择2、巡线1、OpenMv循迹2、灰度循迹 3、装载药品4、识别数字5、LED指示6、双车通信7、转向方案1、开环转向2、位置环速度环闭环串级转向3、MPU6050转向 二、调试经验分享1、循迹2、识别数字3、转向4、双车通信5、逻辑处理6、心态问题 …