24/11/14 算法笔记 EM算法期望最大化算法

news2024/11/16 2:47:01

EM算法用于含有隐变量的概率参数模型的最大似然估计或极大后验概率估计。它在机器学习和统计学中有着广泛的应用,比如在高斯混合模型(GMM)、隐马尔可夫模型(HMM)以及各种聚类和分类问题中。

EM算法的基本思想是:首先根据已经给出的观测数据,估计出模型参数的值;然后再根据上一步估计出的参数值估计缺失数据的值,再根据估计出的缺失数据加上之前已经观测到的数据重新再对参数值进行估计,然后反复迭代,直至最后收敛,迭代结束

EM算法的迭代过程分为两个步骤:

  1. E步(Expectation Step):在这一步中,算法会估计缺失的或隐藏的变量,即基于当前模型参数的估计来计算隐藏变量的期望值。这一步实际上是在计算一个下界(lower bound)或者说是对缺失数据的期望可能性(expected likelihood)的估计。
  2. M步(Maximization Step):在这一步中,算法会优化模型参数以最大化在E步中计算出的期望可能性。这一步实际上是在最大化一个代理函数(surrogate function),这个函数依赖于隐藏变量的期望值。

EM算法的收敛性通常是有保证的,但收敛速度可能是一个问题,特别是在高维数据和复杂模型中。此外,EM算法能够估计复杂模型的参数,但这种复杂性可能会导致模型解释性降低。在实际应用中,我们需要仔细考虑这种权衡。

下面是一个简单的EM算法实现的例子,用于高斯混合模型(GMM)的参数估计

1. 首先,我们需要导入一些必要的库:

import numpy as np
from scipy.stats import multivariate_normal
  • scipy.stats 是SciPy库中的一个子库,它提供了统计学和概率论的函数,这里我们使用其中的 multivariate_normal 类来表示多变量高斯分布。

2. 然后,我们定义一个类来表示高斯混合模型:

class GaussianMixture:
    def __init__(self, n_components=3, covariance_type='full'):
        self.n_components = n_components
        self.covariance_type = covariance_type
        self.weights = None
        self.means = None
        self.covariances = None

    def fit(self, X, n_iter=100):
        n_samples, n_features = X.shape
        self.weights = np.ones(self.n_components) / self.n_components
        self.means = np.random.rand(self.n_components, n_features)
        self.covariances = np.array([np.eye(n_features)] * self.n_components)

        for i in range(n_iter):
            # E-step: Compute responsibilities
            responsibilities = self._compute_responsibilities(X)
            # M-step: Update parameters
            self._update_parameters(X, responsibilities)

    def _compute_responsibilities(self, X):
        responsibilities = np.zeros((X.shape[0], self.n_components))
        for i in range(self.n_components):
            rv = multivariate_normal(self.means[i], self.covariances[i])
            responsibilities[:, i] = self.weights[i] * rv.pdf(X) / np.sum([self.weights[j] * multivariate_normal(self.means[j], self.covariances[j]).pdf(X) for j in range(self.n_components)], axis=0)
        return responsibilities

    def _update_parameters(self, X, responsibilities):
        self.weights = np.mean(responsibilities, axis=0)
        for i in range(self.n_components):
            self.means[i] = np.sum(responsibilities[:, i].reshape(-1, 1) * X, axis=0) / np.sum(responsibilities[:, i])
            self.covariances[i] = np.dot((X - self.means[i]).T * responsibilities[:, i], (X - self.means[i])) / np.sum(responsibilities[:, i])

    def predict(self, X):
        return np.argmax([self._compute_responsibilities(X) for i in range(self.n_components)], axis=1)

这个简单的实现包括了:

  • 初始化模型参数。
  • E步:计算每个数据点属于每个高斯分量的“责任”(即概率)。
  • M步:根据这些责任更新模型参数(权重、均值、协方差)。

我们来分析一下各段代码

2.1 定义高斯混合模型类

class GaussianMixture:
    def __init__(self, n_components=3, covariance_type='full'):
        self.n_components = n_components #混合模型中的高斯分量数量,默认为3。
        self.covariance_type = covariance_type #指定协方差矩阵的类型,这里使用 'full' 表示每个分量都有自己的完整协方差矩阵。
    #分别存储每个高斯分量的权重、均值和协方差矩阵,它们在模型训练过程中会被更新。
        self.weights = None
        self.means = None
        self.covariances = None

2.2 训练模型

    def fit(self, X, n_iter=100):
        n_samples, n_features = X.shape

        #初始化模型的混合权重
        self.weights = np.ones(self.n_components) / self.n_components

        #随机初始化每个组件的均值
        self.means = np.random.rand(self.n_components, n_features)
        
        #初始化每个组件的协方差矩阵。np.eye(n_features) 生成一个单位矩阵,表示每个组件的初始协方差矩阵是单位矩阵。
        self.covariances = np.array([np.eye(n_features)] * self.n_components)

        for i in range(n_iter):
            
            # E-step: Compute responsibilities
            #计算责任度。根据当前的模型参数(均值、协方差和权重)来计算每个数据点属于每个组件的概率。
            responsibilities = self._compute_responsibilities(X)

            # M-step: Update parameters
            #根据计算出的责任度来更新均值、协方差和权重,以最大化数据的似然函数。
            self._update_parameters(X, responsibilities)

2.3 E步:计算责任

    def _compute_responsibilities(self, X):
        responsibilities = np.zeros((X.shape[0], self.n_components))
        for i in range(self.n_components):
            rv = multivariate_normal(self.means[i], self.covariances[i])#高斯正态分布
            responsibilities[:, i] = self.weights[i] * rv.pdf(X) / np.sum([self.weights[j] * multivariate_normal(self.means[j], self.covariances[j]).pdf(X) for j in range(self.n_components)], axis=0)
        return responsibilities
  • 对于每个分量,使用 multivariate_normal.pdf 方法计算数据点的密度,然后通过权重和归一化因子计算责任。

2.4 M步:更新参数

    def _update_parameters(self, X, responsibilities):
        self.weights = np.mean(responsibilities, axis=0)
        for i in range(self.n_components):
            self.means[i] = np.sum(responsibilities[:, i].reshape(-1, 1) * X, axis=0) / np.sum(responsibilities[:, i])
            self.covariances[i] = np.dot((X - self.means[i]).T * responsibilities[:, i], (X - self.means[i])) / np.sum(responsibilities[:, i])

  • 更新权重:每个分量的权重是该分量责任的平均值。
  • 更新均值:每个分量的均值是数据点的加权平均,权重是责任。
  • 更新协方差:每个分量的协方差是加权的数据点偏差的外积,权重是责任。

2.5 预测数据点的类别

    def predict(self, X):
        return np.argmax([self._compute_responsibilities(X) for i in range(self.n_components)], axis=1)

3 使用示例

# 生成一些模拟数据
np.random.seed(0)
data = np.concatenate([np.random.normal(0, 1, (100, 2)), np.random.normal(5, 2, (100, 2))])
data[:, 1] += data[:, 0] * 3

# 训练模型
gmm = GaussianMixture(n_components=2)
gmm.fit(data)

# 预测数据点的类别
labels = gmm.predict(data)
print(labels)
  • 生成模拟数据:两个高斯分布,分别在 (0, 1) 和 (5, 2) 附近。
  • 训练模型:使用 GaussianMixture 类训练模型。
  • 预测类别:使用训练好的模型预测数据点的类别。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241219.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python中的HTML

文章目录 一. HTML1. html的定义2. html的作用3. 基本结构4. 常用的html标签5. 列表标签① 无序列表② 有序列表 6. 表格标签7. 表单标签8. 表单提交① 表单属性设置② 表单元素属性设置 一. HTML 1. html的定义 HTML 的全称为:HyperText Mark-up Language, 指的是…

大数据新视界 -- 大数据大厂之 Impala 存储格式转换:从原理到实践,开启大数据性能优化星际之旅(下)(20/30)

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

【优选算法 — 滑动窗口】水果成篮 找到字符串中所有字母异位词

水果成篮 水果成篮 题目描述 因为只有两个篮子,每个篮子装的水果种类相同,如果从 0 开始摘,则只能摘 0 和 1 两个种类 ; 因为当我们在两个果篮都装有水果的情况下,如果再走到下一颗果树,果树的水果种类…

【ubuntu16.04】机器人学习笔记遇到的问题及解决办法:仿真小海龟

18版本的后面会出问题,避免万一我还是用了之前的16版本,虽然还没有解决粘贴的问题,但是安装ros很成功 可参考该文章博主讲的很详细,成功画出海龟 最后要把鼠标停在第三个终端,再去点击键盘,海龟才会动哦

游戏引擎学习第九天

视频参考:https://www.bilibili.com/video/BV1ouUPYAErK/ 修改之前的方波数据,改播放正弦波 下面主要讲关于浮点数 1. char(字符类型) 大小:1 字节(8 位)表示方式:char 存储的是一个字符的 A…

探索 JNI - Rust 与 Java 互调实战

真正的救赎,并非厮杀后的胜利,而是能在苦难之中,找到生的力量和内心的安宁。 ——加缪Albert Camus 一、Rust Java ? Java 和 Rust 是两种现代编程语言,各自具有独特的优势,适用于不同的应用场景。 1、…

C++11新特性(二)

目录 一、C11的{} 1.初始化列表 2.initializer_list 二、可变参数模版 1.语法与原理 2.包扩展 3.empalce接口 三、新的类功能 四、lambda 1.语法 2.捕捉列表 3.原理 五、句装器 1.function 2.bind 一、C11的{} 1.初始化列表 C11以后想统⼀初始化⽅式&#xff0…

生信:TCGA学习(R、RStudio安装与下载、常用语法与常用快捷键)

前置环境 macOS系统,已安装homebrew且会相关命令。 近期在整理草稿区,所以放出该贴。 R语言、RStudio、R包安装 R语言安装 brew install rRStudio安装 官网地址:https://posit.co/download/rstudio-desktop/ R包下载 注意R语言环境自带…

Vue3集成搜索引擎智能提示API

需求: 如何在项目中实现像百度搜索框一样的智能提示效果,如下图所示: 相关知识: 下面是各厂商提供的免费API 厂商请求百度http://suggestion.baidu.com/su?wd中国&cbwindow.baidu.sug必应http://api.bing.com/qsonhs.as…

大数据技术在智慧医疗中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 大数据技术在智慧医疗中的应用 大数据技术在智慧医疗中的应用 大数据技术在智慧医疗中的应用 引言 大数据技术概述 定义与原理 发…

游戏引擎学习第10天

视频参考:https://www.bilibili.com/video/BV1LyU3YpEam/ 介绍intel architecture reference manual 地址:https://www.intel.com/content/www/us/en/developer/articles/technical/intel-sdm.html RDTS(读取时间戳计数器)指令是 x86/x86_64 架构中的…

「QT」文件类 之 QTemporaryDir 临时目录类

✨博客主页何曾参静谧的博客📌文章专栏「QT」QT5程序设计📚全部专栏「Win」Windows程序设计「IDE」集成开发环境「UG/NX」BlockUI集合「C/C」C/C程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「UG/NX」NX定制…

Kettle配置数据源错误“Driver class ‘org.gjt.mm.mysql.Driver‘ could not be found”解决记录

问题描述 错误提示:“Driver class ‘org.gjt.mm.mysql.Driver’ could not be found, make sure the ‘MySQL’ driver (jar file) is installed.” 原因分析: 根据错误提示是缺少了相关的数据源连接jar包。 解决方案: 安装对应的Mysql…

C++《继承》

在之前学习学习C类和对象时我们就初步了解到了C当中有三大特性,分别是封装、继承、多态,通过之前的学习我们已经了解了C的封装特性,那么接下来我们将继续学习另外的两大特性,在此将分为两个章节来分别讲解继承和多态。本篇就先来学…

力扣(LeetCode)283. 移动零(Java)

White graces:个人主页 🙉专栏推荐:Java入门知识🙉 🐹今日诗词:雾失楼台,月迷津渡🐹 ⛳️点赞 ☀️收藏⭐️关注💬卑微小博主🙏 ⛳️点赞 ☀️收藏⭐️关注💬卑微小博主…

运算放大器的学习(一)输入阻抗

输入阻抗 最近需要对运算放大器进行学习,我们后面逐一对其参数进行了解。 首先了解下输入阻抗。 放大电路技术指标测试示意图: 输入电阻: 从放大电路的输入端看进去的等效电阻称为放大电路的输入电阻,如上图,此处考虑…

Python3.11.9下载和安装

一、Python3.11.9下载和安装 1、下载 下载地址:https://www.python.org/downloads/windows/ 选择版本下载,例如:Python 3.11.9 - April 2, 2024 2、安装 双击exe安装 3、配置环境变量 pathD:\Program Files\python3.11.9 pathD:\Progr…

大模型研究报告 | 2024年中国金融大模型产业发展洞察报告|附34页PDF文件下载

随着生成算法、预训练模型、多模态数据分析等AI技术的聚集融合,AIGC技术的实践效用迎来了行业级大爆发。通用大模型技术的成熟推动了新一轮行业生产力变革,在投入提升与政策扶植的双重作用下,以大模型技术为底座、结合专业化金融能力的金融大…

杰控通过 OPCproxy 获取数据发送到服务器

把数据从 杰控 取出来发到服务器 前提你在杰控中已经有变量了(wincc 也适用) 打开你的opcproxy 软件包 opcvarFile 添加变量 写文件就写到 了 opcproxy.ini中 这个文件里就是会读取到的数据 然后 opcproxy.exe发送到桌面快捷方式再考回来 &#…

Vue3 -- 环境变量的配置【项目集成3】

环境: 在项目开发过程中,至少会经历开发环境、测试环境和生产环境(即正式环境)三个阶段。 开发环境 .env.development测试环境 .env.test生产环境 .env.production 不同阶段请求的状态(如接口地址等)不一样,开发项目的时候要经常配置代理跨…