K 近邻算法

news2024/12/24 21:26:08

为什么学习KNN算法

KNN是监督学习分类算法,主要解决现实生活中分类问题。

(1)首先准备数据,可以是视频、音频、文本、图片等等

(2)抽取所需要的一些列特征,形成特征向量

(3)将这些特征向量连同标记一并送入机器学习算法中,训练出一个预测模型。

(4)采用同样的特征提取方法作用于新数据,得到用于测试的特征向量。

(5)使用预测模型对这些待测的特征向量进行预测并得到结果。

K近邻是机器学习算法中理论最简单,最好理解的算法,虽然算法简单,但效果也不错。

算法的思想:通过K个最近的已知分类的样本来判断未知样本的类别 

  • 图像识别:KNN可以用于图像分类任务,例如人脸识别、车牌识别等。在图像识别领域,KNN通过计算测试图像与训练集中图像的相似度来进行分类。
  • 文本分类:在文本分类方面,KNN算法可以应用于垃圾邮件过滤、情感分析等领域。通过对文本数据的特征提取和距离计算,KNN能够对新文本进行有效的分类。
  • 回归预测:虽然KNN更常用于分类问题,但它也可以用于解决回归问题。在回归任务中,KNN通过找到最近的K个邻居,并根据它们的值来预测连续的输出变量。
  • 医疗诊断:KNN算法可以辅助医生进行疾病的诊断。通过比较患者的临床数据与历史病例数据,KNN有助于识别疾病的模式和趋势。
  • 金融风控:在金融领域,KNN可用于信用评分和欺诈检测。通过分析客户的交易行为和信用历史,KNN可以帮助金融机构评估风险。
  • 推荐系统:KNN还可以用于构建推荐系统,通过分析用户的历史行为和其他用户的行为模式,为用户推荐商品或服务。

Sklearn API

鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa和Virginica。

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

if __name__ == '__main__':
      
    iris = load_iris() 

    # 数据标准化
    transformer = StandardScaler()
    x_ = transformer.fit_transform(iris.data) 

    # 模型训练
    estimator = KNeighborsClassifier(n_neighbors=3) # Knn中的K值
    estimator.fit(x_, iris.target) # 调用fit方法 传入特征和目标进行模型训练

    result = estimator.predict(x_) 
    print(result)

 数据集划分

 为了能够评估模型的泛化能力,可以通过实验测试对学习器的泛化能力进行评估,进而做出选择。因此需要使用一个测试集来测试学习器对新样本的判别能力。(2比8)

留出法:将数据集划分成两个互斥的集合:训练集,测试集。

交叉验证:将数据集划分为训练集,验证集,测试集 (验证集用于参数调整)。

留出法:

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import ShuffleSplit
from collections import Counter
from sklearn.datasets import load_iris


def test1():

    # 载数据集
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))

    # 留出法(随机分割)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('随机类别分割:', Counter(y_train), Counter(y_test))

    # 留出法(分层分割)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
    print('分层类别分割:', Counter(y_train), Counter(y_test))


def test2():

    
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))
    print('*' * 40)

    # 多次划分(随机分割)
    spliter = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for train, test in spliter.split(x, y):
        print('随机多次分割:', Counter(y[test]))

    print('*' * 40)

    # 多次划分(分层分割)
    spliter = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for train, test in spliter.split(x, y):
        print('分层多次分割:', Counter(y[test]))


if __name__ == '__main__':
    test1()
    test2()

原始类别比例: Counter({0: 50, 1: 50, 2: 50})
随机类别分割: Counter({1: 41, 0: 40, 2: 39}) Counter({2: 11, 0: 10, 1: 9})
分层类别分割: Counter({2: 40, 0: 40, 1: 40}) Counter({2: 10, 1: 10, 0: 10})
原始类别比例: Counter({0: 50, 1: 50, 2: 50})
****************************************
随机多次分割: Counter({1: 13, 0: 11, 2: 6})
随机多次分割: Counter({1: 12, 2: 10, 0: 8})
随机多次分割: Counter({1: 11, 0: 10, 2: 9})
随机多次分割: Counter({2: 14, 1: 9, 0: 7})
随机多次分割: Counter({2: 13, 0: 12, 1: 5})
****************************************
分层多次分割: Counter({0: 10, 1: 10, 2: 10})
分层多次分割: Counter({2: 10, 0: 10, 1: 10})
分层多次分割: Counter({0: 10, 1: 10, 2: 10})
分层多次分割: Counter({1: 10, 2: 10, 0: 10})
分层多次分割: Counter({1: 10, 2: 10, 0: 10})

train_test_split 是一个函数,它用于将数据集划分为训练集和测试集。它可以随机地将数据集划分为两个子集,并可以指定划分的比例或数量。这个方法适用于大多数机器学习任务,特别是需要将数据集划分为训练集和测试集的情况。

ShuffleSplit 是一个类,它用于生成多个独立的训练/测试数据划分。与 train_test_split 不同,ShuffleSplit 会随机打乱数据集的顺序,然后根据指定的参数进行划分。这个方法适用于交叉验证的场景,特别是在需要多次划分数据集以评估模型性能的情况下。

总结来说,train_test_split 是一个简单的函数,用于将数据集划分为训练集和测试集;而 ShuffleSplit 是一个类,用于生成多个独立的训练/测试数据划分,适用于交叉验证的场景。

交叉验证法 

 

 K-Fold交叉验证,将数据随机且均匀地分成k分

  • 第一次使用标号为0-8的共9份数据来做训练,而使用标号为9的这一份数据来进行测试,得到一个准确率
  • 第二次使用标记为1-9的共9份数据进行训练,而使用标号为0的这份数据进行测试,得到第二个准确率
  • 以此类推,每次使用9份数据作为训练,而使用剩下的一份数据进行测试,共进行10次训练,最后模型的准确率为10次准确率的平均值
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from collections import Counter
from sklearn.datasets import load_iris

def test():

    
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))
    print('*' * 40)

    #  随机交叉验证
    spliter = KFold(n_splits=5, shuffle=True, random_state=0)
    for train, test in spliter.split(x, y):
        print('随机交叉验证:', Counter(y[test]))

    print('*' * 40)

    # 分层交叉验证
    spliter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    for train, test in spliter.split(x, y):
        print('分层交叉验证:', Counter(y[test]))

 随机交叉验证: Counter({1: 13, 0: 11, 2: 6})
随机交叉验证: Counter({2: 15, 1: 10, 0: 5})
随机交叉验证: Counter({0: 10, 1: 10, 2: 10})
随机交叉验证: Counter({0: 14, 2: 10, 1: 6})
随机交叉验证: Counter({1: 11, 0: 10, 2: 9})
****************************************
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})

分类算法的评估 

  • 利用训练好的模型使用测试集的特征值进行预测

  • 将预测结果和测试集的目标值比较,计算预测正确的百分比

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
#加载鸢尾花数据集
X,y = datasets.load_iris(return_X_y = True)
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)

knn_clf = KNeighborsClassifier(n_neighbors=6)

knn_clf.fit(X_train,y_train)
y_predict = knn_clf.predict(X_test)
sum(y_predict==y_test)/y_test.shape[0]
# 0.8666666666666667

 SKlearn中模型评估

  • sklearn.metrics包中的accuracy_score方法: 传入预测结果和测试集的标签, 返回预测准确率
from sklearn.metrics import accuracy_score
accuracy_score(y_test,y_predict)

如何确定合适的K值

K值过小:容易受到异常点的影响

k值过大:受到样本均衡的问题

我们可以采用交叉验证法来选择最优的K值。

GridSearchCV

GridSearchCV 是 scikit-learn 库中的一个类,用于进行参数网格搜索。它结合了交叉验证和网格搜索的功能,可以自动地对给定的模型和参数组合进行训练和评估,以找到最佳的参数设置。

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# 定义模型和参数网格
model = SVC()
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}

# 创建 GridSearchCV 对象并进行训练
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=5)
grid_search.fit(X_train, y_train)

# 获取最佳参数和对应的评分
best_params = grid_search.best_params_
best_score = grid_search.best_score_

# 使用最佳参数重新训练模型
best_model = grid_search.best_estimator_
best_model.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = best_model.predict(X_test)

GridSearchCV 会遍历所有可能的参数组合,并对每个组合进行交叉验证。这可能会消耗大量的计算资源和时间,特别是当参数空间较大时。因此,在使用 GridSearchCV 时,需要权衡参数网格的大小和计算资源的可用性。

from sklearn.model_selection import GridSearchCV
x, y = load_iris(return_X_y=True)

x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, stratify=y, random_state=0)

# 创建网格搜索对象
estimator = KNeighborsClassifier()
param_grid = {'n_neighbors': [1, 3, 5, 7]}
estimator = GridSearchCV(estimator, param_grid=param_grid, cv=5, verbose=0)
estimator.fit(x_train, y_train)

print('最优参数组合:', estimator.best_params_, '最好得分:', estimator.best_score_)

print('测试集准确率:', estimator.score(x_test, y_test))

# 最优参数组合: {'n_neighbors': 7} 最好得分: 0.9583333333333334
# 测试集准确率: 1.0

手写数字案例

数据集:可以从MNIST数据集或UCI欧文大学机器学习存储库中获取手写数字的数据。这些数据集包含了大量已经标注好的手写数字图片。

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import joblib
from collections import Counter

def show_digit(idx):
    
    data = pd.read_csv('手写数字识别.csv')
    if idx < 0 or idx > len(data) - 1:
        return
    x = data.iloc[:, 1:]
    y = data.iloc[:,0]
    print('当前数字的标签为:',y[idx])    # 查看当前数字的数值

    
    data_ = x.iloc[idx].values
    # 将数据形状修改为 28*28
    data_ = data_.reshape(28, 28)     # 显示当前数字在数据集的图片
    
    # 显示图像
    plt.imshow(data_)
    plt.show()

def train_model():

    # 1. 加载手写数字数据集
    data = pd.read_csv('手写数字识别.csv')
    x = data.iloc[:, 1:] / 255
    y = data.iloc[:, 0]

    # 2. 打印数据基本信息
    print('数据基本信息:', x.shape)
    print('类别数据比例:', Counter(y))

    # 3. 分割数据集
    split_data = train_test_split(x, y, test_size=0.2, stratify=y, random_state=0)
    x_train, x_test, y_train, y_test = split_data

    # 4. 模型训练
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

    # 5. 模型评估
    acc = estimator.score(x_test, y_test)
    print('测试集准确率: %.2f' % acc)

    # 6. 模型保存
    joblib.dump(estimator, 'knn.pth')


def test_model():
    # 读取图片数据
    import matplotlib.pyplot as plt
    import joblib
    img = plt.imread('demo.png') # 对于灰度图像,返回的是一个二维数组,其中每个元素是一个介于0和1之间的浮点数,表示该像素的灰度值
    plt.imshow(img)
    # 加载模型
    knn = joblib.load('knn.pth')
    y_pred = knn.predict(img.reshape(1, -1)) #首先将从图片中读取到的数据 (img) 重塑为一维数组形式,以便给 KNN 分类器进行预测。
    print('您绘制的数字是:', y_pred)


show_digit(1)
    # 训练模型
train_model()
    # 测试模型
test_model()

 

 

小结: 


KNN(K-Nearest Neighbors)算法,即K最近邻算法,是一种监督学习算法,可以用于分类和回归问题。其基本思想是:给定一个训练数据集,对于新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类别,则该输入实例也属于这个类别。

KNN算法的主要步骤如下:

  1. 计算输入实例与训练数据集中的每个实例之间的距离。常用的距离度量方法有欧氏距离、曼哈顿距离等。

  2. 对计算出的距离进行排序,找出距离最近的K个邻居。

  3. 统计这K个邻居所属的类别,选择出现次数最多的类别作为输入实例的预测类别。

  4. 如果用于回归问题,则计算这K个邻居的平均值或加权平均值作为输入实例的预测值。

KNN算法的优点:

  1. 算法简单,易于理解。

  2. 适用于多分类问题。

  3. 对于一些非线性问题,KNN算法具有较好的性能。

KNN算法的缺点:

  1. 当训练数据集较大时,计算距离的时间复杂度较高。

  2. K值的选择对算法性能影响较大,但目前没有确定K值的通用方法。

  3. 对于不平衡数据集,KNN算法的性能较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1513045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iPhone, Android 手机是如何收到推送通知的?

本文转自 公众号 ByteByteGo&#xff0c;如有侵权&#xff0c;请联系&#xff0c;立即删除 iPhone, Android 手机是如何收到推送通知的&#xff1f; 我们的手机或电脑是如何收到推送通知的&#xff1f; 通常我们可以使用消息解决方案 Firebase 来支持通知推送。下图显示了 Fi…

云数据库Redis配置用户名密码连接

一般情况,生产环境6379端口是禁止对外开放的, 所有用户名密码可以不设置。 但是如果有格鲁需求,需要开放redis公网访问,建议端口限制IP,并设置用户密码 spring中配置 阿里云数据库 云数据库 Redis_缓存数据库_高并发_读写分离-阿里云 添加白名单 申请公网访问地址 配…

2024年共享WiFi项目到底怎么样?

共享WiFi项目是近年来兴起的一种新型商业模式&#xff0c;商家通过在自己店铺升级wifi链接模式使其数字化&#xff0c;让用户能够方便地连接到互联网&#xff0c;提升到店体验&#xff0c;增加线上引流。这一项目的出现&#xff0c;为人们的生活带来了诸多便利&#xff0c;同时…

基于SpringBoot的“实习管理系统”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“实习管理系统”的设计与实现&#xff08;源码数据库文档PPT) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 系统首页界面图 学生注册界面图 后台登录界面图 …

1.Python是什么?——《跟老吕学Python编程》

1.Python是什么&#xff1f;——《跟老吕学Python编程》 Python是一种什么样的语言&#xff1f;Python的优点Python的缺点 Python发展历史Python的起源Python版本发展史 Python的价值学Python可以做什么职业&#xff1f;Python可以做什么应用&#xff1f; Python是一种什么样的…

LoadBalancer负载均衡服务调用

LoadBalancer负载均衡服务调用 1、Ribbon目前也进入维护 ​ Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端 负载均衡的工具。 ​ 简单的说&#xff0c;Ribbon是Netflix发布的开源项目&#xff0c;主要功能是**提供客户端的软件负载均衡算法和服务调用。**Ribbon…

自动备份文件:守护数据安全新利器

随着信息化时代的到来&#xff0c;文件已经成为我们日常生活和工作中不可或缺的一部分。然而&#xff0c;数据丢失或损坏的风险也随之而来&#xff0c;因此自动备份文件的重要性愈发凸显。自动备份文件不仅可以保护我们的宝贵数据&#xff0c;还可以在意外发生时迅速恢复&#…

Seata源码流程图

1.第一阶段分支事务的注册 流程图地址&#xff1a;https://www.processon.com/view/link/6108de4be401fd6714ba761d 2.第一阶段开启全局事务 流程图地址&#xff1a;https://www.processon.com/view/link/6108de13e0b34d3e35b8e4ef 3.第二阶段全局事务的提交 流程图地址…

Python | Logger通用日志记录器

一、代码 通用日志记录器&#xff0c;可以输出不同颜色的记录到控制台&#xff0c;并输出到指定文件夹下可以在不同py文件中同时使用&#xff0c;共用同一个记录器适用window或linux平台 #!/usr/bin/env python # -*- coding: utf-8 -*- import os import inspect import log…

镭速教你如何解决大数据量串行处理的问题

大数据的高效处理成为企业发展的关键。然而&#xff0c;大数据量串行处理的问题常常困扰着许多企业&#xff0c;尤其是在数据传输方面。本文将探讨大数据量串行处理的常见问题&#xff0c;并介绍企业常用的处理方式&#xff0c;最后重点阐述镭速如何提供创新解决方案&#xff0…

Claude3发布,将取代ChatGPT4?

目录标题 前言Claude简介Claude 3 的能力高级推理视觉分析代码生成多语言处理 性能比较 前言 一夜之间&#xff0c;全球最强 AI 模型易主。大模型行业再次经历变革。OpenAI 最大的竞争对手 Anthropic 发布了新一代 AI 大模型系列 ——Claude 3。该系列包含三个模型&#xff0c…

鸿蒙开发之MPChart图表开发

一、简介 随着移动应用的不断发展,数据可视化成为提高用户体验和数据交流的重要手段之一,因此需要经常使用图表,如折线图、柱形图等。OpenHarmony提供了一个强大而灵活的图表库是实现这一目标的关键。 在 ohpm 中心仓(https://ohpm.openharmony.cn/)中,汇聚了众多开发者…

【Python】新手入门学习:详细介绍接口分隔原则(ISP)及其作用、代码示例

【Python】新手入门学习&#xff1a;详细介绍接口分隔原则&#xff08;ISP&#xff09;及其作用、代码示例 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、Py…

c++中多种类型sort()排序的用法(数组、结构体、pair、vector)

c中多种类型sort排序的用法 一、对数组排序1、默认排序2、自定义排序 二、对结构体进行排序三、对pair进行排序1、默认排序2、自定义排序 四、对vector进行排序1、默认排序2、去重排序3、自定义排序 一、对数组排序 1、默认排序 默认从小到大进行排序 #include <bits/std…

Windows Server 各版本搭建终端服务器实现远程访问(03~19)

一、Windows Server 2003 左下角开始➡管理工具➡管理您的服务器&#xff0c;点击添加或删除角色 点击下一步 勾选自定义&#xff0c;点击下一步 蒂埃涅吉终端服务器&#xff0c;点击下一步 点击确定 重新登录后点击确定 点击开始➡管理工具➡计算机管理&#xff0c;展开本地…

Java算法总结之冒泡排序(详解)

程序代码园发文地址&#xff1a;Java算法总结之冒泡排序&#xff08;详解&#xff09;-程序代码园小说,Java,HTML,Java小工具,程序代码园,http://www.byqws.com/ ,Java算法总结之冒泡排序&#xff08;详解&#xff09;http://www.byqws.com/blog/3145.html?sourcecsdn 冒泡排序…

网址如何转静态二维码?扫码跳转链接的制作步骤

一般网址想要转换成可以长期使用的二维码&#xff0c;可以通过制作静态码的方式将链接网址转二维码图片使用。这种方式只是将网址从链接的形式转换成二维码的形式&#xff0c;只要添加的网址不失效&#xff0c;那么二维码是可以长期扫码展示内容的&#xff1f;那么如何制作网址…

Mybatis八股

Mybatis是什么 Mybatis是一个半ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;它内部封装了JDBC&#xff0c;加载驱动、创建连接、创建statement等繁杂的过程&#xff0c;开发者开发时只需要关注如何编写SQL语句&#xff0c;可以严格控制sql执行性能&#xff0c;灵…

学成在线-生成扫码下单接口的二维码同时创建创建商品订单记录和交易支付记录

生成下单接口二维码 界面原型 打开课程支付引导界面&#xff0c;点击支付宝支付按钮商户系统生成下单的二维码接口&#xff0c;用户扫描二维码后商户系统开始请求支付宝下单 用户扫码开始请求支付宝下单&#xff0c;但是在生成下单接口的二维码前前端需要做一些操作 前端调用…

vue项目因内存溢出启动报错

前端能正常启动&#xff0c;但只要一改动就报错启动出错。 解决办法&#xff1a; 安装依赖 npm install cross-env increase-memory-limit 然后再做两件事&#xff1a;在node 在package.json 里的 script 里进行配置 LIMIT是你想分配的内存大小&#xff0c;这里的8192单位…