2、k-means聚类算法sklearn与手动实现

news2025/1/18 9:49:55

本文将对k-means聚类算法原理和实现过程进行简述

算法原理

k-means算法原理较简单,基本步骤如下:

1、假定我们要对N个样本观测做聚类,要求聚为K类,首先选择K个点作为初始中心点;
2、接下来,按照距离初始中心点最小的原则,把所有观测分到各中心点所在的类中;
3、每类中有若干个观测,计算K个类中所有样本点的均值,作为第二次迭代的K个中心点;
4、然后根据这个中心重复第2、3步,直到收敛(中心点不再改变或达到指定的迭代次数),聚类过程结束。

聚类过程示意图:
在这里插入图片描述

算法实践

下面对一个具体场景做聚类分析:500x500px的地图上,随机生成60个城市,要求生成10个聚类中心。

Sklearn实现

下面是调取sklearn相关的函数进行实现:

import matplotlib.pyplot as plt
import numpy as np
import warnings

from sklearn.cluster import KMeans

Num_dots = 60  # 城市总数
Num_gas = 10  # 聚类中心总数
Size_map = 500  # 500x500地图


colors = ['#FF3838', '#FF9D97', '#FF701F', '#FFB21D', '#CFD231', '#48F90A', '#92CC17', '#3DDB86', '#1A9334', '#00D4BB',
          '#2C99A8', '#00C2FF', '#344593', '#6473FF', '#0018EC', '8438FF', '#520085', '#CB38FF', '#FF95C8', '#FF37C7']

warnings.filterwarnings("ignore")

# 生成随机点
def generate():
    dots = []
    for i in range(Num_dots):
        dots.append(np.random.uniform([Size_map, Size_map]))
    # dots_sorted_x = sorted(dots, key=lambda dot: dot[0])
    return dots

# 计算两点之间欧式距离
def cal_dist(x, y):
    return ((x[0] - y[0]) ** 2 + (x[1] - y[1]) ** 2) ** 0.5

# 统计数组中各种相同元素个数
def num_same(dots_labels):
    num_labels = []
    key = np.unique(dots_labels)
    for k in key:
        mask = (dots_labels == k)
        y_new = dots_labels[mask]
        v = y_new.size
        num_labels.append(v)
    return num_labels

def cal_center_dist(center, dots):
    distance = 0
    for i in range(len(dots)):
        distance += cal_dist(center, dots[i])
    return distance

# K-Means聚类
def k_means(dots):
    cluster = KMeans(n_clusters=Num_gas)
    dots_labels = cluster.fit_predict(dots)
    centers = cluster.cluster_centers_
    return dots_labels, centers


# 绘制图像
def plot_dots(dots, dots_labels, centers):
    # 绘制点
    for i in range(len(dots_labels)):
        plt.scatter(dots[i][0], dots[i][1], color=colors[dots_labels[i]])
    # 绘制聚类中心
    for i in range(len(centers)):
        plt.scatter(centers[i][0], centers[i][1], marker='x', color="#000000", s=50)
    plt.show()


if __name__ == '__main__':
    np.random.seed(250)
    dots = generate()
    dots_labels, centers = k_means(dots)
    num_labels = num_same(dots_labels)
    # 输出结果
    distance = 0
    for i in range(len(centers)):
        print("聚类中心", i+1, "坐标为", np.round(centers[i], 2))
        index = np.argwhere(dots_labels == i)
        print("属于该聚类中心的城市标号为", [int(x)+1 for x in index])
        mark = [int(x) for x in index]
        distance += cal_center_dist(centers[i], [dots[i] for i in mark])
        print("所有聚类中心和所辖城市的距离之和为", np.round(distance,2))

    # 绘图
    plot_dots(dots, dots_labels, centers)

在这里插入图片描述

输出总距离:所有聚类中心和所辖城市的距离之和为 2860.48.

手动实现

下面根据算法的理解,进行手动实现:

import numpy as np
from matplotlib import pyplot as plt

Num_dots = 60  # 城市总数
Num_gas = 10  # 聚类中心总数
Size_map = 500  # 500x500地图

colors = ['#FF3838', '#FF9D97', '#FF701F', '#FFB21D', '#CFD231', '#48F90A', '#92CC17', '#3DDB86', '#1A9334', '#00D4BB',
          '#2C99A8', '#00C2FF', '#344593', '#6473FF', '#0018EC', '8438FF', '#520085', '#CB38FF', '#FF95C8', '#FF37C7']

# 生成随机点
def generate():
    dots = []
    for i in range(Num_dots):
        dots.append(np.random.uniform([Size_map, Size_map]))
    # dots_sorted_x = sorted(dots, key=lambda dot: dot[0])
    return dots

# 计算两点之间欧式距离
def cal_dist(x, y):
    return ((x[0] - y[0]) ** 2 + (x[1] - y[1]) ** 2) ** 0.5

# 计算中心点距离它所负责的所有点之和
def cal_center_dist(center, dots):
    distance = 0
    for i in range(len(dots)):
        distance += cal_dist(center, dots[i])
    return distance

# 根据城市坐标搜索城市序号
def search_city(value, dots):
    for i, item in enumerate(dots):
        if (item == value).any():
            return i


class K_Means:
    # k是分组数;tolerance‘中心点误差’;max_iter是迭代次数
    def __init__(self, k=2, tolerance=0.0001, max_iter=300):
        self.k_ = k
        self.tolerance_ = tolerance
        self.max_iter_ = max_iter
        self.distance = 0

    def fit(self, data):
        self.centers_ = {}
        for i in range(self.k_):
            self.centers_[i] = data[i]
            # print(self.centers_[i])  # {0: array([256.5, 542. ]), 1: array([586.5, 261.5]), 2: array([869. , 449.5])}

        for iter in range(self.max_iter_):
            self.clf_ = {}
            for i in range(self.k_):
                self.clf_[i] = []
            for feature in data:
                distances = []
                for center in self.centers_:
                    distances.append(cal_dist(feature, self.centers_[center]))

                classification = distances.index(min(distances))
                self.clf_[classification].append(feature)

            # 记录总路程
            self.distance = np.sum(distances)

            # 记录上一阶段中心点位置
            prev_centers = dict(self.centers_)

            # 移动每一个center到所辖城市的中心位置
            for c in self.clf_:
                self.centers_[c] = np.average(self.clf_[c], axis=0)

            # 若center的移动空间在误差范围内,跳出循环得到结果
            optimized = True
            for center in self.centers_:
                org_centers = prev_centers[center]
                cur_centers = self.centers_[center]
                if np.sum((cur_centers - org_centers) / org_centers * 100.0) > self.tolerance_:
                    optimized = False
            if optimized:
                break


if __name__ == '__main__':
    np.random.seed(250)
    dots = generate()
    k_means = K_Means(Num_gas)
    k_means.fit(dots)

    # 输出结果
    for i in range(Num_gas):
        print("聚类中心", i + 1, "坐标为", np.round(k_means.centers_[i], 2))
        city_index = []
        for x in k_means.clf_[i]:
            city_index.append(search_city(x, dots))
        print("属于该聚类中心的城市标号为", city_index)

    print("所有聚类中心和所辖城市的距离之和为", np.round(k_means.distance, 2))

    # 绘制中心点
    for center in k_means.centers_:
        plt.scatter(k_means.centers_[center][0], k_means.centers_[center][1], marker='x', color="#000000", s=50)

    # 绘制城市点
    for cat in k_means.clf_:
        for point in k_means.clf_[cat]:
            plt.scatter(point[0], point[1], c=colors[cat])

    plt.show()

在这里插入图片描述

输出总距离:所有聚类中心和所辖城市的距离之和为 2816.76

结论

聚类的常规标准是让聚类中心和所辖城市的距离之和,在本实验中,手动实现的k-means算法的结果要优于sklearn的结果。

这主要是由于k-means算法本身并不是非常稳定,容易受到初始点、离群点的影响,因此,所求解不一定是最优解。

附录:sklearn K-means参数/属性/接口

下面是sklearn中K-means算法的常用接口参数,数据来自菜菜的机器学习sklearn

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/997756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯官网填空题(三角形的面积)

题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 已知三角形三个顶点在直角坐标系下的坐标分别为: txt (2.3, 2.5) (6.4, 3.1) (5.1, 7.2) txt 求该三角形的面积。 注意,要提交的是一个小…

解决ROS2 humble版本源码编译中resdep init及rosdep update失败的问题

网上有在/etc/hosts中添加ip地址的方法,使用了不成功,具体做法如下,仅供参考: a.打开查询ip的网址: https://site.ip138.com b.输入:raw.githubusercontent.com c.执行sudo vi /etc/hosts 将获取到的ip添…

docker系列(3) - 常用软件安装

文章目录 3. docker安装常用软件3.1 安装nginx3.2 安装redis3.3 安装mysql3.4 部署springboot程序3.4.1 编写dockerfile3.4.2 构建镜像3.4.3 启动镜像 3. docker安装常用软件 3.1 安装nginx docker pull nginx#挂载启动 docker run -it -d \ --namenginx \ --networkpub_netw…

L2 数据仓库和Hive环境配置

1.数据仓库架构 数据仓库DW主要是一个用于存储,分析,报告的数据系统。数据仓库的目的是面向分析的集成化数据环境,分析结果为企业提供决策支持。-DW不产生和消耗数据 结构数据:数据库中数据,CSV文件 直接导入DW非结构…

2023高教杯数学建模2:DE题+参考论文、代码

2023高教杯数学建模2:DE题 写在最前面E题D题2014C题优秀论文笔记问题一(求解母猪年均产仔量以达到或超过盈亏平衡点)问题二(求解小猪选为种猪的比例和母猪的存栏数)问题三(确定最佳经营策略,计算…

docker系列(4) - docker镜像制作

文章目录 4. docker镜像4.1 联合文件系统(UnionFS)4.2 Docker镜像加载原理4.3 docker commit(扩展镜像)(非常重要)4.3.1 案例4.3.1.1 下载ubuntu镜像4.3.1.2 安装vim4.3.1.3 commit新的镜像4.3.1.3 验证新的镜像 4. docker镜像 4.1 联合文件系统(UnionFS) Union文件系统(Unio…

树树树树树

//先序遍历 void PreOrder(BiTree T){if(T!NULL){printf("%c",T->data);PreOrder(T->lchild);PreOrder(T->rchild);} } //后序遍历 void PostOrder(BiTree T){if(T!NULL){PostOrder(T->lchild);PostOrder(T->rchild);printf("%c",T->dat…

美国星链再迎挑战,中国最有钱的通信企业争夺卫星互联网服务

随着一家手机企业发布5G卫星手机,卫星互联网服务的热度大增,业界人士指出目前能提供卫星互联网服务的仅有中国电信,但是中国移动已不甘落后,正在测试低轨道卫星互联网服务,这也是中国与美国星链竞争的序幕。 据了解日前…

表的约束类型

空属性约束 mysql有空属性和非空属性:null和not null 数据库默认字段基本都是字段为空,但是实际开发时,尽可能保证字段不为空,因为数据为空没办法参与运算 所以我们在设计数据库表的时候,一定要在表中进行限制&…

嵌入式:驱动开发 Day2

作业&#xff1a;字符设备驱动&#xff0c;完成三盏LED灯的控制 驱动代码&#xff1a; mychrdev.c #include <linux/init.h> #include <linux/module.h> #include <linux/fs.h> #include <linux/uaccess.h> #include <linux/io.h> #include &q…

oracle表空间释放

oracle表空间释放 1&#xff09;查询表空间信息2&#xff09;查询指定表空间下各个表的表空间使用情况3-1&#xff09;可以直接释放3-2) 可以move3-3&#xff09;重新排列 1&#xff09;查询表空间信息 selecta.tablespace_name as "表空间名",total as "表空间…

初识Java 7-1 多态

目录 向上转型 难点 方法调用绑定 产生正确的行为 可扩展性 陷阱&#xff1a;“重写”private方法 陷阱&#xff1a;字段与静态方法 构造器和多态 构造器的调用顺序 继承和清理 构造器内部的多态方法行为 协变返回类型 使用继承的设计 替换和扩展 向下转型和反射…

Java开发之Mysql【面试篇 完结版】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、知识体系二、Mysql-优化1. 优化-如何定位慢查询① 问题引入② 解决方案③ 问题总结④ 实战面试 2. 优化-sql执行很慢&#xff0c;如何解决① 问题引入② 解…

AI项目五:结印动作识别

若该文为原创文章&#xff0c;转载请注明原文出处。 感谢恩培大佬对项目进行了完整的实现&#xff0c;并将代码进行开源&#xff0c;供大家交流学习。 恩培大佬开源地址&#xff0c;有兴趣的可以去复现一下。GitHub - enpeizhao/CVprojects: computer vision projects | 计算机…

计算机网路学习-time_wait过多

四次挥手 调试命令 netstat -an|awk ‘/tcp/ {print $6}’|sort|uniq -c netstat -an 列出系统中所有处于活动状态的网络连接信息&#xff0c;包括 IP 地址、端口号、协议等。 其中&#xff0c;第六列是tcp的状态。 Proto Recv-Q Send-Q Local Address Foreign Addr…

Aidlux工业视觉缺陷检测

Aidlux工业视觉缺陷检测 1. AidLux简介 AidLux是成都阿加犀智能科技有限公司自主研发的融合架构平台&#xff0c;提供Android&#xff0f;鸿蒙&#xff0b;Linux融合系统&#xff0c; 双系统既能独立使用又能相互通信。 阿加犀致力于人工智能核心技术持续创新&#xff0c; 独…

CSS元素浮动

概述 浮动简介 在最初&#xff0c;浮动是用来实现文字环绕图片效果的&#xff0c;现在浮动是主流的页面布局方式之一。 元素浮动后的特点 脱离文档流。不管浮动前是什么元素&#xff0c;浮动后&#xff0c;默认宽与高都是被内容撑开的&#xff08;尽可能小&#xff09;&am…

AKF拆分原则

在分布式软件环境下&#xff0c;为了保障分布式架构的可靠性、可扩展、高性能&#xff0c;通常会通过集群、扩容、数据分治等思想来实现&#xff0c;比如很多中间件的使用Redis、ZK、Kafka等&#xff0c;都可以通过这种设计思想来提高系统架构吞吐量。AKF是一个系统化的拓展思想…

Vue框架+Element组件库学习笔记

一、Vue框架 vue&#xff1a;是一款前端框架&#xff0c;免除原生JavaScript中的DOM操作&#xff08;如document.getElementById("文本输入框名").value&#xff09;&#xff0c;简化书写。基于MVVM&#xff08;Model-View-ViewModel&#xff09;思想&#xff0c;实…

YApi 新版如何查看 http 请求数据

YApi 新版如何查看 http 请求数据 因chrome 安全策略限制&#xff0c;在 cross-request 升级到 3.0 后&#xff0c; 不再支持文件上传功能&#xff0c;并且需要通过以下方法查看 network:1.首先在chrome 输入 > chrome://extensions打开扩展页2.开启开发者模式3.点击 cross…