KNN算法原理及应用

news2024/12/29 10:48:06

理解KNN 算法原理

KNN是监督学习分类算法,主要解决现实生活中分类问题。

根据目标的不同将监督学习任务分为了分类学习及回归预测问题。

监督学习任务的基本流程和架构:

(1)首先准备数据,可以是视频、音频、文本、图片等等

(2)抽取所需要的一些列特征,形成特征向量(Feature Vectors)。

(3)将这些特征向量连同标记一并送入机器学习算法中,训练出一个预测模型。

(4)然后,采用同样的特征提取方法作用于新数据,得到用于测试的特征向量。

(5)最后,使用预测模型对这些待测的特征向量进行预测并得到结果(Expected Model)。

KNN(K-Nearest Neihbor,KNN)K近邻是机器学习算法中理论最简单,最好理解的算法,是一个非常适合入门的算法,拥有如下特性:

  • 思想极度简单,应用数学知识少(近乎为零),对于很多不擅长数学的小伙伴十分友好

  • 虽然算法简单,但效果也不错

如果要了解一个人的经济水平,只需要知道他最好的5个朋友的经济能力, 对他的这五个人的经济水平求平均就是这个人的经济水平。这句话里面就包含着kNN的算法思想。 

如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

类别的判定

①投票决定,少数服从多数。取类别最多的为测试样本类别。

②加权投票法,依据计算得出距离的远近,对近邻的投票进行加权,距离越近则权重越大,设定权重为距离平方的倒数。

KNN 算法原理简单,不需要训练,属于监督学习算法,常用来解决分类问题

KNN原理:先确定K值, 再计算距离,最后挑选K个最近的邻居进行投票

 KNN的应用 

KNN即能做分类又能做回归, 还能用来做数据预处理的缺失值填充。由于KNN模型具有很好的解释性,对于每一个预测结果,我们可以很好的进行解释。文章推荐系统中, 对于一个用户A,我们可以把和A最相近的k个用户,浏览过的文章推送给A。

算法的思想:通过K个最近的已知分类的样本来判断未知样本的类别。

KNN三要素:

  • 距离度量
  • K值选择
  • 分类决策准则

 鸢尾花数据集

鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa和Virginica

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris() #通过iris.data 获取数据集中的特征值  iris.target获取目标值
# 数据标准化
transformer = StandardScaler()
x_ = transformer.fit_transform(iris.data) # iris.data 数据的特征值

#  模型训练
estimator = KNeighborsClassifier(n_neighbors=3) # n_neighbors 邻居的数量,也就是Knn中的K值
estimator.fit(x_, iris.target) # 调用fit方法 传入特征和目标进行模型训练
# 利用模型预测
result = estimator.predict(x_)
print(result)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 1
 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

sklearn中自带了几个学习数据集,都封装在sklearn.datasets 这个包中,加载数据后,通过data属性可以获取特征值,通过target属性可以获取目标值。

Demo数据集--kNN分类

1: 库函数导入

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets

2: 数据导入

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

3: 模型训练

k_list = [1, 3, 5, 8, 10, 15]
h = .02
# 创建不同颜色画布
cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue'])
cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue'])

plt.figure(figsize=(15,14))
# 根据不同的k值进行可视化
for ind,k in enumerate(k_list):
    clf = KNeighborsClassifier(k)
    clf.fit(X, y)
    
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    
    Z = Z.reshape(xx.shape)

    plt.subplot(321+ind)  
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
    
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
                edgecolor='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("3-Class classification (k = %i)"% k)

plt.show()

当k=1的时候,在分界点位置的数据很容易受到局部的影响,图中蓝色的部分中还有部分绿色块,主要是数据太局部敏感。当k=15的时候,不同的数据基本根据颜色分开,当时进行预测的时候,会直接落到对应的区域。

数据集划分 

不能将所有数据集全部用于训练,为了能够评估模型的泛化能力,可以通过实验测试对学习器的泛化能力进行评估,进而做出选择。因此需要使用一个测试集来测试学习器对新样本的判别能力。

测试集要代表整个数据集、与训练集互斥、测试集与训练集建议比例: 2比8、3比7。

数据集划分的方法

1.将数据集划分成两个互斥的集合:训练集,测试集。

  • 训练集用于模型训练
  • 测试集用于模型验证

2.将数据集划分为训练集,验证集,测试集

  • 训练集用于模型训练
  • 验证集用于参数调整
  • 测试集用于模型验证

 1:将数据集 D 划分为两个互斥的集合,其中一个集合作为训练集 S,另一个作为测试集 T。

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import ShuffleSplit
from collections import Counter
from sklearn.datasets import load_iris


def te1():

    # 1. 加载数据集
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))

        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('随机类别分割:', Counter(y_train), Counter(y_test))

    
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
    print('分层类别分割:', Counter(y_train), Counter(y_test))


def te2():

    # 1. 加载数据集
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))
    print('*' * 40)

    # 2. 多次划分
    spliter = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for train, test in spliter.split(x, y):
        print('随机多次分割:', Counter(y[test]))

    spliter = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for train, test in spliter.split(x, y):
        print('分层多次分割:', Counter(y[test]))

if __name__ == '__main__':
    te1()
    te2()

# 随机类别分割: Counter({0: 43, 2: 40, 1: 37}) Counter({1: 13, 2: 10, 0: 7})
# 分层类别分割: Counter({1: 40, 2: 40, 0: 40}) Counter({2: 10, 1: 10, 0: 10})
随机多次分割: Counter({1: 13, 0: 11, 2: 6})
随机多次分割: Counter({1: 12, 2: 10, 0: 8})
随机多次分割: Counter({1: 11, 0: 10, 2: 9})
随机多次分割: Counter({2: 14, 1: 9, 0: 7})
随机多次分割: Counter({2: 13, 0: 12, 1: 5})

分层多次分割: Counter({0: 10, 1: 10, 2: 10})
分层多次分割: Counter({2: 10, 0: 10, 1: 10})
分层多次分割: Counter({0: 10, 1: 10, 2: 10})
分层多次分割: Counter({1: 10, 2: 10, 0: 10})
分层多次分割: Counter({1: 10, 2: 10, 0: 10})

 K-Fold交叉验证,将数据随机且均匀地分成k分

 

  • 第一次使用标号为0-8的共9份数据来做训练,而使用标号为9的这一份数据来进行测试,得到一个准确率
  • 第二次使用标记为1-9的共9份数据进行训练,而使用标号为0的这份数据进行测试,得到第二个准确率
  • 共进行10次训练,最后模型的准确率为10次准确率的平均值,避免了数据划分而造成的评估不准确。
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from collections import Counter
from sklearn.datasets import load_iris

def test():

    # 1. 加载数据集
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))
    print('*' * 30)

    # 2. 随机交叉验证
    spliter = KFold(n_splits=5, shuffle=True, random_state=0)
    for train, test in spliter.split(x, y):
        print('随机交叉验证:', Counter(y[test]))

    print('*' * 30)

    # 3. 分层交叉验证
    spliter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    for train, test in spliter.split(x, y):
        print('分层交叉验证:', Counter(y[test]))

test()

数据划分结果:

随机交叉验证: Counter({1: 13, 0: 11, 2: 6})
随机交叉验证: Counter({2: 15, 1: 10, 0: 5})
随机交叉验证: Counter({0: 10, 1: 10, 2: 10})
随机交叉验证: Counter({0: 14, 2: 10, 1: 6})
随机交叉验证: Counter({1: 11, 0: 10, 2: 9})
****************************************
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})
分层交叉验证: Counter({0: 10, 1: 10, 2: 10})

留一法:每次抽取一个样本做为测试集。

from sklearn.model_selection import LeaveOneOut
from sklearn.model_selection import LeavePOut
from sklearn.datasets import load_iris
from collections import Counter


def test01():

    # 1. 加载数据集
    x, y = load_iris(return_X_y=True)
    print('原始类别比例:', Counter(y))
    print('*' * 40)

    
    spliter = LeaveOneOut()
    for train, test in spliter.split(x, y):
        print('训练集:', len(train), '测试集:', len(test), test)

test01()

分类算法的评估

  • 利用训练好的模型使用测试集的特征值进行预测

  • 将预测结果和测试集的目标值比较,计算预测正确的百分比

  • 百分比就是准确率, 准确率越高说明模型效果越好

确定合适的K值 

K值过小:容易受到异常点的影响

k值过大:受到样本均衡的问题,K一般取一个较小的数值。

使用 scikit-learn 提供的 GridSearchCV 工具, 配合交叉验证法可以搜索参数组合

x, y = load_iris(return_X_y=True)

# 分割数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y, random_state=0)

estimator = KNeighborsClassifier()
grid = {'n_neighbors': [1, 3, 5]}
estimator = GridSearchCV(estimator, param_grid=grid, cv=5)
estimator.fit(x_train, y_train)

K近邻算法的优缺点:

  • 优点:简单,易于理解,容易实现
  • 缺点:算法复杂度高,结果对K取值敏感,容易受数据分布影响

knn算法中我们最需要关注两个问题:k值的选择和距离的计算。

距离/相似度的计算:

样本之间的距离的计算,我们一般使用对于一般使用Lp距离进行计算。当p=1时候,称为曼哈顿距离,当p=2时候,称为欧氏距离,当p=∞时候,称为极大距离。一般采用欧式距离较多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1394760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型关键技术:上下文学习、思维链、RLHF、参数微调、并行训练、旋转位置编码、模型加速、大模型注意力机制优化、永久记忆、LangChain、知识图谱、多模态

大模型关键技术 大模型综述上下文学习思维链 CoT奖励建模参数微调并行训练模型加速永久记忆:大模型遗忘LangChain知识图谱多模态大模型系统优化AI 绘图幻觉问题从 GPT1 - GPT4 拆解GPTs 对比主流大模型技术点旋转位置编码层归一化激活函数注意力机制优化 大模型综述…

天锐绿盾有哪些功能?

天锐绿盾是一款企业内网安全管理软件,具有多种功能来保护企业的信息安全。 PC地址: https://isite.baidu.com/site/wjz012xr/2eae091d-1b97-4276-90bc-6757c5dfedee 以下是天锐绿盾的主要功能: 文件加密保护:天锐绿盾可以对文件…

C# 线程间操作无效: 从不是创建控件的线程访问它--多线程操作

我们在用线程操作的时候,可能会出现异常:线程间操作无效: 从不是创建控件richTextBox1的线程访问它。因为windows窗体控件不是线程安全的,如果几个线程操作某一控件的状态,可能会使该控件的状态不一致,出现争用或死锁状…

黑马 Javaweb - MySQL 精华篇

我是南城余!阿里云开发者平台专家博士证书获得者! 欢迎关注我的博客!一同成长! 一名从事运维开发的worker,记录分享学习。 专注于AI,运维开发,windows Linux 系统领域的分享! 知…

华为交换机配置NQA DNS检测IP网络DNS解析速度

华为HCIA视频教程:超级实用,华为VRP系统文件详解 华为HCIA视频教程:不会传输层协议,HCIA都考不过 华为HCIA视频教程:网络工程师的基本功:网络地址转换NAT 华为HCIP视频教程:DHCP协议原理与配…

Hadoop集群配置及测试

Hadoop集群配置及测试 NameNode与SecondaryNameNode最好不在同一服务器 ResourceManager较为消耗资源,因而和NameNode与SecondaryNameNode最好不在同一服务器。 配置文件 hadoop102hadoop103hadoop104HDFSNameNodeDataNodeDataNodeSecondaryNameNodeDataNodeYAR…

如何通过IDEA创建基于Java8的Spring Boot项目

上次发现我的IDEA创建Spring Boot项目时只支持11和17的JDK版本,于是就通过Maven搭建SpringBoot项目。 究其原因,原来是Spring官方抛弃了Java8!!! 使用IDEA内置的Spring Initializr创建SpringBoot项目时,已…

设计模式——1_5 享元(Flyweight)

今人不见古时月,今月曾经照古人 ——李白 文章目录 定义图纸一个例子:可以复用的样式表绘制表格降本增效?第一步,先分析 变化和不变的地方第二步,把变化和不变的地方拆开来第三步:有没有办法共享这些内容完…

【数据结构】堆的实现和排序

目录 1、堆的概念和结构 1.1、堆的概念 1.2、堆的性质 1.3、堆的逻辑结构和存储结构 2、堆的实现 2.1、堆的初始化和初始化 2.2、堆的插入和向上调整算法 2.3、堆的删除和向下调整算法 2.4、取堆顶的数据和数据个数 2.5、堆的判空和打印 2.6、测试 3、堆的应用 3.1…

AIGC之视频图片生成工具gen-2

最近无事时研究了一款图片和视频生成工具,先说结论: 1.可以生成视频,生成方式有三种 通过文本的方式生成视频可以通过图片的方式生成视频也可以通过图片文本的方式生成视频 2.可以通过文本描述的方式生成图片 3.生成的视频有瑕疵&#xf…

Eureka整合seata分布式事务

文章目录 前言一、Seata配置1.1、Seata下载1.2、修改conf目录中 flie.conf 文件1.3、修改conf目录中 registry.conf文件1.4、初始化seata数据库 二、微服务整合Seata2.1、父工程项目创建引入依赖 2.2、Eureka集群搭建2.3、搭建账户微服务2.3.1 新建seata-account-service微服务…

React全局状态管理

redux是一个状态管理框架,它可以帮助我们清晰定义state和处理函数,提高可读性,并且redux中的状态是全局共享,规避组件间通过props传递状态等操作。 快速使用 在React应用的根节点,需要借助React的Context机制存放整个…

安卓apk加固后重签名

背景 等保检测,安卓apk使用第三方加固后签名信息会丢失,需要我们重新进行签名 使用jarsigner签名遇到的问题 APP失效无法安装 如何解决签名失效 我们在这里使用Android SDK的apksigner进行签名 mac系统,apksigner 需要设置环境变量 1、…

leedcode刷题day2

题目: 根据这道题我的思路是用python首先将第一个值赋给a,然后将下一个值赋值给b在这里写一个循环计算下一个值是否等于a,不等于就进入数组当等于a的时候输出数组长度,然后比较数组长度输出最长长度对应的元素不过显然这很慢。 然…

【Linux】权限的深度解析

前言:在此之前我们学习了一些常用的Linux指令,今天我们进一步学习Linux下权限的一些概念 💖 博主CSDN主页:卫卫卫的个人主页 💞 👉 专栏分类:Linux的学习 👈 💯代码仓库:卫卫周大胖的学习日记&a…

行列转化【附加面试题】

在MySQL中,行列转换是一种常见的操作。它包括行转列和列转行两种情况。 行转列:行转列是将表中的某些行转换成列,以提供更为清晰、易读的数据视图。例如,假设我们有一个包含科目和分数的表,我们可以使用SUM和CASE语句…

一款轻量级、基于Java语言开发的低代码开发框架,开箱即用!

数字化时代,企业对于灵活、高效和安全的软件开发需求日益旺盛。为了满足这些需求,许多组织转向低代码技术,以寻求更具成本效益和创新性的解决方案。JNPF基础框架正是在这一背景下应运而生,凭借其私有化部署和100%源码交付的特性&a…

011:vue结合css动画animation实现下雪效果

文章目录 1. 实现效果2. 编写一个下雪效果组件 VabSnow.vue3. 页面使用4. 注意点 1. 实现效果 GIF录屏文件太卡有点卡&#xff0c;实际是很丝滑的 2. 编写一个下雪效果组件 VabSnow.vue 在 src 下新建 components 文件&#xff0c;创建VabSnow.vue组件文件 <template>…

C++系统笔记教程----vscode远程连接ssh

C系统笔记教程 文章目录 C系统笔记教程前言开发环境配置总结 前言 开发环境配置 Ubuntu20.24VScode 如果没有linux系统&#xff0c;但是想用其编译&#xff0c;可以使用ssh远程连接。 首先进入vscode,打开远程连接窗口&#xff08;蓝色的小箭头这&#xff09; 选择连接到主机…

三菱plc学习入门(创建属于自己的FB模块)

在现实生活中&#xff0c;往往会需要修改一些属于方便自己的库&#xff0c;1&#xff0c;自己创建的库方便自己使用与查看2&#xff0c;提高自己编程能力&#xff0c;3&#xff0c;保护自己的程序不被外人修改&#xff01;&#xff01;&#xff01;下面就让我来操作一下 导入需…