gumbel-softmax如何实现离散分布可微+torch代码+原理+证明

news2024/11/15 4:26:15

文章目录

    • 背景
    • 方法通俗理解
    • 什么是重参数化
    • gumbel-softmax
    • 为什么是gumbel
    • torch实现
    • 思考

在这里插入图片描述

背景

这里举一个简单的情况,当前我们有p1, p2, p3三个概率,我们需要得到最优的一个即max(p1, p2, p3),例如当前p3 = max(p1, p2, p3),那么理想输出应当为[0, 0, 1],然后应用于下游的优化目标,这种场景在搜索等场景经常出现。
如果暴力的进行clip或者mask操作转化为独热向量的话会导致在梯度反向传播的时候无法更新上游网络。因为p1和p2对应的梯度一定为0。

方法通俗理解

针对上述情况,采用重参数化的思路可以解决。
即然每次前向传播理想情况下是0-1独热向量向量,但同时能保证[p1, p2, p3]这个分布能被根据概率被更新。于是采用了一种重参数化的方法,即从每次都从一个分布中采样一个u,这个u属于一个均匀分布,从这个均匀分布通过转换变成[p1, p2, p3]这个分布。这样就能即保证梯度可以反向传播,同时根据每次采样来实现按照[p1, p2, p3]这个分布更新,而不是每次只能更新最大的一个。
而这种方法就是重参数化。

什么是重参数化

Reparameterization,重参数化,这是一个方法论,是一种技巧。
我们首先可以抽象出来它的数学表达形式:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))
注意:在有些时候 θ ′ ∈ θ \theta' \in \theta θθ或者 θ ′ = θ \theta' = \theta θ=θ
如何理解:这里我们的优化目标是 L θ L_{\theta} Lθ,其中 f θ ( ) f_{\theta}() fθ()一般是我们的模型,而计算 z z z是从分布 p θ ′ ( z ) p_{\theta'}(z) pθ(z)中采样得到的。但是问题是我们不能把一个分布输入到 f θ ( ) f_{\theta}() fθ()中去,只能从选择一个特定的 z z z,但是这样就没法更新 θ ′ \theta' θ
综上,重参数化就是从给定分布中采样得到一个 z z z,同时保证了梯度可以更新 θ ′ \theta' θ,这种保证采样分布和给定分布无损转换的采样策略叫做重参数化。(个人理解,欢迎大佬指正)

由于我们现在解决的是gumbel-softmax问题,所以只关注当 p θ ′ ( z ) p_{\theta'}(z) pθ(z)是离散的情况下,此时:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) = ∑ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) = \sum p_{\theta'}(z)(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))=pθ(z)(fθ(z))
这也就是gumbel-softmax要解决的数学形式。

gumbel-softmax

gumbel-softmax给出的采样方案,叫做gumbel max:
从原来的 a r g m a x i ( [ p 1 , p 2 , . . . ] ) argmax_i([p1, p2, ...]) argmaxi([p1,p2,...]) a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) , ϵ i ∈ U [ 0 , 1 ] argmax_i(log(p_i)-log(-log(\epsilon_i))), \epsilon_i \in U[0, 1] argmaxi(log(pi)log(log(ϵi))),ϵiU[0,1]

也就是先算出各个概率的对数 l o g ( p i ) log(p_i) log(pi),然后从均匀分布 U U U中采样随机数 ϵ i \epsilon_i ϵi,把 − l o g ( − l o g ϵ i ) −log(−log\epsilon_i) log(logϵi)加到 l o g ( p i ) log(p_i) log(pi),然后再进行后续操作。
这里可以理解为通过 ϵ \epsilon ϵ的采样将随机性增加。有的人会疑问,为什么格式变得这么复杂,各种算log,这是为什么?这个就涉及到下一节了,具体原因就是来保证数学的变换正确性,即我增加了随机性,但是保证分布的期望仍然是和原始[p1, p2, p3]是一致的,这个证明在下一节,是有比较严谨的数学证明的。

但是这里还有一个问题,就是argmax或者说onehot操作仍然会丢失梯度,所以采用带超参 τ \tau τ的softmax,来进行平滑:
s o f t m a x ( ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) / τ ) \begin{equation} softmax((log(p_i)-log(-log(\epsilon_i)))/\tau) \end{equation} softmax((log(pi)log(log(ϵi)))/τ)
其中 τ \tau τ也被称为退火参数,用来调整平滑的程度: τ \tau τ越小,越接近onhot向量。

这里也解释清楚了所谓gumbel-softmax是通过gumbel max实现重参数化,通过带退火参数的softmax实现梯度反向传递。

为什么是gumbel

这就涉及到一个gumbel max的证明了。
目标是证明针对 l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) log(p_i)-log(-log(\epsilon_i)) log(pi)log(log(ϵi)),当 a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) = 1 argmax_i(log(p_i)-log(-log(\epsilon_i))) = 1 argmaxi(log(pi)log(log(ϵi)))=1时,其概率为 p 1 p_1 p1

假设:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大

则:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2))
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 3 ) − l o g ( − l o g ( ϵ 3 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_3)-log(-log(\epsilon_3)) log(p1)log(log(ϵ1))>log(p3)log(log(ϵ3))

l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2)) ->
ϵ 1 p 2 / p 1 > ϵ 2 \epsilon_1^{p_2/p_1} > \epsilon_2 ϵ1p2/p1>ϵ2
所以: p 1 p1 p1 大于 p 2 p_2 p2 的概率是 ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1

同理:
p 1 p1 p1 大于 p 3 p_3 p3 的概率是 ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1

所以 l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大的概率是:
ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1 * ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1 * … = ϵ 1 ( 1 − p 1 ) / p 1 \epsilon_1^{(1-p_1)/p_1} ϵ1(1p1)/p1
E ( ϵ 1 ( 1 − p 1 ) / p 1 ) E(\epsilon_1^{(1-p_1)/p_1}) E(ϵ1(1p1)/p1) = ∫ 0 1 ϵ 1 ( 1 − p 1 ) / p 1 d ϵ \int_{0}^{1}\epsilon_1^{(1-p_1)/p_1} d\epsilon 01ϵ1(1p1)/p1dϵ = ∫ 0 1 ϵ 1 ( 1 / p 1 ) − 1 d ϵ 1 \int_{0}^{1}\epsilon_1^{(1/p_1)-1} d\epsilon_1 01ϵ1(1/p1)1dϵ1 = ( p 1 ( ϵ 1 1 / p 1 ) ) ∣ 0 1 (p_1(\epsilon_1^{1/p_1}))|^1_0 (p1(ϵ11/p1))01 = p 1 p_1 p1

证明假设成立

torch实现

def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape)
    U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature=0.5):
    y = torch.log(logits) + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature=1, hard=False):
    """
    input: [B, n_class]
    return: [B, n_class] an one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    
    if not hard:
        return y

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Set gradients w.r.t. y_hard gradients w.r.t. y
    y_hard = (y_hard - y).detach() + y
    return y_hard

思考

为什么gumbel-softmax和softmax的输出是不一样的?
为什么argmax(gumbel-softmax) 和 argmax(softmax)的结果也不一定一样?这是正常的吗?

大家共勉~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502758.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SSRF漏洞基础原理(浅层面解释 + 靶场演示)

一、SSRF漏洞的基本概念: SSRF--全名:Server-Side Request Forgery,汉译:服务端请求伪造,漏洞别名“借刀杀人”。 想象以下,现存在一个 Web应用,这个Web应用可以帮助我们能爬取互联网上的其他…

js 添加、删除DOM元素

1. js添加、删除DOM元素 1.1. 添加DOM元素 1.1.1. appendChild()方法 该方法添加的元素位于父元素的末尾,使用方法: parentNode.appenChild(NewNode) // parentNode是需要添加元素的容器,NewNode是新添加的元素   创建一个li元素并添加到…

数据结构与算法-哈希表

引言 在计算机科学中,数据结构与算法是构建高效软件系统的关键基石。其中,哈希表作为一种非常实用的数据结构,以其快速查找、插入和删除等特性,在诸多领域发挥着无可替代的作用。本文将深入探讨哈希表的工作原理、实现细节以及其在…

Qt Creator常见问题解决方法

Qt Creator源文件重命名的正确方法 光改文件名是不够的,还要在.pro文件中的SOURCES中把名字改成之后的。 中文乱码(字符集设置) 菜单栏-工具-选项-设置为utf-8

Reset Verification IP

Reset Verification IP IP 参数及接口 IP 例化界面 相关函数 assert_reset //置位复位信号 < hierarchy_path>.assert_reset();deassert_reset //取消置位复位信号 < hierarchy_path>.deassert_reset();set_master_mode //设置 RST_VIP 模式为 Master < hi…

STM32基本定时功能

1、定时器就是计数器。 2、怎么计数&#xff1f; 3、我们需要有一恒定频率的方波信号&#xff0c;再加上一个寄存器。 4、比如每来一个上升沿信号&#xff0c;寄存器值加1&#xff0c;就可以完成计数。 5、假设方波频率是100Hz&#xff0c;也就是1秒100个脉冲。…

严刑拷打_微服务

文章详情 &#xff1a;&#x1f60a; 作者&#xff1a;Lion J &#x1f496; 主页&#xff1a; https://blog.csdn.net/weixin_69252724 &#x1f389; 主题&#xff1a; 微服务相关知识 ⏱️ 创作时间&#xff1a;2024年03月8日 ———————————————— 文章目…

[数据结构]队列

1.队列的概念及结构 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出 FIFO(First In First Out) 入队列&#xff1a;进行插入操作的一端称为队尾 出队列&#xff1a;进行删除操作的一端称为队头 2…

Dynamo3.0.3——六年来最大的更新

Hello大家好&#xff01;我是九哥~ 前几天&#xff0c;Dynamo Core 3.0.0版本发布&#xff0c;迎来了Dynamo六年来最大的一次更新。最大的改变&#xff0c;是更新到了.net8&#xff0c;这回对Dynamo节点包产生不小影响。接下来我们详细看一下都有哪些变化。 首先&#xff0…

【vue.js】文档解读【day 3】 | 列表渲染

如果阅读有疑问的话&#xff0c;欢迎评论或私信&#xff01;&#xff01; 文章目录 列表渲染v-forv-for 与对象在 v-for 里使用范围值template 上的 v-forv-for与v-if通过key管理状态组件上使用v-for数组变化侦测 列表渲染 v-for 在我们想要渲染出一个数组中的元素时&#xf…

node-day3-es6模块化+webpack

模块化 一、模块化分类 回顾node.js模块化&#xff1a; node.js遵循了CommonJS的模块化规范【见下文】&#xff0c;其中&#xff1a; 1.导入其它模块使用require()方法 2.模块对外共享成员使用module.exports对象 模块化的好处&#xff1a; 大家都遵守同样的模块化规范写代…

Only fullscreen opaque activities can request orientation

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 未经允许不得转载 目录 一、导读二、概览三、分析四、 推荐阅…

【电工学笔记】上册第一、二章

电工学 上次考试败在了单位&#xff0c;这次单位 一定要记熟。 第一章 电源或信号源的电压或电流称为激励,它推动电路工作; 由激励所产生的电压和电流称为响应。 复杂电路中,一般无法事先判断某个支路电流的 实际方向或者某个电路元件电压的实际方向 140V/4算不出总电阻的 …

leetcode26---删除有序数组中的重复项

大家好&#xff0c;我是大唐&#xff0c;刚刷完了几道经典的leetcode题&#xff0c;今天给大家分享一道leetcode上面的快慢指针经典题型---删除有序数组中的重复项&#xff0c;我们往下看。 题目描述 给你一个 非严格递增排列 的数组 nums &#xff0c;请你原地删除重复出现的元…

【数据结构】拆分详解 - 排序

文章目录 前言一、排序的概念及其运用  1.1 排序的概念  1.2 排序的运用  1.3 常见的排序算法  1.4 排序算法性能测试对比函数 二、常见排序算法的实现  2.1 插入排序   2.1.1  基本思想   2.1.2  直接插入排序   2.1.3  希尔排序     1. 预排序&am…

Dataset 读取数据

Dataset 读取数据 from torch.utils.data import Dataset from PIL import Image import osclass Mydata(Dataset):def __init__(self,root_dir,label_dir):self.root_dir root_dir #根目录 dataset/trainself.label_dir label_dir #标签的后面链接目录 ants_ima…

ChatGPT 提问没反应了,怎么办?4种方法!试试看

用了将近 1 年的 ChatGPT 昨天下午提问忽然之间没反应了&#xff0c;有点失落&#xff0c;我原本以为是账号到期了呢。 之后&#xff0c;尝试用谷歌邮箱注册登录也不行。 打开调试一看&#xff0c;接口状态 403 &#xff0c;没有权限了&#xff0c;logout。 怎么办呢&#xf…

2023年12月CCF-GESP编程能力等级认证Python编程七级真题解析

本文收录于专栏《Python等级认证CCF-GESP真题解析》,专栏总目录・点这里 一、单选题(每题 2 分,共 30 分) 第1题 假设变量 x 为 float 类型,如果下面代码输入为 100,输出最接近( )。 A.0 B.-5 C.-8 D.8 答案:B 第2题 对于下面动态规划方法实现的函数,以下选项中…

【Nestjs实操】服务依赖注入

在开始学习之前&#xff0c;我们首先准备下开发环境&#xff1a; Node&#xff1a;16.20.2包管理器&#xff1a;pnpmnestjs版本&#xff1a;10.2.1全局安装nestjs命令行&#xff1a;pnpm add -g nestjs/cli 一、初始化项目 执行nest new nestjs-blog&#xff0c;系统会自动创…

关于 JVM

1、请你谈谈你对JVM的理解&#xff1f; JVM由JVM运行时数据区&#xff08;图示中蓝色框包含部分&#xff09;、执行引擎、本地库接口、本地方法库组成。 JVM运行时数据区&#xff0c;分为方法区、堆、虚拟机栈、本地方法栈和程序计数器。 1.方法区 Java 虚拟机规范中定…