Python实现KNN算法(附源码)

news2025/1/3 5:29:55

本篇我们将讨论一种广泛使用的分类技术,称为k邻近算法,或者说K最近邻(KNN,k-Nearest Neighbor)。所谓K最近邻,是k个最近的邻居的意思,即每个样本都可以用它最接近的k个邻居来代表。

01、KNN算法思想

如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。

由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

02、KNN算法的决策过程

下图中有两种类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形,中间那个绿色的圆形是待分类数据:

▍近邻分类图

如果K=3,那么离绿色点最近的有2个红色的三角形和1个蓝色的正方形,这三个点进行投票,于是绿色的待分类点就属于红色的三角形。而如果K=5,那么离绿色点最近的有2个红色的三角形和3个蓝色的正方形,这五个点进行投票,于是绿色的待分类点就属于蓝色的正方形。

KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比。

下面用代码来实现KNN算法的应用。本次用到的数据是经典的Iris数据集。该数据集有150条鸢尾花数据样本,并且均匀分布在3个不同的亚种:每个数据样本被4个不同的花瓣、花萼的形状特征所描述。

#读取数据
from sklearn.datasets import load_iris
data = load_iris()
#查看数据大小
data.data.shape
(150, 4)
#查看数据说明
print (data.DESCR)
Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988
This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris
The famous Iris database, first used by Sir R.A Fisher
This is perhaps the best known database to be found in the pattern recognition literature.  Fisher's paper is a classic in the field and is referenced frequently to this day.  (See Duda & Hart, for example.)  The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.  One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

通过上述代码对数据的查验以及数据本身的描述,我们可以了解到Iris数据集共有150条鸢尾花数据样本,并且均匀分布在3个不同的亚种;每一个数据样本被4个不同的花瓣、花萼的形状特征所描述。由于没有指定的测试集,依据管理,我们需要第数据进行随机分割,25%的数据用作测试,75的数据用作训练。

需要强调的是,如果大家自行编写程序用作数据分割,请务必保证是随机采样。尽管很多数据集中的样本的排序相对随机,但是也有例外。本例中,Iris数据就是根据类别一次排列的。如果只采样前25%的数据用作测试,那么所有的测试样本都属于一个类别,同时训练样本也是不均衡的,这样得到的结果存在偏置,并且可信度非常低,Scikit-learn所提供的数据分割模块是默认采用随机采样的功能的,因此大家可不必担心。

#对数据进行分割
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size = 0.25, random_state = 33)

#使用KNN算法进行分类
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
#初始化
ss = StandardScaler()

#数据标准化
X_train = ss.fit_transform(X_train)
X_test = ss.transform(X_test)

#训练模型
knc = KNeighborsClassifier()
knc.fit(X_train, y_train)
#预测
y_pred = knc.predict(X_test)

#模型评估
print ('The accuracy of KNN is:', knc.score(X_test, y_test))
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, target_names = data.target_names))

代码输出结果如下,Knn算法对鸢尾花测试数据的分类准确率为89.474%,其他数据如下可见。

KNN算法的特点分析:KNN算法是非常直观的机器学习模型,因此深受广大初学者的喜爱。许多教科书往往一次模型抛砖引玉,便足以看出其不仅特别,而且尚有瑕疵之处。细心的读者会发现,KNN算法与其他算法模型最大的不同在于:该模型没有参数训练过程。也就是说,我们并没有通过任何学习算法来分析训练数据,而只是根据测试样本在训练数据中的的分布直接做出分类决策。因此,KNN算法属于无参数模型中非常简单的一种。然而,正是这样的决策算法,导致了其非常高的计算复杂度和内存消耗。因为该模型每处理一个测试样本,都需要对所有事先加载在内存中的训练样本进行遍历、逐一计算相似度、排序并且选取K个最近邻训练样本的标记,进而做出分类决策。这是平方级的算法复杂度,一旦数据规模稍大,便需要权衡更多计算时间的代价。

最后,对KNN算法做一个简单的小结:

优点
简单,易于理解,易于实现,无需估计参数,无需训练;
适合对稀有事件进行分类;
特别适合于多分类问题(multi-modal,对象具有多个类别标签),kNN比SVM的表现要好。
缺点
当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数,少数类容易分错。
需要存储全部训练样本。
计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。
可理解性差,无法给出像决策树那样的规则。

03、源码

链接: https://pan.baidu.com/s/1XtygYiZYH51Dob9s4K0tfw?pwd=9j99 提取码: 9j99 复制这段内容后打开百度网盘手机App,操作更方便哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/593001.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

绿色智慧档案顺丰环境一体化平台选型表

盛世宏博八防一体化监控系统选型表 序号 功能选择 1 恒温恒湿系统 温湿度监测 口Y:需要 口N:不需要 空调控制 口Y:需要 口N:不需要 加湿机控制 口Y:需要 口N:不需要 除湿…

KD05丨动量RSI策略

大家好,今天我们来分享魔改RSI策略,RSI即相对强弱指数,本质上就是一个动量指标,用于衡量一定时间内价格变动的速度及其变动的大小。它在0-100的范围内变动,通常以70和30作为过热和过冷的界限。要将RSI指标改为一个趋势…

Smartbi“三步走”构建智慧经营分析平台,实现国有企业监管报送和数智化转型

01. 现状与痛点 — 一直以来,国资国企都是促进我国经济高速发展的领头羊,但近年来受疫情冲击和国际经济下行影响,国资企业经营面临较大压力,同时为实现国有企业高质量发展,国务院国资委下发一系列政策要求&#xff…

Halcon 表面法向量 pcl

一、Halcon halcon 案例: find_surface_model_noisy_data.hdev 思路步骤: 1、读取图像 2、拆通道 3、通过Z通道选出比较合适做匹配的模板 4、通过Z x y 生成一个模型xyz_to_object_model_3d 5、计算表面法向量并生成表面的模型,这个模型…

Linux:apache网页优化

Linux:apache网页优化 一、Apache 网页优化二、网页压缩2.1 检查是否安装 mod_deflate 模块2.2 如果没有安装mod_deflate 模块,重新编译安装 Apache 添加 mod_deflate 模块2.3 配置 mod_deflate 模块启用2.4 检查安装情况,启动服务2.5 测试 m…

字节跳动测试岗,3面都过了,HR告诉我是这个原因才被刷...

说在前面 面试时最好不要虚报工资。本来字节跳动是很想去的,几轮面试也通过了,最后没offer,自己只想到下面几个原因: 虚报工资,比实际高30%;有更好的人选,这个可能性不大,我看还在…

医院检验科LIS系统的常规检验项目有哪些?

医院检验科LIS系统的常规检验项目包括: 白细胞数目、中性粒细胞数目、淋巴细胞数目、单核细胞数目、嗜酸性粒细胞数目、嗜碱性粒细胞数目、中性粒细胞百分比、 淋巴细胞百分比、单核细胞百分比、嗜酸性粒细胞百分比、嗜碱性粒细胞百分比、红细胞数目、血红蛋白、红…

淘宝太细了:mysql 和 es 的5个一致性方案,你知道吗?

说在前面 在40岁老架构师 尼恩的读者交流群(50)中,最近有小伙伴拿到了一线互联网企业如拼多多、极兔、有赞、希音的面试资格,遇到一几个很重要的面试题: 说5种mysql 和 elasticsearch 数据一致性方案 与之类似的、其他小伙伴遇到过的问题还…

电脑怎么隐藏文件夹?这样做,快速搞定!

案例:我想把一些敏感和重要的文件夹隐藏起来,不想别人看到它们。在电脑上如何隐藏电脑文件夹?有没有小伙伴知道如何操作?急需! 我们在使用电脑的过程中,会产生大量文件,有些文件可能包含私密信…

4.1 Spark SQL概述、数据帧与数据集

一、数据帧 - DataFrame (一)准备工作 1、准备数据文件 2、启动Spark Shell (二)加载数据为Dataset 1、读文件得数据集 2、显示数据集内容 3、显示数据集模式 (三)给数据集添加元数据信息 1、定…

强!PCB“金手指”从设计到生产全流程

在电脑内存条、显卡上,有一排金黄色导电触片,就是大家俗称的“金手指”。 在PCB设计制作行业中的“金手指”(Gold Finger,或称Edge Connector),是由connector连接器作为PCB板对外连接网络的出口。 关于“金手指”你知道多少呢&a…

像核战争一样,AI可能灭绝人类:Geoffrey Hinton、Sam Altman等百名专家签署了一封公开信

多位图灵奖得主、顶级 AI 公司 CEO、顶尖高校教授,与数百位在各自领域享有话语权的专家,共同签署了一份公开信,内容简单却有力: 降低 AI 灭绝人类的风险,应该与大流行病、核战争等其他社会规模的风险一样,…

AI落地:儿童节礼物指南

这个儿童节,用AI做点不一样的礼物,给孩子一个惊喜。 可行清单: 写走心的贺卡(增强表达能力,培养心思细腻)用AI让孩子的画的小人动起来(激发创造力,培养想象力)把孩子的…

Ansys Zemax | 如何模拟部分反射和散射的表面

这篇文章介绍了如何模拟一个部分反射的表面,该表面会根据指定的散射分布对一部分入射光能量进行散射。本文介绍的示例包含部分吸收以及部分镜面反射的情况。(联系我们获取文章附件) 介绍 使用 OpticStudio 非序列模式模拟散射和膜层的能力,我们可以模拟一…

MFC按钮中添加图标

目录 一、创建对话框 二、 开始添加 1、将.ico图片放进res路径下 2、添加资源 3、添加按钮 4、将按钮属性中icon修改为true 5、代码添加 一、创建对话框 首先需要创建个对话框程序,参考之前写的博客: mfc入门基础(三)创…

浅谈智能化配电室在居民小区的建设应用

安科瑞 徐浩竣 江苏安科瑞电器制造有限公司 zx acrelxhj 摘要:近年来居民小区配电室的数量增长快且设备情况较复杂,以致巡视效果不理想、缺陷和事故处理不及时,亟需建立一套智能化的配电室监控系统。按照实用性、统一性、分层和模块化设计…

RobotFramework接口测试方案

1. Robot FrameWork介绍 1.1 介绍 Robot Framework是用于验收测试和回归测试的通用测试自动化框架。它使用易于理解的表格数据语法,非常友好的实现了关键字驱动和数据驱动模式。它的测试功能可以通过使用Python或Java实现的测试库进行扩展,用户可以使用…

RCE代码及命令执行漏洞全解(30)

web应用中,有时候程序员为了考虑灵活性,简洁性,会在代码中调用代码或执行命令执行函数去处理。 比如当应用在调用一些能将字符串转化成代码的函数时,没有考虑用户是否能够控制这些字符串,将代码执行漏洞,同…

华为OD机试真题B卷 Java 实现【求最大连续bit数】,附详细解题思路

一、题目描述 求一个int类型数字对应的二进制数字中1的最大连续数,例如3的二进制为00000011,最大连续2个1。 二、输入描述 输入一个int类型数字。 三、输出描述 输出转成二进制之后连续1的个数。 四、解题思路 首先通过输入获取一个 int 类型的数…

K8s环境使用Triton实现云端模型推理

前置条件:K8集群、helm 1、以模型名作为目录名,创建目录 mkdir resnet50_pytorch 2、将模型文件、配置文件(输入、输出等)存到刚创建的目录下,resnet50_pytorch目录下文件层级结构如下 model-respository/ └── …