【机器学习】聚类(一):原型聚类:K-means聚类

news2025/1/10 22:14:19

文章目录

  • 一、实验介绍
    • 1. 算法流程
    • 2. 算法解释
    • 3. 算法特点
    • 4. 应用场景
    • 5. 注意事项
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 导入必要的库
    • 1. Kmeans类
      • a. 构造函数
      • b. 闵可夫斯基距离
      • c. 初始化簇心
      • d. K-means聚类
      • e. 聚类结果可视化
    • 2. 辅助函数
    • 3. 主函数
      • a. 命令行界面 (CLI)
      • b. 数据加载
      • c. 模型训练及可视化
    • 4. 运行脚本的命令
    • 5. 代码整合

  原型聚类中的K均值算法是一种常用的聚类方法,该算法的目标是通过迭代过程找到数据集的簇划分,使得每个簇内的样本与簇内均值的平方误差最小化。这一过程通过不断迭代更新簇的均值来实现。

一、实验介绍

在这里插入图片描述

1. 算法流程

  1. 初始化: 从样本集中随机选择k个样本作为初始均值向量。
  2. 迭代过程: 重复以下步骤直至均值向量不再更新:
    • 对每个样本计算与各均值向量的距离。
    • 将样本划分到距离最近的均值向量所对应的簇。
    • 更新每个簇的均值向量为该簇内样本的平均值。
  3. 输出: 返回最终的簇划分。

2. 算法解释

  • 步骤1中,通过随机选择初始化k个均值向量。
  • 步骤2中,通过计算样本与均值向量的距离,将每个样本分配到最近的簇。然后,更新每个簇的均值向量为该簇内样本的平均值。
  • 算法通过迭代更新,不断优化簇内样本与均值向量的相似度,最终得到较好的聚类结果。

3. 算法特点

  • K均值算法是一种贪心算法,通过局部最优解逐步逼近全局最优解。
  • 由于需要对每个样本与均值向量的距离进行计算,算法复杂度较高。
  • 对于大型数据集和高维数据,K均值算法的效果可能受到影响。

4. 应用场景

  • K均值算法适用于样本集可以被均值向量较好表示的情况,特别是当簇呈现球形或近似球形分布时效果较好。
  • 在图像分割、用户行为分析等领域广泛应用。

5. 注意事项

  • 对于K均值算法,初始均值向量的选择可能影响最终聚类结果,因此有时需要多次运行算法,选择最优的结果。
  • 算法对异常值敏感,可能导致簇的均值向量被拉向异常值,因此在处理异常值时需要谨慎。

二、实验环境

1. 配置虚拟环境

conda create -n ML python==3.9
conda activate ML
conda install scikit-learn matplotlib seaborn

2. 库版本介绍

软件包本实验版本
matplotlib3.5.2
numpy1.21.5
python3.9.13
scikit-learn1.0.2
seaborn0.11.2

三、实验内容

0. 导入必要的库

import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
import argparse

1. Kmeans类

  • __init__ :初始化K均值聚类的参数,包括聚类数目 k、数据 data、初始化模式 mode(默认为 “random”)、最大迭代次数 max_iters、闵可夫斯基距离的阶数 p、随机种子 seed等。
  • minkowski_distance 函数:计算两个样本点之间的闵可夫斯基距离。
  • center_init 函数:根据指定的模式初始化聚类中心。
  • fit 方法:执行K均值聚类的迭代过程,包括分配样本到最近的簇、更新簇中心,直到满足停止条件。
  • visualization 函数:使用Seaborn和Matplotlib可视化聚类结果。

a. 构造函数

class Kmeans(object):
    def __init__(self, k, data: np.ndarray, mode="random", max_iters=0, p=2, seed=0):
        self.k = k
        self.data = data

        self.mode = mode
        self.max_iter = max_iters if max_iters > 0 else int(1e8)
        self.p = p
        self.seed = seed

        self.centers = None
        self.clu_idx = np.zeros(len(self.data), dtype=np.int32)  # 样本的分类簇
        self.clu_dist = np.zeros(len(self.data), dtype=np.float64)  # 样本与簇心的距离
  • 参数:
    • 聚类数目 k
    • 数据集 data
    • 初始化模式 mode
    • 最大迭代次数 max_iters
    • 闵可夫斯基距离的阶数 p 以及随机种子 seed
  • 初始化类的上述属性,此外
    • self.centers 被初始化为 None,表示簇心尚未计算
    • self.clu_idxself.clu_dist 被初始化为全零数组,表示每个样本的分类簇和与簇心的距离。

b. 闵可夫斯基距离

    def minkowski_distance(self, x, y=0):
        return np.linalg.norm(x - y, ord=self.p)
  • 使用了NumPy的 linalg.norm 函数,其中 ord 参数用于指定距离的阶数。

c. 初始化簇心

    def center_init(self):
        random.seed(self.seed)

        if self.mode == "random":
            ids = random.sample(range(len(self.data)), k=self.k)  # 随机抽取k个样本下标
            self.centers = self.data[ids]  # 选取k个样本作为簇中心
        else:
            ids = [random.randint(0, self.data.shape[0])]
            for _ in range(1, self.k):
                max_idx = 0
                max_dis = 0
                for i, x in enumerate(self.data):
                    if i in ids:
                        continue
                    dis = 0
                    for y in self.data[ids]:
                        dis += self.minkowski_distance(x - y)
                    if max_dis < dis:
                        max_dis = dis
                        max_idx = i
                ids.append(max_idx)
            self.centers = self.data[ids]
  • 根据指定的初始化模式,选择随机样本或使用 “far” 模式。
    • 在 “random” 模式下,通过随机抽样选择 k 个样本作为簇心;
    • 在 “far” 模式下,通过计算每个样本到已选簇心的距离之和,选择距离总和最大的样本作为下一个簇心。

d. K-means聚类

    def fit(self):
        self.center_init()  # 簇心初始化

        for _ in range(self.max_iter):
            flag = False  # 判断是否有样本被重新分类

            # 遍历每个样本
            for i, x in enumerate(self.data):
                min_idx = -1  # 最近簇心下标
                min_dist = np.inf  # 最小距离
                for j, y in enumerate(self.centers):  # 遍历每个簇,计算与该样本的距离
                    # 计算样本i到簇j的距离dist

                    dist = self.minkowski_distance(x, y)

                    if min_dist > dist:
                        min_dist = dist
                        min_idx = j
                if self.clu_idx[i] != min_idx:
                    # 有样本改变分类簇,需要继续迭代更新簇心

                    flag = True

                # 记录样本i与簇的最小距离min_dist,及对应簇的下标min_idx
                self.clu_idx[i] = min_idx
                self.clu_dist[i] = min_dist

            # 样本的簇划分好之后,用样本均值更新簇心
            for i in range(self.k):
                x = self.data[self.clu_idx == i]
                # 用样本均值更新簇心
                self.centers[i] = np.mean(x, axis=0)

            if not flag:
                break
  • 在每次迭代中
    • 遍历每个样本,计算其到各个簇心的距离,将样本分配到距离最近的簇中。
    • 更新每个簇的均值(簇心)为该簇内所有样本的平均值。
  • 上述过程迭代进行,直到满足停止条件(样本不再重新分配到不同的簇)或达到最大迭代次数。

e. 聚类结果可视化

    def visualization(self, k=3):
        current_palette = sns.color_palette()
        sns.set_theme(context="talk", palette=current_palette)
        for i in range(self.k):
            x = self.data[self.clu_idx == i]
            sns.scatterplot(x=x[:, 0], y=x[:, 1], alpha=0.8)
        sns.scatterplot(x=self.centers[:, 0], y=self.centers[:, 1], marker="+", s=500)
        plt.title("k=" + str(k))
        plt.show()

2. 辅助函数

def order_type(v: str):
    if v.lower() in ("-inf", "inf"):
        return -np.inf if v.startswith("-") else np.inf
    else:
        try:
            return float(v)
        except ValueError:
            raise argparse.ArgumentTypeError("Unsupported value encountered")


def mode_type(v: str):
    if v.lower() in ("random", "far"):
        return v.lower()
    else:
        raise argparse.ArgumentTypeError("Unsupported value encountered")

  • order_type 函数:用于处理命令行参数中的 -p(距离测量参数),将字符串转换为浮点数。
  • mode_type 函数:用于处理命令行参数中的 --mode(初始化模式参数),将字符串转换为合法的初始化模式。

3. 主函数

a. 命令行界面 (CLI)

  • 使用 argparse 解析命令行参数
    parser = argparse.ArgumentParser(description="Kmeans Demo")
    parser.add_argument("-k", type=int, default=3, help="The number of clusters")
    parser.add_argument("--mode", type=mode_type, default="random", help="Initial centroid selection")
    parser.add_argument("-m", "--max-iters", type=int, default=40, help="Maximum iterations")
    parser.add_argument("-p", type=order_type, default=2., help="Distance measurement")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--dataset", type=str, default="./kmeans.2.txt", help="Path to dataset")
    args = parser.parse_args()
    

b. 数据加载

  • 从指定路径加载数据集。
	dataset = np.loadtxt(args.dataset)

在这里插入图片描述

c. 模型训练及可视化

	model = Kmeans(k=args.k, data=dataset, mode=args.mode, max_iters=args.max_iters, p=args.p,
                   seed=args.seed)
    model.fit()

    # 聚类结果可视化
    model.visualization(k=args.k)

4. 运行脚本的命令

  • 通过命令行传递参数来运行脚本,指定聚类数目、初始化模式、最大迭代次数等。
python kmeans.py -k 3 --mode random -m 40 -p 2 --seed 0 --dataset ./kmeans.2.txt

在这里插入图片描述

5. 代码整合

import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
import argparse


class Kmeans(object):
    def __init__(self, k, data: np.ndarray, mode="random", max_iters=0, p=2, seed=0):
        self.k = k
        self.data = data

        self.mode = mode
        self.max_iter = max_iters if max_iters > 0 else int(1e8)
        self.p = p
        self.seed = seed

        self.centers = None
        self.clu_idx = np.zeros(len(self.data), dtype=np.int32)  # 样本的分类簇
        self.clu_dist = np.zeros(len(self.data), dtype=np.float64)  # 样本与簇心的距离

    def minkowski_distance(self, x, y=0):

        return np.linalg.norm(x - y, ord=self.p)

    # 簇心初始化
    def center_init(self):
        random.seed(self.seed)

        if self.mode == "random":
            ids = random.sample(range(len(self.data)), k=self.k)  # 随机抽取k个样本下标
            self.centers = self.data[ids]  # 选取k个样本作为簇中心
        else:
            ids = [random.randint(0, self.data.shape[0])]
            for _ in range(1, self.k):
                max_idx = 0
                max_dis = 0
                for i, x in enumerate(self.data):
                    if i in ids:
                        continue
                    dis = 0
                    for y in self.data[ids]:
                        dis += self.minkowski_distance(x - y)
                    if max_dis < dis:
                        max_dis = dis
                        max_idx = i
                ids.append(max_idx)
            self.centers = self.data[ids]

    def fit(self):
        self.center_init()  # 簇心初始化

        for _ in range(self.max_iter):
            flag = False  # 判断是否有样本被重新分类

            # 遍历每个样本
            for i, x in enumerate(self.data):
                min_idx = -1  # 最近簇心下标
                min_dist = np.inf  # 最小距离
                for j, y in enumerate(self.centers):  # 遍历每个簇,计算与该样本的距离
                    # 计算样本i到簇j的距离dist

                    dist = self.minkowski_distance(x, y)

                    if min_dist > dist:
                        min_dist = dist
                        min_idx = j
                if self.clu_idx[i] != min_idx:
                    # 有样本改变分类簇,需要继续迭代更新簇心

                    flag = True

                # 记录样本i与簇的最小距离min_dist,及对应簇的下标min_idx
                self.clu_idx[i] = min_idx
                self.clu_dist[i] = min_dist

            # 样本的簇划分好之后,用样本均值更新簇心
            for i in range(self.k):
                x = self.data[self.clu_idx == i]
                # 用样本均值更新簇心
                self.centers[i] = np.mean(x, axis=0)

            if not flag:
                break

    def visualization(self, k=3):
        current_palette = sns.color_palette()
        sns.set_theme(context="talk", palette=current_palette)
        for i in range(self.k):
            x = self.data[self.clu_idx == i]
            sns.scatterplot(x=x[:, 0], y=x[:, 1], alpha=0.8)
        sns.scatterplot(x=self.centers[:, 0], y=self.centers[:, 1], marker="+", s=500)
        plt.title("k=" + str(k))
        plt.show()


def order_type(v: str):
    if v.lower() in ("-inf", "inf"):
        return -np.inf if v.startswith("-") else np.inf
    else:
        try:
            return float(v)
        except ValueError:
            raise argparse.ArgumentTypeError("Unsupported value encountered")


def mode_type(v: str):
    if v.lower() in ("random", "far"):
        return v.lower()
    else:
        raise argparse.ArgumentTypeError("Unsupported value encountered")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Kmeans Demo")
    parser.add_argument("-k", type=int, default=3, help="The number of clusters")
    parser.add_argument("--mode", type=mode_type, default="random", help="Initial centroid selection")
    parser.add_argument("-m", "--max-iters", type=int, default=40, help="Maximum iterations")
    parser.add_argument("-p", type=order_type, default=2., help="Distance measurement")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--dataset", type=str, default="./kmeans.2.txt", help="Path to dataset")
    args = parser.parse_args()

    dataset = np.loadtxt(args.dataset)
    model = Kmeans(k=args.k, data=dataset, mode=args.mode, max_iters=args.max_iters, p=args.p,
                   seed=args.seed)  # args.seed)
    model.fit()

    # 聚类结果可视化
    model.visualization(k=args.k)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1248114.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jetson Orin Nano 内核编译

首先是安装编译环境所需的依赖 sudo apt-get install git bison flex libssl-dev zip libncurses-dev makesudo apt-get install build-essential bc下载交叉编译器以及代码&#xff0c;官方链接: link https://developer.nvidia.com/embedded/jetson-linux 解压下载的两个文件…

JVM之jvisualvm多合一故障处理工具

jvisualvm多合一故障处理工具 1、visualvm介绍 VisualVM是一款免费的&#xff0c;集成了多个 JDK 命令行工具的可视化工具&#xff0c;它能为您提供强大的分析能力&#xff0c;对 Java 应 用程序做性能分析和调优。这些功能包括生成和分析海量数据、跟踪内存泄漏、监控垃圾回…

应用内测分发平台如何上传应用包体?

●您可免费将您的应用&#xff08;支持苹果.ios安卓.apk文件&#xff09;上传至咕噜分发平台&#xff0c;我们将免费为应用生成下载信息&#xff0c;但咕噜分发将会对应用的下载次数进行收费&#xff08;每个账号都享有免费赠送的下载点数以及参加活动的赠送点数&#xff09;&a…

RTMDet原理与代码解析

paper&#xff1a;RTMDet: An Empirical Study of Designing Real-Time Object Detectors official implementation&#xff1a;https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet 本文的创新点 Backbone and Neck 在backbone的basic building block中采…

Spark SQL 时间格式处理

初始化Spark Sql package pbcp_2023.clear_dataimport org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{current_date, current_timestamp}object twe_2 {def main(args: Array[String]): Unit {val con new …

js获取时间日期

目录 Date 对象 1. 获取当前时间 2. 获取特定日期时间 Date 对象的方法 1. 获取各种日期时间组件 2. 获取星期几 3. 获取时间戳 格式化日期时间 1. 使用 toLocaleString() 方法 2. 使用第三方库 UNIX 时间戳 内部表示 时区 Date 对象 JavaScript中内置的 Date 对象…

获取DOM元素和类型判断

一、获取dom元素 <div id"one" class"one">我是第一个div</div> <div>我是第二个div</div> <div name"username">我是第三个div</div> <input type"text" name"username"> 1.qu…

【完美世界】叶倾仙强势登场,孔雀神主VS护道人,石昊重逢清漪

Hello,小伙伴们&#xff0c;我是拾荒君。 《完美世界》国漫第138集已更新。在这一集中&#xff0c;天人族的行为让人大跌眼镜&#xff0c;他们不仅没有对石昊感恩戴德&#xff0c;反而一心想要夺取他身上的所有秘法宝术。天人族的护道人登场&#xff0c;试图以强大的实力压制石…

Theta*: Any-Angle Path Planning on Grids 原文翻译

Theta*: Any-Angle Path Planning on Grids 文章目录 Theta*: Any-Angle Path Planning on Grids翻译摘要1.Introduction2. Path-Planning Problem and Notation3. Existing Terrain Discretizations4.Existing Path-Planning Algorithms4.1 A* on GridsA* with Post-Smoothed …

2023年【R1快开门式压力容器操作】考试资料及R1快开门式压力容器操作复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 R1快开门式压力容器操作考试资料参考答案及R1快开门式压力容器操作考试试题解析是安全生产模拟考试一点通题库老师及R1快开门式压力容器操作操作证已考过的学员汇总&#xff0c;相对有效帮助R1快开门式压力容器操作复…

python获取json所有节点和子节点

使用python获取json的所有父结点和子节点 并使用父节点加下划线命名子节点 先展示一段json代码 {"level1": {"level2": {"level3": [{"level4": "4value"},{"level4_2": "4_2value"}]},"level2_…

毅速丨3D打印随形水路为何受到模具制造追捧

在模具制造行业中&#xff0c;随形水路镶件正逐渐成为一种革命性的技术&#xff0c;其提高冷却效率、优化产品设计、降低成本等优点&#xff0c;为模具制造带来了巨大的创新价值。 随形水路是一种根据产品形状定制的冷却水路&#xff0c;其镶件可以均匀地分布在模具的表面或内部…

迪科DTC-F81收费机DTC-F82

迪科DTC-F81收费机是一款挂式收费机&#xff0c;广泛应用的学校食堂刷卡消费&#xff0c;DTC-F82收费机是台式消费机&#xff0c;常用在学校超市&#xff0c;放在桌子上使用的&#xff0c;这2款消费机是迪科畅销机型&#xff0c;如下图 机器质量可靠稳定&#xff0c;不少用户使…

vivado产生报告阅读分析19-设计收敛报告

Challenging Timing Paths “ Challenging Timing Paths ” &#xff08; 时序收敛困难的路径 &#xff09; 部分列出了“ Assessment Details ” &#xff08; 评估详情 &#xff09; 部分中未能通过检查的时序路径的关键属性。默认情况下&#xff0c; 该命令会对每个时钟组中…

2024北京林业大学计算机考研分析

24计算机考研|上岸指南 北京林业大学 特色优势 Characteristics & Advantages&#xff1a;信息学院创建于2001年&#xff0c;是一个年轻而有朝气的学院。学院秉承“结构、特色、质量、创新”的八字方针&#xff0c;坚持以“质量提升、行业融合”为核心的内涵式发展战略&am…

在Linux上搭建JavaWeb项目运行环境

文章目录 安装JDK安装Tomcat安装数据库 安装JDK 安装Oracle官方的JDK比较麻烦&#xff0c;我们在此处选择安装开源社区维护的openjdk。他们俩的差别不大且兼容。 安装Tomcat 我们把本地下载好的 tomcat.zip 包拖到Linux页面上&#xff0c;让Linux也有一个zip包&#xff0c;再…

运动鞋品牌识别

一、前期工作 1. 设置GPU from tensorflow import keras from tensorflow.keras import layers,models import os, PIL, pathlib import matplotlib.pyplot as plt import tensorflow as tfgpus tf.config.list_physical_devices("GPU")if gpus:gpu0 …

网络安全工程师究竟是什么?怎么入门?

首先啊骚年们我们必须先了解网络安全这个行业究竟是干啥的。 是打ctf的&#xff1f;一个个都像韩商言吴白那么帅刷刷敲几个代码就能轻易夺旗&#xff1f; 还是像十大黑客之一的米特尼克一样闯入了“北美空中防务指挥系统”的计算机主机内&#xff0c;还在被通缉逃跑期间控制了…

【多线程】Thread类的使用

目录 1.概述 2.Thread的常见构造方法 3.Thread的几个常见属性 4.启动一个线程-start() 5.中断一个线程 5.1通过共享的标记来进行沟通 5.2 调用 interrupt() 方法来通知 6.等待一个进程 7.获取当前线程引用 8.线程的状态 8.1所有状态 8.2线程状态和转移的意义 1.概述 …

基于java技术的社区交易二手平台

基于java技术的社区交易二手平台的设计与实现 &#xff08;一&#xff09;开发背景 随着因特网的日益普及与发展&#xff0c;更多的人们开始通过因特网来寻求便利。但是&#xff0c;许多人都觉得网上商店里的东西不贵。所以&#xff0c;有些顾客宁愿去那些用二次定价建立起来的…