KNN模型进行分类和回归任务

news2024/9/23 17:16:12

KNN工作原理
“近朱者赤,近墨者黑”可以说是KNN的工作原理。整个计算过程分为三步:1:计算待分类物体与其他物体之间的距离;2:统计距离最近的K个邻居;3:对于K个最近的邻居,它们属于哪个分类最多,待分类物体就属于哪一类。K-最近邻算法(K-Nearest Neighbor, KNN)中的K值是一个重要的超参数,不同的K值会影响模型的性能。常见的选择K值的方法包括以下几种

  • 网格搜索(Grid Search):指定一组候选的K值,对每个K值进行交叉验证,选取平均交叉验证误差最小的K值作为最佳K值。缺点是需要进行大量的计算,时间开销较大。
  • K折交叉验证(K-fold Cross Validation):将训练集分成K个子集,每次使用其中K-1个子集作为训练集,剩下的1个子集作为验证集,重复K次。对于每个K值,计算K次的平均交叉验证误差,选取平均交叉验证误差最小的K值作为最佳K值。这种方法的优点是可以减少模型的方差,但是计算时间仍然比较长。
  • 自助法(Bootstrap):从训练集中有放回地随机抽取样本,构建新的训练集。对于每个K值,计算自助样本的平均误差,选取平均误差最小的K值作为最佳K值。这种方法的优点是计算速度快,但是对于小数据集来说,可能会出现较大的方差。

网格搜索(Grid Search)

接下来先看看如何通过网格搜索(Grid Search)获取K值。GridSearchCV是Scikit-Learn库中用于网格搜索的函数,其主要作用是在指定的超参数范围内进行穷举搜索,并使用交叉验证来评估每种超参数组合的性能,以找到最优的超参数组合。该函数包含多个参数,具体参数以及每个参数含义如下所示:

estimator:通常是一个Scikit-Learn模型对象,例如KNeighborsClassifier()、RandomForestClassifier()等,用于表示要使用的模型。
param_grid:需要遍历的超参数空间,是一个字典,其中每个键是一个超参数名称,对应的值是超参数的取值列表。例如,对于KNN模型,可以指定param_grid = {'n_neighbors': [3, 5, 7, 9], 'weights': ['uniform', 'distance'], 'p': [1, 2]},表示K值在3, 5, 7和9中选择,权重方式为'uniform'和'distance',距离度量方式为曼哈顿距离和欧几里得距离。当然,除了这两种距离计算方式,还可以选择:闵可夫斯基距离;切比雪夫距离;余弦距离。

scoring:评价指标,用于评估模型性能的指标,通常是一个字符串或可调用的函数,例如'accuracy'、'f1'、'precision'、'recall'等。如果需要评估多个指标,则可以将评价指标指定为列表或元组。
cv:交叉验证的折数,通常为整数或KFold对象。例如,cv = 5表示将数据集分成5个折,其中4个用于训练,1个用于验证。
n_jobs:并行处理的数量,通常为整数,指定在训练期间使用的CPU数量。如果设置为-1,则使用所有可用的CPU。
verbose:输出详细程度,通常为整数。0表示不输出任何消息,1表示输出少量消息,大于1表示输出更多消息。
return_train_score:是否返回每个超参数组合在训练集上的性能指标。默认情况下,它为False,表示只返回每个超参数组合在验证集上的性

下面是使用GridSearchCV执行分类任务的demo代码,运行demo代码会显示执行的交叉参数组合,且给出最优的参数组合值。

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=0.3,
                                                    random_state=42)
# 定义待调优的超参数及其取值范围
param_grid = {
    'n_neighbors': [3, 5, 7, 9],
    'weights': ['uniform', 'distance'],
    'p': [1, 2]
}
# 构建KNN模型
knn = KNeighborsClassifier()
# 使用网格搜索进行超参数调优
grid_search = GridSearchCV(knn, param_grid, cv=5, verbose=2)
grid_search.fit(X_train, y_train)
# 输出最优超参数组合及其在验证集上的性能指标
print("Best parameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
# 在测试集上进行评估
score = grid_search.score(X_test, y_test)
print("Test score: ", score)

K折交叉验证(K-fold Cross Validation)

KFold函数是Scikit-Learn库中用于生成K折交叉验证分割的函数。该函数的主要参数及含义如下:

  • n_splits:交叉验证折数,默认值为5。
  • shuffle:是否对样本进行随机排序,默认值为False。
  • random_state:随机种子数,默认为None,即随机种子为当前时间戳。
  • indices:指定分割的索引数组,可以用于固定分割以进行可重复的交叉验证。

下面是使用KFold函数,采用交叉验证进行模型评估的demo代码。

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 定义模型和超参数
knn = KNeighborsClassifier(n_neighbors=5, weights='uniform', p=2)
# 定义交叉验证的折数
kfold = KFold(n_splits=10, shuffle=True, random_state=42)
# 使用交叉验证进行模型评估
scores = cross_val_score(knn, X, y, cv=kfold)
# 输出平均分数和标准差
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

各类算法准确率对比

前面介绍了KNN算法、SVM 算法、多项式朴素贝叶斯算法等,下面的demo例子使用手写数字作为训练数据,观察每种算法的精确度,具体code如下所示。其中,sklearn.datasets是Scikit-Learn库中用于加载各种标准数据集的模块之一。load_digits函数可以加载一个手写数字数据集,该数据集包含1797个8x8像素的手写数字图像。每个图像都有相应的标签,表示图像中的数字。该数据集可以用于分类和降维等任务。

# 手写数字分类
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt

# 加载数据
digits = load_digits()
data = digits.data
# 数据探索
print(data.shape)
# 查看第一幅图像
print(digits.images[0])
# 第一幅图像代表的数字含义
print(digits.target[0])
# 将第一幅图像显示出来
plt.gray()
plt.imshow(digits.images[0])
plt.show()

# 分割数据,将25%的数据作为测试集,其余作为训练集
train_x, test_x, train_y, test_y = train_test_split(data,
                                                    digits.target,
                                                    test_size=0.25,
                                                    random_state=33)

# 采用Z-Score规范化
ss = preprocessing.StandardScaler()
train_ss_x = ss.fit_transform(train_x)
test_ss_x = ss.transform(test_x)

# 创建KNN分类器
knn = KNeighborsClassifier()
knn.fit(train_ss_x, train_y)
predict_y = knn.predict(test_ss_x)
print("KNN准确率: %.4lf" % accuracy_score(test_y, predict_y))

# 创建SVM分类器
svm = SVC()
svm.fit(train_ss_x, train_y)
predict_y = svm.predict(test_ss_x)
print('SVM准确率: %0.4lf' % accuracy_score(test_y, predict_y))

# 采用Min-Max规范化
mm = preprocessing.MinMaxScaler()
train_mm_x = mm.fit_transform(train_x)
test_mm_x = mm.transform(test_x)

# 创建Naive Bayes分类器
mnb = MultinomialNB()
mnb.fit(train_mm_x, train_y)
predict_y = mnb.predict(test_mm_x)
print("多项式朴素贝叶斯准确率: %.4lf" % accuracy_score(test_y, predict_y))

# 创建CART决策树分类器
dtc = DecisionTreeClassifier()
dtc.fit(train_mm_x, train_y)
predict_y = dtc.predict(test_mm_x)
print("CART决策树准确率: %.4lf" % accuracy_score(test_y, predict_y))

实验结果如下图所示,可以看到KNN和SVM准确率比较接近,多项式朴素贝叶斯和CART决策树准确率稍低。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/801447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Nginx12】Nginx学习:HTTP核心模块(九)浏览器缓存与try_files

Nginx学习:HTTP核心模块(九)浏览器缓存与try_files 浏览器缓存在 Nginx 的 HTTP 核心模块中其实只有两个简单的配置,这一块也是 HTTP 的基础知识。之前我们就一直在强调,学习 Nginx 需要的就是各种网络相关的基础知识&…

前端程序员入门:先学Vue3还是Vue2?

一、前言 对于新手来说,学习Vue.js框架时往往会有这样一个疑问:应该先学习Vue2还是直接学习Vue3?在回答这个问题之前,我们先简单介绍一下Vue.js框架。 Vue.js是一个轻量级的MVVM(Model-View-ViewModel)框架,它以数据驱…

数字世界未来十年面貌展望

2023年,数字技术已经深刻改变了我们的生活和社会,而未来十年数字世界的面貌将会更加令人瞩目。从人工智能到区块链,从虚拟现实到5G,各种科技将继续发展演进,给我们带来更多令人兴奋的可能性。以下是对数字世界未来十年…

交换机之HOL拥塞

队首阻塞(Head of Line Blocking, HOL)是一种出现在缓存式通信网络交换中的一种现象,其交换结构通常由缓存式FIFO输入端、交换结构(Switch Fabric)、FIFO输出端构成。 HOL阻塞用一个现实生活中的例子说明,就如同你在一条单车道的马路上右转,…

人机交互与人机混合智能的区别

人机交互和人机融合智能是两个相关但不完全相同的概念: 人机交互是指人与计算机之间的信息交流和互动过程。它关注的是如何设计和实现用户友好的界面,以便人们能够方便、高效地与计算机进行沟通和操作。人机交互通常强调用户体验和界面设计,旨…

如何找回删除的文件?文件恢复,3招就行!

“昨天不小心把我的毕业资料删除了,因为改了很多版,删除的时候没想到把正确的版本删除了,错误的版本还在!这种情况应该怎么办呢?怎样才能找回我删除的文件呀?” 对于一些比较重要的文件,不小心删…

【C++初阶】C++基础(上)——C++关键字、命名空间、C++输入输出、缺省参数、函数重载

目录 1. C关键字 2. 命名空间 2.1 命名空间的定义 2.2 命名空间的使用 3. C输入&输出 4. 缺省参数 4.1 缺省参数概念 4.2 缺省参数分类 5. 函数重载 5.1 函数重载概念 5.2 C支持函数重载的原理——名字修饰(name Mingling) 5.3 extern &…

围棋基础知识

1、气 1.1星位位置 1.2天元位置 1.3 气的位置 2、禁入点 白棋里面的位置就是禁入点,也可以称为没有气的位置可以称为禁入点 破解之法: 在于将白棋全部围住,下一步为围住之策,即可。 3、死棋和活棋 3.1活棋 3.2 死棋 白棋的样…

探寻数据服务的本质:API之外的可能性

数据服务在数据建设中发挥着重要的作用。数据服务到底啥样? 是不是只对外提供一个API? 这么简单? 而我希望你能在学完这部分内容之后,真正掌握数据服务的产品功能设计和系统架构设计。因为这会对你设计一个数据服务,或…

青少年护眼灯哪个好?2023全新五款台灯推荐

国内儿童青少年的视力健康问题越来越突出,甚至许多孩子年纪非常小就已经近视了,所以许多老师以及眼科医生都和家长们强调护眼台灯的重要性。不过,护眼台灯虽好,但在选购时也要注意那些无法护眼的不专业品牌,许多产品有…

ICC2如何计算Gate Count?

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧?知识星球入口 我们认为gate count等于standard cell(非physical only)总面积 / 最小驱动二输入与非门面积。 ICC2没有专门的命令去报告gate count,只能自己计算,使用report_d…

MySQL数据库——DDL基本操作

文章目录 前言数据库操作查看已存在的所有数据库创建数据库选中数据库删除数据库修改数据库编码 表操作创建表显示创建表时的语句显示表结构删除表修改表的结构增加列修改列删除列 修改表名 前言 DDL 操作是与数据库结构相关的操作,它们不涉及实际的数据操作&#…

B2B企业如何选择CRM系统?

CRM软件的优势在于简化业务流程,实现企业的降本增效。越来越多的B2B企业通过CRM为业务赋能,B2B企业如何快速找到适合公司业务的CRM系统?总的来说就是根据企业自身业务而量身打造的一套系统。 1.整理业务需求 B2B企业首先要考虑是业务痛点&a…

易基因:m6A-seq等揭示RBM33参与调控m6A去甲基化酶ALKBH5活性及其底物选择性|科研进展

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 RNA结合蛋白(RNA-binding protein,RBP)是一类结构和功能多样化的蛋白质,参与多种生物过程。越来越多的证据表明,RBP通过调控编…

意外:WPS编程新工具,不用编程,excel用户:可以不用VBA啦

来来来,拓宽一下视野! 别总以为excel和WPS只能用VBA编程,也别总是想着ACCESS这些老生常谈的工具。其实对于电子表格高级用户来讲,不会VBA,不用ACCESS,也一样可以解决复杂问题或者高级应用。 尤其是WPS用户…

C++多线程编程(第三章 利用栈特性自动释放锁RALL,锁管理器、控制器)

1、什么是RALL,手动代码实现 RALL(resource Acquisition Is Initialization )C 之父Bjarne Stroustrup 提出; 使用局部对象来管理资源的技术称为资源获取即初始化;它的生命周期是由操作系统来管理的,无需人…

Hive分区分桶

分区 分区概念 在逻辑上分区表与未分区表没有区别,在物理上分区表会将数据按照分区键的列值存储在表目录的子目录中,目录名“分区键键值”。其中需要注意的是分区键的值不一定要基于表的某一列(字段),它可以指定任意…

ubuntu软件:录制视频和截图工具,压缩视频

1. 自带录制视频工具; 使用方式: 无需下载 开始录屏/结束录屏:Ctrl Alt Shift r 当看到 Ubuntu 桌面的右上方多了一个红色的小圆点,代表正在录制 注意: 录屏默认的时长30秒,超时会自动结束&#xff01…

Postman如何导出接口的几种方法

本文主要介绍了Postman如何导出接口的几种方法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下 前言: 我的文章还是一贯的作风,简确用风格(简单确实有用)&am…

无涯教程-jQuery - serialize( )方法函数

serialize()方法将一组输入元素序列化为数据字符串。 serialize( ) - 语法 $.serialize( ) serialize( ) - 示例 假设无涯教程在serialize.php文件中具有以下PHP内容- <?php if( $_REQUEST["name"] ) {$name$_REQUEST[name];echo "Welcome ". $na…