机器学习实战:Python基于EM期望最大化进行参数估计(十五)

news2025/1/12 21:08:43

文章目录

    • 1. 前言
      • 1.1 EM的介绍
      • 1.2 EM的应用场景
    • 2. 高斯混合模型估计
      • 2.1 导入函数
      • 2.2 创建数据
      • 2.3 初始化
      • 2.4 Expectation Step
      • 2.5 Maximization step
      • 2.6 循环迭代可视化
    • 3. 多维情况
    • 4. 讨论

1. 前言

1.1 EM的介绍

Expectation-MaximizationEM)是一种迭代式的优化算法,主要用于解决含有隐变量的概率模型的参数估计问题。它的目标是在给定观测数据和未观测数据(隐变量)的情况下,估计概率模型的参数,使得模型能够最好地拟合观测数据。

EM算法的基本思想是通过交替进行两个步骤来优化模型参数:E步骤(Expectation)和M步骤(Maximization)。

  • E步骤(Expectation):
    在E步骤中,我们根据当前的参数估计值,计算出每个观测数据属于每个隐变量状态的概率,即计算出每个观测数据的后验概率。这些后验概率称为期望,因为它们代表了在当前参数下观测数据所“期望”的隐变量状态。

  • M步骤(Maximization):
    在M步骤中,我们根据E步骤得到的后验概率,最大化对数似然函数(或者叫Q函数)来更新模型参数。这一步骤可以看作是在给定观测数据和当前隐变量的情况下,对模型参数进行最大似然估计。

通过反复迭代E步骤和M步骤,EM算法不断优化模型参数,直到达到收敛条件。最终得到的模型参数能够使得模型对观测数据的拟合效果达到最优。

优点

  • 强大的参数估计能力:EM算法在含有隐变量的概率模型中具有较强的参数估计能力,尤其对于复杂模型和大规模数据集表现出色。

  • 高效的迭代优化:EM算法采用迭代的方式优化参数,通常能够在有限的迭代次数内收敛到局部最优解,相比其他优化方法更高效。

  • 灵活性:EM算法可以用于广泛的机器学习任务,包括聚类、混合高斯模型、隐马尔可夫模型等,使其在不同领域中得到广泛应用。

  • 统计性解释:EM算法基于最大似然估计,提供了对模型参数的统计性解释,能够在一定程度上量化参数估计的不确定性。

缺点

  • 收敛性不稳定:EM算法对于参数的初始值敏感,可能会陷入局部最优解,导致收敛性不稳定。

  • 需要选择合适的迭代次数:EM算法的收敛速度取决于迭代次数的选择,过多或过少的迭代次数都可能影响参数估计的精度。

  • 对高维数据敏感:在高维数据上,EM算法可能会面临维度灾难和过拟合问题,导致模型性能下降。

  • 可能陷入局部最优解:EM算法是一种局部优化方法,可能会陷入局部最优解,而无法得到全局最优解。

1.2 EM的应用场景

EM算法主要是用于参数估计,特别是在一些含有隐变量的概率模型,因此应用领域相对广泛:

  1. 计算机视觉:在图像处理和计算机视觉中,EM算法可以用于图像分割、目标识别和人脸识别等任务,特别是在混合高斯模型和高斯混合模型中的应用较为广泛。

  2. 自然语言处理:在自然语言处理领域,EM算法常用于文本聚类、主题模型和情感分析等任务,例如隐含狄利克雷分布模型(Latent Dirichlet Allocation,LDA)就是一种常见的应用。

  3. 生物信息学:在基因组学和蛋白质结构预测中,EM算法可以用于基因表达聚类、DNA序列分析和蛋白质折叠等问题。

  4. 金融领域:在金融风险评估、投资组合优化和市场预测中,EM算法可以用于建模和预测复杂的金融数据。

  5. 推荐系统:在个性化推荐和协同过滤任务中,EM算法可以用于学习用户和物品的隐含因子,从而实现更准确的推荐。

  6. 医学图像分析:在医学影像处理和分析中,EM算法可以用于图像分割、病灶检测和医学图像重建等任务。

  7. 无线通信:在无线信号处理和通信中,EM算法可以用于信号检测、通信信道估计和信号解调等问题。

2. 高斯混合模型估计

EM算法通常用于无监督学习问题,这里就用简单的高斯混合模型GMM)作实战演示,此外还有隐马尔可夫模型HMM)也是挺常见的。

2.1 导入函数

import random
import numpy as np
from numpy.linalg import inv
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

2.2 创建数据

m1 = [1, 1]   
m2 = [7, 7]
cov1 = [[3, 2], [2, 3]]
cov2 = [[2, -1], [-1, 2]]

x = np.random.multivariate_normal(m1, cov1, size=(200,))
y = np.random.multivariate_normal(m2, cov2, size=(200,))
d = np.concatenate((x, y), axis=0)

查看分布情况

plt.figure(figsize=(10,10))                                 
plt.scatter(d[:,0], d[:,1], marker='o')     
plt.axis('equal')                                  
plt.xlabel('X-Axis', fontsize=16)              
plt.ylabel('Y-Axis', fontsize=16)                     
plt.title('Ground Truth', fontsize=22)    
plt.grid()            
plt.show()

2.3 初始化

这里是在进行EM算法前对两个高斯分布的均值和协方差矩阵初始,其中参数pi初始化为 0.5,表示两个高斯分布的先验概率相等

m1 = random.choice(d)
m2 = random.choice(d)
cov1 = np.cov(np.transpose(d))
cov2 = np.cov(np.transpose(d))
pi = 0.5

可视化高斯分布情况(等高线)

x1 = np.linspace(-4, 11, 200)
x2 = np.linspace(-4, 11, 200)
X, Y = np.meshgrid(x1, x2)

Z1 = multivariate_normal(m1, cov1)
Z2 = multivariate_normal(m2, cov2)

pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

plt.figure(figsize=(10, 10))
plt.scatter(d[:, 0], d[:, 1], marker='o')
plt.contour(X, Y, Z1.pdf(pos), colors="r", alpha=0.5)
plt.contour(X, Y, Z2.pdf(pos), colors="b", alpha=0.5)
plt.axis('equal')
plt.xlabel('X-Axis', fontsize=16)
plt.ylabel('Y-Axis', fontsize=16)
plt.title('Initial State', fontsize=22)
plt.grid()
plt.show()

2.4 Expectation Step

计算数据点对应于每个类别的"期望"

def Estep(lis1):
    m1=lis1[0]
    m2=lis1[1]
    cov1=lis1[2]
    cov2=lis1[3]
    pi=lis1[4]
    
    pt2 = multivariate_normal.pdf(d, mean=m2, cov=cov2)
    pt1 = multivariate_normal.pdf(d, mean=m1, cov=cov1)
    w1 = pi * pt2
    w2 = (1-pi) * pt1
    eval1 = w1/(w1+w2)

    return(eval1)

2.5 Maximization step

使用E步骤中得到的隐含变量的估计值,来最大化(最优化)模型的对数似然函数

def Mstep(eval1):
    num_mu1,din_mu1,num_mu2,din_mu2=0,0,0,0

    for i in range(0,len(d)):
        num_mu1 += (1-eval1[i]) * d[i]
        din_mu1 += (1-eval1[i])

        num_mu2 += eval1[i] * d[i]
        din_mu2 += eval1[i]

    mu1 = num_mu1/din_mu1
    mu2 = num_mu2/din_mu2

    num_s1,din_s1,num_s2,din_s2=0,0,0,0
    for i in range(0,len(d)):

        q1 = np.matrix(d[i]-mu1)
        num_s1 += (1-eval1[i]) * np.dot(q1.T, q1)
        din_s1 += (1-eval1[i])

        q2 = np.matrix(d[i]-mu2)
        num_s2 += eval1[i] * np.dot(q2.T, q2)
        din_s2 += eval1[i]

    s1 = num_s1/din_s1
    s2 = num_s2/din_s2

    pi = sum(eval1)/len(d)
    
    lis2=[mu1,mu2,s1,s2,pi]
    return(lis2)

2.6 循环迭代可视化

这里修改迭代次数(i)分别为1,2,3,4,结果:

iterations = 20
lis1=[m1,m2,cov1,cov2,pi]
for i in range(0,iterations):
    lis2 = Mstep(Estep(lis1))
    lis1=lis2
    if(i==0 or i == 4 or i == 9 or i == 14 or i == 19):
        plot(lis1)

i = 0

i = 1

i = 2

i = 3

i = 4

当i越大,等高线的重叠程度越小,说明经过更多的迭代,高斯混合模型的参数估计越接近真实值,模型的拟合效果越好。根据结果可以看到在第二次迭代后两个模型等高线几乎没有变化,这表示模型已经收敛到一个稳定状态。在EM算法中,迭代会持续更新参数,直到收敛到一个局部最优解或全局最优解为止。

当模型达到收敛状态后,后续的迭代可能不会有显著的变化,因为模型已经找到了最优解或接近最优解。此时,进一步迭代可能只会导致细微的调整,不会对整体结果产生重大影响。

3. 多维情况

有小伙伴会问如果我一个数据集有多个分组,直接上:GMM的建模代码和刚刚一样

import matplotlib.pyplot as plt
from matplotlib import style
style.use('fivethirtyeight')
from sklearn.datasets import make_blobs
import numpy as np
from scipy.stats import multivariate_normal


# 0. Create dataset
X, Y = make_blobs(cluster_std=1.5, random_state=20, n_samples=500, centers=3)

# Stratch dataset to get ellipsoid data
X = np.dot(X, np.random.RandomState(0).randn(2, 2))


class GMM:

    def __init__(self,X,number_of_sources,iterations):
        self.iterations = iterations
        self.number_of_sources = number_of_sources
        self.X = X
        self.mu = None
        self.pi = None
        self.cov = None
        self.XY = None
        
    def run(self):
        self.reg_cov = 1e-6*np.identity(len(self.X[0]))
        x,y = np.meshgrid(np.sort(self.X[:,0]),np.sort(self.X[:,1]))
        self.XY = np.array([x.flatten(),y.flatten()]).T
           
        self.mu = np.random.randint(min(self.X[:,0]),max(self.X[:,0]),size=(self.number_of_sources,len(self.X[0])))
        self.cov = np.zeros((self.number_of_sources,len(X[0]),len(X[0])))
        for dim in range(len(self.cov)):
            np.fill_diagonal(self.cov[dim],5)

        self.pi = np.ones(self.number_of_sources)/self.number_of_sources
        log_likelihoods = []

        fig = plt.figure(figsize=(10,10))
        ax0 = fig.add_subplot(111)
        ax0.scatter(self.X[:,0],self.X[:,1])
        ax0.set_title('Initial state')
        for m,c in zip(self.mu,self.cov):
            c += self.reg_cov
            multi_normal = multivariate_normal(mean=m,cov=c)
            ax0.contour(np.sort(self.X[:,0]),np.sort(self.X[:,1]),multi_normal.pdf(self.XY).reshape(len(self.X),len(self.X)),colors='black',alpha=0.3)
            ax0.scatter(m[0],m[1],c='grey',zorder=10,s=100)
        
        for i in range(self.iterations):               
            r_ic = np.zeros((len(self.X),len(self.cov)))

            for m,co,p,r in zip(self.mu,self.cov,self.pi,range(len(r_ic[0]))):
                co+=self.reg_cov
                mn = multivariate_normal(mean=m,cov=co)
                r_ic[:,r] = p*mn.pdf(self.X)/np.sum([pi_c*multivariate_normal(mean=mu_c,cov=cov_c).pdf(X) for pi_c,mu_c,cov_c in zip(self.pi,self.mu,self.cov+self.reg_cov)],axis=0)

            self.mu = []
            self.cov = []
            self.pi = []
            log_likelihood = []

            for c in range(len(r_ic[0])):
                m_c = np.sum(r_ic[:,c],axis=0)
                mu_c = (1/m_c)*np.sum(self.X*r_ic[:,c].reshape(len(self.X),1),axis=0)
                self.mu.append(mu_c)
                self.cov.append(((1/m_c)*np.dot((np.array(r_ic[:,c]).reshape(len(self.X),1)*(self.X-mu_c)).T,(self.X-mu_c)))+self.reg_cov)
                self.pi.append(m_c/np.sum(r_ic))
            
            log_likelihoods.append(np.log(np.sum([k*multivariate_normal(self.mu[i],self.cov[j]).pdf(X) for k,i,j in zip(self.pi,range(len(self.mu)),range(len(self.cov)))])))

        fig2 = plt.figure(figsize=(10,10))
        ax1 = fig2.add_subplot(111) 
        ax1.set_title('Log-Likelihood')
        ax1.plot(range(0,self.iterations,1),log_likelihoods)
    
    def predict(self,Y):
        fig3 = plt.figure(figsize=(10,10))
        ax2 = fig3.add_subplot(111)
        ax2.scatter(self.X[:,0],self.X[:,1])
        for m,c in zip(self.mu,self.cov):
            multi_normal = multivariate_normal(mean=m,cov=c)
            ax2.contour(np.sort(self.X[:,0]),np.sort(self.X[:,1]),multi_normal.pdf(self.XY).reshape(len(self.X),len(self.X)),colors='black',alpha=0.3)
            ax2.scatter(m[0],m[1],c='grey',zorder=10,s=100)
            ax2.set_title('Final state')
            for y in Y:
                ax2.scatter(y[0],y[1],c='orange',zorder=10,s=100)
        prediction = []        
        for m,c in zip(self.mu,self.cov):  
            prediction.append(multivariate_normal(mean=m,cov=c).pdf(Y)/np.sum([multivariate_normal(mean=mean,cov=cov).pdf(Y) for mean,cov in zip(self.mu,self.cov)]))
        return prediction

GMM = GMM(X, 3, 50)
GMM.run()
GMM.predict([[0.5, 0.5]])

原始状态:

EM优化后:

五个组别的也一样可以收敛的很到位:

4. 讨论

EM算法广泛应用于许多领域,尤其在统计学、机器学习和数据挖掘中,用于处理包含缺失数据或未观测变量的复杂模型,如高斯混合模型(GMM)、隐马尔可夫模型(HMM)等。通过EM算法,我们可以估计模型参数,对数据进行聚类、密度估计等任务,从而更好地理解和分析数据。

常见的机器学习算法实战演示基本上都在前10节归纳到位了,从本节起也会陆续把重心放到优化算法上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/815219.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实战案例:使用 Python 机器学习预测外卖送餐时间

现在的天气是一天比一天热,好多人周末休息在家的时候,就会选择点外卖,毕竟出去一趟又晒又热。 如果你太饿了,点餐太晚了,就可能去关注外卖员送餐到哪了,还有多少时间能送达。 这些信息在美团、饿了吗的Ap…

MapReduce原理剖析

一、基本介绍 MapReduce是Hadoop的核心,是Google提出的一个软件架构,用于大规模数据集(大于1TB)的并行运算。概念“Map(映射)”和“Reduce(化简)”,及他们的主要思想&am…

AWS 推出开源 AutoML 工具包“AutoGluon”

亚马逊网络服务最近推出了一个开源库,使开发人员只需几行代码即可在图像、文本或表格数据上实现深度学习模型。 AutoGluon 旨在成为一个易于使用且易于扩展的 AutoML 工具包,适合机器学习初学者和专家。它只需几行即可对深度学习模型进行原型设计;自动超…

stm8_独立看门狗配置顺序错误导致不断复位

1、问题 在配置stm8独立看门狗的时候,先设置分频、重载寄存器,然后启动看门狗,发现不断复位。 按照手册中的表格,看门狗的超时时间应该是1s,但是在这1s中多次喂狗也不断复位,然后排查到是配置顺序的问题&…

重新审视MHA与Transformer

本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上,早在一年前博主就已经分别介绍了两者:各种注意力机制的PyTorch实现、从零开始手写一个Transformer,但当时的实现大部分是基于d2l教程的,这次将基于PyTorch…

【实践篇】最全的【DDD领域建模】小白学习手册(文末附资料) | 京东云技术团队

导读 DDD领域建模被各个大小厂商提起并应用,而每个人都有自己的理解,本文就是针对小白,系统地讲解DDD到底是什么,解决了什么问题,及一些建议和实践。本文主要是思想的一种碰撞和分享,希望能对朋友们有所启…

第四章 No.2单点线段树的介绍与使用

文章目录 基本操作练习题1275. 最大数245. 你能回答这些问题吗246. 区间最大公约数 基本操作 单点线段树一共4个常用操作,pushup, build, modify, query 相比区间线段树少了pushdown,懒标记,由于pushdown的实现极容易SF,所以能用…

Python GUI应用程序开发之wxPython库详解

概要 wxPython是一个强大的跨平台GUI工具包,它使用Python编程语言开发,提供了丰富的控件功能。如果你是一名Python开发者,而且希望创建一个功能齐全的桌面应用程序,那么wxPython是一个值得考虑的选择。wxPython是wxWidgets C库的P…

算法——十大排序 (部分未完结)

总结 为什么需要稳定排序? ▪ 让第⼀个关键字的排序结果服务于第⼆个关键字排序中数值相同的那些数 ▪ 主要是为了第⼀次考试分数相同时候,可以按照第⼆次分数的⾼低进行排序 一、冒泡排序 从最简单的冒泡排序开始 思想:交换相邻的元素&am…

电子文件管理系统的最佳实践指南分享

电子文件管理系统是一种专门用于管理电子文件的软件工具,可以帮助组织更有效地管理、存储、检索和共享文件。 首先,在选择适合自己组织的电子文件管理系统时,需要考虑以下几个关键因素。首先,系统的易用性和用户界面是否友好&…

Qt应用开发(基础篇)——布局管理Layout Management

目录 一、前言 二:相关类 三、水平、垂直、网格和表单布局 四、尺寸策略 一、前言 在实际项目开发中,经常需要使用到布局,让控件自动排列,不仅节省控件还易于管控。Qt布局系统提供了一种简单而强大的方式来自动布局小部件中的…

前段时间面试了一些人,有这些槽点跟大家说说

大家好,我是拭心。 前段时间组里有岗位招人,花了些时间面试,趁着周末把过程中的感悟和槽点总结成文和大家讲讲。 简历书写和自我介绍 今年的竞争很激烈:找工作的人数量比去年多、平均质量比去年高。裸辞的慎重,要做好…

Android 第三方库CalendarView

Android 第三方库CalendarView 根据需求和库的使用方式,自己弄了一个合适自己的日历,仅记录下,方便下次弄其他样式的日历。地址 需求: 只显示当月的数据 默认的月视图有矩形的线 选中的天数也要有选中的矩形框 今天的item需要…

强推!大语言模型『百宝书』,一文缕清所有大模型!

夕小瑶科技说 原创 作者 | 王思若 最近,大型语言模型无疑是AI社区关注的焦点,各大科技公司和研究机构发布的大模型如同过江之鲫,层出不穷又眼花缭乱。 让笔者恍惚间似乎又回到了2020年国内大模型“军备竞赛”的元年,不过那时候…

package-lock.json 作用

参照: https://www.cnblogs.com/honkerzh/p/16767566.html

【雕爷学编程】MicroPython动手做(25)——语音合成与语音识别

知识点:什么是掌控板? 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片,支持WiFi和蓝牙双模通信,可作为物联网节点,实现物联网应用。同时掌控板上集成了OLED…

山西电力市场日前价格预测【2023-08-01】

日前价格预测 预测明日(2023-08-01)山西电力市场全天平均日前电价为310.15元/MWh。其中,最高日前电价为335.18元/MWh,预计出现在19: 45。最低日前电价为288.85元/MWh,预计出现在14: 00。 价差方向预测 1:实…

无涯教程-jQuery - css( properties )方法函数

css(properties)方法将键/值对象设置为所有匹配元素的样式属性。 css( properties ) - 语法 selector.css( properties ) 上面的语法可以写成如下- selector.css( {key1:val1, key2:val2....keyN:valN}) 这是此方法使用的所有参数的描述- key:value - 设置为样式属…

郑州https数字证书

很多注重隐私的网站都注重网站信息的安全,比如购物网站就需要对客户的账户信息以及支付信息进行安全保护,否则信息泄露,客户与网站都有损失,网站也会因此流失大量客户。而网站使用https证书为客户端与服务器之间传输的信息加了一个…

<Git>版本控制工具Git常见的开发操作

下载安装,环境变量配置直接百度; 1.代码拉取: 操作步骤:在正确配置完git的条件下:在本地文件夹下:右键–Git Bash -Here: 出现如下弹窗: 在黑窗口输入代码拉取路径(一般都是把命令和路径直接在外面写好,直接粘贴(在窗口右键,Paste)) 代码拉去…