小白学Pytorch系列- -torch.distributions API Distributions (1)

news2024/10/5 23:29:43

小白学Pytorch系列- -torch.distributions API Distributions (1)

分布包包含可参数化的概率分布和抽样函数。这允许构造用于优化的随机计算图和随机梯度估计器。这个包通常遵循TensorFlow分发包的设计。

不可能通过随机样本直接反向传播。但是,有两种主要方法可以创建可以反向传播的代理函数。这些是得分函数估计器/似然比估计器/REINFORCE 和路径导数估计器。REINFORCE 通常被视为强化学习中策略梯度方法的基础,而路径导数估计器通常出现在变分自动编码器的重新参数化技巧中。而得分函数只需要样本的值 f ( x ) f(x) f(x), 路径导数需要导数 F ‘ ( x ) F^‘(x) F(x). 下一节将在强化学习示例中讨论这两者。有关更多详细信息,请参阅 使用随机计算图进行梯度估计。

评分功能

当概率密度函数关于其参数可微时,我们只需要sample()log_prob()实现 REINFORCE:
Δ θ = α r ∂ log ⁡ p ( a ∣ π θ ( s ) ) ∂ θ \Delta \theta=\alpha r \frac{\partial \log p\left(a \mid \pi^\theta(s)\right)}{\partial \theta} Δθ=αrθlogp(aπθ(s))

其中 θ θ θ是参数, α α α是学习率, r r r是奖励, p ( a ∣ π θ ( s ) ) p(a|πθ(s)) p(aπθ(s))是在给定策略πθ的情况下在状态 s s s中采取行动 a a a的概率。

在实践中,我们会从网络的输出中抽取一个动作,将这个动作应用到一个环境中,然后使用log_probb来构造一个等效的损失函数。请注意,我们使用负数是因为优化器使用梯度下降,而上面的规则假设梯度上升。使用分类策略,实现REINFORCE的代码如下所示

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

路径导数

实现这些随机/策略梯度的另一种方法是使用从rsample()方法中使用重新聚集技巧,其中参数化的随机变量可以通过无参数无参数随机变量的参数化确定性函数构造。因此,重新聚集样品变得可区分。实施路径衍生物的代码如下:

params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

分布

分布是概率分布的抽象基类。

ExponentialFamily

Bernoulli

Beta

Binomial

Categorical

Cauchy

Chi2

ContinuousBernoulli

Dirichlet

Exponential

FisherSnedecor

Gamma

Geometric

Gumbel

HalfCauchy

HalfNormal

Independent

Kumaraswamy

LKJCholesky

Laplace

LogNormal

LowRankMultivariateNormal

MixtureSameFamily

Multinomial

MultivariateNormal

NegativeBinomial

Normal

OneHotCategorical

Pareto

Poisson

RelaxedBernoulli

LogitRelaxedBernoulli

RelaxedOneHotCategorical

StudentT

TransformedDistribution

基于一个基础分布和一系列分布变换构建一个新的分布。

  • arg_constraints

  • cdf(value) 通过反转变换和计算基本分布的分数来计算累积分布函数。

  • expand(batch_shape, _instance=None)

  • icdf(value) 使用变换和计算基本分布的分数计算逆累积分布函数。

  • log_prob(value) 通过反变换对样本进行评分,并使用基本分布的评分和对数ab(det)雅可比矩阵的评分计算评分。
    log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是
    f ( x ) = 1 2 π σ e − ( x − μ ) 2 2 σ 2 f(x)=\frac{1}{\sqrt{2 \pi} \sigma} e^{-\frac{(x-\mu)^2}{2 \sigma^2}} f(x)=2π σ1e2σ2(xμ)2
    对其取对数可得
    log ⁡ ( f ( x ) ) = − ( x − μ ) 2 2 σ 2 − log ⁡ ( σ ) − log ⁡ ( 2 π ) \log (f(x))=-\frac{(x-\mu)^2}{2 \sigma^2}-\log (\sigma)-\log (\sqrt{2 \pi}) log(f(x))=2σ2(xμ)2log(σ)log(2π )

  • rsample(sample_shape=torch.Size([])) 如果分布参数是批处理的,则生成一个样本形状的重新参数化样本或重新参数化样本的样本形状的批处理。首先从基本分布中采样,并对列表中的每个转换应用transform()。

  • sample(sample_shape=torch.Size([])) 如果分布参数是批量的,则生成一个样本形状的样本或样本形状的样本批次。首先从基本分布中采样,并对列表中的每个转换应用transform()。

Uniform

VonMises

Weibull

Wishart

KL Divergence

Transforms

Constraints

Constraint Registry

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/411146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【华为机试真题详解JAVA实现】—矩阵乘法

目录 一、题目描述 二、解题代码 一、题目描述 如果A是个x行y列的矩阵,B是个y行z列的矩阵,把A和B相乘,其结果将是另一个x行z列的矩阵C。这个矩阵的每个元素是由下面的公式决定的 矩阵的大小不超过100*100 输入描述: 第一行包含一个正整数x,代表第一个矩阵的行数 第二行…

APP测试弱网测试

1、为什么要做弱网测试 当前APP网络环境比较复杂,网络制式有2G、3G、4G网络,还有越来越多的公共Wi-Fi。不同的网络环境和网络制式的差异,都会对用户使用app造成一定影响。 另外,当前app使用场景多变,如进地铁、上公交、…

【Linux】进程理解与学习Ⅳ-进程地址空间

环境:centos7.6,腾讯云服务器Linux文章都放在了专栏:【Linux】欢迎支持订阅🌹相关文章推荐:【Linux】冯.诺依曼体系结构与操作系统【Linux】进程理解与学习Ⅰ-进程概念浅谈Linux下的shell--BASH【Linux】进程理解与学习…

跟姥爷深度学习1 浅用tensorflow做个天气预测

一、前言 最近人工智能、深度学习又火了,我感觉还是有必要研究一下。三年前浅学了一下原理没深入研究框架,三年后感觉各种框架都成熟了,现成的教程也丰富了,所以我继续边学边写。原教程链接: 第一章:tens…

Linux- 系统随你玩之--玩出花活的命令浏览器下

文章目录1、背景2、常规操作2.1、测试相关2.1.1、修改 HTML 请求标头2.1.2、 模拟不同浏览器发出2.1.3、重定向2.2、 下载相关操作2.2.1、 后台下载2.2.2、设置下载重试次数2.2.3、过滤指定格式下载2.2.4、限制总下载文件大小2.2.5、匿名FTP下载2.2.6、FTP认证下载2.2.7、利用代…

(链表专题) 725. 分隔链表 ——【Leetcode每日一题】

725. 分隔链表 给你一个头结点为 head 的单链表和一个整数 k ,请你设计一个算法将链表分隔为 k 个连续的部分。 每部分的长度应该尽可能的相等:任意两部分的长度差距不能超过 1 。这可能会导致有些部分为 null 。 这 k 个部分应该按照在链表中出现的顺…

亚马逊 CodeWhisperer: 个人免费的类似GitHubCopilot能代码补全的 AI 编程助手

1、官网 AI Code Generator - Amazon CodeWhisperer - AWS 官方扩展安装教程 2、安装VSCode 下载安装VSCode 3、VSCode安装CodeWhisperer插件 安装VSCode插件 - AWS Toolkit主侧栏,点击AWS ,展开CodeWhisperer,点击Start 在下拉菜单中点…

【100个 Unity实用技能】 | C# 中关于补位的写法 PadLeft,PadRight 函数

Unity 小科普 老规矩,先介绍一下 Unity 的科普小知识: Unity是 实时3D互动内容创作和运营平台 。包括游戏开发、美术、建筑、汽车设计、影视在内的所有创作者,借助 Unity 将创意变成现实。Unity 平台提供一整套完善的软件解决方案&#xff…

LeetCode_101

内容提要 贪心算法 保证每次操作都属局部最优的,从而使得最后的结果是全局最优 全局结果是局部结果的简单求和,且局部结果互不相干 分配问题 分发饼干 455 简单 分发糖果 135 困难 先从左往右遍历一遍,如果右边孩子的评分比左边的高…

TryHackMe-Year of the Jellyfish(linux渗透测试)

Year of the Jellyfish 请注意 - 此框使用公共 IP 进行部署。想想这对你应该如何应对这一挑战意味着什么。如果您高速枚举公共 IP 地址,ISP 通常会不满意… 端口扫描 循例nmap 扫描结果中还有域名,加进hosts FTP 枚举 尝试anonymous Web枚举 有三个端…

LoRa无线通信技术之CAD介绍

信道活动检测 Lora扩频调制技术的使用在确定信道是否已被可能低于接收机噪声底限的信号。在这种情况下使用常规的RSSI方式判断显然是不切实际的。为此,信道活动检测器用于检测其他LoRaTM信号的存在。下图为通道活动检测(CAD)过程: 工作原理 Lora信道活动检测模式被设计成以最…

一站式指标平台 Kyligence Zen 功能详解

近日,Kyligence 正式发布一站式指标平台 Kyligence Zen GA 版本。其基于 Kyligence 核心 OLAP 能力打造,融合了领先企业建设指标平台的丰富实践,具备 ZenML 指标语言、指标目录、Excel / WPS 直连分析、模板市场等创新能力,将以简…

GPU受限,国内AI大模型能否交出自己的答卷?

继百度之后,阿里、华为、京东、360等大模型也陆续浮出水面,大模型军备竞赛正式开启。 4月7日,阿里云宣布自研大模型“通义千问”开始邀请企业用户测试体验。 4月8日,华为云人工智能领域首席科学家田奇现身《人工智能大模型技术高峰…

一起学 WebGL:图元的类型

大家好,我是前端西瓜哥,今天来说说 WebGL 中的三种图元。 在 WebGL 中,图元有三种:点、线、以及三角形。 绘制的 API 为: gl.drawArrays(mode, first, count)这里的 mode 就是要绘制的图元类型。 我们绘制 4 个点&…

办公协作效率想提质增效,可借助开源大数据工具!

在信息爆炸式发展的今天,提升办公协作效率,让各部门的信息有效互通起来,做好数据管理,已经成为众企业提升竞争力的方式方法。那么,如果想要提升办公效率,就需要了解开源大数据工具了。在数字化发展进程中&a…

HTTP协议概述 | 简析HTTP请求流程 | HTTP8种请求方法

目录 🌏 HTTP的简单介绍 何为HTTP HTTP1.0与HTTP1.1 🌏 HTTP的请求方法 1、OPTIONS 2、HEAD 3、GET 4、POST 5、PUT 6、DELETE 7、TRACE 8、CONNECT 🌏 HTTP的工作原理 🌏 HTTP请求/响应的步骤 1、客户端连接到Web…

AI 芯片的简要发展历史

随着人工智能领域不断取得突破性进展。作为实现人工智能技术的重要基石,AI芯片拥有巨大的产业价值和战略地位。作为人工智能产业链的关键环节和硬件基础,AI芯片有着极高的技术研发和创新的壁垒。从芯片发展的趋势来看,现在仍处于AI芯片发展的…

【案例教程】基于R语言、MaxEnt模型融合技术的物种分布模拟、参数优化方法、结果分析制图与论文写作实践技术

【原文链接】: 基于R语言、MaxEnt模型融合技术的物种分布模拟、参数优化方法、结果分析制图与论文写作实践技术https://mp.weixin.qq.com/s?__bizMzU5NTkyMzcxNw&mid2247537049&idx3&sn31ef342c4808aed6fee6ac108b899a33&chksmfe6897f3c91f1ee5c4fa8e4eeea34…

JDBC概述三(批处理+事务操作+数据库连接池)

一(批处理) 1.1 批处理简介 批处理,简而言之就是一次性执行多条SQL语句,在一定程度上可以提升执行SQL语句的速率。批处理可以通过使用Java的Statement和PreparedStatement来完成,因为这两个语句提供了用于处理批处理…

IO多路复用机制详解

高性能IO模型浅析 服务器端编程经常需要构造高性能的IO模型,常见的IO模型有四种: (1)同步阻塞IO(Blocking IO):即传统的IO模型。 (2)同步非阻塞IO(Non-blo…