数据聚类:Mean-Shift和EM算法

news2024/11/20 8:43:48

目录

  • 1. 高斯混合分布
  • 2. Mean-Shift算法
  • 3. EM算法
  • 4. 数据聚类
  • 5. 源码地址


1. 高斯混合分布

在高斯混合分布中,我们假设数据是由多个高斯分布组合而成的。每个高斯分布被称为一个“成分”(component),这些成分通过加权和的方式来构成整个混合分布。

高斯混合分布的公式可以表示为:

p ( x ) = ∑ i = 1 K π i N ( x ∣ μ i , Σ i ) p(x) = \sum^K_{i=1} \pi_i N(x|\mu_i, \Sigma_i) p(x)=i=1KπiN(xμi,Σi)

其中:

  • p ( x ) p(x) p(x)是观测数据点 x x x的概率密度函数,
  • K K K是高斯分布的数量,
  • π i \pi_i πi是第 i i i个高斯分布的混合系数,满足 ∑ i = 1 K π i = 1 \sum^K_{i=1} \pi_i = 1 i=1Kπi=1,
  • μ i \mu_i μi是第 i i i个高斯分布的均值向量,
  • Σ i \Sigma_i Σi是第 i i i个高斯分布的协方差矩阵。

为了简单呈现结果,我们选取 K = 2 K=2 K=2个高斯分布。下图为一个高斯混合分布的采样散点图,其中两个高斯分布的 μ i \mu_i μi分别为 [ 0 , 0 ] , [ 5 , 5 ] [0,0], [5,5] [0,0],[5,5],协方差矩阵均为:
[ 1 0 0 1 ] \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} [1001]

在这里插入图片描述

Fig. 1. 高斯混合分布的采样散点图

2. Mean-Shift算法

Mean-Shift是一种非参数化的密度估计和聚类算法,用于将数据点组织成具有相似特征的群集。它是一种迭代算法,通过计算数据点的梯度信息来寻找数据点在特征空间中的密度极值点,从而确定聚类中心。

算法的核心思想是通过不断地更新数据点的位置,将它们移向密度估计函数梯度的最大方向,直到达到收敛条件。具体来说,Mean-Shift算法包括以下步骤:

  • 初始化:选择一个数据点作为初始聚类中心,或者随机选择一个点作为初始中心。
  • 确定梯度向量:对于每个数据点,计算其与其他数据点之间的距离,并根据一定的核函数(如高斯核)计算梯度向量。梯度向量的方向指向密度估计函数增加最快的方向。
  • 移动数据点:将每个数据点移动到梯度向量的方向上,即向密度估计函数增加最快的方向移动一定的步长。
  • 更新聚类中心:对于移动后的每个数据点,重新计算它们周围数据点的梯度向量,并更新它们的位置。重复这个过程直到达到收敛条件,比如聚类中心的移动距离小于某个阈值。
  • 形成聚类:最终,根据收敛后的聚类中心,将数据点分配到最近的聚类中心,形成最终的聚类结果。

Mean-Shift算法的优点是不需要事先指定聚类的个数,且能够自适应地调整聚类中心的数量和形状。它在处理非线性和非凸形状的数据集时表现出良好的聚类效果。然而,该算法对于大规模数据集的计算复杂度较高,且对初始聚类中心的选择敏感。Mean-Shift算法的具体实现见代码片:

class MeanShift:
    def __init__(self, bandwidth=1.0, max_iterations=100):
        self.min_shift = 1
        self.n_clusters_ = None
        self.cluster_centers_ = None
        self.labels_ = None
        self.bandwidth = bandwidth
        self.max_iterations = max_iterations

    def euclidean_distance(self, x1, x2):
        return np.sqrt(np.sum((x1 - x2) ** 2))

    def gaussian_kernel(self, distance, bandwidth):
        return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth) ** 2))

    def shift_point(self, point, data, bandwidth):
        shift_x = 0.0
        shift_y = 0.0
        total_weight = 0.0

        for i in range(len(data)):
            distance = self.euclidean_distance(point, data[i])
            weight = self.gaussian_kernel(distance, bandwidth)
            shift_x += data[i][0] * weight
            shift_y += data[i][1] * weight
            total_weight += weight

        shift_x /= total_weight
        shift_y /= total_weight

        return np.array([shift_x, shift_y])

    def fit(self, data):
        centroids = np.copy(data)

        for _ in range(self.max_iterations):
            shifts = np.zeros_like(centroids)

            for i, centroid in enumerate(centroids):
                distances = cdist([centroid], data)[0]
                weights = self.gaussian_kernel(distances, self.bandwidth)
                shift = np.sum(weights[:, np.newaxis] * data, axis=0) / np.sum(weights)
                shifts[i] = shift

            shift_distances = cdist(shifts, centroids)
            centroids = shifts

            if np.max(shift_distances) < self.min_shift:
                break

        unique_centroids = np.unique(np.around(centroids, 3), axis=0)

        self.cluster_centers_ = unique_centroids
        self.labels_ = np.argmin(cdist(data, unique_centroids), axis=1)
        self.n_clusters_ = len(unique_centroids)

3. EM算法

EM算法是一种迭代的数值优化算法,用于求解包含隐变量的概率模型参数的最大似然估计。它在统计学和机器学习领域被广泛应用,尤其在聚类问题中有着重要的应用。其基于观测数据和隐变量之间的概率模型,通过交替进行两个步骤:E步骤(Expectation Step)和M步骤(Maximization Step)来迭代地优化模型参数。下面是EM算法的基本步骤:

  • 初始化:选择一组初始参数来开始迭代过程。
  • E步骤:根据当前的参数估计,计算隐变量的后验概率,即给定观测数据下隐变量的条件概率分布。
  • M步骤:使用在E步骤中计算得到的后验概率,对参数进行更新,以最大化对数似然函数。
  • 重复步骤2-3至收敛:重复执行E步骤和M步骤,直到参数的变化很小或满足收敛条件。

在聚类问题中,EM算法可以用于估计混合高斯模型的参数,从而实现数据的聚类。EM算法在聚类中的应用优点是能够处理具有隐变量的概率模型,适用于灵活的聚类问题。然而,EM算法对于初始参数的选择敏感,可能会陷入局部最优解,并且在处理大规模数据集时可能会面临计算复杂度的挑战。EM算法(包含正则化步骤)的具体实现见代码片:

class RegularizedEMClustering:
    def __init__(self, n_clusters, max_iterations=100, epsilon=1e-4, regularization=1e-6):
        self.labels_ = None
        self.X = None
        self.n_features = None
        self.n_samples = None
        self.cluster_probs_ = None
        self.cluster_centers_ = None
        self.n_clusters = n_clusters
        self.max_iterations = max_iterations
        self.epsilon = epsilon
        self.regularization = regularization

    def fit(self, X):
        self.X = X
        self.n_samples, self.n_features = X.shape

        self.cluster_centers_ = X[np.random.choice(self.n_samples, self.n_clusters, replace=False)]

        self.cluster_probs_ = np.ones((self.n_samples, self.n_clusters)) / self.n_clusters

        # EM
        for iteration in range(self.max_iterations):
            # E-step
            prev_cluster_probs = self.cluster_probs_
            self._update_cluster_probs()

            # M-step
            self._update_cluster_centers()

            delta = np.linalg.norm(self.cluster_probs_ - prev_cluster_probs)

            if delta < self.epsilon:
                break

        self.labels_ = np.argmax(self.cluster_probs_, axis=1)

    def _update_cluster_probs(self):
        distances = np.linalg.norm(self.X[:, np.newaxis, :] - self.cluster_centers_, axis=2)

        # Calculate the cluster probabilities with regularization
        numerator = np.exp(-distances) + self.regularization
        denominator = np.sum(numerator, axis=1, keepdims=True)
        self.cluster_probs_ = numerator / denominator

    def _update_cluster_centers(self):
        self.cluster_centers_ = np.zeros((self.n_clusters, self.n_features))
        for k in range(self.n_clusters):
            self.cluster_centers_[k] = np.average(self.X, axis=0, weights=self.cluster_probs_[:, k])

    def predict(self, X):
        distances = np.linalg.norm(X[:, np.newaxis, :] - self.cluster_centers_, axis=2)
        return np.argmin(distances, axis=1)

4. 数据聚类

Mean-Shift和EM算法的聚类结果分别如图2的a-b子图所示,由于MoG比较简单,两种算法均可以合理且完整地实现聚类,聚类中心也没有显著差异。

在这里插入图片描述

Fig. 2. Mean-Shift(a)和EM(b)算法的聚类结果

5. 源码地址

如果对您有用的话可以点点star哦~

https://github.com/Jurio0304/cs-math/blob/main/hw3_clustering.ipynb
https://github.com/Jurio0304/cs-math/blob/main/func.py


创作不易,麻烦点点赞和关注咯!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1629131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql-connector-java和spring-boot-starter-jdbc和mybatis-spring-boot-start

mysql-connector-java和spring-boot-starter-jdbc和mybatis-spring-boot-start JDBC是什么意思&#xff1f; JDBC是使用java语言操作mysql数据库的规范&#xff0c;java语言必须按照这个规范写才可以操作mysql数据库。 mysql-connector-java 在最开始的时候 程序中是不允许…

深入解析AI绘画算法:从GANs到VAEs

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

蛋糕购物商城

蛋糕购物商城 运行前附加数据库.mdf&#xff08;或使用sql生成数据库&#xff09; 登陆账号&#xff1a;admin 密码&#xff1a;123456 修改专辑价格时去掉&#xffe5;以及上传专辑图片 c#_asp.net 蛋糕购物商城 网上商城 三层架构 在线购物网站&#xff0c;电子商务系统 …

Linux——终端

一、终端 1、终端是什么 终端最初是指终端设备&#xff08;Terminal&#xff09;&#xff0c;它是一种用户与计算机系统进行交互的硬件设备。在早期的计算机系统中&#xff0c;终端通常是一台带有键盘和显示器的电脑&#xff0c;用户通过它输入命令&#xff0c;计算机在执行命…

redisson分布式锁的单机版应用

package com.redis;/*** author linn* date 2024年04月23日 15:31*/ import org.redisson.Redisson; import org.redisson.api.RedissonClient; import org.redisson.config.Config; import org.springframework.context.annotation.Bean; import org.springframework.context.…

SpringBoot 3.2.5 引入Swagger(OpenApi)

SpringBoot 3.2.5 引入Swagger&#xff08;OpenApi&#xff09; pom文件配置文件启动类Controller 层ApiFox题外话 springdoc-openapi 和 swagger 都可以用&#xff0c;用其中一个就行&#xff0c;不用两个都引入。 这里简单记录以下springdoc-openapi。 springdoc-openapi(J…

【AI相关】模型相关技术名词

目录 过拟合和欠拟合 1.过拟合 2.欠拟合 特征清洗、数据变换、训练集、验证集和测试集 1.特征清洗 2.数据变换 3.训练集 4.验证集 5.测试集 跨时间测试和回溯测试 1.跨时间测试&#xff08;OOT 测试&#xff09; 2.回溯测试 联合建模与联邦学习 1.联合建模 2.联…

用友政务财务系统FileDownload接口存在任意文件读取漏洞

声明&#xff1a; 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 简介 用友政务财务系统是由用友软件开发的一款针对政府机…

OPPO手机支持深度测试+免深度测试解锁BL+ROOT权限机型整理-2024年3月更新

绿厂OPPO手机线上线下卖的都很不错&#xff0c;目前市场份额十分巨大&#xff0c;用户自然也非常多&#xff0c;而近期ROM乐园后台受到很多关于OPPO手机的私信&#xff0c;咨询哪些机型支持解锁BL&#xff0c;ROOT刷机&#xff0c;今天ROM乐园正式盘点当前市场上可以解BL刷root…

树莓派4-通过IIC实现图片循环播放

一、环境 1、树莓派4&#xff1b; 2、串口连接电脑&#xff1b; 3、树莓派由杜邦线连接0.96寸OLED1306协议 4、树莓派能够联网&#xff0c;便于安装环境。离线情况也可以安装&#xff0c;相对麻烦&#xff1b; 二、目标 1、树莓派可以开启IIC并识别已连接的IIC&#xff1b; …

机器人-轨迹规划

旋转矩阵 旋转矩阵--R--一个3*3的矩阵&#xff0c;其每列的值时B坐标系在A坐标系上的投影值。 代表B坐标系相对于A坐标系的姿态。 旋转矩阵的转置矩阵 其实A相对于B的旋转矩阵就相当于把B的列放到行上就行。 视频 &#xff08;将矩阵的行列互换得到的新矩阵称为转置矩阵。&…

4月26日 阶段性学习汇报

1.毕业设计与毕业论文 毕业设计已经弄完&#xff0c;加入了KNN算法&#xff0c;实现了基于四种常见病的判断&#xff0c;毕业论文写完&#xff0c;格式还需要调整&#xff0c;下周一发给指导老师初稿。目前在弄答辩ppt&#xff08;25%&#xff09;。25号26号两天都在参加校运会…

【蓝桥杯省赛真题38】python字符串拼接 中小学青少年组蓝桥杯比赛 算法思维python编程省赛真题解析

目录 python字符串拼接 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 七、 推荐资料 1、蓝桥杯比赛 2、考级资料 3、其它资料 python字符串拼接 第十三届蓝桥杯青少年组python编程省赛真题 一、题目…

Cadence OrCAD学习笔记(2)OrCAD原理图

最近换份工作主要用到Cadence&#xff0c;之前都是用AD居多&#xff0c;所以现在也开始记录下Cadence学习过程&#xff0c;方便后面复习。 参考教程&#xff1a; OrCAD视频教程第2期&#xff1a;10分钟学会OrCAD原理图_哔哩哔哩_bilibili 本期主要介绍原理图中的基本操作&…

ZooKeeper 搭建详细步骤之二(伪集群模式)

ZooKeeper 搭建详细步骤之一&#xff08;单机模式&#xff09; ZooKeeper 及相关概念简介 伪集群搭建 ZooKeeper 伪集群是指在一个单一的物理或虚拟机环境中模拟出一个由多个 ZooKeeper 节点构成的集群。尽管这些节点实际上运行在同一台机器上&#xff0c;但它们通过配置不同的…

【学习笔记二十八】EWM和QM集成的后台配置和前台展示

一、EWM和QM集成概述 SAP EWM(扩展仓库管理)和QM(质量管理)的集成是SAP系统中一个重要的特性,它允许企业在仓库管理过程中实现质量控制和检验流程的自动化。以下是关于EWM和QM集成的一些关键点概述: 集成优势:通过集成,企业可以确保仓库中的物料在收货、存储、…

flutter笔记-主要控件及布局

文章目录 1. 富文本实例2. Image2.1 本地图片2.2 网络图片 笔记3. 布局4. 滑动相关view4.1 GridView类似九宫格view4.2 ListView 关于widget的生命周期的相关知识这里就不做介绍&#xff0c;和很多语言类似&#xff1b; 1. 富文本实例 Dart中使用richtext&#xff0c;示例如下…

深入浅出MySQL-02-【MySQL支持的数据类型】

文章目录 前言1.数值类型2.日期时间类型3.字符串类型3.1.CHAR和VARCHAR类型3.2.ENUM类型3.3.SET类型 4.JSON类型 前言 环境&#xff1a; Windows11MySQL-8.0.35 1.数值类型 MySQL中的数值类型&#xff0c;如下&#xff1a; 整数类型字节最小值最大值TINYINT1有符号 -128无…

C#反射应用

1.根据类名名称生成类实例 CreateInstance后面的参数部分一定要和所构造的类参数数量对应&#xff0c;即使设置参数默认值&#xff0c;也不可省略。 2.只知道类名&#xff0c;需要将该类作为参数调用泛型接口。 3.只知道类名&#xff0c;需要将该类的数组作为参数调用泛型接口…

基于51单片机的电梯仿真系统

基于51单片机的电梯设计 &#xff08;仿真&#xff0b;程序PPT&#xff09; 功能介绍 具体功能&#xff1a; 1.一共4层&#xff0c;数码管显示当前楼层&#xff1b; 2.六个按键模拟电梯外按键&#xff08;1上、2上、2下、3上、3下、4下&#xff09;&#xff0c;每当按下时有…