KMeans聚类算法实现

news2025/1/8 5:53:10

目录

1. K-Means的工作原理

2.Kmeans损失函数

3.Kmeans优缺点

4.编写KMeans算法实现类

5.KMeans算法测试

6.结果


       Kmeans是一种无监督的基于距离的聚类算法,其变种还有Kmeans++。其中,sklearn中KMeans的默认使用的即为KMeans++。使用sklearn相关算法API的调用案例可参考博主另一篇文章:KMeans算法实现图像分割。本文主要通过纯手写的方式,帮助学习理解KMeans算法的数据处理过程。

1. K-Means的工作原理

       在K-Means算法中,簇的个数K是一个超参数,需要人为输入来确定。K-Means的核心任务就是根据设定好的K,找出K个最优的质心,并将离这些质心最近的数据分别分配到这些质心代表的簇中去。具体过程可以总结如下:

  • 首先随机选取样本中的K个点作为聚类中心;
  • 分别算出样本中其他样本距离这K个聚类中心的距离,并把这些样本分别作为自己最近的那个聚类中心的类别;
  • 对上述分类完的样本再进行每个类别求平均值,求解出新的聚类质心;
  • 与前一次计算得到的K个聚类质心比较,如果聚类质心发生变化,转过程b,否则转过程e;
  • 当质心不发生变化时(当我们找到一个质心,在每次迭代中被分配到这个质心上的样本都是一致的,即每次新生成的簇都是一致的,所有的样本点都不会再从一个簇转移到另一个簇,质心就不会变化了),停止并输出聚类结果。

综上,K-Means 的算法步骤能够简单概括为:

1-分配:样本分配到簇。

2-移动:移动聚类中心到簇中样本的平均位置。

2.Kmeans损失函数

和其他机器学习算法一样,K-Means 也要评估并且最小化聚类代价,在引入 K-Means 的代价函数之前,先引入如下定义:

引入代价函数:

3.Kmeans优缺点

优点:
1.容易理解,聚类效果不错,虽然是局部最优, 但往往局部最优就够了;
2.处理大数据集的时候,该算法可以保证较好的伸缩性;
3.当簇近似高斯分布的时候,效果非常不错;
4.算法复杂度低。

缺点:
1.K 值需要人为设定,不同 K 值得到的结果不一样;
2.对初始的簇中心敏感,不同选取方式会得到不同结果;
3.对异常值敏感;
4.样本只能归为一类,不适合多分类任务;
5.不适合太离散的分类、样本类别不平衡的分类、非凸形状的分类。

4.编写KMeans算法实现类

import numpy as np


class KMeans:
    def __init__(self, data, num_clusters):
        self.data = data
        self.num_clusters = num_clusters

    def train(self, max_iterations):
        centerids = KMeans.centerids_init(self.data, self.num_clusters)        
        num_examples = self.data.shape[0]        
        closest_centerids_ids = np.empty((num_examples, 1))        
        for _ in range(max_iterations):
            closest_centerids_ids = KMeans.centerids_find_closest(self.data, centerids)            
            centerids = KMeans.centerids_compute(self.data, closest_centerids_ids, self.num_clusters)        
        return centerids, closest_centerids_ids

    @staticmethod    
    def centerids_init(data, num_clusters):
        num_examples = data.shape[0]        
        random_ids = np.random.permutation(num_examples)        
        centerids = data[random_ids[:num_clusters], :]        
    return centerids

    @staticmethod    
    def centerids_find_closest(data, centerids):
        num_examples = data.shape[0]        
        num_centerids = centerids.shape[0]        
        closest_centerids_ids = np.zeros((num_examples, 1))        
        for example_index in range(num_examples):
            distance = np.zeros((num_centerids, 1))            
            for centerid_index in range(num_centerids):
                distance_diff = data[example_index, :] - centerids[centerid_index, :]                
                distance[centerid_index] = np.sum((distance_diff ** 2))            
                closest_centerids_ids[example_index] = np.argmin(distance)        
        return closest_centerids_ids

    @staticmethod    
    def centerids_compute(data, closest_centerids_ids, num_clusters):
        num_features = data.shape[1]        
        centerids = np.zeros((num_clusters, num_features))        
        for centerid in range(num_clusters):
            closest_ids = closest_centerids_ids == centerid
            centerids[centerid] = np.mean(data[closest_ids.flatten(), :], axis=0)        
        return centerids

5.KMeans算法测试

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

from cls_kmeans.k_means import KMeans

iris = load_iris()data = pd.DataFrame(data=iris.data, columns=iris.feature_names)
data["species"] = iris.target_names[iris.target]

# print(data.head())
# print(iris.feature_names)

x_axis = iris.feature_names[2]
y_axis = iris.feature_names[3]

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)  # 一行两列,第一个图
for iris_type in iris.target_names:
    plt.scatter(data[x_axis][data["species"] == iris_type],                
    data[y_axis][data["species"] == iris_type],                
    label=iris_type)
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label Known")
    plt.legend()
    
    plt.subplot(1, 2, 2)  # 一行两列,第二个图
    plt.scatter(data[x_axis][:], data[y_axis][:], label="all_type")
    plt.title("Label Unknown")
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.show()
    
    # print(np.unique(iris.target).shape[0])
    num_examples = data.shape[0]
    x_train = data[[x_axis, y_axis]].values.reshape(num_examples, 2)
    max_iterations = 50
    num_clusters = 3
    
    kmeans = KMeans(data=x_train, num_clusters=num_clusters)
    (centerids, closest_centerids_ids) = kmeans.train(max_iterations=max_iterations)
    
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)  # 一行两列,第一个图
    for iris_type in iris.target_names:
        plt.scatter(data[x_axis][data["species"] == iris_type],                
                    data[y_axis][data["species"] == iris_type],                
                    label=iris_type)
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label Known")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    for centerid_id, centerid in enumerate(centerids):
        current_example_index = (closest_centerids_ids == centerid_id).flatten()    
        plt.scatter(data[x_axis][current_example_index],                
                    data[y_axis][current_example_index],                
                    label=centerid_id)

    for centerid_id, centerid in enumerate(centerids):
        plt.scatter(centerid[0], centerid[1], c="black", marker="x")

    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label KMeans")
    plt.legend()
    plt.show()

6.结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/713181.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构与算法】Huffman编码/译码(C/C++)

实践要求 1. 问题描述 利用哈夫曼编码进行信息通讯可以大大提高信道利用率,缩短信息传输时间,降低传输成本。但是,这要求在发送端通过一个编码系统对待传数据预先编码;在接收端将传来的数据进行译码(复原)。对于双工信道(即可以…

3D渲染的定义和应用领域

三维渲染(3D rendering)是一种将三维模型数据转化为二维图像的技术,通常利用计算机图形学的方法来实现。通过运用光线、材质、纹理、阴影等效果,将三维物体展现在二维屏幕上,以模拟真实世界中的三维景象。 一、三维渲…

el-table 默认勾选数据

目录 效果图 步骤: 1. 看elementui 官网上的案例,用到的方法是自带的 toggleRowSelection 2. 思路 原委 选中主表中的一条数据;判断与子表中的数据是否关联(如果子表的关联ID主表的ID,则子表的这条数据显示被勾选&a…

CADD蛋白结构分析、虚拟筛选、分子对接(蛋白-蛋白、蛋白-

时间:第一天上午 课程名称:生物分子互作基础 课程内容:1.生物分子相互作用研究方法 1.1蛋白-小分子、蛋白-蛋白相互作用原理 1.2 分子对接研究生物分子相互作用 1.3 蛋白蛋白对接研究分子相互作用 课程名称:蛋白数据库 课程内容:1. PDB 数据库介绍 1.1 PDB蛋白数据库功能 1.2 …

Springboot整合jdbc_template

1.构建Springboot项目 利用springboot整合jdbctemplate,并不需要导入其他的依赖,具体的项目结构图如图 2.写domain层 数据库映射的实体类 package com.jkk.springboot_jdbc_template.domain;/*** author jkk*/import lombok.AllArgsConstructor; import lombok…

04 - C++学习笔记: 循环语句和跳转语句

在C编程中,循环语句和跳转语句是非常重要的控制结构。循环语句允许重复执行一段代码,而跳转语句允许在程序执行过程中改变执行的流程。本篇笔记将介绍C中常用的循环语句和跳转语句,并通过例子进行说明。 🔁循环类型 C 编程语言提…

查询子节点 postgresql

数据库为postgresql WITH RECURSIVE cte AS (SELECTn. ID,n. com_name,n."parentId" AS pidFROMcompany AS nWHEREn. ID = 2UNION ALLSELECTr. ID,r. com_name,cte. ID AS pidFROMcteJOIN company AS r ON r.

轻松实现邮箱验证码功能!快来体验Spring Boot的神奇力量!

邮件验证是现代互联网服务中常用的安全功能,本文介绍如何利用Spring Boot框架快速搭建一个高效易用的邮箱验证码功能。从配置邮箱>发送服务,到编写验证逻辑,无痛实现邮箱验证码功能轻而易举。快来掌握这个技能,加强您的应用安全…

论文解读 | CVPR 2020:PV-RCNN:用于三维物体检测的点体素特征集提取

原创 | 文 BFT机器人 论文《PV-RCNN: Point-Voxel Feature Set Abstraction for 3D Object Detection》是一篇关于三维物体检测的论文。该论文提出了一种名为PV-RCNN的方法,用于从点云数据中进行三维物体检测,并在各种应用中取得了优秀的性能。 论文的主…

数据库第一章

一。数据库 1.概述 数据库database用来存储数据和管理数据的仓库 分类:关系型MySQL/非关系型Redis 关系型数据库(二维表格模型):Oracle,MySQL,SQLServer,Access 非关系型数据库:MongoDB,Redis&#xf…

linux 文件锁flock与fcntl bytes级别精细控制不再是困难

​专栏内容: postgresql内核源码分析 手写数据库toadb 并发编程 个人主页:我的主页 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 文件锁 概述 前面博客介绍了多任务下互斥的方法,如…

Docker容器的数据卷

Docker容器的数据卷 一、数据卷概念 概念:数据卷是宿主机中的一个目录或文件 当容器目录和数据卷目录绑定后,对方的修改会立即同步一个数据卷可以被多个容器同时挂载一个容器也可以挂载多个数据卷 可以解决以下问题 可以解决容器数据的持久化&#xff0…

高效学习法

目标明确,难度适中 全面:宏观概述,微观详尽 明确:目标要明确,否则陷入选择漩涡,导致大脑内耗。李白的“行路难,多歧路” 渐进:既要进步,也要逐步…

47 # 实现可读流

上一节讲了 fs.createReadStream 创建一个可读流,那么怎么查看它的源码是怎么实现的? 我们可以采用打断点的方式:我们可以看到先执行了 lazyLoadStreams 如果没有 ReadStream 就会 require 内部的 internal/fs/streams 模块 通过 internal/f…

免费开源 | 基于SpringBoot+Vue的物流管理系统

1-介绍 基于SpringBootvuemybatis-plus的简单的物流管理系统DEMO,前后端分离,可用于扩展基础,可用于简单课设,可用于基础学习 2-技术架构 SpringBootvuemybatis-plusmysql 8.0 3-使用说明 安装数据库demo/sql/wuliu.sql运行后端demo 1-…

QT调用glog日志流程

glog日志库是Google开源的轻量级的日志库,平时在开发过程中经常要使用到日志,本篇记录Qt项目使用glog日志库的记录。 1.首先下载cmake,Download | CMake 安装设置环境变量,检查安装情况 2.下载glog源码 git clone https://git…

指数分布的概率密度推导

指数分布的概率密度,一直理解的不够深入,一直都不明白为什么是这么奇怪的形式,指数位置的分母为什么有个-theta,也一直不太明白该分布的特点,直到看到如下篇博文: 指数分布概率密度推导1 指数分布概率密度…

PyCharm安装配置PyQt5/QtDesigner/PyUic的超详细教程

目录 1.介绍 2.安装与配置 1.下载安装PyQt5 2.QtDesignerPyUic的安装配置 1.下载安装 2.打开designer.exe所在位置 3.配置PyCharm QtDesigner 4.验证安装是否成功 5.PyCharmPyUic快捷菜单工具配置:便于将Qt的UI文件转换成.py文件 6.配置PyQt5 PyRcc:便于将资源文件转码 1…

拒绝裸奔,使用jasypt为SpringBoot配置文件进行加密。

平日使用Github上传代码时,不可避免的会遇到一个问题就是配置文件中的敏感信息的处理,如MySQL的用户名密码,Redis的密码等。而如果一不注意提交到Github后,无异于出门不锁还留把钥匙,后果不堪设想, 近些年开…

随笔-毕业十周年聚会

文章目录 随笔-毕业十周年聚会1. 引子2. 流水账3. 感悟 随笔-毕业十周年聚会 1. 引子 上周三,许久不联系的班长给我发了个微信,问我周六有没有时间,学校和学院组织了毕业十周年校友返校活动,凑着这个机会大家聚一聚。 一时间有…