使用TensorFlow Probability实现最大似然估计

news2025/1/19 8:27:47

TensorFlow Probability是一个构建在TensorFlow之上的Python库。它将我们的概率模型与现代硬件(例如GPU)上的深度学习结合起来。

极大似然估计

最大似然估计是深度学习模型中常用的训练过程。目标是在给定一些数据的情况下,估计概率分布的参数。简单来说,我们想要最大化我们在某个假设的统计模型下观察到的数据的概率,即概率分布。

这里我们还引入了一些符号。连续随机变量的概率密度函数大致表示样本取某一特定值的概率。我们将表示这个函数𝑃(𝑥|𝜃),其中𝑥是样本的值,𝜃是描述概率分布的参数:

 tfd.Normal(0, 1).prob(0)
 
 <tf.Tensor: shape=(), dtype=float32, numpy=0.3989423>

当从同一个分布中独立抽取多个样本时(我们通常假设),样本值𝑥1,…,𝑥𝑛的概率密度函数是每个个体𝑥𝑖的概率密度函数的乘积:

可以很容易地用一个例子来计算上面的问题。假设我们有一个标准的高斯分布和一些样本:𝑥1=−0.5,𝑥2=0和𝑥3=1.5。正如我们上面定义的那样,我只需要计算每个样本的概率密度函数,并将输出相乘。

 X = [-0.5, 0, 1.5]
 
 np.prod(tfd.Normal(0, 1).prob(X))
 
 0.01819123

现在,我想直观地告诉大家概率密度函数和似然函数之间的区别。它们本质上是在计算类似的东西,但角度相反。

从概率密度函数开始,我们知道它们是样本𝑥1,…,𝑥𝑛的函数。参数𝜃被认为是固定的。因此当参数𝜃已知时,我们使用概率密度函数,找出相同样本𝑥1,…,𝑥𝑛的概率。简单地说,当我们知道产生某个过程的分布并且我们想从它中推断可能的抽样值时,我们使用这个函数。

对于似然函数,我们所知道的是样本,即观测数据𝑥1,…,𝑥𝑛。这意味着我们的自变量现在是𝜃,因为我们不知道是哪个分布产生了我们观察到的这个过程。所以当我们知道某个过程的样本时,使用这个函数,即我们收集了数据,但我们不知道最初是什么分布生成了该过程。也就是说既然我们知道这些数据,我们就可以对它们来自的分布进行推断。

对于似然函数,惯例是使用字母𝐿,而对于概率密度函数,我们引入了上面的符号。我们可以这样写:

我们准备定义参数为𝜇和𝜎的高斯分布的似然函数:

作为对似然函数有更多直观了解,我们可以生成足够多的样本来直观地了解它的形状。我们对从概率分布中生成样本不感兴趣,我们感兴趣的是生成参数𝜃,使观测数据的概率最大化,即𝑃(𝑥1,…,𝑥𝑛|𝜃)。

我们使用与上面相同的样本𝑥1=−0.5,𝑥2=0和𝑥3=1.5。

 X
 
 [-0.5, 0, 1.5]

为了能够构建2D可视化,我们可以创建一个潜在参数的网格,在一段时间间隔内均匀采样,𝜇从[-2,2]采样,𝜎从[0,3]采样。由于我们对每个参数采样了100个值,得到了𝑛^2个可能的组合。对于每个参数的组合,我们需要计算每个样本的概率并将它们相乘。

 μ = np.linspace(-2, 2, 100)
 σ = np.linspace(0, 3, 100)
 
 l_x = []
 for mu in μ:
     for sigma in σ:
         l_x.append(np.prod(tfd.Normal(mu, sigma).prob(X)))
         
 l_x = np.asarray(l_x).reshape((100, 100)).T

现在准备绘制似然函数。注意这是观察到的样本的函数,这些是固定的,参数是我们的自变量。

 plt.contourf(μ, σ, l_x)
 plt.xlabel('μ')
 plt.ylabel('σ')
 plt.colorbar()
 plt.title('Likelihood');

我们感兴趣的是最大化数据的概率。这意味着想要找到似然函数的最大值,这可以借助微积分来实现。函数的一阶导数对参数的零点应该足以帮助我们找到原函数的最大值。

但是,将许多小概率相乘在数值上是不稳定的。为了克服这个问题,可以使用同一函数的对数变换。自然对数是一个单调递增的函数,这意味着如果x轴上的值增加,y轴上的值也会增加。这很重要,因为它确保概率对数的最大值出现在与原始概率函数相同的点。它为我们做了另一件非常方便的事情,它将乘积转化为和。

让我们执行变换:

现在可以着手解决优化问题了。最大化我们数据的概率可以写成:

上面的表达式可以被求导以找到最大值。展开参数有log(𝐿(𝑋|𝜇,𝜎))。由于它是两个变量𝜇和𝜎的函数,使用偏导数来找到最大似然估计。

专注于𝜇´("撇"表示它是一个估计值,即我们的输出),我们可以使用以下方法计算它:

为了找到最大值,我们需要找到临界值,因此需要将上面的表达式设为零。

得到

这是数据的平均值,可以为我们的样本𝑥1=−0.5,𝑥2=0和𝑥3=1.5计算μ和σ的最大值,并将它们与真实值进行比较。

 idx_μ_max = np.argmax(l_x, axis=1)[-1]
 print(f'μ True Value: {np.array(X).mean()}')
 print(f'μ Calculated Value: {μ[idx_μ_max]}')
 print(f'σ True Value: {np.array(X).std()}')
 print(f'σ Calculated Value: {σ[np.nanargmax(l_x[:,idx_μ_max], axis=0)]}')
 
 μ True Value: 0.3333333333333333
 μ Calculated Value: 0.3434343434343434
 σ True Value: 0.8498365855987975
 σ Calculated Value: 0.8484848484848485

最大似然估计在TensorFlow Probability中的实现

我们先创建一个正态分布随机变量并从中取样。通过绘制随机变量的直方图,可以看到分布的形状。

 x_train = np.random.normal(loc=1, scale=5, size=1000).astype('float32')[:, np.newaxis]
 
 plt.hist(x_train, bins=50);

然后计算随机变量的均值,这是我们想用最大似然估计学习的值。

 x_train.mean()
 
 0.85486585

将TensorFlow Variable对象定义为分布的参数。这向TensorFlow说明,我们想在学习过程中学习这些参数。

 normal = tfd.Normal(loc=tf.Variable(0., name='loc'), scale=5)
 normal.trainable_variables
 
 (<tf.Variable 'loc:0' shape=() dtype=float32, numpy=0.0>,)

下一步是定义损失函数。我们已经看到了我们想要达到的目标最大化似然函数的对数变换。但是在深度学习中,通常需要最小化损失函数,所以直接将似然函数的符号改为负。

 def nll(x_train):
     return -tf.reduce_sum(normal.log_prob(x_train))

最后建立训练程序,使用自定义训练循环,可以自己定义过程细节(即使用自定义损失函数)。

使用tf.GradientTape(),它是访问TensorFlow的自动微分特性的API。然后指定要训练的变量,最小化损失函数并应用梯度。

 @tf.function
 def get_loss_and_grads(x_train):
     with tf.GradientTape() as tape:
         tape.watch(normal.trainable_variables)
         loss = nll(x_train)
         grads = tape.gradient(loss, normal.trainable_variables)
     return loss, grads
 
 optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

现在训练程序已经准备完毕了。

 @tf.function
 def get_loss_and_grads(x_train):
     with tf.GradientTape() as tape:
         tape.watch(normal.trainable_variables)
         loss = nll(x_train)
         grads = tape.gradient(loss, normal.trainable_variables)
     return loss, grads
 
 optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
 
 Step 000: Loss: 13768.004 Loc: 0.855
 Step 001: Loss: 13768.004 Loc: 0.855
 Step 002: Loss: 13768.004 Loc: 0.855
 ...
 Step 1997: Loss: 13768.004 Loc: 0.855
 Step 1998: Loss: 13768.004 Loc: 0.855
 Step 1999: Loss: 13768.004 Loc: 0.855

我们通过最大化在第一时间生成的抽样数据的概率,计算了参数𝜇的最大似然估计。它是有效的,因为能够得到一个非常接近原始值的𝜇值。

 print(f'True Value: {x_train.mean()}')
 print(f'Estimated Value: {normal.trainable_variables[0].numpy()}')
 
 True Value: 0.8548658490180969
 Estimated Value: 0.8548658490180969

总结

本文介绍了最大似然估计的过程,和TensorFlow Probability的实现。通过一个简单的例子,我们对似然函数的形状有了一些直观的认识。最后通过定义一个TensorFlow变量、一个负对数似然函数并应用梯度,实现了一个使用TensorFlow Probability的自定义训练过程。

https://avoid.overfit.cn/post/e604c2173f754788869c5c1332ccba6d

作者:Luís Roque

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/84877.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开源依赖项管理指南

就像人际关系中人与人之间的关系一样&#xff0c;软件生态系统中包含一个庞大的关系网络。其中一些联系非常深入&#xff0c;而有一些关系则更为表面。但实际上&#xff0c;现代基于开源的软件开发涉及一个极其庞大的依赖关系树&#xff0c;依赖关系层层叠加&#xff0c;同时涉…

喜讯丨创新微MinewSemi的MS11SF1系列荣获2022中国IoT创新奖—产品金狮奖

北京时间2022年12月8日&#xff0c;由知名电子科技媒体“电子发烧友”举办的2022第九届中国IoT大会在深圳圆满落幕&#xff0c;创新微MinewSemi凭借高性能、低功耗的WiFiBLE Combo 模块—MS11SF1系列&#xff0c;在众多参会嘉宾和行业主流媒体的共同见证下&#xff0c;荣获2022…

卷积神经网络中卷积的作用与原理

目录 前言 卷积的作用 卷积的参数 卷积核大小&#xff08;kernel_size&#xff09; 填充&#xff08;padding&#xff09; same valid full 卷积核算子&#xff08;operator&#xff09; Robert 算子 Prewitt算子 Sobel 算子 Laplance 算子 卷积核深度与个数&…

【C++进阶】哈希(万字详解)—— 运用篇(下)

&#x1f387;C学习历程&#xff1a;入门 博客主页&#xff1a;一起去看日落吗持续分享博主的C学习历程博主的能力有限&#xff0c;出现错误希望大家不吝赐教分享给大家一句我很喜欢的话&#xff1a; 也许你现在做的事情&#xff0c;暂时看不到成果&#xff0c;但不要忘记&…

[附源码]Python计算机毕业设计电子工厂进销存管理系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等…

为什么要把测试环境的告警当成生产环境的告警处理?是一个哲学问题,还是一个技术问题?...

开发不愿意了一个后端服务通常有三个环境&#xff1a;测试环境&#xff0c;预发布环境&#xff0c;生产环境。运维在给测试环境增加告警规则和告警路由时&#xff0c;开发人员反对。这很容易理解&#xff0c;如果真把告警规则配置到测试环境&#xff0c;他们可能无时不刻地收到…

Web GIS开发教程

Web GIS开发教程 非程序员的基本 Web GIS 开发 课程英文名&#xff1a;Web GIS development course 此视频教程共4.0小时&#xff0c;中英双语字幕&#xff0c;画质清晰无水印&#xff0c;源码附件全 下载地址 课程编号&#xff1a;355 百度网盘地址&#xff1a;https://p…

杭州联合银行 x 袋鼠云:打造智能标签体系,助力银行大零售业务转型

“智能标签平台上线后&#xff0c;支行及业务部门已创建多个客群用于营销&#xff0c;为我行客户精细化管理打下了良好基础。” 2021 年&#xff0c;联合银行就已搭建了大数据基础平台&#xff0c;围绕平台搭建了数据研发平台、大数据调度平台及大数据服务平台&#xff0c;提高…

(附源码)Python飞机票销售系统 毕业设计 141432

摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐渐被人们所认识&#xff0c;科学化的管理&#xff0c;使信息存…

Vue组件的嵌套关系,父组件传递子组件 ,事件总线,Provide,inject,作用域插槽,具名插槽非props的attribute ,子组件传递父组件

组件化 – 组件间通信 认识组件的嵌套 ◼ 前面我们是将所有的逻辑放到一个App.vue中:  在之前的案例中,我们只是创建了一个组件App;  如果我们一个应用程序将所有的逻辑都放在一个组件中,那么这个组件就会变成非常的臃 肿和难以维护;  所以组件化的核心思想应该是对…

【YOLOv5】记录YOLOv5的学习过程

以下记录的是Ubuntu20.04版本&#xff0c;其他Ubuntu版本也相差不大~ 一、安装pytorch GPU版本、显卡驱动、CUDA、cuDNN 下载pytorch GPU版本&#xff1a; 最新版本链接&#xff1a;Start Locally | PyTorch 历史版本链接&#xff1a;Previous PyTorch Versions | PyTorch…

MySQL——内置函数

文章目录内置函数日期函数字符串函数数学函数其他函数内置函数 日期函数 基本使用&#xff1a; 可以进行运算&#xff1a; 在日期基础上加时间 在日期基础上减时间 计算两个日期相差的天数 案例1&#xff1a; 建一张表&#xff0c;记录生日 案例2&#xff1a; 创建一…

设计有趣的轻巧真无线,体积小续航长,南卡小音舱上手

大家平时都会听听音乐、玩玩游戏&#xff0c;这时候就需要用到蓝牙耳机&#xff0c;特别是在户外接打电话时&#xff0c;戴上一副耳机都会方便很多。最近发现了一款南卡小音舱Lite2&#xff0c;这些天用过之后感觉它质量不错&#xff0c;做得十分小巧&#xff0c;日常携带特别方…

Postman带sessionId的post请求访问失败

Postman带sessionId的post请求访问失败1、Python 调用过程2、Postman 错误示例3、Postman 正确示例4、总结使用 Python 访问一个数据接口&#xff0c;调用是正常的&#xff0c;但是使用 Postman 进行访问时出错了&#xff0c;搞了两天&#xff0c;后面发现很简单&#xff0c;故…

如何理解FFT中时间窗与RBW的关系

作为一种常用的频谱分析工具&#xff0c;快速傅里叶变换(FFT) 实现了时域到频域的转换&#xff0c;是数字信号分析中最常用的基本功能之一。FFT 频谱分析是否与传统的扫频式频谱仪类似&#xff0c;也具有分辨率带宽(RBW) 的概念&#xff1f;如果具有RBW &#xff0c;那么FFT 的…

前端食堂技术周刊第 63 期:Vite 4.0、State of CSS 2022、Rome v11、Web 性能日历、VueConf 2022 PPT

美味值&#xff1a;&#x1f31f;&#x1f31f;&#x1f31f;&#x1f31f;&#x1f31f; 口味&#xff1a;霜糖山楂 食堂技术周刊仓库地址&#xff1a;https://github.com/Geekhyt/weekly 本期摘要 Vite 4.0State of CSS 2022 调查结果Rome v11HTMHell Advent Calendar 20…

虚幻引擎VR游戏开发基础教程

虚幻引擎VR游戏开发基础教程 了解如何使用 Oculus Quest 2 的蓝图在虚幻引擎 4 中从头开始构建基本的 VR 机制 课程英文名&#xff1a;Unreal Engine VR Development Fundamentals 此视频教程共4.0小时&#xff0c;中英双语字幕&#xff0c;画质清晰无水印&#xff0c;源码附…

推荐一些Python练手项目,了解完毕后才吃惊

前言 入门篇&#xff1a; 0.Python初学者一般都是那些根本没有编程基础的学生。做这个项目&#xff0c;你应该首先开始基本语法。教程中的几个实验可以让完全零基础的学生在一个下午学习Linux、python基础知识和GitHub命令。 1.Python-Python 图片转字符画50 行 Python 代码…

web前端期末大作业网页设计与制作 ——汉口我的家乡旅游景点 5页HTML+CSS+JavaScript

家乡旅游景点网页作业制作 网页代码运用了DIV盒子的使用方法&#xff0c;如盒子的嵌套、浮动、margin、border、background等属性的使用&#xff0c;外部大盒子设定居中&#xff0c;内部左中右布局&#xff0c;下方横向浮动排列&#xff0c;大学学习的前端知识点和布局方式都有…

JDBC 入门

目录1 JDBC 快速入门1.1 JDBC 的概念1.2 JDBC 快速入门2 JDBC 功能类详解2.1 DriverManager2.2 Connection2.3 Statement2.4 ResultSet3 JDBC 工具类4 SQL 注入攻击5 JDBC 事务5.1 JDBC 管理事务6 连接池6.1 数据库连接池的概念6.2 自定义数据库连接池6.2.1 DataSource6.2.2 归…