《机器学习公式推导与代码实现》chapter22-EM算法

news2024/11/14 2:26:23

《机器学习公式推导与代码实现》学习笔记,记录一下自己的学习过程,详细的内容请大家购买作者的书籍查阅。

EM算法

作为一种迭代算法,EM算法(expectation maximization期望极大值算法)用于包含隐变量的概率模型参数的极大似然估计

EM算法包括两个步骤:E步,求期望(expectation);M步,求极大(maximization)。

1 极大似然估计

极大似然估计maximum likelihood estimation,MLE)是统计学领域中一种经典的参数估计方法。对于某个随机样本满足某种概率分布,但其中的统计参数未知的情况,极大似然估计可以让我通过若干次试验的结果来估计参数的值。

以一个经典的例子进行说明,比如我们想了解某高校学生的身高分布。我们假设该校学生的身高分布服从一个正态分布 N ( μ , σ 2 ) N(\mu,\sigma^{2}) N(μ,σ2),其中分布参数 μ \mu μ σ 2 \sigma^{2} σ2未知。全校有数万名学生,要一个一个实测肯定不现实,所以我们决定采用统计抽样的方法,随机选取100名学生测得其身高。

要通过100人的身高估算全校学生的身高,需要明确以下问题。第一个问题是抽到这100人的概率是多少。因为每个人的选取都是独立的,所以抽到这100人的概率可以表示为单个概率的乘积:
L ( θ ) = L ( x 1 , x 2 , ⋯   , x n ; θ ) = ∏ i = 1 n p ( x i ∣ θ ) L(\theta )=L(x_{1},x_{2},\cdots,x_{n};\theta)=\prod_{i=1}^{n}p(x_{i}\mid \theta) L(θ)=L(x1,x2,,xn;θ)=i=1np(xiθ)
上式为似然函数,为了计算方便,我们会对似然函数取对数:
H ( θ ) = ln ⁡ L ( θ ) = ln ⁡ ∏ i = 1 n p ( x i ∣ θ ) = ∑ i = 1 n ln ⁡ p ( x i ∣ θ ) H(\theta )=\ln_{}{L(\theta )}=\ln_{}{\prod_{i=1}^{n}p(x_{i}\mid \theta)}=\sum_{i=1}^{n} \ln_{}{p(x_{i}\mid \theta)} H(θ)=lnL(θ)=lni=1np(xiθ)=i=1nlnp(xiθ)
第二个问题是为什么刚好抽到这100人。按照极大似然估计理论,在学校这么多学生中,我们恰好抽到这100人而不是另外100人,正是因为这100人出现的概率极大,即其对应的似然函数极大:
θ ^ = a r g m a x L ( θ ) \hat{\theta} = argmax L(\theta ) θ^=argmaxL(θ)
最后一个问题是如何求解,直接对 L ( θ ) L(\theta) L(θ)求导并使其为0。

所以极大似然估计法可以看作由抽样结果对条件的反推,即已知某个参数能使得这些样本出现的概率极大,我们就直接把该参数作为参数估计的真实值。

2 EM算法

假设全校学生的身高付出一个分布的假设过于笼统,实际上男女分布不同,假设其中男生身高服从分布 N ( μ 1 , σ 1 2 ) N(\mu^{}_{1},\sigma^{2}_{1}) N(μ1,σ12),女生身高分布为 N ( μ 2 , σ 2 2 ) N(\mu^{}_{2},\sigma^{2}_{2}) N(μ2,σ22)。现在估计该校学生身高,就不能简单地使用一个分布的假设了。

假设分别抽取50个男生和50个女生,对他们分开进行估计。假设我们并不知道抽样得到的这样样本来自男生还是女生。

学生的身高是观测变量(observable variable),样本的性别是一种隐变量(hidden variable)。

现在我们需要估计两个问题:一是这个样本是男生的还是女生的,而是男生和女生对应的身高的正态分布参数分别是多少。这种情况极大似然估计就不太适用了,要估计男女生身高分布,就必须先估计该学生是男是女。反过来要估计该学生是男还是女,又得从身高来判断。但二者相互依赖,直接用极大似然估计无法计算。

针对这种包含隐变量的参数估计问题,一般使用EM(expectation maximization)算法,即期望极大化算法来进行求解。针对上述身高估计问题,EM算法的求解思路是:既然两个问题相互依赖,这肯定是一个动态求解过程。不如直接给定男女身高的分布初始值,根据初始值估计哪个样本是男/女生的概率(E步),然后据此使用极大似然估计男女生的身高分布参数(M步),之后动态迭代调整到满足终止条件为止。

EM算法的应用场景就是解决包含隐变量的概率模型参数估计问题。给定观测变量数据Y隐变量数据Z联合概率分布 P ( Y , Z ∣ θ ) P(Y,Z|\theta) P(Y,Zθ),以及关于隐变量的条件分布 P ( Z ∣ Y , θ ) P(Z|Y,\theta) P(ZY,θ),使用EM算法对模型参数 θ \theta θ进行估计的流程如下:
(1)初始化模型参数 θ ( 0 ) \theta^{(0)} θ(0),开始迭代。
(2)E步:记 θ ( i ) \theta^{(i)} θ(i)为第 i i i次迭代参数 θ \theta θ的估计值,在第 i + 1 i+1 i+1次迭代的E步,计算Q函数:
Q ( θ , θ ( i ) ) = E Z [ log ⁡ P ( Y , Z ∣ θ ) ∣ Y , θ ( i ) ] = ∑ Z log ⁡ P ( Y , Z ∣ θ ) P ( Z ∣ Y , θ ( i ) ) Q(\theta,\theta^{(i)})=E_{Z}\left [ \log_{}{P(Y,Z\mid\theta)}\mid Y,\theta^{(i)} \right ] =\sum_{Z}^{}\log_{}{P(Y,Z\mid\theta)}P(Z|Y,\theta^{(i)}) Q(θ,θ(i))=EZ[logP(Y,Zθ)Y,θ(i)]=ZlogP(Y,Zθ)P(ZY,θ(i))
其中 P ( Z ∣ Y , θ ( i ) ) P(Z|Y,\theta^{(i)}) P(ZY,θ(i))为给定观测数据 Y Y Y和当前参数估计 θ ( i ) \theta^{(i)} θ(i)的情况下隐变量数据 Z Z Z的条件概率分布。E步的关键是这个Q函数,Q函数定义为完全数据的对数似然函数 log ⁡ P ( Y , Z ∣ θ ) \log_{}{P(Y,Z\mid\theta)} logP(Y,Zθ)关于在给定观测数据 Y Y Y和当前参数 θ ( i ) \theta^{(i)} θ(i)的情况下未观测数据 Z Z Z的条件概率分布。
(3)M步:求使得Q函数最大化的参数 θ \theta θ,确定第 i + 1 i+1 i+1次迭代的参数估计值 θ ( i + 1 ) \theta^{(i+1)} θ(i+1)
θ ( i + 1 ) = a r g m a x θ Q ( θ , θ ( i ) ) \theta^{(i+1)}=\underset{\theta}{argmax}Q(\theta,\theta^{(i)}) θ(i+1)=θargmaxQ(θ,θ(i))
(4)重复迭代E步和M步直至收敛。

由EM算法过程可知,其关键在于E步要确定Q函数。E步在固定模型参数的情况下估计隐状态变量分布,而M步则是固定隐变量来估计模型参数。二者交互进行,直至算法收敛条件。

EM算法动态迭代过程:
请添加图片描述

3 三硬币模型

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4 基于numpy实现三硬币模型

import numpy as np

## EM算法过程定义
def em(data, thetas, max_iter=50, eps=1e-3): # data观测数据,thetas初始化的估计参数值,eps收敛阈值

    ll_old = 0 # 初始化似然函数值
    for i in range(max_iter):
        # E步:求隐变量分布
        log_like = np.array([np.sum(data*np.log(theta), axis=1) for theta in thetas]) # 对数似然 2*5
        like = np.exp(log_like) # 似然 2*5
        ws = like/like.sum(0) # 隐变量分布 2*5
        ll_new = np.sum([w*l for w, l in zip(ws, log_like)]) # 期望

        # M步:更新参数值
        vs = np.array([w[:, None] * data for w in ws]) # 概率加权 2*5*2
        thetas = np.array([v.sum(0)/v.sum() for v in vs]) # 2*2

        # 打印结果
        print(f'Iteration:{i+1}')
        print(f'theta_B = {thetas[0,0]:.2}, theta_C = {thetas[1,0]:.2}, {ll_new}')

        # 满足条件退出迭代
        if np.abs(ll_new - ll_old) < eps:
            break
        ll_old = ll_new
    
    return thetas

EM算法求解三硬币问题:

# 观测数据,5次独立实验,每次试验10次抛掷的正反面次数
observed_data = np.array([(5, 5), (9, 1), (8, 2), (4, 6), (7, 3)]) # 比如第一次试验为5次正面5次反面
# 初始化参数值,硬币B出现正面的概率0.6,硬币C出现正面的概率为0.5
thetas = np.array([[0.6, 0.4], [0.5, 0.5]])
# EM算法寻优
thetas = em(observed_data, thetas, max_iter=30)
thetas
Iteration:1
theta_B = 0.71, theta_C = 0.58, -32.68721052517165
Iteration:2
theta_B = 0.75, theta_C = 0.57, -31.258877917413145
Iteration:3
theta_B = 0.77, theta_C = 0.55, -30.760072598843628
Iteration:4
theta_B = 0.78, theta_C = 0.53, -30.33053606687176
Iteration:5
theta_B = 0.79, theta_C = 0.53, -30.071062062760774
Iteration:6
theta_B = 0.79, theta_C = 0.52, -29.95042921516964
Iteration:7
theta_B = 0.8, theta_C = 0.52, -29.90079955867412
Iteration:8
theta_B = 0.8, theta_C = 0.52, -29.881202814860167
Iteration:9
theta_B = 0.8, theta_C = 0.52, -29.873553692091832
Iteration:10
theta_B = 0.8, theta_C = 0.52, -29.870576075992844
Iteration:11
theta_B = 0.8, theta_C = 0.52, -29.86941691676721
Iteration:12
theta_B = 0.8, theta_C = 0.52, -29.868965223428773

array([[0.7967829 , 0.2032171 ],
       [0.51959543, 0.48040457]])

算法在第7次迭代时收敛,最后硬币B和硬币C出现正面的概率分别为0.80和0.52。

笔记本_Github地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/798918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

devops(后端)

1.前言 该devpos架构为gitlabjenkinsharbork8s&#xff0c;项目是java项目&#xff0c;流程为从gitlab拉取项目代码到jenkins&#xff0c;jenkins通过maven将项目代码打成jar包&#xff0c;通过dockerfile构建jdk环境的镜像并把jar包放到镜像中启动&#xff0c;构建好的镜像通…

springboot运行报错Failed to load ApplicationContext for xxx

Failed to load ApplicationContext for报错解决方法 报错Failed to load ApplicationContext for 报错Failed to load ApplicationContext for 网上找了一堆方法都尝试了还是没用 包括添加mapperScan&#xff0c;添加配置类 配置pom文件 [外链图片转存失败,源站可能有防盗链机…

com.android.ide.common.signing.KeytoolException:

签名没问题但是提示Execution failed for task :app:packageDebug. > A failure occurred while executing com.android.build.gradle.tasks.PackageAndroidArtifact$IncrementalSplitterRunnable > com.android.ide.common.signing.KeytoolException: Failed to read ke…

21.2:象棋走马问题

请同学们自行搜索或者想象一个象棋的棋盘&#xff0c; 然后把整个棋盘放入第一象限&#xff0c;棋盘的最左下角是(0,0)位置 那么整个棋盘就是横坐标上9条线、纵坐标上10条线的区域 给你三个 参数 x&#xff0c;y&#xff0c;k 返回“马”从(0,0)位置出发&#xff0c;必须走k步 …

数据结构—串

4.1串 4.1.1串的定义 串&#xff08;String&#xff09;——零个或多个任意字符组成的有限序列 S"a1 a2...an"串的定义——几个术语 子串&#xff1a;串中任意个连续字符组成的子序列称为该串的子串 例如&#xff0c;“abcde”的子串有&#xff1a; “ ”、“a”、…

【C++】【自用】选择题 刷题总结

文章目录 【类和对象】1. 构造、拷贝构造的调用2. 静态成员变量3. 初始化列表4. 成员函数&#xff1a;运算符重载5. 友元函数、友元类55. 特殊类设计 【细节题】1. 构造 析构 new \ deletet、new[] \ delete[] 【类和对象】 1. 构造、拷贝构造的调用 #include using namespace…

大数据面试题:超详细版MapReduce工作原理

面试题来源&#xff1a; 《大数据面试题 V4.0》 大数据面试题V3.0&#xff0c;523道题&#xff0c;679页&#xff0c;46w字 参考答案&#xff1a; MapReduce详细流程&#xff1a; 1、准备待处理文件&#xff08;200M&#xff09; 2、submit()对原始文件进行切片分析&#…

热点活动-秒杀功能设计

一、需求描述 秒杀活动是电子商务兴起后出现的一种新型的购物方式&#xff0c;通过网上APP、小程序等平台推出一些低于市场价格的商品&#xff0c;提升购买率的营销活动&#xff0c;所有买家在同一时间网上抢购的一种销售方式。对比其他的营销活动&#xff0c;秒杀限时性更强&…

地平线J5芯片部署参考算法(2023.07.27)

本文主要是记录地平线官方提供的可在J5芯片上部署的参考算法。 参考算法数据集FPSPointPillarsKITTI116 (双核)CenterPointNuscenes98.72&#xff08;双核&#xff09;FCOS3DNuscenes589 (双核)GANetCULane2431&#xff08;双核&#xff09;Swin TransformerImageNet133&#…

网络加速技巧

某APP限制网速&#xff0c;可以这么做&#xff1a; &#xff08;1&#xff09;把网络禁用 &#xff08;2&#xff09;在APP的设置里面&#xff0c;把优化速率打开 &#xff08;3&#xff09;启用网络 2023年7月27日亲测有用&#xff0c;开启优化速率之前是100k/s&#xff0c;开…

机器学习---混淆矩阵代码

1. 导包&#xff1a; import pandas as pd from sklearn.preprocessing import LabelEncoder from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.pipeline import Pipeline from sklearn.svm import SVC …

共用体类型

共用体&#xff08;union&#xff09;是一种成员共享存储空间的结构体类型。 union 共用体类型名 {成员列表 } 共用体内存长度是所有成员内存长度的最大值。 #include <iostream> using namespace std;int main() {//先声明共用体类型再定义共用体对象 union A {int m,…

11-2_Qt 5.9 C++开发指南_QSqlQueryModel的使用(QSqlQueryModel 只能作为只读数据源使用,不可以编辑数据)

文章目录 1 QSqlQueryModel 功能概述2 使用 QSqlQueryModel 实现数据查询2.1 实例功能2.2 可视化UI设计2.3 主窗口类定义&#xff08;去除自动生成的槽函数&#xff09;2.4 打开数据库2.5 记录移动 1 QSqlQueryModel 功能概述 从下图中可以看到&#xff0c;QSqlQueryModel 是 …

代码随想录算法训练营day13 | 239. 滑动窗口最大值,347. 前 K 个高频元素

239. 滑动窗口最大值 目录 239. 滑动窗口最大值 347. 前 K 个高频元素 239. 滑动窗口最大值 难度&#xff1a;hard 类型&#xff1a;队列&#xff0c;单调队列&#xff0c;滑动窗口 思路&#xff1a; 构造单调队列&#xff0c;维护大小为k的队列。队列里的元素始终是单调递…

无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本。npm.ps1 cannot be loaded

目录 原因 解决方法 提示 查看当前的执行策略命令 改回默认值 "Restricted"命令 这个错误提示是因为您的系统禁止执行 PowerShell 脚本。 原因 现用执行策略是 Restricted&#xff08;默认设置&#xff09; 解决方法 以管理员身份运行 PowerShell&#xff1a;右键…

AICodeConvert网站,可以用AI把代码从一种语言转换为另一种语言实现,代码开源了,从 6.24 到现在一个月, 没有主动推广,居然9.8K 访问量

这是我一个之前周六 6.24 开始验证思路的项目&#xff0c;验证的感觉差不多&#xff0c;不做主动推广到现在一个月&#xff0c;访问量 9.8K 。 源码开源了&#xff0c;github.com 网址&#xff1a;AICodeConvert 另一个在佛系验证中的还有这个&#xff1a;Base64.kr&#xf…

gedit更改字体大小颜色、行号、更改各种属性

最近在linux&#xff08;CentOS&#xff09;中运行gedit时发现&#xff1a; 如果用普通用户运行&#xff0c;不会报错&#xff0c;但是不会出现Preferences &#xff08;首选项&#xff09;等选项&#xff0c;不能进行基本属性参数的更改&#xff1b;如果采用su、sudo 运行则会…

机器学习之十大经典算法

机器学习算法是计算机科学和人工智能领域的关键组成部分&#xff0c;它们用于从数据中学习模式并作出预测或做出决策。本文将为大家介绍十大经典机器学习算法&#xff0c;其中包括了线性回归、逻辑回归、支持向量机、朴素贝叶斯、决策树等算法&#xff0c;每种算法都在特定的领…

云原生架构的定义

前言&#xff1a; 从技术的角度&#xff0c;云原生架构是基于云原生技术的一组架构原则和设计模式的集合&#xff0c;旨在将云应用中非业务代码的部分进行最大化的剥离&#xff0c;从而让云设施接管应用中原有的大量非功能特性&#xff08;如弹性、韧性、安全、可观测性、灰度…

MySQL中锁的简介——全局锁

1.锁的概述及分类 2.全局锁的介绍 给数据库加全局锁&#xff1a; flush tables with read lock;数据备份&#xff1a; mysqldump备份指令 root用户名 1234 密码 itcast数据库名称 itcast.sql备份文件名称 mysqldump -uroot -p1234 itcast >itcast.sql;数据库全局锁解锁&am…