从头开始使用 KNN 进行 KNN 和 MNIST 手写数字识别的初学者指南

news2024/12/28 23:25:35

坦维·佩努穆迪

Kaggle参考: MNIST Perfect 100% using kNN | Kaggle

一、说明

        MNIST (“修改后的国家标准与技术研究所”)是事实上的计算机视觉“hello world”数据集。自 1999 年发布以来,这个经典的手写图像数据集一直作为分类算法基准测试的基础。随着新的机器学习技术的出现,MNIST 仍然是研究人员和学习者的可靠资源。

最终目标是从数万张手写图像的数据集中正确识别数字。

图片来源:维基百科

我们现在将尝试从头开始使用KNN(K 最近邻)算法对数字进行分类

在此之前,我们先来了解一下KNN到底是什么!

二、如何读取Mnist? 

读取Mnist可以用tensorflow完成,也可以用numpy完成。如下:

def load_data(path):
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_data('../input/mnist-numpy/mnist.npz')

三、什么是KNN?

        K 最近邻可用于分类和回归。K 最近邻是一种简单的算法,它存储所有可用案例并根据相似性度量对新案例进行分类。

      KNN是一种基于实例的学习惰性学习,其中函数仅在本地进行近似,并且所有计算都推迟到分类。KNN 算法是所有机器学习算法中最简单的算法之一。

        它是一种非参数算法,不需要训练数据来进行推理,因此与参数学习算法相比,训练速度要快得多,而推理速度要慢得多,原因显而易见。

四、KNN 到底是如何工作的?

我们通过一个简单的例子来理解这个算法。

以下是红色圆圈 (RC)绿色方块 (GS)的分布:

        您打算找出蓝星 (BS) 的等级。BS 可以是 RC 或 GS​​,仅此而已。KNN 算法中的“K”是我们希望投票的最近邻居。假设 K = 3。因此,我们现在将以 BS 为中心制作一个圆,其大小仅包含平面上的三个数据点。更多详情请参考下图:

        离SB最近的3个点都是RC。因此,根据我们的智能置信水平,我们可以说 bs 应该属于 RC 类别。在这里,选择变得非常明显,因为最近邻居的所有 3 票都投给了 RC。在此算法中,参数K的选择至关重要。接下来,我们将了解得出最有效的K需要考虑哪些因素。

注: KNN 的一些假设 —

  • 当您有 2 个类时,请选择奇数 K 值以避免平局。即,如果新数据点位于两个类之间,则它无法决定选择哪一个。
  • K 不能是类数的倍数
  • 如果 K 非常小(过拟合),如果您有很多数据点 (n) 将不准确
  • 如果 K 很大(Underfit),K 一定不能等于数据点的数量 n

五、我们如何选择因子K?

首先,让我们尝试了解 K 对算法到底有什么影响。如果我们看到最后一个例子,假设所有 6 个训练观察值保持不变,使用给定的 K 值,我们可以为每个类别划分边界。

K 值的训练误差

        正如您所看到的,训练样本在 K=1 时的错误率始终为零。这是因为与任何训练数据点最接近的点就是其本身。因此,当 K=1 时,预测总是准确的。如果验证误差曲线相似,我们选择的 K 将为 1。

        以下是不同 K 值的验证误差曲线:

K 值的测试/验证错误

        这让故事变得更加清晰。当 K=1 时,我们过度拟合了边界。因此,错误率最初下降并达到最小值。在最小值点之后,它会随着 K 的增加而增加。为了获得 K 的最佳值,您可以将训练和验证与初始数据集分开。现在绘制验证误差曲线以获得 K 的最佳值。这个 K 值应该用于所有预测,这与肘部方法类似

六、KNN 的伪代码

任何人都可以按照下面给出的伪代码步骤来实现 KNN 模型。

  1. 加载数据
  2. 初始化k的值
  3. 要获得预测类别,请从 1 迭代到训练数据点总数
  4. 计算测试数据与每行训练数据之间的距离。在这里,我们将使用欧几里德距离作为距离度量,因为它是最流行的方法。其他可以使用的度量有切比雪夫、余弦等。
  5. 根据距离值对计算出的距离进行升序排序
  6. 从排序数组中获取前 k 行
  7. 获取这些行中最常见的类别
  8. 返回预测的类别

七、KNN 的变体

        在传统提出的 KNN 中,正如我们所见,我们对所有类别和距离给予相同的权重,这是您应该了解的 KNN 的变体!

7.1 距离加权 KNN

        在距离加权 KNN 中,您基本上更多地强调更接近测试值的值,更少地强调远离测试值的值,并类似地为每个值分配权重。

其中 wk 是 —

7.2 加权距离函数

        由于我们在传统 KNN 中为所有特征赋予了相同的权重,因此我们尝试在此变体中为每个特征分配不同的权重。重要的特征将具有较大的权重,而不太重要的特征将具有较低的权重,而最不重要的特征将具有0或接近0的权重。

八、测量距离的方法

  • 闵可夫斯基距离
  • 曼哈顿距离
  • 欧氏距离
  • 汉明距离
  • 余弦距离

        当数据具有高维度时,曼哈顿距离通常优于更常见的欧几里得距离。汉明距离用于衡量分类变量之间的距离,余弦距离度量主要用于查找两个数据点之间的相似程度,明可夫斯基是欧几里德距离和曼哈顿距离在较低级别上的推广。

有关这方面的更多信息,请查看 —机器学习中使用的不同类型的距离度量

九、从头开始实施

您最初需要导入的库:

import numpy as np
import operator 
from operator import itemgetter

让我们首先定义一个返回两点之间的欧几里得距离的函数:

图片来源:Science Direct — 欧几里得距离公式

def euc_dist(x1, x2):
    return np.sqrt(np.sum((x1-x2)**2))

现在,让我们编写一个类“KNN”并为“K”值初始化一个实例

class KNN:
    def __init__(self, K=3):
        self.K = K

让我们在类中添加另一个函数来初始化实例以拟合我们的训练集 — X-train 和 y-train:

class KNN:
    def __init__(self, K=3):
        self.K = K
    def fit(self, x_train, y_train):
        self.X_train = x_train
        self.Y_train = y_train

现在让我们将预测函数添加到此类中:

def predict(self, X_test):
    predictions = [] 
    for i in range(len(X_test)):
        dist = np.array([euc_dist(X_test[i], x_t) for x_t in   
        self.X_train])
        dist_sorted = dist.argsort()[:self.K]
        neigh_count = {}
        for idx in dist_sorted:
            if self.Y_train[idx] in neigh_count:
                neigh_count[self.Y_train[idx]] += 1
            else:
                neigh_count[self.Y_train[idx]] = 1
        sorted_neigh_count = sorted(neigh_count.items(),    
        key=operator.itemgetter(1), reverse=True)
        predictions.append(sorted_neigh_count[0][0]) 
    return predictions

哇!这是很多代码!让我们逐行理解这一点——

        我们初始化了一个列表来存储我们的预测,然后运行一个循环来计算每个测试示例到每个相应训练示例的欧几里德距离,并将所有这些距离存储在 NumPy 数组中,之后我们返回第一个 K- 的索引对距离值进行排序,然后我们创建了一个字典,其中类标签作为键,它们的出现作为每个键的值。

        然后,我们将每个计数附加到每个键值对的 neigh_count 字典中,之后,我们将键值对从最常出现的值到最少出现的值进行排序,其中,我们最常出现的值将是我们对每个训练示例的预测。然后我们返回了预测。

这就是从头开始实现 KNN 的全部内容,现在让我们在 MNIST 数据集上测试我们的模型!

from sklearn.datasets import load_digits
mnist = load_digits()
print(mnist.data.shape)
Out:
(1797, 64)
X = mnist.data 
y = mnist.target

将我们的数据分为训练和测试:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=123)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
Out:
(1347, 64) (1347,) 
(450, 64) (450,)
print(np.unique(y_train,return_counts=True))
print(np.unique(y_test,return_counts=True))
Out:
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([127, 140, 136, 143, 129, 134, 133, 138, 129, 138])) (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([51, 42, 41, 40, 52, 48, 48, 41, 45, 42]))

将数据分为测试和训练就足够了吗?真的有帮助吗?

        通常建议使用交叉验证来分割我们的数据-

        在交叉验证中,我们没有将数据拆分为两部分,而是将其拆分为 3 部分(或K取决于K 折交叉验证中的 K 值)。训练数据、交叉验证数据和测试数据。在这里,我们使用训练数据来查找最近邻居,我们使用交叉验证数据来找到“K”的最佳值(这里是 K 个邻居),最后我们在完全看不见的测试数据上测试我们的模型。这个测试数据就相当于未来未见过的数据点。

        让我们导入更多辅助函数来评估 Sklearn 中的模型:

from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score

使用从 3 到 100 的所有可能的 K 值(奇数)训练我们的模型:

kVals = np.arange(3,100,2)
accuracies = []
for k in kVals:
  model = KNN(K = k)
  model.fit(X_train, y_train)
  pred = model.predict(X_test)
  acc = accuracy_score(y_test, pred)
  accuracies.append(acc)
  print("K = "+str(k)+"; Accuracy: "+str(acc))
Out:
K = 3; Accuracy: 0.9755555555555555
K = 5; Accuracy: 0.9755555555555555
K = 7; Accuracy: 0.9755555555555555
K = 9; Accuracy: 0.9755555555555555
K = 11; Accuracy: 0.9733333333333334
K = 13; Accuracy: 0.9711111111111111
K = 15; Accuracy: 0.9688888888888889
K = 17; Accuracy: 0.9666666666666667
K = 19; Accuracy: 0.9666666666666667
K = 21; Accuracy: 0.9666666666666667
K = 23; Accuracy: 0.9644444444444444
K = 25; Accuracy: 0.9644444444444444
K = 27; Accuracy: 0.9666666666666667
K = 29; Accuracy: 0.9622222222222222
K = 31; Accuracy: 0.96
K = 33; Accuracy: 0.96
K = 35; Accuracy: 0.9577777777777777
K = 37; Accuracy: 0.9577777777777777
K = 39; Accuracy: 0.9577777777777777
K = 41; Accuracy: 0.9555555555555556
K = 43; Accuracy: 0.9511111111111111
K = 45; Accuracy: 0.9488888888888889
K = 47; Accuracy: 0.9444444444444444
K = 49; Accuracy: 0.9444444444444444
K = 51; Accuracy: 0.9377777777777778
K = 53; Accuracy: 0.9355555555555556
K = 55; Accuracy: 0.9333333333333333
K = 57; Accuracy: 0.9333333333333333
K = 59; Accuracy: 0.9311111111111111
K = 61; Accuracy: 0.9333333333333333
K = 63; Accuracy: 0.9333333333333333
K = 65; Accuracy: 0.9311111111111111
K = 67; Accuracy: 0.9288888888888889
K = 69; Accuracy: 0.9266666666666666
K = 71; Accuracy: 0.9288888888888889
K = 73; Accuracy: 0.9311111111111111
K = 75; Accuracy: 0.9288888888888889
K = 77; Accuracy: 0.9266666666666666
K = 79; Accuracy: 0.92
K = 81; Accuracy: 0.9222222222222223
K = 83; Accuracy: 0.9222222222222223
K = 85; Accuracy: 0.92
K = 87; Accuracy: 0.9177777777777778
K = 89; Accuracy: 0.9177777777777778
K = 91; Accuracy: 0.9111111111111111
K = 93; Accuracy: 0.9111111111111111
K = 95; Accuracy: 0.9088888888888889
K = 97; Accuracy: 0.9088888888888889
K = 99; Accuracy: 0.9066666666666666

该模型在 K=3 时最准确:

max_index = accuracies.index(max(accuracies))
print(max_index)
Out:
0

绘制我们的准确性:

from matplotlib import pyplot as plt 
plt.plot(kVals, accuracies) 
plt.xlabel("K Value") 
plt.ylabel("Accuracy")
Out:
Text(0, 0.5, 'Accuracy')

检查精确率、召回率和 F 分数(以获得最准确的 K 值):

model = KNN(K = 3) 
model.fit(X_train, y_train) 
pred = model.predict(X_train)
precision, recall, fscore, _ = precision_recall_fscore_support(y_train, pred)
print("Precision \n", precision)
print("\nRecall \n", recall)
print("\nF-score \n", fscore)
Out:
Precision 
 [1.         0.9929078  1.         1.         1.         1.
 0.98518519 1.         0.9921875  0.99275362]

Recall 
 [1.         1.         1.         1.         1.         0.99253731
 1.         0.99275362 0.98449612 0.99275362]

F-score 
 [1.         0.99644128 1.         1.         1.         0.99625468
 0.99253731 0.99636364 0.98832685 0.99275362]

在测试集上对经过训练的模型进行推理:

model = KNN(K = 3)
model.fit(X_train, y_train)
pred = model.predict(X_test)
acc = accuracy_score(y_test, pred)
precision, recall, fscore, _ = precision_recall_fscore_support(y_test, pred)
print("Precision \n", precision)
print("\nRecall \n", recall)
print("\nF-score \n", fscore)
Out:
Precision 
 [1.         0.89361702 1.         0.93023256 0.98113208 1.
 1.         1.         1.         0.95      ]

Recall 
 [1.         1.         0.97560976 1.         1.         0.95833333
 1.         1.         0.91111111 0.9047619 ]

F-score 
 [1.         0.94382022 0.98765432 0.96385542 0.99047619 0.9787234
 1.         1.         0.95348837 0.92682927]
print(acc) #testing accuracy
Out:
0.9755555555555555 

我知道一下子要吸收很多东西!但你坚持到了最后!对此表示敬意!不要忘记查看我即将发表的文章!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1116825.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

编写内联函数求解 2x²+4x+5的值,并用主函数调用该函数

动态内存分配可以根据实际需要在程序运行过程中动态地申请内存空间,这种内存空间的分配和释放是由程序员自己管理的,因此也被称为手动内存分配。 C++ 中,动态内存的分配和释放是通过 new 和 delete 操作符进行的。new 操作符用于在堆内存上为对象动态分配空间,dele…

Python之哈希表-哈希表原理

Python之哈希表-哈希表原理 集合Set 集合,简称集。由任意个元素构成的集体。高级语言都实现了这个非常重要的数据结构类型。Python中,它是可变的、无序的、不重复的元素的集合 初始化 set() -> new empty set objectset(iterable) -> new set …

Docker——如何自定义镜像【将自己的项目制作成镜像】?

目录 前言:我们以前是如何部署项目的? 1、镜像由哪几部分构成的 2、如何手动自定义一个镜像 2.1、Dockerfile 2.2、dockerfile文本文件中,最终要写什么? 2.3、构建镜像 3、案例:部署java项目 4、如何与其他容器…

Python 机器学习入门之ID3决策树算法

系列文章目录 第一章 Python 机器学习入门之线性回归 第一章 Python 机器学习入门之梯度下降法 第一章 Python 机器学习入门之牛顿法 第二章 Python 机器学习入门之逻辑回归 番外 Python 机器学习入门之K近邻算法 番外 Python 机器学习入门之K-Means聚类算法 第三章 Python 机…

华为云HECS云服务器docker环境下安装nacos

华为云HECS云服务器,安装docker环境,查看如下文章。 华为云HECS安装docker-CSDN博客 一、拉取镜像 docker pull nacos/nacos-server二、宿主机创建挂载目录 执行如下命令: mkdir -p /usr/local/nacos/logs mkdir -p /usr/local/nacos/con…

服务端监控要怎么做?

目录 前言 一、Google的四类黄金指标 二、RED方法 三、USE方法 RED方法 vs USE方法 四、监控指标 WEB服务监控 MySQL数据库监控 QPS TPS 最大连接数 缓存监控 总结 前言 众所周知,业界各种大中型软件系统在生产运行时,总会有一些手段来…

ESP32C3 LuatOS TM1650②动态显示累加整数

--注意:因使用了sys.wait()所有api需要在协程中使用 -- 用法实例 PROJECT "ESP32C3_TM1650" VERSION "1.0.0" _G.sys require("sys") local tm1650 require "tm1650"-- 拆分整数,并把最低位数存放在数组最大索引处 loc…

车载网关通信能力解析——SV900-5G车载网关推荐

随着车联网的发展,各类车载设备对车载网关的需求日益增长。车载网关作为车与车、车与路、车与云之间连接的关键设备,其通信能力直接影响整个系统的性能。本文将详细解析车载网关的通信能力,并推荐性价比高的SV900-5G车载网关。 链接直达:https://www.key-iot.com/i…

类与对象-对象特性-构造函数与析构函数

#include<iostream> using namespace std; class Person { public: //构造函数 : //没有返回值,不写void //函数名与类名相同 //构造函数可以有参数,能发生重载 //创建对象时,构造函数会自动调用,而且只调用一次 Person() { cout << &quo…

哈里斯鹰算法优化BP神经网络(HHO-BP)回归预测研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

Python爬虫:ad广告引擎的模拟登录

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ &#x1f434;作者&#xff1a;秋无之地 &#x1f434;简介&#xff1a;CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作&#xff0c;主要擅长领域有&#xff1a;爬虫、后端、大数据…

Ubuntu桌面环境的切换方法

你在找它吗&#xff1f; 国内麒麟、深度等系统虽然界面更炫&#xff0c;但——软件仓库与Ubuntu官方已不兼容。国内系统遇到稳定性问题&#xff0c;还是得拿Ubuntu做参照。今天本来介绍下这款Linux桌面。 为什么在 Ubuntu 上考虑 LXQt&#xff1f; 性能&#xff1a;LXQt设计为…

基于SSM的教务管理系统运行教程

文章目录 1、前期必备1.1、所需软件版本说明1.2、下载源码1.3、下载开发工具1.4、下载JDK并配置环境变量1.5、安装数据库和数据库管理工具1.6、安装配置Maven 2、将SQL文件导入到数据库2.1、新建MySQL连接2.2、新建数据库并导入SQL 3、用Eclipse运行程序3.1、导入educationalMa…

操作系统【OS】操作系统的引导

激活CPU。 激活的CPU读取ROM中的boot程序&#xff0c;将指令寄存器置为BIOS(基本输入输出系统)的第一条指令&#xff0c; 即开始执行BIOS的指令。硬件自检。 启动BIOS程序后&#xff0c;先进行硬件自检&#xff0c;检查硬件是否出现故障。如有故障&#xff0c;主板会发出不同含…

【JavaEE】Java多线程编程案例 -- 多线程篇(3)

Java多线程编程案例 1. 单例模式1.1 代码的简单实现1.2 懒汉模式的线程安全代码 2. 阻塞队列2.1 阻塞队列的概念2.2 使用库中的BlockingDeque2.3 模拟实现阻塞队列2.4 生产者消费者模型 3. 定时器3.1 概念3.2 使用库的定时器 - Timer类3.3 模拟实现定时器 4. 线程池4.1 概念4.2…

JetBrains系列IDE全家桶激活

jetbrains全家桶 正版授权&#xff0c;这里有账号授权的渠道&#xff1a; https://www.mano100.cn/thread-1942-1-1.html 附加授权后的一张图片

YOLOv8改进实战 | 更换主干网络Backbone(二)之轻量化模型GhostnetV2

前言 轻量化网络设计是一种针对移动设备等资源受限环境的深度学习模型设计方法。下面是一些常见的轻量化网络设计方法: 网络剪枝:移除神经网络中冗余的连接和参数,以达到模型压缩和加速的目的。分组卷积:将卷积操作分解为若干个较小的卷积操作,并将它们分别作用于输入的不…

9 线程池

为什么要使用线程池&#xff1f; 1 重复利用已创建的线程&#xff0c;减少线程创建和销毁带来的开销 2 提高响应速度&#xff1a;任务可以不用等待线程创建就能立即执行&#xff08;T1 创建线程 T2执行任务 T3销毁线程&#xff09;&#xff0c;若T1T3>T2&#xff0c;可以通过…

macOS Tor 在启动期间退出

macos Tor 在启动期间退出。这可能是您的 torrc 文件存在错误&#xff0c;或者 Tor 或您的系统中的其他程序存在问题&#xff0c;或者硬件问题。在您解决此问题并重新启动 Tor 前&#xff0c;Tor 浏览器将不会启动。退出Tor、卸载Tor、删除目录 TorBrowser-Data、重启电脑 访…

关于antdpro的EdittableProTable编辑时两个下拉搜索框的数据联动以及数据回显(以及踩坑)

需求&#xff1a;使用antdpro的editprotable编辑两个下拉框&#xff0c;且下拉框是一个搜索下拉框。下拉框1和2的值是一个编码和名字的联动关系&#xff0c;1变化会带动2&#xff0c;2变化会带动1的一个联动作用。&#xff08;最后有略完整的代码&#xff0c;但是因为公司项目问…