K最近邻算法:简单高效的分类和回归方法(三)

news2024/12/24 8:29:08

文章目录

  • 🍀引言
  • 🍀训练集和测试集
  • 🍀sklearn中封装好的train_test_split
  • 🍀超参数

🍀引言

本节以KNN算法为主,简单介绍一下训练集和测试集超参数


🍀训练集和测试集

训练集和测试集是机器学习和深度学习中常用的概念。在模型训练过程中,通常将数据集划分为训练集和测试集,用于训练和评估模型的性能。

训练集是用于模型训练的数据集合。模型通过对训练集中的样本进行学习和参数调整来提高自身的预测能力。训练集应该尽可能包含各种不同的样本,以使模型能够学习到数据集中的模式和规律,并能够适应新的数据。

测试集是用于评估模型性能的数据集合。模型训练完成后,使用测试集中的样本进行预测,并与真实标签进行对比,以评估模型的精度、准确度和其他性能指标。测试集应该与训练集相互独立,以确保对模型的泛化能力进行准确评估。

一般来说,训练集和测试集的划分比例是80:20或者70:30。有时候还会引入验证集,用于在训练过程中调整模型的超参数。训练集、验证集和测试集是机器学习中常用的数据集拆分方式,以确保模型的准确性和泛化能力。

接下来我们回顾一下KNN算法的简单原理,选取离待预测最近的k个点,再使用投票进行预测结果

from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
from sklearn.datasets import load_iris  # 因为我们并没有数据集,所以从库里面调出来一个
iris = load_iris()
X = iris.data
y = iris.target
knn_clf.fit(X,y)
knn_clf.predict()

那么我们如何评价KNN模型的好坏呢?

这里我们将数据集分为两部分,一部分为训练集,一部分为测试集,因为这里的训练集和测试集都是有y的,所以我们只需要将训练集进行训练,然后产生的模型应用到测试集,再将预测的y和原本的y进行对比,这样就可以了

接下来进行简易代码演示讲解

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target

我们可以把y打印出来看看
在这里插入图片描述
这里我们不妨思考一下,如果训练集和测试集是8:2的话,测试集的y岂不是都是2了,那么还有啥子意义,所以我们需要将其打乱一下下,当然我们这里打乱的是index也就是下标,可不要自以为是的将y打乱了

import numpy as np
indexs = np.random.permutation(len(X))

导入必要的库后,我们将数据集下标进行打乱并保存于indexs中,接下来迎来重头戏分割数据集

test_ratio = 0.2
test_size = int(len(X) * test_ratio)
test_indexs = shuffle_indexs[:test_size] # 测试集
train_indexs = shuffle_indexs[test_size:] # 训练集

不信的小伙伴可以使用如下代码进行检验

test_indexs.shape
train_indexs.shape

在这里插入图片描述
接下来将打乱的下标进行分别赋值

X_train = X[train_indexs]
y_train = y[train_indexs]
X_test = X[test_indexs]
y_test = y[test_indexs]

分割好数据集后,我们就可以使用KNN算法进行预测了

from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train,y_train)
y_predict = knn_clf.predict(X_test)

我们这里可以打印一下y_predict和y_test进行肉眼对比一下
在这里插入图片描述
在这里插入图片描述
最后一步就是将精度求出来

np.sum(np.array(y_predict == y_test,dtype='int'))/len(X_test)

在这里插入图片描述


🍀sklearn中封装好的train_test_split

上面我们只是简单演示了一下,接下来我们使用官方的train_test_split

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y) # 注意这里返回四个结果

这里你可以试着看一眼,分割的比例与之前手动分割的比例大不相同
最后按部就班来就行

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train,y_train)
knn_clf.predict(X_test) 
knn_clf.score(X_test,y_test)

在这里插入图片描述


🍀超参数

什么是超参数,可以点击链接查看

在pycharm中我们可以查看一些参数
在这里插入图片描述

接下来通过简单的演示来介绍一下

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
knn_clf = KNeighborsClassifier(weights='distance') 
from sklearn.model_selection import train_test_split
iris = load_iris()
X = iris.data
y = iris.target
X_train,X_test,y_train,y_test = train_test_split(X,y)

上面是老熟人了就不一一赘述了,但是注意这里面有个超参数(weights),这个参数有两种,一个是distance一个是uniform,前者和距离有关联,后者无关


首先测试一下n_neighbors这个参数代表的就行之前的那个k,邻近点的个数

%%time
best_k = 0
best_score = 0.0
best_clf = None
for k in range(1,21):
    knn_clf = KNeighborsClassifier(n_neighbors=k)
    knn_clf.fit(X_train,y_train)
    score = knn_clf.score(X_test,y_test)
    if score>best_score:
        best_score = score
        best_k = k
        best_clf = knn_clf
print(best_k)
print(best_score)
print(best_clf)

在这里插入图片描述
测试完参数n_neighbors,我们再来试试weights

%%time
best_k = 0
best_score = 0.0
best_clf = None
best_method = None
for weight in ['uniform','distance']:
    for k in range(1,21):
        knn_clf = KNeighborsClassifier(n_neighbors=k,weights=weight)
        knn_clf.fit(X_train,y_train)
        score = knn_clf.score(X_test,y_test)
        if score>best_score:
            best_score = score
            best_k = k
            best_clf = knn_clf
            best_method = weight
print(best_k)
print(best_score)
print(best_clf)
print(best_method)

在这里插入图片描述
最后我们测试一下参数p

%%time
best_k = 0
best_score = 0.0
best_clf = None
best_p = None
for p in range(1,6):
    for k in range(1,21):
        knn_clf = KNeighborsClassifier(n_neighbors=k,weights='distance',p=p)
        knn_clf.fit(X_train,y_train)
        score = knn_clf.score(X_test,y_test)
        if score>best_score:
            best_score = score
            best_k = k
            best_clf = knn_clf
            best_p = p
            
print(best_k)
print(best_score)
print(best_clf)
print(best_p)

或许大家不知道这个参数p的含义,下面我根据几个公式带大家简单了解一下
请添加图片描述

请添加图片描述
请添加图片描述

三张图分别代表欧拉距离曼哈顿距离明科夫斯基距离,细心的小伙伴就可以发现了,p=1位曼哈顿距离,p=2位欧拉距离,这里不做详细的说明,感兴趣的小伙伴可以翻阅相关数学书籍

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/849023.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

个人对智能家居平台选择的思考

本人之前开发过不少MicroPython程序,其中涉及到自动化以及局域网控制思路,也可以作为智能家居的实现方式。而NodeMCUESPHome的方案具有方便添加硬件、容易更新程序和容量占用小的优势,本人也查看过相关教程后感觉部署ESPHome和编译固件的步骤…

科学与信息化杂志科学与信息化杂志社科学与信息化编辑部2023年第14期目录

科学视野 现代技术角度下对光纤通信传输技术的思考 崔文佳1-3 浅谈非道路移动机械污染防治难点与对策 刘华4-6 基层公路养护档案管理 高富丽7-9《科学与信息化》投稿:cnqikantg126.com 奉贤区第二次全国污染源普查及防治对策建议 卫伟10-12 数字化赋能在国土空间治理…

工厂方法模式-java实现

介绍 工厂方法模式,通过把工厂抽象为一个接口,这样当我们新增具体产品的时候,就只需要实现一个新的具体工厂类即可。一个具体工厂类,对应着一个产品。 请注意:在工厂方法模式中,一个具体工厂类只对应生产…

vue3+vite配置多入口文件

1.修改vite.config.ts 文件: 2.在src目录底下建相应的html文件和对应的ts入口文件和vue文件,如下图: npm run dev运行后本地访问: http://127.0.0.1:5173/home_index.htmlnpm run build打包后的结构如图:

物联网的定义、原理、示例、未来

什么是物联网? 物联网 (IoT) 是指由嵌入传感器、软件和网络连接的物理设备、车辆、电器和其他物理对象组成的网络,允许它们收集和共享数据。这些设备(也称为“智能对象”)的范围可以从简单的“智能家居”设备(如智能恒温器)到可穿戴设备(如智能手表和支持RFID的服…

Anaconda Prompt使用pip安装PyQt5-tools后无法打开Spyder或闪退

艹!MLGBZD! 真TMD折腾人! 出现原因: 首次安装完Anaconda3-2023.07-1-Windows-x86_64.exe后首次打开Spyder,此时是没有问题的,然后打开Anaconda Prompt,查看有哪些包,pip list 这时候开始首次安…

k8s之Pod控制器

目录 一、Pod控制器及其功用二、pod控制器的多种类型2.1 pod容器中的有状态和无状态的区别 三、Deployment 控制器四、SatefulSet 控制器4.1 StatefulSet由以下几个部分组成4.2 为什么要有headless?4.3 为什么要有volumeClaimTemplate?4.4 滚动更新4.5 扩…

Rocketmq Filter 消息过滤(TAGS、SQL92)原理详解 源码解析

1. 背景 1.1 Rocketmq 支持的过滤方式 Rocketmq 作为金融级的业务消息中间件,拥有强大的消息过滤能力。其支持多种消息过滤方式: 表达式过滤:通过设置过滤表达式的方式进行过滤 TAG:根据消息的 tag 进行过滤。SQL92&#xff1a…

【每日一题】—— B. Maximum Rounding(Codeforces Round 891 (Div. 3))

🌏博客主页:PH_modest的博客主页 🚩当前专栏:每日一题 💌其他专栏: 🔴 每日反刍 🟡 C跬步积累 🟢 C语言跬步积累 🌈座右铭:广积粮,缓称…

UNIX网络编程——UDP协议,CS架构

目录 一.socket创建通信的套接字 二.IPv4地址结构 三.通用地址结构 四. 两种地址结构的使用场合 五.sendto发送数据 六.bind固定地址信息​编辑 七.recvfrom接受UDP的消息​编辑 一.socket创建通信的套接字 二.IPv4地址结构 三.通用地址结构 四. 两种地址结构的使用场合…

MySQL— 基础语法大全及操作演示!!

MySQL—— 基础 一、MySQL概述1.1 、数据库相关概念1.2 、MySQL 客户端连接1.3 、数据模型 二、SQL2.1、SQL通用语法2.2、SQL分类2.3、DDL2.4、DML2.5、DQL2.6、DCL 三、函数四、约束五、多表查询六、事务 一、MySQL概述 1.1 、数据库相关概念 数据库、数据库管理系统、SQL&a…

sql server 删除指定字符串

replace方法 update #test set FIVCODEreplace(FIVCODE,440,) WHERE SOURCEFENTRYID140728

嵌入式软件测试-测试类型

使用质量属性来定义测试类型,即回归到测试类型的本质。 如果测试负载在系统允许的负载范围内,那测试的是系统的功能,此时的测试属于功能性测试;若在此基础上再加大测试时间,那就是稳定性测试了,此时关注的…

有哪些常用的设计素材网站?

素材网站可以是设计师和创意人员的灵感来源。这些网站收集了各种类型的平面设计图片,包括标志、海报、网站设计、包装设计、插图等。在本文中,我将推荐15个平面设计图素材网站,以帮助您找到新的想法和灵感。 1.即时设计资源社区 即时设计资…

8月8日上课内容 研究nginx组件rewrite

location 匹配uri location 匹配的规则和优先级。(重点,面试会问,必须理解和掌握) nginx常用的变量,这个要求掌握 rewrite:重定向功能。有需要掌握,有需要理解的。 location匹配:…

【RabbitMQ】golang客户端教程5——使用topic交换器

topic交换器(主题交换器) 发送到topic交换器的消息不能具有随意的routing_key——它必须是单词列表,以点分隔。这些词可以是任何东西,但通常它们指定与消息相关的某些功能。一些有效的routing_key示例:“stock.usd.ny…

角角の Qt学习笔记(一)

目录 一、解决在创建新项目时遇到的几个问题 二、信号和槽(非自定义) 三、调用 UI 中的元素(比如按钮) 一、解决在创建新项目时遇到的几个问题 在新建项目时,我选择的构建系统为CMake。然后勾选了Generate form&…

程序员月薪3w、4w难吗?该如何突破?

先说结论,如果你能成为互联网大厂的程序员,那么恭喜你,你的月薪大概率能达到3w、4w,甚至更高,此外一些非互联网大厂的程序员,比如金融、汽车制造等,月薪突破3w、4w的概率也非常高,但…

分享一个计算器

先看效果&#xff1a; 再看代码&#xff08;查看更多&#xff09;&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>计算器</title><style>* {box-sizing: border-box;}body…

threejs点击模型实现模型边缘高亮的选中效果--更改后提高帧率

先来个效果图 之前写的那个稍微有点问题&#xff0c;帧率只有30&#xff0c;参照官方代码修改后&#xff0c;帧率可以达到50了&#xff0c;在不全屏的状态下&#xff0c;帧率60 1.首先需要导入库 // 用于模型边缘高亮 import { EffectComposer } from "three/examples/js…