EM算法讲解

news2024/9/21 20:25:18

一、EM算法中隐变量的作用

在EM算法(期望最大化算法)中,隐变量(latent variables)是算法核心的一部分。它们的引入可以简化问题,将复杂的概率模型转换为更易处理的形式。隐变量的主要作用如下:

1. 转化难解问题为易解问题

  • 在许多实际应用中,直接计算模型参数的最大似然估计是非常困难的,特别是在存在未观测变量的情况下。隐变量可以帮助将难以直接处理的问题(例如,包含复杂分布的后验概率估计)分解成两个步骤:期望步骤(E步)最大化步骤(M步),从而简化计算。
    • E步:根据当前的参数估计,计算隐变量的条件期望。
    • M步:使用E步计算出的期望,优化模型参数。

通过将这些未观测的变量当作隐变量,EM算法能够在每次迭代中对这些未观测变量进行期望估计,然后更新参数。这种分解方法使得问题变得可解。

2. 隐变量帮助计算完整数据的似然

  • 在EM算法中,假设观测数据 D D D 和隐变量 Z Z Z 的联合分布是已知的。EM算法的目标是找到模型参数,使得完整数据似然 P ( D , Z ∣ θ ) P(D, Z | \theta) P(D,Zθ) 最大化。
    • 但是,由于隐变量 Z Z Z 是不可观测的,我们不能直接计算其真实值。因此,EM算法在E步中计算的是隐变量的条件期望,即在当前模型参数条件下,隐变量 Z Z Z 的可能取值。

这种条件期望允许我们绕过隐变量的实际值,间接地通过观测数据更新参数,最终收敛到最大似然估计。

3. 引入隐变量提供更多的模型解释力

  • 隐变量不仅在技术上帮助EM算法进行参数估计,还提供了对数据的更多解释。例如,在高斯混合模型(Gaussian Mixture Model, GMM)中,隐变量可以表示每个样本属于哪个高斯分布成分。通过引入这些隐变量,我们可以在E步中为每个样本分配一个“软”类别归属,表示其属于不同成分的概率。

这种归属可以帮助模型更好地捕捉数据中的潜在结构,并解释不同数据点是如何生成的。

4. 处理缺失数据

  • 隐变量通常还用于处理缺失数据。在许多问题中,某些数据点可能是不可观测的。隐变量可以自然地表示这些缺失值,EM算法则通过估计这些隐变量的期望值(即填补缺失数据),从而继续优化模型参数。

总结

在EM算法中,隐变量的作用主要体现在:

  1. 简化难解问题:将复杂问题分解成可迭代的两步,分别计算期望和最大化。
  2. 辅助似然计算:通过估计隐变量的期望,帮助间接地优化似然函数。
  3. 增强解释力:帮助模型更好地捕捉数据的潜在结构。
  4. 处理缺失数据:有效应对部分数据缺失的情况。

隐变量在EM算法中起到了核心的连接作用,使得算法能够在存在未观测变量或复杂概率分布的情况下,依然有效地进行参数估计。

二、实际例子讲解EM算法隐变量的应用

让我们用一个具体的例子来讲解EM算法中隐变量的应用:高斯混合模型(Gaussian Mixture Model, GMM)

高斯混合模型(GMM)概述

高斯混合模型是一种常用的聚类算法,用于将数据划分为多个簇,每个簇由一个高斯分布(即正态分布)来描述。在GMM中,数据点属于不同的高斯分布(簇),但我们不知道每个数据点属于哪个簇,这就是问题中的隐变量

隐变量 Z Z Z 用于表示每个数据点 X X X 属于哪个高斯分布。由于我们无法直接观测到数据点的簇归属,因此 Z Z Z 是隐含的,但它对模型参数的估计非常关键。

EM算法在高斯混合模型中的应用

假设我们有一个二维数据集,我们认为这些数据是由两个不同的高斯分布生成的,但我们不知道每个点属于哪个高斯分布,也不知道两个高斯分布的参数。

1. 模型设置

在GMM中,每个数据点 X i X_i Xi 的概率密度可以写成多个高斯分布的加权和:
P ( X i ∣ θ ) = ∑ k = 1 K π k ⋅ N ( X i ∣ μ k , Σ k ) P(X_i | \theta) = \sum_{k=1}^{K} \pi_k \cdot \mathcal{N}(X_i | \mu_k, \Sigma_k) P(Xiθ)=k=1KπkN(Xiμk,Σk)

  • K K K 是高斯分布的数量(即簇的数量)。
  • π k \pi_k πk 是第 k k k 个高斯分布的权重,满足 ∑ k = 1 K π k = 1 \sum_{k=1}^{K} \pi_k = 1 k=1Kπk=1
  • μ k \mu_k μk Σ k \Sigma_k Σk 分别是第 k k k 个高斯分布的均值和协方差矩阵。
  • θ = { π k , μ k , Σ k } \theta = \{\pi_k, \mu_k, \Sigma_k\} θ={πk,μk,Σk} 是需要估计的参数。

隐变量 Z i Z_i Zi 是一个指示变量,表示第 i i i 个数据点属于哪个高斯分布(即哪个簇)。

2. EM算法的步骤

步骤 1:初始化参数

  • 随机初始化参数 θ = { π k , μ k , Σ k } \theta = \{\pi_k, \mu_k, \Sigma_k\} θ={πk,μk,Σk},即每个簇的权重、均值和协方差矩阵。

步骤 2:E步(期望步骤)

  • 计算每个数据点 X i X_i Xi 属于每个高斯分布 k k k后验概率,即隐变量 Z i Z_i Zi 的条件期望。可以理解为每个数据点在当前参数下属于每个簇的概率:
    γ ( Z i k ) = π k ⋅ N ( X i ∣ μ k , Σ k ) ∑ j = 1 K π j ⋅ N ( X i ∣ μ j , Σ j ) \gamma(Z_{ik}) = \frac{\pi_k \cdot \mathcal{N}(X_i | \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \cdot \mathcal{N}(X_i | \mu_j, \Sigma_j)} γ(Zik)=j=1KπjN(Xiμj,Σj)πkN(Xiμk,Σk)
  • 这里 γ ( Z i k ) \gamma(Z_{ik}) γ(Zik) 表示在第 t t t 轮迭代后,第 i i i 个数据点属于簇 k k k 的概率。

步骤 3:M步(最大化步骤)

  • 利用E步中的隐变量期望,重新估计模型参数 θ \theta θ
    • 更新每个高斯分布的均值
      μ k = ∑ i = 1 N γ ( Z i k ) X i ∑ i = 1 N γ ( Z i k ) \mu_k = \frac{\sum_{i=1}^{N} \gamma(Z_{ik}) X_i}{\sum_{i=1}^{N} \gamma(Z_{ik})} μk=i=1Nγ(Zik)i=1Nγ(Zik)Xi
    • 更新协方差矩阵
      Σ k = ∑ i = 1 N γ ( Z i k ) ( X i − μ k ) ( X i − μ k ) T ∑ i = 1 N γ ( Z i k ) \Sigma_k = \frac{\sum_{i=1}^{N} \gamma(Z_{ik}) (X_i - \mu_k)(X_i - \mu_k)^T}{\sum_{i=1}^{N} \gamma(Z_{ik})} Σk=i=1Nγ(Zik)i=1Nγ(Zik)(Xiμk)(Xiμk)T
    • 更新每个高斯分布的权重
      π k = 1 N ∑ i = 1 N γ ( Z i k ) \pi_k = \frac{1}{N} \sum_{i=1}^{N} \gamma(Z_{ik}) πk=N1i=1Nγ(Zik)
    • 这里 N N N 是数据点的数量。

步骤 4:收敛条件

  • 重复步骤2和步骤3,直到参数变化足够小,或者达到最大迭代次数。
3. 具体解释隐变量的作用

在GMM模型中,我们的目标是找到每个数据点属于哪个高斯分布(即属于哪个簇)。由于这些信息是未知的,隐变量 Z Z Z 表示每个数据点归属哪个簇。在E步中,我们使用当前参数来估计这些隐变量的期望值,即每个数据点属于各个簇的概率。然后,在M步中,我们使用这些期望值来更新模型参数(均值、协方差、权重),使得数据点的归属越来越明确。

隐变量的作用是帮助处理簇归属未知的情况。它允许我们在没有明确知道每个数据点属于哪个簇的情况下,仍然可以通过估计簇归属的概率来迭代优化模型。

总结

在高斯混合模型中,隐变量表示每个数据点属于哪个簇。由于这些信息是不可观测的,EM算法通过E步估计这些隐变量的期望值,然后在M步中利用这些期望值来更新模型参数。隐变量的引入使得我们能够在观测数据存在潜在结构(如不同簇)的情况下,有效地估计模型参数。

三、 具体实现

为了演示EM算法以及其在高斯混合模型中的应用,我们将实现一个简化版本的GMM,而不是直接调用现有的库。整个过程包含以下步骤:

  1. 初始化参数:初始化混合模型的参数,比如高斯分布的均值、协方差矩阵和权重。
  2. E步(期望步骤):计算每个数据点属于每个高斯分布的概率,即隐变量的期望。
  3. M步(最大化步骤):使用E步中的结果更新高斯分布的参数。
  4. 迭代执行E步和M步,直到参数收敛。

代码实现(不调用库的情况下手动实现EM算法)

import numpy as np

# 定义二维高斯概率密度函数
def gaussian_pdf(x, mean, cov):
    n = x.shape[0]
    diff = x - mean
    exp_part = np.exp(-0.5 * np.dot(np.dot(diff.T, np.linalg.inv(cov)), diff))
    return (1.0 / np.sqrt((2 * np.pi) ** n * np.linalg.det(cov))) * exp_part

# E步:计算每个数据点属于每个高斯分布的后验概率
def e_step(X, means, covs, pis):
    N = X.shape[0]  # 数据点数量
    K = len(means)  # 高斯分布数量
    responsibilities = np.zeros((N, K))

    # 计算每个数据点属于每个高斯分布的概率
    for i in range(N):
        for k in range(K):
            responsibilities[i, k] = pis[k] * gaussian_pdf(X[i], means[k], covs[k])
        responsibilities[i, :] /= np.sum(responsibilities[i, :])  # 归一化
    
    return responsibilities

# M步:更新模型参数
def m_step(X, responsibilities):
    N, d = X.shape
    K = responsibilities.shape[1]

    pis = np.zeros(K)
    means = np.zeros((K, d))
    covs = np.zeros((K, d, d))

    for k in range(K):
        Nk = np.sum(responsibilities[:, k])
        pis[k] = Nk / N  # 更新簇的权重

        # 更新均值
        means[k] = np.sum(responsibilities[:, k].reshape(-1, 1) * X, axis=0) / Nk

        # 更新协方差矩阵
        cov = np.zeros((d, d))
        for i in range(N):
            diff = (X[i] - means[k]).reshape(-1, 1)
            cov += responsibilities[i, k] * np.dot(diff, diff.T)
        covs[k] = cov / Nk

    return pis, means, covs

# 主EM算法循环
def gmm_em(X, K, max_iters=100, tol=1e-4):
    N, d = X.shape

    # 初始化参数
    np.random.seed(42)
    means = X[np.random.choice(N, K, replace=False)]  # 从数据中随机选取初始均值
    covs = [np.eye(d) for _ in range(K)]  # 初始协方差矩阵为单位矩阵
    pis = np.ones(K) / K  # 初始每个簇的权重相等

    log_likelihoods = []

    for iteration in range(max_iters):
        # E步
        responsibilities = e_step(X, means, covs, pis)

        # M步
        pis, means, covs = m_step(X, responsibilities)

        # 计算当前的对数似然
        log_likelihood = 0
        for i in range(N):
            prob = 0
            for k in range(K):
                prob += pis[k] * gaussian_pdf(X[i], means[k], covs[k])
            log_likelihood += np.log(prob)
        log_likelihoods.append(log_likelihood)

        # 检查是否收敛
        if iteration > 0 and np.abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
            break

    return pis, means, covs, responsibilities, log_likelihoods

# 生成数据
np.random.seed(42)
mean1 = [0, 0]
cov1 = [[1, 0.5], [0.5, 1]]
X1 = np.random.multivariate_normal(mean1, cov1, 300)

mean2 = [3, 4]
cov2 = [[1, -0.2], [-0.2, 1]]
X2 = np.random.multivariate_normal(mean2, cov2, 300)

X = np.vstack([X1, X2])

# 运行EM算法
K = 2  # 假设有两个簇
pis, means, covs, responsibilities, log_likelihoods = gmm_em(X, K)

# 输出结果
print("簇的权重:", pis)
print("簇的均值:", means)
print("簇的协方差矩阵:", covs)

代码详解

  1. 二维高斯概率密度函数gaussian_pdf 函数计算每个数据点属于某个高斯分布的概率。
  2. E步e_step 函数计算每个数据点属于每个高斯分布的后验概率(即隐变量的期望)。
  3. M步m_step 函数根据E步的结果更新高斯分布的参数,包括簇的权重、均值和协方差矩阵。
  4. EM循环:主循环中,E步和M步交替进行,直到对数似然值收敛。

输出结果

程序将输出每个簇的权重、均值和协方差矩阵。这些参数是通过EM算法迭代计算得出的,隐变量帮助我们在没有观测到簇归属的情况下,估计出这些簇的参数。

可视化结果

可以在输出结果之后添加可视化代码来查看聚类效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2153231.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为OD机试 - 信号强度(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试真题&#xff08;Python/JS/C/C&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加入华为OD刷题交流群&#xff0c;…

网站渗透这块水太深,你把握不住!但你叔我能(十年经验分享)

很多朋友问我&#xff0c;想搞网络安全&#xff0c;编程重要吗&#xff0c;选什么语言呢&#xff1f; 国内其实正经开设网络安全专业的学校很少&#xff0c;大部分同学是来自计算机科学、网络工程、软件工程专业的&#xff0c;甚至很多非计算机专业自学的。因此不像这三个专业…

【技术文章】ArcGIS Pro如何批量导出符号和工程样式?

目录 1.确定Pro软件版本 2.共享工程样式 3.管理和调用项目样式 制作好的地图&#xff0c;如何快速分享地图中的符号样式用于其它地图的制作&#xff1f; 在ArcMap软件中&#xff0c;可以通过命令一键批量导出所有符号。ArcGIS Pro软件是否也可以批量导出符号用于其它地图…

Java-数据结构-排序-(一) (。・ω・。)

文本目录&#xff1a; ❄️一、排序的概念及引用&#xff1a; ➷ 排序的概念&#xff1a; ➷ 常见的排序算法&#xff1a; ❄️二、插入排序的实现&#xff1a; ➷ 1、直接插入排序&#xff1a; ☞ 直接插入排序的基本思想&#xff1a; ☞ 直接插入排序的实现&#xff1a; ▶…

UI自动化测试(python)Web端4.0

✨博客主页&#xff1a; https://blog.csdn.net/m0_63815035?typeblog &#x1f497;《博客内容》&#xff1a;.NET、Java.测试开发、Python、Android、Go、Node、Android前端小程序等相关领域知识 &#x1f4e2;博客专栏&#xff1a; https://blog.csdn.net/m0_63815035/cat…

PyCharm与Anaconda超详细安装配置教程

1、安装Anaconda&#xff08;过程&#xff09;-CSDN博客 2.创建虚拟环境conda create -n pytorch20 python3.9并输入conda activate pytorch20进入 3.更改镜像源conda/pip(只添加三个pip源和conda源即可) 4.安装PyTorch&#xff08;CPU版&#xff09; 5.安装Pycharm并破解&…

使用 Anaconda 环境在Jupyter和PyCharm 中进行开发

目录 前言 一、在特定环境中使用jupyter 1. 列出所有环境 2. 激活环境 3. 进入 Jupyter Notebook 二、在特定环境中使用pycham 1. 打开 PyCharm 2. 打开设置 3. 配置项目解释器 4. 选择 Conda 环境 5. 应用设置 6. 安装所需库&#xff08;如果需要&#xff09; 总结 &#x1f3…

2024年中国研究生数学建模竞赛C题——解题思路

2024年中国研究生数学建模竞赛C题——解题思路 数据驱动下磁性元件的磁芯损耗建模——解决思路 二、问题描述 为解决磁性元件磁芯材料损耗精确计算问题&#xff0c;通过实测磁性元件在给定工况&#xff08;不同温度、频率、磁通密度&#xff09;下磁芯材料损耗的数据&#xf…

卡西欧相机SD卡格式化后数据恢复指南

在数字摄影时代&#xff0c;卡西欧相机以其卓越的性能和便携性成为了众多摄影爱好者的首选。然而&#xff0c;随着拍摄量的增加&#xff0c;SD卡中的数据管理变得尤为重要。不幸的是&#xff0c;有时我们可能会因为操作失误或系统故障而将SD卡格式化&#xff0c;导致珍贵的照片…

在线骑行网站设计与实现

摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装在线骑行网站软件来发挥其高效地信息处理的作用&#xff0c…

C++之深拷贝和浅拷贝*

两者本质&#xff1a; 浅拷贝&#xff1a;简单的赋值拷贝操作 深拷贝&#xff1a;在堆区中重新申请空间&#xff0c;进行拷贝操作new & delete 注意事项&#xff1a;堆区是在地址中重新申请空间&#xff0c;所以后续一系列操作new delete是通过指针* age进行操作&#xff0…

某 XXX 云主机,使用感受

简单来说就是: 垃圾&#xff01; 1. 登录垃圾。 我都已经实名认证了&#xff0c; 手机验证码非要发2遍。非要给我起个很难记住的账户名 2. 文档垃圾。 太高估用户的水平了。 建议做点视频教程。而不是各种文档&#xff0c;互相链接&#xff0c;转来转去&#xff0c; 让人心…

LeetCode[简单] 20.有效的括号

给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &#xff0c;判断字符串是否有效。 有效字符串需满足&#xff1a; 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的相同类型的左括…

NRK3502空气净化器语音芯片方案,本地识别算法+芯片架构

随着环境污染问题的日益严重&#xff0c;空气净化器成为人们居家、办公环境中不可或缺的设备&#xff0c;为了提升用户体验和产品性能&#xff0c;广州九芯电子研发出了一款创新的空气净化器语音芯片方案--NRK3502。此方案结合了本地识别算法与芯片架构&#xff0c;提供Turnkey…

SpringBoot+vue集成sm2国密加密解密

文章目录 前言认识SM2后端工具类实现引入依赖代码实现工具类&#xff1a;SM2Util 单元测试案例1&#xff1a;生成服务端公钥、私钥&#xff0c;前端js公钥、私钥案例2&#xff1a;客户端加密&#xff0c;服务端完成解密案例3&#xff1a;服务端进行加密&#xff08;可用于后面前…

巴蒂克图案识别系统源码分享

巴蒂克图案识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer V…

安全热点问题

安全热点问题 1.DDOS2.补丁管理3.堡垒机管理4.加密机管理 1.DDOS 分布式拒绝服务攻击&#xff0c;是指黑客通过控制由多个肉鸡或服务器组成的僵尸网络&#xff0c;向目标发送大量看似合法的请求&#xff0c;从而占用大量网络资源使网络瘫痪&#xff0c;阻止用户对网络资源的正…

手把手教你java+selenium数据驱动测试框架搭建与实践

最近在看JavaseleniumTestNgExcel的数据驱动&#xff0c;如何使用TestNg和Excel进行数据驱动测试。我其实是个自动化测试小白&#xff0c;工作之余看看这方面的书&#xff0c;照着敲敲代码&#xff0c;慢慢理解&#xff0c;希望通过自己坚持不懈的努力&#xff0c;在测试这个职…

Python语言基础教程(下)4.0

✨博客主页&#xff1a; https://blog.csdn.net/m0_63815035?typeblog &#x1f497;《博客内容》&#xff1a;.NET、Java.测试开发、Python、Android、Go、Node、Android前端小程序等相关领域知识 &#x1f4e2;博客专栏&#xff1a; https://blog.csdn.net/m0_63815035/cat…

记录一下(goland导入其他包方法编译不了爆红但能正常使用)

在goLand里面新建go文件,里面放了一个方法,首字符也大写了,但是在别的包里面报错显示无法识别,爆红显示,但是项目能正常运行,这里说下我的解决方案 即可解决