[Python] scikit-learn - K近邻算法介绍和使用案例

news2025/1/10 3:16:11

什么是K近邻算法?

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的基本思想是:给定一个训练数据集,对于一个新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数类别就是该输入实例的类别。

思路:

  1. 计算输入实例与训练数据集中每个实例之间的距离。
  2. 对距离进行排序,找到距离最近的K个实例。
  3. 根据这K个实例的类别进行投票,得到输入实例的类别。

K近邻算法使用场景和注意事项

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的使用场景包括:

  1. 数据集较小的情况:当数据集较小时,KNN算法可以快速地进行训练和预测,而不需要大量的计算资源。
  2. 数据集中存在噪声的情况:由于KNN算法是基于实例的,因此它对数据集中的噪声具有一定的容忍度。
  3. 数据集中存在异常值的情况:KNN算法在处理异常值时,会根据邻近实例的类别来进行投票,从而降低了异常值对结果的影响。
  4. 数据集中存在不平衡类别的情况:KNN算法在处理不平衡类别的数据集时,可以通过调整K值来平衡各个类别之间的样本数量。

在使用KNN算法时,需要注意以下几点:

  1. 选择合适的K值:K值的选择对算法的性能有很大影响。通常情况下,可以通过交叉验证等方法来选择合适的K值。
  2. 特征选择:KNN算法对特征的数量和质量要求较高,因此需要对特征进行选择和预处理,以提高算法的性能。
  3. 距离度量:KNN算法需要计算实例之间的距离,因此需要选择合适的距离度量方法,如欧氏距离、曼哈顿距离等。
  4. 性能评估:为了确保算法的性能,需要对算法进行性能评估,如准确率等指标。

K近邻算法python实现

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import numpy as np
from collections import Counter

def euclidean_distance(x1, x2):
    # 计算欧氏距离
    return np.sqrt(np.sum((x1 - x2) ** 2))

class KNN:
    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        y_pred = [self._predict(x) for x in X]
        return np.array(y_pred)

    def _predict(self, x):
        # 计算输入实例与训练数据集中每个实例之间的距离
        distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
        # 对距离进行排序,找到距离最近的K个实例的索引
        k_indices = np.argsort(distances)[:self.k]
        # 根据这K个实例的类别进行投票,得到输入实例的类别
        k_nearest_labels = [self.y_train[i] for i in k_indices]
        most_common = Counter(k_nearest_labels).most_common(1)
        return most_common[0][0]


data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

knn = KNN(k=3)
knn.fit(X_train, y_train)
predictions = knn.predict(X_test)

print("Accuracy:", accuracy_score(y_test, predictions))

scikit-learn中的K近邻算法

K近邻算法用于分类任务

sklearn.neighbors.KNeighborsClassifier — scikit-learn 1.4.0 documentation

 

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

knc = KNeighborsClassifier(n_neighbors=3)
knc.fit(X_train, y_train)
predictions = knc.predict(X_test)

print("Accuracy:", accuracy_score(y_test, predictions))

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsClassifier对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的准确度。 

K近邻算法用于回归任务

sklearn.neighbors.KNeighborsRegressor — scikit-learn 1.4.0 documentation

 

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

# 加载iris花卉数据集
data = load_iris()
X = data.data
y = data.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建KNeighborsRegressor对象,设置K值为3
knn = KNeighborsRegressor(n_neighbors=3)

# 使用训练集对模型进行训练
knn.fit(X_train, y_train)

# 使用测试集进行预测
y_pred = knn.predict(X_test)

# 计算预测结果的均方误差
mse = mean_squared_error(y_test, y_pred)
print("均方误差:", mse)

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsRegressor对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的均方误差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1405965.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot统一返回和统一异常处理

Session 认证和 Token 认证 过滤器和拦截器 SpringBoot统一返回和统一异常处理 上篇文章我们学习了基于 Token 认证的登录功能实现,分别使用了过滤器和拦截器去实现登录功能,这篇文章我们来学习项目中常用的统一返回结果和统一异常处理。 一、统一返…

java.lang.IllegalArgumentException: When allowCredentials is true

1.遇到的错误 java.lang.IllegalArgumentException: When allowCredentials is true, allowedOrigins cannot contain the special value "*" since that cannot be set on the "Access-Control-Allow-Origin" response header. To allow credentials to a…

vue echarts地图

下载地图文件: DataV.GeoAtlas地理小工具系列 范围选择器右侧行政区划范围中输入需要选择的省份或地市,选择自己想要的数据格式,这里选择了geojson格式,点右侧的蓝色按钮复制到浏览器地址栏中,打开的geojson文件内容…

[每日一题] 01.23 - 画矩形

画矩形 height,width,c,d input().split() height,width,d int(height),int(width),int(d) lis [c * width if d else c * (width - 2) c for i in range(height) ]lis: ##### # # # # ##### 或 # # # # # # # #if not d:print(c * width)for i in lis[1:-1…

苹果眼镜(Vision Pro)的开发者指南(3)-【3D UI SwiftUI和RealityKit】介绍

为了更深入地理解SwiftUI和RealityKit,建议你参加专注于SwiftUI场景类型的系列会议。这些会议将帮助你掌握如何在窗口、卷和空间中构建出色的用户界面。同时,了解Model 3D API将为你提供更多关于如何为应用添加深度和维度的知识。此外,通过学习RealityView渲染3D内容,你将能…

【加解密篇】电子数据取证分析之特殊的自加密BitLocker解密

【加解密篇】电子数据取证分析之特殊的自加密BitLocker解密 数据加解密通常是个耗时费力的事情—【蘇小沐】 1、实验环境 Windows 11 专业版,[23H2(22631.3007)] (一)自动开启BitLocker之天坑 1、经验之谈 在201…

php基础学习之数据类型

php数据类型的基本概念 数据类型:data type,在PHP中指的是数据本身的类型,而不是变量的类型。 PHP 是一种弱类型语言,变量本身没有数据类型。 把变量类比成一个杯子(容器),杯子可以装雪碧、可…

2024茶饮品牌如何出圈,媒介盒子分析

随着新式茶饮的消费场景更加多元化,品类不断拓宽,消费者对新式茶饮的热情也是只增不减。居民可支配收入水平不断上升,居民消费升级为新式茶饮的发展也提供了良好基础,今天媒介盒子就来和大家聊聊:2024茶饮品牌如何出圈…

【数据分析】matplotlib、numpy、pandas速通

教程链接:【python教程】数据分析——numpy、pandas、matplotlib 资料:https://github.com/TheisTrue/DataAnalysis 1 matplotlib 官网链接:可查询各种图的使用及代码 对比常用统计图 1.1 折线图 (1)引入 from …

软考之软件工程

一、瀑布模型 严格区分阶段,每个阶段因果关系紧密相连,只适合需求明确的项目 缺点:软件需求完整性、正确性难确定;严格串行化,很长时间才能看到结果;瀑布模型要求每个阶段一次性完全解决该阶段工作&#xf…

Prometheus+Grafana监控Mysql数据库

Promethues Prometheus https://prometheus.io Prometheus是一个开源的服务监控系统,它负责采集和存储应用的监控指标数据,并以可视化的方式进行展示,以便于用户实时掌握系统的运行情况,并对异常进行检测。因此,如何…

【测试开发】Junit5 + YAML 轻松实现参数化和数据驱动,让 App 自动化测试更高效(一)

1. 何为数据驱动 什么是参数化?什么又是数据驱动?经常有人会搞不明白他们的关系,浅谈一下个人的理解,先来看两个测试中最常见的场景: 登录:不同的用户名,不同的密码,不同的组合都需要…

makefile编译静态链接库(.a文件)

文章目录 makefile编译静态链接库(.a文件) makefile编译静态链接库(.a文件) 搞个文件测试静态链接库 aTest.h // // Created by qiufh on 2024-01-23. //#ifndef UNTITLED3_ATEST_H #define UNTITLED3_ATEST_Hclass aTest { pu…

c语言数据结构:单链表及其相关基础操作

目录 0.创建一个额外新的结点 1.链表的概念及其结构 2.单链表的概念 3.单链表的结点的创建 4. 顺序表的打印 5. 链表的尾插 5.1 有关单链表的传参 (重点) 5.1.1 错误的写法 5.1.2 如何修正 5.1.3 正确的写法 5.1.4 看穿二级指针变量 ​编辑 6.链表的头插 7.单链表的…

C++进阶:多态(下)

1、多态的原理 多态之所以可以实现,主要是因为虚函数表的存在,虚函数表用于记录虚函数的地址,他是一个函数指针数组,在类中用一个函数指针数组指针来指向数组,子类继承了父类的虚函数表,当有重写的情况发生…

软考14-上午题-编译、解释程序翻译阶段

一、编译、解释程序【回顾】 目的:高级程序设计语言(汇编语言、高级语言)—【翻译】—>机器语言 1-1、编译方式 将高级语言书写的源程序——>目标程序(汇编语言、机器语言) 包含的工作阶段:词法分…

latex添加图片以及引用的实例教程

原理 在 LaTeX 中插入图片,通常是使用 \includegraphics 命令,它是由 graphicx 包提供的。首先,确保在文档的前言部分(\documentclass 之后和 \begin{document} 之前)包含了 graphicx 包。 下面是一个基本的例子来展…

Hikvision综合安防管理平台files;.css接口存在任意文件读取漏洞 附POC软件

免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文章仅供学习用途使用。 1. Hikvisi…

HIVE中关联键类型不同导致数据重复,以及数据倾斜

比如左表关联键是string类型,右表关联键是bigint类型,关联后会出现多条的情况 解决方案: 关联键先统一转成string类型再进行关联 原因: 根据HIVE版本不同,数据位数上限不同, 低版本的超过16位会出现这种…

变分自编码器VAE模型与应用

变分自编码器(VAE,Variational Autoencoder)是一种深度学习模型,用于数据生成和特征学习。它结合了自编码器(autoencoders)和贝叶斯推断。 下面是VAE的详细解释: 自编码器(Autoenco…