论文笔记CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX

news2025/1/13 15:32:33

目录

  • Gumbel-Softmax分布
  • Gumbel-Softmax Estimator
  • Straight-Through (ST) Gumbel-Softmax Estimator
    • Straight-Through Estimator (STE)
    • Straight-Through (ST) Gumbel-Softmax Estimator
  • 参考

Gumbel-Softmax分布

Gumbel-Softmax分布是一个定义在单纯形(simplex)上的连续分布。
Gumbel-Softmax分布可以近似categorical分布。
z z z表示为服从 π = ( π 1 , … , π k ) \pi = (\pi_1,\ldots,\pi_k) π=(π1,,πk)的categorical随机变量。categorical分布的样本表示为 k k k维的one-hot向量,在 k − 1 k-1 k1维的单纯形空间 △ k − 1 \bigtriangleup^{k-1} k1中。

Gumbel-Max trick是reparametrization tricks的一个特例,其提供了一个简单有效的从categorical分布采样的方法:
z = one-hot ( argmax ⁡ i [ g i + log ⁡ π i ] ) (1) z = \text{one-hot}\left(\operatorname{argmax}_i[g_i + \log \pi_i]\right) \tag{1} z=one-hot(argmaxi[gi+logπi])(1)其中 g i ∼ G u m b e l ( 0 , 1 ) g_i \sim Gumbel(0,1) giGumbel(0,1)。Gumbel分布用于对各种分布的多个样本的最大值(或最小值)的分布进行建模。Gumbel分布的概率密度是:
G u m b e l ( μ , β ) = 1 β exp ⁡ ( − x − μ β + exp ⁡ ( − x − μ β ) ) Gumbel(\mu, \beta) = \frac{1}{\beta}\exp(-\frac{x - \mu}{\beta} + \exp(-\frac{x - \mu}{\beta})) Gumbel(μ,β)=β1exp(βxμ+exp(βxμ))
softmax函数是可导的,使用softmax函数去近似公式(1)中的argmax,可以得到样本 y ∈ △ k − 1 y\in\bigtriangleup^{k-1} yk1
y i = softmax ⁡ [ g i + log ⁡ π i ] = exp ⁡ ( log ⁡ π i + g i τ ) ∑ j = 1 k exp ⁡ ( log ⁡ π j + g j τ ) y_i = \operatorname{softmax}[g_i + \log \pi_i] = \frac{\exp(\frac{\log \pi_i + g_i}{\tau})}{\sum_{j = 1}^k \exp(\frac{\log \pi_j + g_j}{\tau})} yi=softmax[gi+logπi]=j=1kexp(τlogπj+gj)exp(τlogπi+gi)Gumbel-Softmax分布的概率密度函数是:
在这里插入图片描述
随着 τ \tau τ趋近于0,Gumbel-Softmax分布的样本逐渐变成one-hot的,Gumbel-Softmax分布也逐渐变成了categorical分布。如下图所示:

在这里插入图片描述

Gumbel-Softmax Estimator

Gumbel-Softmax分布的 ∂ y ∂ π \frac{\partial y}{\partial \pi} πy是有定义的。
通过用Gumbel-Softmax样本替换categorical样本,我们可以使用反向传播来计算梯度。
把在训练阶段,用可导的Gumbel-Softmax样本替代不可导的categorical样本的过程称为Gumbel-Softmax Estimator。

在温度 τ \tau τ小时,样本接近单热但梯度方差大,在温度 τ \tau τ大时,样本平滑但梯度方差小。
实际中,我们从高温 τ \tau τ开始,然后退火到一个很小但非零的温度。

Straight-Through (ST) Gumbel-Softmax Estimator

Straight-Through Estimator (STE)

首先介绍下Straight-Through Estimator (STE)。
STE是量化(quantization)中常见的求导方式。
比如有sign函数:
w b = sign ⁡ ( w ) = { + 1 ,  if  w ≥ 0 − 1 ,  otherwise  w_{b}=\operatorname{sign}(w)=\left\{ \begin{array}{ll}{+1,}{\text { if } w \geq 0} \\ {-1,}{\text { otherwise }}\end{array}\right. wb=sign(w)={+1, if w01, otherwise 这个sign函数在定义域范围内导数都是0。STE就是用来解决sign函数梯度无法反传的问题的。
二值网络训练过程可以是这样:模型中每个参数其实都是一个浮点型的数,每次迭代其实都是在更新这个浮点型数。但是,在前向传播的过程中,先用sign函数对浮点型参数二值化处理然后再参与到运算,而此时并没有把这个浮点型数值抛弃掉,而是暂时在内存中保存起来。前向传播完之后,网络得到一个输出,就可以接着通过反向传播算出二值参数的梯度,再直接用这个梯度来更新对应的浮点型参数。这样,前向反向就跑通了。等训练的差不多了,就最后对模型的这些浮点型参数做一次二值化处理形成最终的二值网络,此时浮点型的参数就完成了任务,可以被抛弃掉了。

Straight-Through (ST) Gumbel-Softmax Estimator

在前向的时候使用argmax离散化 y y y,但在梯度反传的时候,使用连续近似 ∇ θ z ≈ ∇ θ y \nabla_\theta z \approx \nabla_\theta y θzθy

参考

ICLR 2017 Categorical Reparameterization with Gumbel-Softmax
Emma Benjaminson blog
二值网络,围绕STE的那些事儿
gumbel-max-trick的数学证明

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/94238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目成功的制胜法宝——有效的领导力

项目经理在开展项目的过程中,为了确保项目的成功且实现价值交付,往往会使出浑身解数、有勇有谋、甚至熟练使用“孙子兵法”。毕竟在很多情况下,需要“带兵打仗”的项目经理权力微弱,不能像“将军”那般拥有权力的加持、一呼百应。…

javaSE - 认识字符串(String class)上半部分

前言 字符串: 在C语言里面 是 没有字符串类型的! 但是,在 Java 和 C 里,有字符串类型【String】 什么是字符串?什么是字符? 在java里面有表示字符串的类 String 使用双引号,且双引号中包含任意数量的字符【…

42 | iptables的使用方法

1 iptables简介 iptables是一个linux下的防火墙工具,能帮助我们基于规则进行网络流量控制。它可以做到,但不限于以下功能: 允许/拒绝某种协议的链接建立,比如TCP,UDP允许/拒绝 来自某个ip的访问允许/拒绝某个端口被访…

田径运动会成绩管理系统

开发工具(eclipse/idea/vscode等): 数据库(sqlite/mysql/sqlserver等): 功能模块(请用文字描述,至少200字): 模块划分:通知类型、通知信息、裁判信息、运动员信息、项目类型、项目信息、场地信息、项目安排、报名信息、…

深度学习目标检测:YOLOv5实现红绿灯检测(含红绿灯数据集+训练代码)

深度学习目标检测:YOLOv5实现红绿灯检测(含红绿灯数据集训练代码) 1. 前言 本篇博客,我们将手把手教你搭建一个基于YOLOv5的红绿灯目标检测项目。目前,基于YOLOv5s的红绿灯检测精度平均值mAP_0.50.93919,mAP_0.5:0.950.63967&…

_9LeetCode代码随想录算法训练营第九天-C++栈和队列

_9LeetCode代码随想录算法训练营第九天-C栈和队列 理论基础232.用栈实现队列225.用队列实现栈 基础知识 C标准库版本 HP STL 其他版本的C STL,一般是以HP STL为蓝本实现出来的,HP STL是C STL的第一个实现版本,而且开放源代码。P.J.Plauge…

【软件测试】测试面试,面试官其实想要的答案......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 测试经理 保障xxx的…

Python入门自学进阶-Web框架——29、DjangoAdmin项目应用-整表只读、自定义用户认证

一、整表只读 对于readonly_fields是对单个字段设置只读,现在要对整个表使用只读,也做成可配置的。在自己项目的admin.py中进行配置,如在mytestapp_admin.py中对Customer进行整表只读配置,在基类BaseAdmin中增加readonly_table …

juc-1-进程/线程/创建线程代码

1 场景: 1:执行的任务,短而快。不能是很耗时的任务。 2:异步写日志(日志框架底层就是)。 3:异步发送邮件,短信。 用MQ 最好。 4:多线程下载 2 进程与线程 进程: 进程&#xff1…

Centerfusion算法环境配置及模型训练

Centerfusion算法环境配置及模型训练概述1. 配置conda环境1.1 新建conda环境1.2 安装cuda1.3 安装cudnn1.4 安装pytorch1.5 安装cocoapi2. 配置Centerfusion2.1 克隆CenterFusion的github库2.2 安装依赖包2.3 安装DCNv22.4 下载nuscenes-devkit包3. 数据集准备3.1 下载数据集3.…

细粒度图像分类论文研读-2019

文章目录Cross-X Learning for Fine-Grained Visual Categorization(by end-to-end)AbstractIntroductionApproachPreliminariesCross- Category Cross-Semantic RegularizerCross-Layer RegularizerOptimizationLearning a Mixture of Granularity-Spec…

公司刚来的京东架构师:看完我写的 spring 笔记,甩给了我一份文档

Spring 是分层的 full-stack(全栈) 轻量级开源框架,以 IoC 和 AOP 为内核,提供了展现层 SpringMVC 和业务层事务管理等众多的企业级应⽤技术,还能整合开源世界众多著名的第三⽅框架和类库,已经成为使⽤最多…

优维助力国内某省级商业银行同城异地灾备自动化建设

银监会在《商业银行数据中心监管指引》中明确要求“商业银行每年至少进行一次重要信息系统专项灾备切换演练,每三年至少一次重要信息系统全面灾备切换演练,以真实业务接管为目标,验证灾备系统有效接管生产系统与安全回切的能力,并…

SpringSecurity概念以及整合ssm框架

基本概念 Spring中提供安全认证服务的框架,认证:验证用户密码是否正确的过程,授权:对用户能访问的资源进行控制 用户登录系统时我们协助 SpringSecurity 把用户对应的角色、权限组装好,同时把各个资源所要求的权限信息…

毕业设计-基于大数据技术的旅游推荐系统-python

目录 前言 课题背景和意义 实现技术思路 实现效果图样例 前言 📅大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,一边要为毕业设计耗费大量精力。近几年各个学校要求的毕设项目越来越难,有不少课题是研究生级别难度的,对本科…

前端基础(十一)_Float浮动、清除浮动的几种方法

浮动 1、什么是浮动? 目的:为了让多个块级元素在同一行显示; 文档流:可显示的对象在排列时所占的位置; 浮动:使元素脱离正常的文档流,按照指定的顺序,方向发生移动,直到…

疫情下景区管理

开发工具(eclipse/idea/vscode等):idea 数据库(sqlite/mysql/sqlserver等):mysql 功能模块(请用文字描述,至少200字):模块划分:公告类型、公告信息、用户信息、用户咨询、地区信息、景区信息、景区开放、景区预约、统计…

【Flask框架】——17 Flask蓝图

在一个Flask 应用项目中,如果业务视图过多,可否将以某种方式划分出的业务单元单独维护,将每个单元用到的视图、静态文件、模板文件等独立分开? 例如从业务角度上,可将整个应用划分为用户模块单元、商品模块单元、订单…

应急响应-windows/Linux主机加固

windows/Linux主机加固 1,账户安全 首先要确保电脑上的账户均为经常使用的账户,要禁止Guest用户,禁用其他无用账户,一段时间后无反馈即可删除,同时要留意是否有隐藏账户存在。 查看本地用户和组:右键此电脑>计算机管…

PG::FunboxEasy

nmap -Pn -p- -T4 --min-rate1000 192.168.58.111 nmap -Pn -p 22,80,33060 -sCV 192.168.58.111 查看80端口的页面,未发现可用信息。 对路径进行爆破 gobuster dir -u http://192.168.58.111/ -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt…