KNN算法与模型选择及调优

news2025/1/13 8:04:49

KNN算法-分类

1 样本距离判断

(1)欧式距离

欧式距离(Euclidean distance),也称为欧氏度量,是用来衡量两个点之间直线距离的方法。

(2)曼哈顿距离

曼哈顿距离(Manhattan distance),也称为城市街区距离,是在网格状坐标系统中,从一个点到另一个点的距离之和。

2 KNN 算法原理

K-近邻算法(K-Nearest Neighbors,简称KNN)是一种基于样本相似性的分类方法。它通过比较目标样本与其最近的 ( K ) 个邻居样本的类别来决定目标样本的类别。

具体步骤如下:

  1. 选择 ( K ):确定要考虑的邻居数量 ( K )。

  2. 计算距离:计算目标样本与所有其他样本的距离,找到距离最近的 ( K ) 个样本。

  3. 投票决策:统计这 ( K ) 个邻居中各类别的出现次数,目标样本将被分配到出现次数最多的类别。

假设我们有 10000 个样本,并且选择 ( K = 5 )。我们需要对一个待分类的样本 B 进行分类。首先,我们找到与样本 B 距离最近的 5 个样本。假设这 5 个邻居的类别分布如下:

  • 类别 A:3 个样本

  • 类别 B:1 个样本

  • 类别 C:1 个样本

在这种情况下,由于类别 A 的邻居样本最多,样本 B 将被分类为类别 A。

弊端:计算量大维度灾难需要选择合适的 ( K ) 值和距离度量

sklearn.neighbors.KNeighborsClassifier 的主要参数和方法如下:
参数
n_neighbors:邻居数量,默认为 5。
algorithm:寻找邻居的算法,默认为 'auto'。
方法
fit(X, y):训练模型。
predict(X):进行预测。

使用KNN算法预测《唐人街探案》电影属于哪种类型?分别计算每个电影和预测电影的距离然后求解:

import pandas as pd
from collections import Counter
​
# 读取电影数据
data = pd.read_csv("./src/movies.csv")
df = pd.DataFrame(data)
​
# 目标电影数据
target_movie = {
    '电影名称': '唐人街探案',
    '搞笑镜头': 23,
    '拥抱镜头': 3,
    '打斗镜头': 17,
    '距离': None  # 不需要实际的距离值
}
​
# 按距离排序并选择最近的3个邻居
k = 3
nearest_neighbors = df.sort_values(by='距离').head(k)#head(k) 是 pandas 的一个方法,用于从 DataFrame 的顶部获取前 k 行数据
​
# 统计K个邻居的电影类型
neighbor_types = nearest_neighbors['电影类型']
predicted_type = Counter(neighbor_types).most_common(1)[0][0]
#most_common(n) 是 Counter 类的一个方法,它返回出现频率最高的 n 个元素及其计数,返回一个列表,其中包含元组 (元素, 频次)。
#[0][0] 表示取第一个元组,然后取第一个元素,即电影类型
​
print(f"唐人街探案的预测电影类型是: {predicted_type}")
2.sklearn 实现KNN示例

用KNN算法对葡萄酒进行分类

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_wine
import joblib #模型保存与加载需导入
​
def knn(path):
    # 加载酒类数据集
    wine = load_wine()
    x = wine.data  # 特征数据
    y = wine.target  # 目标标签
​
    # 划分训练集和测试集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
​
    # 创建并训练 KNN 模型
    model = KNeighborsClassifier(n_neighbors=7)
    model.fit(x_train, y_train)
​
    # 评估模型准确率
    accuracy = model.score(x_test, y_test)
    print("准确率:", accuracy)
​
    # 保存模型
    joblib.dump(model, path)
​
# 定义模型保存路径并调用函数
path = './src/knn.pkl'
knn(path)
​
# 加载保存的模型
model = joblib.load(path)
​
# 预测新样本的类别
y_predict = model.predict([[2.16, 5.51, 1.34, 1.9, 2.76, 1.35, 4.43, 2.05, 0.94, 1.36, 4.36, 3.59, 1.86]])
print(y_predict)
​
# 输出预测结果对应的标签名
print(load_wine().target_names[y_predict])
​

模型选择与调优

1 交叉验证
(1) 保留交叉验证HoldOut

HoldOut 交叉验证将数据集分为训练集和验证集,通常70%用于训练,30%用于验证。这是一种简单直接的方法。

优点:

  • 实现简单,操作方便。

缺点:

  1. 不适用于不平衡数据集:在不平衡数据集中,训练集和验证集可能不均衡,导致模型无法有效学习少数类的数据。(比如80%的数据属于 “0 “类,其余20%的数据属于 “1 “类)

  2. 数据利用不充分:在小数据集上,一部分数据用于验证,可能导致模型错过关键特征,影响性能。

(2) K-折交叉验证(K-fold)

K-Fold 交叉验证将数据集划分为K个大小相同的部分(Fold)。每次用一个Fold作为验证集,其余K-1个Fold作为训练集。这一过程重复K次,确保每个Fold都被用作验证集。

最终模型的准确度通过计算K次验证结果的平均值来获得。

(3) 分层k-折交叉验证Stratified k-fold

分层 K-Fold 交叉验证是 K-Fold 交叉验证的变种。在每一折中保持原始数据中各类别的比例。例如,如果原始数据有三类,比例为 1:2:1,那么每一折中的类别比例也保持 1:2:1。这种方法使得每一折的验证结果更加可靠。

(4) 其它验证
去除p交叉验证)
留一交叉验证)
蒙特卡罗交叉验证
时间序列交叉验证

(5)API的使用

strat_k_fold=sklearn.model_selection.KNeighborsClassifier(n_splits=5, shuffle=True, random_state=42)
​
•   n_splits划分为几个折叠 
•   shuffle是否在拆分之前被打乱(随机化),False则按照顺序拆分
•   random_state随机因子

from sklearn.datasets import load_wine
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import StratifiedKFold
import numpy as np
​
# 加载葡萄酒数据集
wine = load_wine()
x = np.array(wine.data)  
y = np.array(wine.target)  
​
dts = []  
# 初始化分层 K-Fold 交叉验证器
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=44)
​
# 遍历每个折叠
for train_index, test_index in kf.split(x, y):
    x_train, x_test = x[train_index], x[test_index]  # 划分训练集和测试集
    y_train, y_test = y[train_index], y[test_index]
    
    # 初始化 KNN 分类器并训练
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(x_train, y_train)
    
    # 计算测试集的准确率并存储
    dt = knn.score(x_test, y_test)
    dts.append(dt)
​
print("每个折叠的准确率:", dts)
print("平均准确率:", np.mean(dts))
2 超参数搜索及其API使用

超参数搜索也叫网格搜索(Grid Search)

比如在KNN算法中,k是一个可以人为设置的参数,所以就是一个超参数。网格搜索能自动的帮助我们找到最好的超参数值。

sklearn.model_selection.GridSearchCV 是一个工具,用于同时进行网格搜索和交叉验证。它也被视为一个估计器(estimator)。

class sklearn.model_selection.GridSearchCV(estimator, param_grid)
      best_params_  最佳参数
      best_score_ 在训练集中的准确率
      best_estimator_ 最佳估计器
      cv_results_ 交叉验证过程描述
      best_index_最佳k在列表中的下标
参数:
estimator:待优化的 scikit-learn 估计器
param_grid:参数名称(字符串)和对应值列表的字典,如 {"n_neighbors": [1, 3, 5, 7, 9, 11]}
cv:交叉验证策略
None:默认5折
integer:指定折数
分类器使用“分层K折交叉验证”(StratifiedKFold),其他情况使用KFold。
3 示例-葡萄酒分类

用KNN算法对葡萄酒进行分类,添加网格搜索和交叉验证

from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_wine
import joblib
​
x, y = load_wine(return_X_y=True)
​
knn = KNeighborsClassifier()
​
model = GridSearchCV(knn, {"n_neighbors": [5, 6, 9, 10]})
​
​
re = model.fit(x, y)
​
print("最佳参数:\n", model.best_params_)
print("最佳k在列表中的下标", model.best_index_)
print("最佳准确率:\n", model.best_score_)
print("最佳估计器:\n", model.best_estimator_)
print("交叉验证过程描述:\n", model.cv_results_)
​
# 使用最佳模型进行预测
y_predict = re.predict([[2.06, 5.41, 1.34, 1.85, 2.73, 1.65, 3.83, 2.05, 1.04, 1.36, 4.36, 3.49, 1.96]])
print(y_predict)
​
# 验证是否模型与最佳估计器相同
print(re == model.best_estimator_)
joblib.dump(model.best_estimator_, "./src/knn2.pkl")
​
model = joblib.load("./src/knn2.pkl")
​
# 使用加载的模型进行预测
y_predict = model.predict([[2.16, 5.51, 1.34, 1.9, 2.76, 1.35, 4.43, 2.05, 0.94, 1.36, 4.36, 3.59, 1.86]])
print(y_predict)
​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2060701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探索CompletableFuture:高效异步编程的利器

目录 一、CompletableFuture基本功能安利 二、CompletableFuture使用介绍 (一)任务创建使用 1.supplyAsync创建带有返回值的异步任务 2.runAsync创建没有返回值的异步任务 (二)异步回调使用 1.异步回调:thenApp…

Android 修改SystemUI 音量条的声音进度条样式

一、前言 Android System UI 开发经常会遇到修改音量进度条样式的需求,主要涉及的类有VolumeDialogImpl与xml文件,接下来会逐步实现流程。先看看效果。 修改前 修改后 二、找到对应类 通过aidegen 打断点调试对应代码类VolumeDialogImpl定位到volume_d…

中国第一起名大师的老师颜廷利: 名字中的姓氏家谱字辈的最新解析

在探讨文化和文明的深层含义时,我们常常发现,对传统的尊重与现代价值观之间存在着一种微妙的张力。这种张力在一个简单的例子中得到了生动的体现:姓名的选择。 在古代社会,名字不仅仅是个体的标识,更是家族传承和社会结…

JavaScript_11_练习:小米搜索框案例(焦点事件)

效果图 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>练习&#xff1a;小米搜索框案例&#…

第十二届青少年蓝桥杯Python组省赛试题

一、选择题 1.设s’Hello Lan Qiao’,执行print(s[4:11])输出的结果为()。 *选择题严禁使用程序验证 A、lo Lan Qi B、lo Lan Q C、o Lan Qi D、o Lan Q 提示&#xff1a;切片 2.循环语句for i in range(8,-4,-2):执行了几次循环()。 *选择题严禁使用程序验证 A、4 B、5 C、6…

LabVIEW锅炉燃烧远程监控系统

随着信息技术的发展&#xff0c;远程监控技术已经广泛应用于各种工业过程。开发了一个基于LabVIEW和互联网技术的锅炉燃烧远程监控系统&#xff0c;该系统不仅提高了锅炉运行的安全性和效率&#xff0c;还具备了故障远程诊断的功能&#xff0c;为锅炉管理提供了一种全新的解决方…

[论文笔记]Improving Retrieval Augmented Language Model with Self-Reasoning

引言 今天带来一篇百度提出的关于提升RAG准确率的论文笔记&#xff0c;Improving Retrieval Augmented Language Model with Self-Reasoning。 为了简单&#xff0c;下文中以翻译的口吻记录&#xff0c;比如替换"作者"为"我们"。 检索增强语言模型(Retrie…

谷歌浏览器自动填充密码怎么设置

谷歌浏览器的自动填充密码功能为用户提供了一种安全而便捷的在线体验&#xff0c;让用户在下次登录网站的时候&#xff0c;减去重复输入密码的麻烦。下面就给大家分享一下关于谷歌浏览器自动填充密码的相关内容&#xff0c;让你更加轻松的管理自己的账户。 谷歌浏览器自动填充密…

26.删除有序数组中的重复项---力扣

题目链接&#xff1a; . - 力扣&#xff08;LeetCode&#xff09;. - 备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/remove-duplicates-from-sorted-array/descript…

使用maven快速生成打包文件

最近在部署基于SpringBoot开发的项目时&#xff0c;由于微服务较多&#xff0c;本地工程编译后只得出一个JAR包&#xff0c;部署起来实在不方便&#xff0c;因此总想着怎么偷偷懒&#xff0c;执行一次命令编译出整个部署的文件。先说结果&#xff0c;最后期望打包的目录如下&am…

【数据结构篇】~双向链表(附源码)

前言 学完了单链表&#xff0c;还有其他等着我们去攻克&#xff0c;链表其实分为8种 &#xff0c;共2 * 2 * 28种 之前的单链表是不带头单向不循环链表 一、双向链表 注意&#xff1a;这里的“带头”跟前面我们说的“头结点”是两个概念&#xff0c;实际前面的在单链表阶段称…

百度地图路书实现历史轨迹回放、轨迹回放进度、聚合点、自定义弹框和实时监控视频、多路视频轮巡播放

前言 分享一个刚做完项目集成技术&#xff0c;一个车辆行驶轨迹监控、行车视频监控、对特种车辆安全监管平台&#xff0c;今年政府单位有很多监管平台项目&#xff0c;例如&#xff1a;渣土车监控、租出车监管、危害气体运输车监管等平台&#xff0c;这些平台都有车辆行驶轨迹…

QT基础知识5

思维导图 client.cpp #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget), socket(new QTcpSocket(this))//给客户端实例化分配空间 {ui->setupUi(this);//初始化界面ui->msgEdit-&…

微盟年中报:聚焦主业降本增效,经调整净亏损同比大幅收窄81.4%

8月21日&#xff0c;微盟集团&#xff08;2013.HK&#xff09;发布2024年中期业绩报告。在充满挑战的市场环境中&#xff0c;公司积极进行业务布局优化调整&#xff0c;战略性聚焦核心业务&#xff0c;集团总收入达8.67亿元人民币&#xff0c;整体毛利率保持平稳。报告期内&…

语言基础/单向链表的构建和使用(含Linux中SLIST的解析和使用)

文章目录 概述简单的链表描述链表的术语简单实现一个单链表 Linux之SLIST机理分析结构定义单链表初始化单链表插入元素单链表遍历元素单链表删除元素 Linux之SLIST使用实践纯C中typedef重命名带来的问题预留 概述 本文讲述了数据结构中单链表的基本概念&#xff0c;头指针、头…

监控状态流图中的测试点

此示例展示了如何将数据或状态指定为测试点&#xff0c;你可以在仿真过程中使用浮动范围绘制这些测试点或将其记录到MATLAB基本工作区。 关于状态流图中的测试点 Stateflow测试点是您可以在模拟过程中观察到的信号&#xff0c;例如&#xff0c;通过使用浮动范围块。您可以使用…

进阶SpringBoot之 SpringSecurity(1)环境搭建

Spring Security 中文文档 Spring Security 是一个 Java 框架&#xff0c;用于保护应用程序的安全性 它提供认证&#xff08;authentication&#xff09;、授权&#xff08;authorization&#xff09;和保护&#xff0c;以抵御常见的攻击 Spring Security 基于过滤器链的概念…

Linux虚拟机磁盘管理-创建新磁盘分区

1.查看新加的硬盘情况 b英文为block表示块 查看磁盘信息方法一&#xff1a;ll /dev/sd* 查看磁盘信息方法二&#xff1a;lsblk 2.创建分区 1&#xff09;创建磁盘分区 以sdb这块磁盘进行分区为例 一个磁盘最多分4个分区 输入w进行确认创建一个房间&#xff0c;这个房间就能…

Nginx: 配置项之main段核心参数用法梳理

概述 我们了解下配置文件中的一个全局段&#xff0c;有哪些配置参数&#xff0c;包括后面的 events 字段&#xff0c;有哪些配置参数这里面也有一些核心参数, 对于我们Nginx运行的性能也是有很重要的帮助我们现在首先关注整个 main 段的一个核心参数用法所谓 main 段&#xff…

前后端分离开发:用 Apifox 高效管理 API

目录 1.前后台分离开发介绍 2.API 2.1 APIfox介绍 2.2 接口文档管理 1.前后台分离开发介绍 前端开发有2种方式&#xff1a;「前后台混合开发」和「前后台分离开发」。 前后台混合开发&#xff0c;顾名思义就是前台后台代码混在一起开发&#xff0c;如下图所示&#xff1a…