什么是期望最大化算法?

news2024/9/28 18:25:19

一、期望最大化算法

        期望最大化(EM)算法是一种在统计学和机器学习中广泛使用的迭代方法,它特别适用于含有隐变量的概率模型参数估计问题。在统计学和机器学习中,有很多不同的模型,例如高斯混合模型(GMM)、隐马尔可夫模型(HMM)等,都可以用EM算法来估计这些模型中的参数。EM算法的主要思想是通过两个步骤的交替执行来找到模型参数的估计值:期望(E)步骤和最大化(M)步骤。此外,EM算法的收敛性也意味着它可以在多次迭代后得到稳定的参数估计,这对于模型的预测和分析非常重要。

二、期望最大化算法原理

1、E步骤(Expectation step)

        在E步骤中,我们计算隐变量的条件期望给定观测数据和当前参数估计。假设我们有一个数据集X = \left \{ x_{1},x_{2},...,x_{N} \right \},隐变量Z = \left \{ z_{1},z_{2},...,z_{N} \right \}和参数\theta,则E步骤计算的是:

Q(z_{i}|x_{i},\theta) = P(z_{i}|x_{i},\theta)

        其中,Q(z_{i}|x_{i},\theta)是隐变量z_{i}在观测数据x_{i}和当前参数\theta下的条件概率。

2、M步骤(Maximization step):

        在M步骤中,我们利用E步骤计算出的隐变量的分布来更新参数\theta的估计,以最大化似然函数。M步骤的计算公式为:

\theta^{(new)} = argmax_{\theta}\sum_{i=1}^{N}\sum_{z_{i}}Q(z_{i}|x_{i},\theta^{(old)})logp(x_{i},z_{i}|\theta)

        其中,\theta^{(new)}是更新后的参数估计,\theta^{(old)}是上一步的参数估计,N是数据点的数量,求和是对所有数据点和所有可能的隐变量值进行的。

三、EM算法应用

        假设我们有一个高斯混合模型GMM,其中有K个高斯分布,参数为\phi = \left \{ \pi _{k},\mu _{k},\sigma _{k}^{2} \right \},其中\pi_{k}是第k个高斯分布的权重,\mu_{k}是均值,\sigma _{k}^{2}是方差,则计算EM有:

1、E步骤

        计算每个数据点属于每个高斯分布的responsibility(也称为 posterior probability):

Q(z_{i}|x_{i},\phi ) = \frac{\pi_{k}N(x_{i}|\mu_{k},\sigma_{k}^{2})}{\sum_{j=1}^{K}\pi_{j}N(x_{i}|\mu_{j},\sigma_{j}^{2})}

        这里,N(x|\mu,\sigma^{2})是多元正态分布的概率密度函数

2、M步骤

        更新每个高斯分布的参数:

\pi_{k} = \frac{1}{N}\sum_{i=1}^{N}Q(z_{i}=k|x_{i},\phi)

\mu_{k} = \frac{\sum_{i=1}^{N}Q(z_{i}=k|x_{i},\phi)x_{i}}{\sum_{i=1}^{N}Q(z_{i}=k|x_{i},\phi)}

\sigma_{k}^{2} = \frac{\sum_{i=1}^{N}Q(z_{i}=k|x_{i},\phi)(x_{i}-\mu_{k})^{2}}{\sum_{i=1}^{N}Q(z_{i}=k|x_{i},\phi)}

        其中,Q(z_{i}=k|x_{i},\phi)是数据点x_{i}属于第k个高斯分布的后验概率。N是数据点的数量,求和对所有数据点进行,E步骤和M步骤交替进行,知道参数\phi收敛。参数更新公式中,分子和分母有相同的部分,但不能简单约去,因为分母中的部分确保了每个数据点在计算新的均值时,其贡献是按照它属于该高斯分布的概率加权的。

四、python实现EM算法

        这里,首先生成两个高斯分布的数据,然后定义一个高斯函数来计算给定均值和标准差的数据的概率密度。接下来,定义E步骤和M步骤的函数。最后,运行EM算法迭代100次。

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# 生成示例数据
np.random.seed(42)
X = np.vstack([np.random.multivariate_normal([0, 0], np.eye(2), 100),
               np.random.multivariate_normal([5, 5], np.eye(2), 100)])

# 定义高斯函数
def gaussian(X, mean, cov):
    return multivariate_normal.pdf(X, mean, cov)

# E步骤
def e_step(X, weights, means, covariances):
    n, d = X.shape
    k = len(weights)
    responsibilities = np.zeros((n, k))
    
    for i in range(k):
        responsibilities[:, i] = weights[i] * gaussian(X, means[i], covariances[i])
    
    responsibilities /= responsibilities.sum(axis=1, keepdims=True)
    return responsibilities

# M步骤
def m_step(X, responsibilities):
    n, d = X.shape
    k = responsibilities.shape[1]
    
    weights = responsibilities.sum(axis=0) / n
    means = np.dot(responsibilities.T, X) / responsibilities.sum(axis=0)[:, np.newaxis]
    covariances = np.zeros((k, d, d))
    
    for i in range(k):
        diff = X - means[i]
        covariances[i] = np.dot(responsibilities[:, i] * diff.T, diff) / responsibilities[:, i].sum()
    
    return weights, means, covariances

# 初始化参数
def initialize_parameters(X, k):
    n, d = X.shape
    weights = np.ones(k) / k
    means = X[np.random.choice(n, k, False)]
    covariances = np.array([np.eye(d)] * k)
    return weights, means, covariances

# EM算法
def em_algorithm(X, k, max_iter=100, tol=1e-6):
    weights, means, covariances = initialize_parameters(X, k)
    log_likelihoods = []
    
    for i in range(max_iter):
        responsibilities = e_step(X, weights, means, covariances)
        weights, means, covariances = m_step(X, responsibilities)
        log_likelihoods.append(log_likelihood(X, weights, means, covariances))
        
        if i > 0 and np.abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
            break
    
    return weights, means, covariances, log_likelihoods, responsibilities

# 计算对数似然
def log_likelihood(X, weights, means, covariances):
    n, d = X.shape
    k = len(weights)
    log_likelihood = 0
    
    for i in range(k):
        log_likelihood += weights[i] * gaussian(X, means[i], covariances[i])
    
    return np.log(log_likelihood).sum()

# 绘制高斯分布的等高线
def draw_ellipse(mean, cov, ax, label, alpha=1.0):
    from matplotlib.patches import Ellipse
    v, w = np.linalg.eigh(cov)
    v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
    u = w[0] / np.linalg.norm(w[0])
    angle = np.arctan(u[1] / u[0])
    angle = 180.0 * angle / np.pi
    ell = Ellipse(mean, v[0], v[1], 180.0 + angle, edgecolor='red', lw=2, facecolor='none', alpha=alpha, label=label)
    ax.add_patch(ell)

# 运行EM算法
k = 2
weights, means, covariances, log_likelihoods, responsibilities = em_algorithm(X, k)

# 可视化最终结果
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], s=10, label='Data points')
ax = plt.gca()
for j in range(k):
    draw_ellipse(means[j], covariances[j], ax, label=f'Gaussian {j+1}', alpha=weights[j])
plt.title('Final Gaussian Mixture Model')
plt.legend()
plt.show()

# 打印结果
print("权重:", weights)
print("均值:", means)
print("协方差矩阵:", covariances)

# 打印对数似然
print("对数似然:", log_likelihoods[-1])

# 计算AIC和BIC
n, d = X.shape
num_params = k * (d + d * (d + 1) / 2) + k - 1
aic = 2 * num_params - 2 * log_likelihoods[-1]
bic = np.log(n) * num_params - 2 * log_likelihoods[-1]
print("AIC:", aic)
print("BIC:", bic)

        其中,aic和bic计算模型的AIC和BIC值,AIC和BIC值越小,表示模型越好。理想情况下,高斯分布的等高线应该很好地覆盖数据点的分布区域,可视化结果如下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2174390.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NSSCTF [HNCTF 2022 Week1]超级签到

查看主函数 看到遍历 Str2&#xff0c;如果字符为 o&#xff0c;则替换为 0 int __fastcall main_0(int argc, const char **argv, const char **envp) {char *v3; // 指向 v7 的指针__int64 i; // 循环计数器size_t v5; // 存储 Str2 的长度char v7; // 存储输入字符int j; …

如何快速自定义一个Spring Boot Starter!!

目录 引言&#xff1a; 一. 我们先创建一个starter模块 二. 创建一个自动配置类 三. 测试启动 引言&#xff1a; 在我们项目中&#xff0c;可能经常用到别人的第三方依赖&#xff0c;又是引入依赖&#xff0c;又要自定义配置&#xff0c;非常繁琐&#xff0c;当我们另一个项…

mysql8.0安装后没有my.ini

今天安装mysql后想改一下配置文件看了一下安装路径 C:\Program Files\MySQL\MySQL Server 8.0 发现根本没有这个文件查看隐藏文件也没用查了之后才知道换地方了和原来的5.7不一样 新地址是C:\ProgramData\MySQL\MySQL Server 8.0 文件也是隐藏的记得改一下配置

【Redis 源码】7RDB持久化

1 功能说明 RDB (Redis Database Backup) 是 Redis 的一种持久化方式&#xff0c;它通过将某一时刻的内存快照&#xff08;snapshot&#xff09;以二进制格式保存到磁盘上。这种持久化方式提供了高性能和紧凑的数据存储&#xff0c;但相对于 AOF (Append Only File) 来说&…

充电桩安装-理想充电桩如何安装全流程-从准备到材料准备全流程

充电桩安装 Willya 2023年3月6日 新能源车出行成本低&#xff0c;那肯定是要在便利的条件下&#xff0c;得有自己的充电桩才行&#xff0c;实在安装不了自己的充电桩&#xff0c;那也要保证居住周边有充足的充电站&#xff0c;这样才能保证用车的便捷。 理想汽车充电桩安装一般…

智能化转型新篇章:EasyCVR引领大型连锁超市视频监控进入AI时代

随着科技的飞速发展&#xff0c;视频监控系统在各行各业中的应用日益广泛&#xff0c;大型连锁超市作为人员密集、商品繁多的公共场所&#xff0c;其安全监控显得尤为重要。为了提升超市的安全管理水平、减少损失、保障顾客和员工的安全&#xff0c;引入高效、全面的视频监控系…

胤娲科技:AI界的超级充电宝——忆阻器如何让LLM告别电量焦虑

当AI遇上“记忆橡皮擦”&#xff0c;电量不再是问题&#xff01; 嘿&#xff0c;朋友们&#xff0c;你们是否曾经因为手机电量不足而焦虑得像个无头苍蝇&#xff1f;想象一下&#xff0c;如果这种“电量焦虑”也蔓延到了AI界&#xff0c; 特别是那些聪明绝顶但“耗电如喝水”的…

逃离陷阱:如何巧妙避免机器学习中的过拟合与欠拟合

逃离陷阱&#xff1a;如何巧妙避免机器学习中的过拟合与欠拟合 前言过拟合&#xff1a;定义与识别定义表现原因示例&#xff1a;决策树模型的过拟合 欠拟合&#xff1a;定义与识别定义表现原因示例&#xff1a;线性回归模型的欠拟合 避免过拟合的策略减少模型复杂度使用正则化…

基于nodejs+vue的校园二手物品交易系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

SSM超市售卖管理系统-计算机毕业设计源码23976

目 录 摘要 Abstract 1 绪论 1.1研究的背景和意义 1.2研究内容 1.3论文结构与章节安排 2 开发技术介绍 2.1 SSM框架 2.2 MySQL数据库 3 超市售卖管理系统系统分析 3.1 可行性分析 3.2 系统流程分析 3.2.1 数据流程 3.3.2 业务流程 3.3 系统功能分析 3.3.1 功…

低代码可视化-UniApp二维码可视化-代码生成器

市面上提供了各种各样的二维码组件&#xff0c;做了一简单的uniapp二维码组件&#xff0c;二维码实现依赖davidshimjs/qrcodejs。 组件特点 跨浏览器支持&#xff1a;利用Canvas元素实现二维码的跨浏览器兼容性&#xff0c;兼容微信小程序、h5、app。 无依赖性&#xff1a;QR…

留学生如何适应海外生活以及应对文化差异

对于即将出国学习和生活的留学生来说&#xff0c;文化差异和生活方式的变化常常是一个紧迫的问题。那么&#xff0c;如何应对这些文化差异&#xff0c;以及如何适应新的学习环境和社交生活呢&#xff1f;本文将分享一些具体可行的建议和方法&#xff0c;助您顺利跨越这道难关&a…

数据结构:队列及其应用

队列&#xff08;Queue&#xff09;是一种特殊的线性表&#xff0c;它的主要特点是先进先出&#xff08;First In First Out&#xff0c;FIFO&#xff09;。队列只允许在一端&#xff08;队尾&#xff09;进行插入操作&#xff0c;而在另一端&#xff08;队头&#xff09;进行删…

Hadoop三大组件之YARN(一)

YARN架构与任务提交流程详解 1. YARN的组成架构 YARN&#xff08;Yet Another Resource Negotiator&#xff09;是Hadoop生态系统中的一个重要组成部分&#xff0c;主要用于资源管理和调度。YARN的架构主要由以下几个关键组件构成&#xff1a; 1.1 ResourceManager&#xff…

vue3结合 vue-router和keepalive实现路由跳转保持滚动位置不改变(超级简易清晰)

1.首先我们在路由跳转页面设置keepalive(Seeall是我想实现结果的页面) 2. 想实现结果的页面中如果不是全屏实现滚动而是有单独的标签实现滚动效果

Spring Boot技术栈:打造高效在线商城

2 相关技术 2.1 Springboot框架介绍 Spring Boot是由Pivotal团队提供的全新框架&#xff0c;其设计目的是用来简化新Spring应用的初始搭建以及开发过程。该框架使用了特定的方式来进行配置&#xff0c;从而使开发人员不再需要定义样板化的配置。通过这种方式&#xff0c;Spring…

AI大模型对我国劳动力市场潜在影响研究报告(2024)|附19页PDF文件下载

前言 北京大学国家发展研究院与智联招聘日前联合发布《AI大模型对我国劳动力市场潜在影响研究》。该研究显示&#xff0c;2024年上半年&#xff0c;招聘职位数同比增速前五的人工智能职业&#xff0c;包括大语言模型方面的自然语言处理&#xff08;111%&#xff09;、深度学习…

STM32 RTC实时时钟学习总结

STM32 RTC实时时钟学习总结 写于2024/9/25下午 文章目录 STM32 RTC实时时钟学习总结1. 简介2. 流程框图介绍3. 相关寄存器介绍4. 代码解析 1. 简介 STM32F103 的实时时钟&#xff08;RTC&#xff09;是一个独立的定时器。STM32 的 RTC 模块拥有一组连续计数的计数器&#xff…

【C语言】动态内存管理:malloc、calloc、realloc、free

本篇介绍一下C语言中的malloc/calloc/realloc。 使用这些函数需要包含头文件<stdlib.h>。malloc/calloc/realloc申请的空间都是 堆区的。 1.malloc和free 1.1 malloc C语言提供了一个动态内存开辟的函数malloc&#xff0c;函数原型如下。 void* malloc(size_t size);…

实例讲解电动汽车故障限功限速控制策略及Simulink建模方法

电动汽车出现转向系统、制动系统及其他对车辆行车产生一定风险的故障&#xff0c;整车控制器判定为二级故障&#xff0c;功率限制为正常状态的50%&#xff0c;车速限制至20km/h。当车辆进入二级故障后&#xff0c;VCU需要根据故障处理机制进行限功限速控制。有关故障分级处理策…