【AI】Python 实现 KNN 手写数字识别

news2025/1/13 17:27:00

KNN 算法

1. 题目介绍

K近邻(K-Nearest Neighbor, KNN)是一种最经典和最简单的有监督学习方法之一。K-近邻算法是最简单的分类器,没有显式的学习过程或训练过程,是懒惰学习(Lazy Learning)。当对数据的分布只有很少或者没有任何先验知识时,K 近邻算法是一个不错的选择。

从背景上来说,KNN 并不复杂,本文不介绍 KNN 的原理,重点关注如何使用 KNN 来实现手写数字的识别。具体来说,本文使用两种办法来实现 KNN,第一种是使用 numpy 手动实现该算法,第二种是使用 sklearn 中封装好的 KNN 接口,并会简要比较一下两种办法。

本文使用的数据集采用文本文件,每一个文件使用大小为 32 × 32 32×32 32×32 的 0-1 阵列来表示一个手写数字。我们的目标是输入一张这样的图片,然后返回对该图片的预测值。例如下面的几张图片都表示手写数字 ‘0’:

在这里插入图片描述

点击链接下载数据集 (有效期至 2023/12/5,过期请留言更新)

2. 代码编排

本实验使用 jupyter 完成,下面按照 cell 的顺序进行介绍。点击下载可执行文件

2.1 全局定义

首先是导入整个项目需要使用到的库,并且定义一些全局变量。training_dir 和 test_dir 分别是训练集和测试集的目录地址,虽然 KNN 中严格来说不存在“训练”和“测试”的概念,但此处把“训练集”理解作空间中已有的那些点,“测试集”就是输入的待分类的点:

import os
import numpy as np
import operator
from sklearn.neighbors import KNeighborsClassifier as kNN
import time
training_dir = 'data/knn-digits/training_digits'
test_dir = 'data/knn-digits/test_digits'
k_global = 3

然后定义对数据集的处理方法。每个文件是 32 × 32 32×32 32×32 的 0-1 阵列,所以我们把他转化为 1 × 1024 1×1024 1×1024 的单行数据,再将单个的数据全部拼接在一起;训练集中一共有 1934 个文件,则最终得到的训练集的大小为 1934 × 1024 1934×1024 1934×1024

对于测试集,也采用和训练集很类似的方法,但我们希望每提取到一个文件,就对它跑一遍 KNN 算法,以此提高程序的并发度,这里使用了 yield 方法。

# 将32*32的数据转为1*1024的数据
def img2vector(filename):
    return_vector = np.zeros((1, 1024))
    f = open(filename)
    for i in range(32):
        line = f.readline()
        for j in range(32):
            return_vector[0, 32 * i + j] = int(line[j])
    return return_vector


def load_training_data():
    training_label = []
    training_file_list = os.listdir(training_dir)
    training_size = len(training_file_list)
    training_data = np.zeros((training_size, 1024))
    for i in range(training_size):
        filename = training_file_list[i]
        label = int(filename.split('_')[0])
        training_label.append(label)
        training_data[i, :] = img2vector(training_dir + '/' + filename)
    return training_data, training_label


def load_test_data():
    test_file_list = os.listdir(test_dir)
    test_size = len(test_file_list)
    for i in range(test_size):
        filename = test_file_list[i]
        label = int(filename.split('_')[0])
        test_data = img2vector(test_dir + '/' + filename)
        yield test_data, label

2.2 使用 numpy 实现 KNN

我们使用两个函数配合来实现 KNN 算法。第一个函数 classify0 用来对单条数据进行分类,它计算测试点 (shape=(1, 1024)) 与训练数据 (shape=(1934, 1024)) 中每一个点分别的欧式距离,得到一个 1934 大小的一维数组,再从其中挑选 k_global 条距离最近的训练点,将这些点的标签作为 KNN 做出决策的标准。

# 对单条数据进行分类
def classify0(in_data, data_set, labels, k):
    data_size = data_set.shape[0]
    diff_mat = np.tile(in_data, (data_size, 1)) - data_set
    distances = (diff_mat ** 2).sum(axis=1) ** 0.5
    argsort_distances = distances.argsort()
    class_count = {}
    for i in range(k):
        label = labels[argsort_distances[i]]
        class_count[label] = class_count.get(label, 0) + 1
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]


# knn的总体流程
def knn():
    error_count = 0
    correct_count = 0
    training_data, training_label = load_training_data()
    for test_data, test_label in load_test_data():
        pred_label = classify0(test_data, training_data, training_label, k_global)
        if pred_label == test_label:
            correct_count += 1
        else:
            error_count += 1
    num_test = error_count + correct_count
    acc = correct_count / (correct_count + error_count)
    print('test number: %d, failure number: %d, accuracy: %.6f' % (num_test, error_count, acc))

下面执行上述的 knn 函数,并记录它所花费的时间:

time_begin = time.time()
print('use knn implementing from scratch:')
knn()
time_end = time.time()
print('took %f.4 s' % (time_end - time_begin))

程序输出为:

在这里插入图片描述

可以看到,一共测试了 946 张图片,仅有 10 张分类错误了,正确率高达 98.94%,效果还是非常不错的。但是它花费了 14.66s,这个时间相比于 sklearn 中现成的接口还是稍显慢的。

2.3 使用 sklearn 实现 KNN

总的来说,sklearn 中的 knn 接口主要就是替代了上文中的 classify0 函数,主体的逻辑流程和之前手动实现的 knn 函数还是很类似的:

def knn_sklearn(algorithm):
    error_count = 0
    correct_count = 0
    training_data, training_label = load_training_data()
    classifier = kNN(n_neighbors=k_global, algorithm=algorithm)
    classifier.fit(training_data, training_label)
    for test_data, test_label in load_test_data():
        pred_label = classifier.predict(test_data)
        if pred_label == test_label:
            correct_count += 1
        else:
            error_count += 1
    num_test = error_count + correct_count
    acc = correct_count / (correct_count + error_count)
    print('test number: %d, failure number: %d, accuracy: %.6f' % (num_test, error_count, acc))

kNN 函数中有一个参数 algorithm,这个参数决定快速 k 近邻搜索算法,默认为 auto,可以理解为算法自己决定合适的搜索算法。除此之外,可取的值有 kd_tree、ball_tree、brute。

其中,kd_tree 参数构造 kd 树存储数据以便对其进行快速检索的树形数据结构,kd 树也就是数据结构中的二叉树,以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高;ball_tree是为了克服 kd 树高纬失效而发明的,其构造过程是以质心 C 和半径 r 分割样本空间,每个节点是一个超球体;brute 是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。

下面分别演示这四个参数对程序性能的影响:

① auto:

time_begin = time.time()
print('use knn from sklearn:')
knn_sklearn(algorithm='auto')
time_end = time.time()
print('took %f.4 s' % (time_end - time_begin))

输出结果为:
在这里插入图片描述

② brute:

time_begin = time.time()
print('use knn from sklearn:')
knn_sklearn(algorithm='brute')
time_end = time.time()
print('took %f.4 s' % (time_end - time_begin))

输出结果为:
在这里插入图片描述

③ kd_tree:

time_begin = time.time()
print('use knn from sklearn:')
knn_sklearn(algorithm='kd_tree')
time_end = time.time()
print('took %f.4 s' % (time_end - time_begin))

输出结果为:

在这里插入图片描述

④ ball_tree:

time_begin = time.time()
print('use knn from sklearn:')
knn_sklearn(algorithm='ball_tree')
time_end = time.time()
print('took %f.4 s' % (time_end - time_begin))

输出结果为:

在这里插入图片描述

3. 结果分析

基于上述的代码,在不同的 k_global 值下分别测试了 numpy 实现的 KNN、sklearn 中的 KNN (algorithm=‘auto’) 的性能,得到的表格如下:

k_globalnumpy实现的KNNsklearn中实现的KNN
1t=14.86s, acc=98.73%t=6.40s, acc=98.63%
3t=14.79s, acc=98.94%t=8.43s, acc=98.73%
5t=14.35s, acc=98.20%t=7.55s, acc=98.10%
10t=14.59s, acc=98.89%t=6.63s, acc=97.57%
20t=14.73s, acc=97.15%t=6.47s, acc=96.83%

从上面的表格可以得到几个基本的结论:① k_global 的值不能过大也不能过小,在本实验中,该值为 3 时可以取得较高的精度;② 随着 k_global 的增大,模型所消耗的时间差异并不大,所以不用为了节省时间而选择一个较小的 k_global;③ 使用 numpy 实现的 KNN 耗时总是远高于 sklearn 中的 KNN,但前者的精度只是略高于后者,实际的项目中要根据数据集大小来在时间和精度中取一个 trade-off。

在 k_global = 3 的前提下,比较 sklearn 中的 KNN 在不同 algorithm 参数下的性能,得到下面的表格:

algorithmtimeacc
auto6.52s98.73%
brute6.48s98.73%
kd_tree5.83s98.73%
ball_tree4.71s98.73%

观察上表,结合上文对 algorithm 参数的介绍,可以得到一个基本的结论:在本实验中,由于样本量并不多,ball_tree 可以使算法用时最少,而 brute 使算法耗时最大(因为它是线性扫描的);默认情况下,auto 参数选择的可能是 brute 参数,因为它们非常接近。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/62665.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何搭建一个自己的音乐服务器 审核中

点赞再看,动力无限。 微信搜「 程序猿阿朗 」。 本文 Github.com/niumoo/JavaNotes 和 未读代码博客 已经收录,有很多知识点和系列文章。 最近发现,经常用的网易云音乐,有很多歌曲下架了,能听的越来越少了;…

设计模式之中介者模式(十五)

目录 1. 背景 1.1 智能家庭管理项目 1.2 中介者模式概述 2. 中介者模式 2.1 中介者模式解决上述问题 1. 背景 1.1 智能家庭管理项目 智能家庭项目: 智能家庭包括各种设备,闹钟、咖啡机、电视机、窗帘 等。主人要看电视时,各个设备可以协…

7 支持向量机

支持向量机 支持向量机(SVM)是在统计学习理论基础上发展起来的一种数据挖掘方法,1992 年由Boser, Guyon和Vapnik提出,在解决小样本、非线性、高维的回归和分类问题上, 有许多优势。 1 支持向量分类概述 支持向量分类以训练样本集为数据对象…

支持向量机核技巧:10个常用的核函数总结

支持向量机是一种监督学习技术,主要用于分类,也可用于回归。它的关键概念是算法搜索最佳的可用于基于标记数据(训练数据)对新数据点进行分类的超平面。 一般情况下算法试图学习一个类的最常见特征(区分一个类与另一个类的特征),分类是基于学…

[附源码]JAVA毕业设计律师事务所网站(系统+LW)

[附源码]JAVA毕业设计律师事务所网站(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&a…

ubuntu20.04安装anaconda3搭建python环境

1.清华源下载anaconda3 清华源anaconda软件镜像网站: Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 下载完成在终端打开Downloads 运行 bash Anaconda3-5.3.1-Linux-x86_64.sh 进入到下载页面 紧接着你可以使用conda --…

HTTP抓包神器---Fiddler

今天我们介绍一个好用的抓包工具.是针对于HTTP协议的抓包工具. Fiddler 下载地址 下载工具当然是要去官网下载啦. 这里为了防止有些人在网上找不到下载路径.我们直接把下载路径放在下面. https://www.telerik.com/download/fiddler 下载 点击上述链接以后会直接跳转到下…

蓝海创意云·11月大事记 || 12月,暖心相伴

秋尽冬生,日短天寒 告别了立冬与小雪 时光不紧不慢开启了新一月的篇章 万物冬藏,沉淀酝酿 站在十二月的路口 蛰伏打磨,静待厚积而薄发 导 读 ● 客户端更新:新增PSD通道合成选项 ● 渲染案例:绝代双骄重启江湖…

K8S - Pod 的概念和简介

1. POD的基本概念 Pod 是K8s 系统中可以创建(部署)和管理的最小单元。 Pod 里面可以包含多个容器(多实例),是一组容器的集合。 也就是讲K8S 不会直接管理容器 1个POD中的容器共享网络命名空间(共享ip) P…

MongoDB_前期准备(一)

目录一、数据库(Database)数据库分类1、关系型数据库(RDBMS)2、非关系型数据库(No SQL)二、MongoDB简介1)MongoDB VS MySql2)MongoDB中的三个概念3) MongoDB安装一、数据…

Reading Note(10)——AutoBridge

这篇论文是FPGA 2021年的best paper award,主要解决的是在HLS编译过程中优化布局和布线,最终达到整个multi-die的FPGA板上的大规模HLS设计时钟频率尽可能提升的目的,这篇工作在当前chiplet工艺铺展开来的当下更加有现实意义,通过这…

代码随想录训练营第41天|LeetCode 343. 整数拆分、 96.不同的二叉搜索树

参考 代码随想录 题目一&#xff1a;LeetCode 343.整数拆分 确定dp数组及其下标的含义 dp[i]为整数i拆分后得到的最大化乘积。确定递推公式 dp[i]可以有两种方式得到&#xff1a; dp[i] j * (i-j)&#xff0c;即只拆分成两个数&#xff0c;其中1 < j < i/2&#xff…

Nginx的高可用集群

1、什么是 nginx高可用 只有一台nginx服务器时&#xff0c;如果nginx服务器宕机了&#xff0c;那么请求就无法访问。 要实现高可用&#xff0c;那就可以部署多台nginx服务器&#xff0c;下面以两台nginx服务器为例&#xff0c;示意图如下&#xff1a; 要配置nginx集群&#xf…

西部学刊杂志西部学刊杂志社西部学刊编辑部2022年第22期目录

百年党建与马克思主义中国化研究 党的纪律建设的实践、启示与创新——基于“三大纪律八项注意”的研究 武艳; 5-8 西部研究《西部学刊》投稿&#xff1a;cn7kantougao163.com 新疆红色资源运用现状调查研究——以南疆部分地区为例 王艺潼;努尔古扎丽阿不都克里木; 9-12…

BP神经网络对指纹识别的应用(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 在现代计算机具有强大的计算和信息处理能力的今天,指纹识别作为个人身份鉴定等领域的热点问题一直被人们长期关注着,目前也得到…

版本控制 | 一文了解什么是组件化开发,以及如何从单体架构转向组件化开发

传统开发模式中&#xff0c;所有代码都写在APP模块中。随着项目的发展&#xff0c;代码量逐渐庞大&#xff0c;编译时间越来越长。为了方便后续项目的开发和测试、提高编译性能&#xff0c;您需要了解组件化开发&#xff0c;以及如何利用版本控制系统从单体架构转向组件化开发。…

【Python自然语言处理】使用SVM、随机森林法、梯度法等多种方法对病人罹患癌症预测实战(超详细 附源码)

需要源码和数据集请点赞关注收藏后评论区留言私信~~~ 一、数据集背景 乳腺癌数据集是由加州大学欧文分校维护的 UCI 机器学习存储库。 数据集包含 569 个恶性和良性肿瘤细胞样本。 样本类别分布&#xff1a;良性357&#xff0c;恶性212 数据集中的前两列分别存储样本的唯一 …

Prototypical Networks for Few-shot Learning

摘要 我们为零样本分类问题提出了一个原型网络。在这里分类器必须能够被泛化到新类别&#xff08;在训练集中不可见&#xff09;&#xff0c;每个新类只给出少量示例。 原型网络能够学习一个度量空间&#xff0c;通过计算每个类别的原型表示距离实现分类。与少样本学习近几年的…

华为机试 - 完全二叉树非叶子部分后序遍历

目录 题目描述 输入描述 输出描述 用例 题目解析 算法源码 题目描述 给定一个以顺序储存结构存储整数值的完全二叉树序列&#xff08;最多1000个整数&#xff09;&#xff0c;请找出此完全二叉树的所有非叶子节点部分&#xff0c;然后采用后序遍历方式将此部分树&#x…

AOP事务管理(下)

Transactional注解可以设置参数。 readOnly&#xff1a;true只读事务&#xff0c;false读写事务&#xff0c;增删改要设为false,查询设为true。 timeout:设置超时时间单位秒&#xff0c;在多长时间之内事务没有提交成功就自动回滚&#xff0c;-1表示不设置超 时时间。 rollbac…