机器学习 | 期望最大化(EM)算法介绍和实现

news2024/9/28 1:16:15

在现实世界的机器学习应用中,通常有许多相关的特征,但只有其中的一个子集是可观察的。当处理有时可观察而有时不可观察的变量时,确实可以利用该变量可见或可观察的实例,以便学习和预测不可观察的实例。这种方法通常被称为处理缺失数据。通过使用变量可观察的可用实例,机器学习算法可以从观察到的数据中学习模式和关系。然后,这些学习到的模式可以用于预测变量在缺失或不可观察的情况下的值。

期望最大化算法可用于处理变量部分可观察的情况。当某些变量是可观察的时,我们可以使用这些实例来学习和估计它们的值。然后,我们可以预测这些变量在不可观测的情况下的值。

EM算法是在1977年由亚瑟·登普斯特、南·莱尔德和唐纳德·鲁宾发表的一篇开创性论文中提出并命名的。他们的工作形式化了算法,并证明了其在统计建模和估计中的实用性。

EM算法适用于潜变量,潜变量是不能直接观测到的变量,而是从其他观测变量的值推断出来的。通过利用控制这些潜在变量的概率分布的已知一般形式,EM算法可以预测它们的值。

EM算法是机器学习领域中许多无监督聚类算法的基础。它提供了一个框架来找到统计模型的局部最大似然参数,并在数据缺失或不完整的情况下推断潜在变量。

期望最大化算法

期望最大化(EM)算法是一种迭代优化方法,它结合了不同的无监督机器学习算法,以找到涉及未观察到的潜在变量的统计模型中参数的最大似然或最大后验估计。EM算法通常用于潜变量模型,可以处理缺失数据。它由估计步骤(E步骤)和最大化步骤(M步骤)组成,形成迭代过程以改善模型拟合。

  • 在E步骤中,算法使用当前参数估计值计算潜在变量,即对数似然的期望值。
  • 在M步骤中,算法确定使在E步骤中获得的期望对数似然最大化的参数,并且基于估计的潜在变量更新相应的模型参数。

在这里插入图片描述

通过迭代地重复这些步骤,EM算法寻求最大化观察数据的可能性。它通常用于无监督学习任务,例如聚类,其中隐变量被推断并在各种领域中应用,包括机器学习,计算机视觉和自然语言处理。

EM算法中的关键术语

期望最大化(EM)算法中最常用的一些关键术语如下:

  • 潜在变量:潜变量是统计模型中不可观测的变量,只能通过其对可观测变量的影响间接推断。它们不能直接测量,但可以通过它们对可观察变量的影响来检测。
  • 可能性:在给定模型参数的情况下,观察到给定数据的概率。在EM算法中,目标是找到使可能性最大化的参数。
  • 对数似然函数:它是似然函数的对数,用于度量观测数据与模型之间的拟合优度。EM算法寻求最大化对数似然。
  • 最大似然估计(Maximum Likewise Estimation,MLE):MLE是一种通过找到使似然函数最大化的参数值来估计统计模型参数的方法,该方法衡量模型解释观测数据的程度。
  • 后验概率:在贝叶斯推理的背景下,EM算法可以扩展到估计最大后验(MAP)估计,其中参数的后验概率是基于先验分布和似然函数计算的。
  • 预期(E)步骤:EM算法的E步骤计算给定观测数据和当前参数估计的潜在变量的期望值或后验概率。它涉及计算每个数据点的每个潜在变量的概率。
  • 最大化(M)步骤:EM算法的M步通过最大化从E步获得的预期对数似然来更新参数估计值。它涉及找到优化似然函数的参数值,通常通过数值优化方法。
  • 收敛:收敛是指EM算法达到稳定解的条件。它通常通过检查对数似然或参数估计值的变化是否低于预定义的阈值来确定。

期望最大化(EM)算法是如何工作的

期望最大化算法的本质是使用数据集的可用观测数据来估计缺失数据,然后使用该数据来更新参数的值。让我们详细了解EM算法。

在这里插入图片描述

  1. 初始化:
    首先,考虑一组参数的初始值。假设观测数据来自特定的模型,给系统一组不完整的观测数据。
  2. E-Step(期望步骤):在这一步中,我们使用观察到的数据来估计或猜测缺失或不完整数据的值。它主要用于更新变量。
    在给定观测数据和当前参数估计值的情况下,计算每个潜在变量的后验概率。
    使用当前参数估计值估计缺失或不完整的数据值。
    基于当前参数估计值和估计缺失数据计算观测数据的对数似然。
  3. M步(最大化步骤):在这一步中,我们使用前面的“期望”步骤中生成的完整数据来更新参数值。它主要用于更新假设。
    通过最大化从E步骤获得的预期完整数据对数似然来更新模型的参数。
    这通常涉及解决优化问题,以找到最大化对数似然的参数值。
    所使用的具体优化技术取决于问题的性质和所使用的模型。
  4. 融合:在该步骤中,检查值是否收敛,如果是,则停止,否则重复步骤2和步骤3,即“期望”步骤和“最大化”步骤,直到收敛发生。
    通过比较迭代之间的对数似然或参数值的变化来检查收敛性。
    如果变化低于预定义的阈值,则停止并认为算法收敛。
    否则,返回E步骤并重复该过程,直到实现收敛。

期望最大化算法的实现

导入必要的库

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

生成具有两个高斯分量的数据集

# Generate a dataset with two Gaussian components
mu1, sigma1 = 2, 1
mu2, sigma2 = -1, 0.8
X1 = np.random.normal(mu1, sigma1, size=200)
X2 = np.random.normal(mu2, sigma2, size=600)
X = np.concatenate([X1, X2])

# Plot the density estimation using seaborn
sns.kdeplot(X)
plt.xlabel('X')
plt.ylabel('Density')
plt.title('Density Estimation of X')
plt.show()

在这里插入图片描述
初始化参数

# Initialize parameters
mu1_hat, sigma1_hat = np.mean(X1), np.std(X1)
mu2_hat, sigma2_hat = np.mean(X2), np.std(X2)
pi1_hat, pi2_hat = len(X1) / len(X), len(X2) / len(X)

执行EM算法

  • 迭代指定数量的epoch(本例中为20)。
  • 在每个epoch中,E步骤通过评估每个分量的高斯概率密度并通过相应的比例对其进行加权来计算(伽马值)。
  • M步通过计算每个分量的加权平均值和标准差来更新参数。
# Perform EM algorithm for 20 epochs
num_epochs = 20
log_likelihoods = []

for epoch in range(num_epochs):
	# E-step: Compute responsibilities
	gamma1 = pi1_hat * norm.pdf(X, mu1_hat, sigma1_hat)
	gamma2 = pi2_hat * norm.pdf(X, mu2_hat, sigma2_hat)
	total = gamma1 + gamma2
	gamma1 /= total
	gamma2 /= total
	
	# M-step: Update parameters
	mu1_hat = np.sum(gamma1 * X) / np.sum(gamma1)
	mu2_hat = np.sum(gamma2 * X) / np.sum(gamma2)
	sigma1_hat = np.sqrt(np.sum(gamma1 * (X - mu1_hat)**2) / np.sum(gamma1))
	sigma2_hat = np.sqrt(np.sum(gamma2 * (X - mu2_hat)**2) / np.sum(gamma2))
	pi1_hat = np.mean(gamma1)
	pi2_hat = np.mean(gamma2)
	
	# Compute log-likelihood
	log_likelihood = np.sum(np.log(pi1_hat * norm.pdf(X, mu1_hat, sigma1_hat)
								+ pi2_hat * norm.pdf(X, mu2_hat, sigma2_hat)))
	log_likelihoods.append(log_likelihood)

# Plot log-likelihood values over epochs
plt.plot(range(1, num_epochs+1), log_likelihoods)
plt.xlabel('Epoch')
plt.ylabel('Log-Likelihood')
plt.title('Log-Likelihood vs. Epoch')
plt.show()

在这里插入图片描述
绘制最终密度估计

# Plot the final estimated density
X_sorted = np.sort(X)
density_estimation = pi1_hat*norm.pdf(X_sorted,
										mu1_hat, 
										sigma1_hat) + pi2_hat * norm.pdf(X_sorted,
																		mu2_hat, 
																		sigma2_hat)


plt.plot(X_sorted, gaussian_kde(X_sorted)(X_sorted), color='green', linewidth=2)
plt.plot(X_sorted, density_estimation, color='red', linewidth=2)
plt.xlabel('X')
plt.ylabel('Density')
plt.title('Density Estimation of X')
plt.legend(['Kernel Density Estimation','Mixture Density'])
plt.show()

在这里插入图片描述

EM算法的应用

  • 它可用于填充样本中缺失的数据
  • 它可以作为无监督聚类学习的基础
  • 它可以用于估计隐马尔可夫模型(HMM)的参数
  • 它可以用来发现潜在变量的值

EM算法的优缺点

EM算法的优点

  • 总是保证可能性将随着每次迭代而增加
  • E步骤和M步骤在实现方面对于许多问题来说通常是相当容易的
  • M阶的解通常以封闭形式存在

EM算法的缺点

  • 它收敛缓慢
  • 它只收敛到局部最优
  • 它需要向前和向后的概率(数值优化只需要向前概率)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1551121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++】手撕哈希表的闭散列和开散列

> 作者:დ旧言~ > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:手撕哈希表的闭散列和开散列 > 毒鸡汤:谁不是一边受伤,一边学会坚强。 > 专栏选自:C嘎嘎进阶 > 望小伙伴们…

全国超市数据可视化仪表板制作

全国超市消费数据展示 指定 Top几 客户销费数据展示 指定 Top几 省份销费数据展示 省份销售额数据分析 完整结果

蓝牙耳机哪个品牌质量最好最耐用?购前必读的高热度机型指南!

​面对市场上众多不同场景使用的蓝牙耳机,我们该如何选择呢?我们最怕遇到耳机延迟高、不防水防汗、音质差以及佩戴体验差的问题。针对这些常见问题,我这次精选了五款市面上热销且质量不错的蓝牙耳机分享给大家,让我们一起来看看吧…

当当狸智能激光雕刻机 多种材质自由雕刻,轻松打造独一无二的作品

提及“激光雕刻”,大多数人的印象一般都是:笨重巨大、价格昂贵、操作复杂、使用门槛较高、调试难度大...不是普通人能够随意操作的,让人望尘莫及。 而小米有品上新的这台「当当狸桌面智能激光雕刻机L1」,将超乎你的想象&#xff…

【unity】如何汉化unity编译器

在【unity】如何汉化unity Hub这篇文章中,我们已经完成了unity Hub的汉化,现在让我们对unity Hub安装的编译器也进行下汉化处理。 第一步:在unity Hub软件左侧栏目中点击安装,选择需要汉化的编译器,再点击设置图片按钮…

【MATLAB源码-第171期】基于matlab的布谷鸟优化算法(COA)无人机三维路径规划,输出做短路径图和适应度曲线

操作环境: MATLAB 2022a 1、算法描述 布谷鸟优化算法(Cuckoo Optimization Algorithm, COA)是一种启发式搜索算法,其设计灵感源自于布谷鸟的独特生活习性,尤其是它们的寄生繁殖行为。该算法通过模拟布谷鸟在自然界中…

RabbitMQ 的高阶应用及可靠性保证

目录 一、RabbitMQ 高阶应用 1.1 消息何去何从 1.2 过期时间 1.3 死信队列 1.4 延迟队列 1.5 优先级队列 1.6 消费质量保证(QOS) 二、持久化 三、生产者确认 四、消息可靠性和重复消费 4.1 消息可靠性 4.2 重复消费问题 上篇文章介绍了 Rabb…

理解JVM:从字节码到程序运行

大家好,我是程序员大猩猩。 今天我们来讲一下JVM,好多面试者在面试的时候,都会被问及JVM相关知识。那么JVM到底是什么,要理解它到底是出于什么原因? JVM俗称Java虚拟机,它是一个抽象的计算机,…

香港 4月7日「比特币新资产 Runes 协议」线下沙龙报名开启!

比特币减半在即,聚焦新资产、布局新赛道。 香港时间 2024 年 4 月 7 日,聚焦「比特币新资产 & Runes 协议」的《New Asset Issuance and Settlement on Bitcoin》主题沙龙将于「香港海富中心」欢迎大家的到来! 我们邀请了比特币新资产生…

题目:摆花(蓝桥OJ 0389)

问题描述&#xff1a; 题解&#xff1a; #include <bits/stdc.h> using namespace std; using ll long long; const int N 105; const ll p 1e6 7; ll a[N], dp[N][N];int main() {int n, m; cin >> n >> m;for(int i 1; i < n; i)cin >> a[i…

IDEA Android新建项目基础

title: IDEA Android基础开发 search: 2024-03-16 tags: “#JavaAndroid开发” 一、构建基本项目 在使用 IDEA 进行基础的Android 开发时&#xff0c;我们可以通过IDEA自带的新建项目功能进行Android应用开发基础架构的搭建&#xff0c;可以直接找到 File --> New --> …

【机器学习-08】参数调优宝典:网格搜索与贝叶斯搜索等攻略

超参数是估计器的参数中不能通过学习得到的参数。在scikit-learn中&#xff0c;他们作为参数传递给估计器不同类的构造函数。典型的例子有支持向量分类器的参数C&#xff0c;kernel和gamma&#xff0c;Lasso的参数alpha等。 ​ 在超参数集中搜索以获得最佳cross validation交叉…

Java开发过程中如何进行进制换换

最近由于工作上的需要&#xff0c;遇到进制转换的问题。涉及到的进制主要是十进制、十六进制、二进制转换。 1、十进制转十六进制、二进制 调用java自带的api,测试十进制转16进制、2进制 package com.kangning.common.utils.reflect;/*** 十进制 转 十六进制* 十进制 转 二进…

黑群晖Docker安装aria2-pro

前言 最近买了星际蜗牛C款当Nas&#xff0c;来满足我的存储需求&#xff0c;在之前我写过一篇docker安装aria2-pro的文章&#xff0c;既然买了nas那当然也要安装一个aria2-pro做下载器 1.安装 Container Manager 套件 可以在套件中心搜索docker找到 2.下载aria2-pro镜像 打…

力扣热门算法题 89. 格雷编码,92. 反转链表 II,93. 复原 IP 地址

89. 格雷编码&#xff0c;92. 反转链表 II&#xff0c;93. 复原 IP 地址&#xff0c;每题做详细思路梳理&#xff0c;配套Python&Java双语代码&#xff0c; 2024.03.24 可通过leetcode所有测试用例。 目录 89. 格雷编码 解题思路 完整代码 Python Java 92. 反转链表…

利用Tensor在jetson orin 上加速YOLOv5

一、第一种方法&#xff0c;需要下载各种包&#xff1a; 要用到一个大佬的开源&#xff0c;GitHub地址如下&#xff1a; https://github.com/wang-xinyu/tensorrtx/tree/master/yolov51. 安装pycuda&#xff0c;在线安装pycuda pip3 install pycuda 2. Windows操作&#xf…

Ubuntu Desktop 更改默认应用程序 (Videos -> SMPlayer)

Ubuntu Desktop 更改默认应用程序 [Videos -> SMPlayer] References System Settings -> Details -> Default Applications 概况、默认应用程序、可移动介质、法律声明 默认应用程序&#xff0c;窗口右侧列出了网络、邮件、日历、音乐、视频、照片操作的默认应用程序…

2024全行业数字化转型企业建设解决方案PPT合集(附下载)

精品推荐&#xff0c;2024全行业数字化转型企业建设解决方案PPT合集&#xff0c;精品PPT源格式共21份。 点击直达星球下载地址&#xff08;文末领取优惠券&#xff09;&#xff1a;2024全行业数字化转型企业建设解决方案PPT合集 1.制造业数字化转型解决方案及应用.pptx 2.医院…

Java代码基础算法练习-求一个三位数的各位数字之和-2024.03.27

任务描述&#xff1a; 输入一个正整数n&#xff08;取值范围&#xff1a;100<n<1000&#xff09;&#xff0c;然后输出每位数字之和 任务要求&#xff1a; 代码示例&#xff1a; package M0317_0331;import java.util.Scanner;public class m240327 {public static voi…

langchin-chatchat部分开发笔记(持续更新)

大模型相关目录 大模型&#xff0c;包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步&#xff0c;扬帆起航。 大模型应用向开发路径及一点个人思考大模型应用开发实用开源项目汇总大模型问答项目…