Gumbel Softmax Trick

news2024/11/20 1:21:05

Gumbel Softmax Trick

  • 重参数化技巧(re-parameters trick)
  • Gumbel softmax trick
    • 基于Softmax的采样
    • 基于Gumbel-max的采样
    • 基于Gumbel-softmax采样
      • Softmax中的温度系数`tau`

算法学习之gumbel softmax
【Learning Notes】Gumbel 分布及应用浅析
gumbel-softmax(替代argmax)
**重参数化技巧(Gumbel-Softmax)

重参数化技巧(re-parameters trick)

从高斯分布 N ( μ , σ 2 ) N(\mu, \sigma^2) N(μ,σ2)从采样 x x x,改为 从标准分布 N ( 0 , 1 ) N(0, 1) N(0,1)中采样 z z z, 再得到 x = z ∗ σ + μ x = z * \sigma + \mu x=zσ+μ。这样做的好处是 将随机性转移到了 z z z这个常量上,而 σ \sigma σ μ \mu μ则当作仿射变换网络的一部分(可学习参数)。

直接采样导致梯度不可导。


在VAE中,期望encoder学习分布 N ( μ , σ 2 ) N(\mu, \sigma^2) N(μ,σ2),在从中采样一个 z z z,给decoder解码。但这个采样操作是不可导的,所以使用到了重参数化技巧。

让encoder学习均值 μ \mu μ和标准差 σ \sigma σ,我们只需要从标准分布 N ( 0 , 1 ) N(0, 1) N(0,1)中采样噪声 q q q,再得到 [ z = q ∗ σ + μ ] ∈ N ( μ , σ 2 ) [z = q * \sigma + \mu ]\in N(\mu, \sigma^2) [z=qσ+μ]N(μ,σ2)即可。

Gumbel softmax trick

解决随机采样不可导问题

【Learning Notes】Gumbel 分布及应用浅析

例如,

对于, logits = ( x 1 , x 2 , . . . , x k ) \text{logits} = (x_1, x_2, ..., x_k) logits=(x1,x2,...,xk),我们需要(按概率)采样得到其中的一个下标,如1, 2, …。

基于Softmax的采样

利用softmax归一化 logits \text{logits} logits

π i = e x i ∑ j = 1 k e x j \pi_i = { e^{x_i} \over \sum_{j=1}^k e^{x_j}} πi=j=1kexjexi

这样得到的 ∑ i = 1 k x i = 1 \sum_{i=1}^k x_i = 1 i=1kxi=1。然后得到的每个 π i ∈ ( 0 , 1 ) \pi_i \in (0, 1) πi(0,1)可以看作概率,然后使用这个概率去抽样下标。

numpy实现的soft-max方法

x = torch.randn(10)
size = 100000
def sample_with_softmax(logits, size):
    # size:     抽取次数
    # 默认有放回采样
    prob = F.softmax(logits)
    indices = torch.multinomial(prob, size, replacement=True)
    return indices

indices_softmax = sample_with_softmax(x, size)

print(x)
print(indices_softmax)

基于Gumbel-max的采样

x = torch.randn(10)
size = 100000
def sample_with_gumbel_max(logits, size):
    gumbel_dist = torch.distributions.gumbel.Gumbel(0, 1)
    noise = gumbel_dist.sample((size, logits.shape[-1]))
    indices = np.argmax(logits + noise, axis=-1)
    return indices

indices_gumbel_max = sample_with_gumbel_max(x, size)

print(indices_gumbel_max)

可以证明,Gumbel-max方法的采样效果等价于softmax采样的方法

如果我们分别利用 两种方法,进行多次采样,得到如下图。

import matplotlib.pylab as plt
import numpy as np
import torch
from torch.nn import functional as F


x = torch.randn(10)
size = 100000
def softmax(x):
    x -= np.max(x)
    return np.exp(x) / np.sum(np.exp(x))

def sample_with_softmax(logits, size):
    # size:     抽取次数
    # 默认有放回采样
    prob = F.softmax(logits)
    indices = torch.multinomial(prob, size, replacement=True)
    return indices

indices_softmax = sample_with_softmax(x, size)

print(x)
print(indices_softmax)

def sample_with_gumbel_max(logits, size):
    gumbel_dist = torch.distributions.gumbel.Gumbel(0, 1)
    noise = gumbel_dist.sample((size, logits.shape[-1]))
    indices = np.argmax(logits + noise, axis=-1)
    return indices

indices_gumbel_max = sample_with_gumbel_max(x, size)

print(indices_gumbel_max)

fig, axes = plt.subplots(1, 2)
axes[0].hist(indices_softmax, bins=100)
axes[1].hist(indices_gumbel_max, bins=100)

请添加图片描述

横坐标是下标,纵坐标是下标出现的次数。对随机生成的10大小的logits,采样10万次。

这里,解决了随机采样的问题。(利用argmax我们也可以进行随机采样)


但如上两种采样方式,都会导致不可导的问题。

  • sample_with_softmax中的np.random.choice
  • sample_with_gumbel_max中的np.argmax

那有没有什么方法使它可导呢?

基于Gumbel-softmax采样

def sample_with_softmax_hard(logits, size, tau=1):
    y = F.softmax(logits / tau)
    y_hard = torch.eye(y.shape[-1])[torch.argmax(y, dim=-1)]        # ont-hot
    y_hard = y + (y_hard - y).detach()            # straight-through estimator   直接复制梯度
    return y_hard

直接将梯度复制,回传跨过argmax。称为gradient straight-through。

  • 这里的tau是一个温度系数,这里暂不提及,见下文。
  • 在前向过程中,我们得到的是y_hard,反向过程中计算的梯度是y

但在sample_with_softmax_hard中,无法实现随机采样。这里我们结合上面的gumbel-max的方法。

def sample_with_gumbel_softmax(logits, size, tau=1):
    gumbel_dist = torch.distributions.gumbel.Gumbel(0, 1)
    noise = gumbel_dist.sample((size, logits.shape[-1]))
    y = F.softmax((logits+noise) / tau)
    y_hard = torch.eye(y.shape[-1])[torch.argmax(y, dim=-1)]        # ont-hot
    y_hard = y + (y_hard - y).detach()            # straight-through estimator   直接复制梯度
    return y_hard

即,给logits加上一个gumbel噪声,使得argmax能够实现随机抽样。

这里,解决了梯度的不可导。


Softmax中的温度系数tau

temperature 是大于零的参数,它控制着softmax的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。 下面示例对比了不同温度下,softmax 的结果。

def softmax_plus(x, tau=1):
    y = F.softmax(x / tau)
    return y

x = torch.randn(10)
a = softmax_plus(x, tau=0.1)
b = softmax_plus(x, tau=1)
c = softmax_plus(x, tau=50)

fig, axes = plt.subplots(1, 3)
axes[0].bar(list(range(0, 10)), a, color='red')
axes[0].set_ylim(0, 1)
axes[0].set_title('tau=0.1')
axes[1].bar(list(range(0, 10)), b)
axes[1].set_ylim(0, 1)
axes[1].set_title('tau=1')
axes[2].bar(list(range(0, 10)), c, color='green')
axes[2].set_ylim(0, 1)
axes[2].set_title('tau=50')
plt.show()

请添加图片描述

横坐标是类别,纵坐标是softmax之后的值。可见,随着温度的升高,生成的概率趋于平滑。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/790892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Failed to connect to 127.0.0.1 port 7890科学上网导致的问题

找了很多种解法: 首先这个.config配置文件有两个地方存在:先使用第一种方式,不管用再试第二种 第一个位置git安装路径:不需要重启 E:\git\Git\etc,这个需要看你自己的安装路径,找到http_proxy删除即可第二…

类型转换函数

再论类型转换 标准数据类型之间会进行隐式的类型安全转换 转换规则如下: 问题 普通类型与类类型之间能否进行类型转换? 类类型之间能否进行类型转换? 再论构造函数 构造函数可以定义不同类型的参数 参数满足下列条件时称为转换构造函数…

LocalDateTime的json格式化问题

目录 解决: 1、注册日期序列化器 2、自定义LocalDateTime的JSON格式 3、使用第三方库 总结: 实体类中定义了LocalDateTime类型的属性,获取数据会出现以下日期格式问题: 讲述: 对于LocalDateTime的JSON序列化和反序…

解析数据可视化工具:如何选择最合适的软件

在当今信息爆炸的时代,数据已成为各行各业的重要资源。为了更好地理解和分析数据,数据可视化成为一种必不可少的工具。市面上数据可视化工具不说上千也有上百,什么帆软、powerbi、把阿里datav,腾讯云图、山海鲸可视化等等等等&…

【VCS】(6)Code Coverage

Code Coverage VCS 中 Code Coverage 的类型Code Coverage Flow代码覆盖率选项Lab Code Coverage初步尝试其他格式的覆盖率报告屏蔽部分代码屏蔽整个模块 设计和验证到底要做到什么程度? 这里其中一个指标就是 Code Coverage。 代码覆盖率一般考虑以下几个方面&…

前置操作符和后置操作符

下面的代码有没有区别?为什么? 意想不到的事实 现代编译器产品会对代码进行优化 优化使得最终的二进制程序更加高效 优化后的二进制程序丢失了 C/C 的原生语义 不可能从编译后的二进制程序还原 C/C 程序 思考 操作符可以重载吗? 如何区分…

V1.4基站仓储三代标签操作指导

一、管理系统使用 1、启动v1.4基站 插上电源,用网线连接基站和电脑。基站默认ip为192.168.1.200,所以需要修改电脑的IP地址为192.168.1.x,例如:192.168.1.100 ​ 注:当基站第二个灯(绿色)闪烁…

DAY12_JSPEL表达式JSTL标签MVC模式和三层架构

目录 1 JSP 概述2 JSP 快速入门2.1 搭建环境2.2 导入 JSP 依赖2.3 创建 jsp 页面2.4 编写代码2.5 测试 3 JSP 原理4 JSP 脚本4.1 JSP 脚本分类4.2 案例4.2.1 需求4.2.2 实现4.2.3 成品代码4.2.4 测试 4.3 JSP 缺点 5 EL 表达式5.1 概述5.2 代码演示5.3 域对象 6 JSTL标签6.1 概…

leetcode数据结构题解(Java实现)(存在重复元素、最大子数组和、两数之和、合并两个有序数组)

文章目录 第一天217. 存在重复元素53.最大子数组和 第二天1. 两数之和88. 合并两个有序数组 第一天 217. 存在重复元素 题解思路:首先题目需要的是判断数组中是否存在相同的数字,存在返回true,不存在就返回false。 那么显然可以这样做,先进行…

全光谱护眼灯怎么选择?护眼灯全光谱和自然光谱的区别

一、全光谱护眼台灯的挑选技巧 全光谱:想要护眼台灯能有自然光的效果,选择台灯时建议选择全光谱台灯,并且显色指数大于Ra95以上的,显色指数越高越还原色彩,并且选择RGO豁免蓝光才是真的不会伤害眼睛的。 照射面积&…

Python(四十)for-in练习题——100到999之间的水仙花数

❤️ 专栏简介:本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中,我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 :本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

2023年发布的25个开源大型语言模型总结

大型语言模型(llm)是一种人工智能(AI),在大量文本和代码数据集上进行训练。它们可以用于各种任务,包括生成文本、翻译语言和编写不同类型的创意内容。 今年开始,人们对开源LLM越来越感兴趣。这些模型是在开源许可下发布的,这意味…

Redis Stream 流的深度解析与实现高级消息队列【一万字】

详细介绍了 Redis 5.0 版本新增加的数据结构Stream的使用方式以及原理,如何实现更加可靠的消息队列。 文章目录 Stream 概述2 Stream基本结构3 存储数据3.1 Entry ID3.2 数量限制 4 获取数据4.1 范围查询4.2 独立消费消息4.2.1 非阻塞使用4.2.2 阻塞的使用 4.3 消费…

【Spring定时器】SpringBoot整合Quartz

SpringBoot整合Quartz 简单介绍 简单操作 导入相关pom依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-quartz</artifactId></dependency>创建继承类MyQuartz package com.ustc.quartz; …

C语言假期作业 DAY 01

题目 1.选择题 1、执行下面程序&#xff0c;正确的输出是&#xff08; &#xff09; int x5,y7; void swap() { int z; zx; xy; yz; } int main() { int x3,y8; swap(); printf("%d,%d\n"&#xff0c;x, y)…

Docker 单机/集群 部署 Nacos2.2.0

单机部署 1- 拉取镜像 docker pull nacos/nacos-server:v2.2.02- 准备挂载的配置文件目录和日志目录 日志目录(空目录)&#xff1a;./nacos/logs配置文件&#xff1a;./nacos/conf/application.properties 从官网下载 nacos 压缩包&#xff1a;Release 2.2.0 (Dec 14, 2022…

基于SpringBoot+vue的医院信管系统设计与实现(源码+LW+部署文档等)

博主介绍&#xff1a; 大家好&#xff0c;我是一名在Java圈混迹十余年的程序员&#xff0c;精通Java编程语言&#xff0c;同时也熟练掌握微信小程序、Python和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

TypeScript -- 基础类型

文章目录 TypeScript -- 基础类型let 和 const基本类型写法布尔类型 -- boolean数字类型 -- number字符串类型 -- string数组类型元组类型枚举类型 -- enum任意类型 -- any空值 -- voidNull 和 Undefined不存在的类型 -- never对象 -- object类型断言 TypeScript – 基础类型 1…

关于 ivanti Access Client软件配置问题

最近需要使用ivanti工具连接校园网&#xff0c;但是经常出现ivanti连接后&#xff0c;WIFI或有线网络就显示无互联网连接的情况。 为此&#xff0c;我检查了一下网络的配置状态&#xff0c;发现ivanti连接的时候回临时创建一个网络adapter&#xff0c;该adapter有 一个身份验证…

Java训练二

一、斐波那契数列 1、1、2、3、5、8、13、21、34、...是一组典型的斐波那契数列&#xff0c;前两个数相加等于第三个数。那么请问这组数中的第n个数的值是多少&#xff1f; package haha; import java.util.Scanner; public class helloworld{public static void main(String…