机器学习之KNN算法:基于pytorch在MNIST数据集上实现数据分类预测

news2024/9/23 9:29:40

1 KNN算法介绍

KNN算法又叫做K近邻算法,是众多机器学习算法里面最基础入门的算法。KNN算法是最简单的分类算法之一,同时,它也是最常用的分类算法之一。KNN算法是有监督学习中的分类算法,它看起来和Kmeans相似(Kmeans是无监督学习算法),但却是有本质区别的。

KNN算法基于实例之间的相似性进行分类或回归预测。在KNN算法中,要解决的问题是将新的数据点分配给已知类别中的某一类。该算法的核心思想是通过比较距离来确定最近邻的数据点,然后利用这些邻居的类别信息来决定待分类数据点的类别。其核心思想为:“近朱者赤近墨者黑”

4c4b2043083b47c6935026a1ae5f4f69.png

1.1 KNN算法三要素

  • 距离度量算法:一般使用的是欧氏距离。也可以使用其他距离:曼哈顿距离、切比雪夫距离、闵可夫斯基距离等。
  • k值的确定:k值越小,模型整体变得越复杂,越容易过拟合。通常使用交叉验证法来选取最优k值
  • 分类决策:一般使用多数表决,即在 k 个临近的训练点钟的多数类决定输入实例的类。可以证明,多数表决规则等价于经验风险最小化

1.2 KNN是一种非参的,惰性的算法模型。

  • 非参:并不是说这个算法不需要参数,而是意味着这个模型不会对数据做出任何的假设,与之相对的是线性回归总会假设线性回归是一条直线。KNN建立的模型结构是根据数据来决定的,这也比较符合现实的情况。
  • 惰性:同样是分类算法,逻辑回归需要先对数据进行大量训练,最后会得到一个算法模型。而KNN算法却不需要,它没有明确的训练数据的过程,或者说这个过程很快。

1.3 KNN算法的优缺点

(1)KNN算法具有以下优点:

  • 简单易懂:KNN算法的基本思想直观简单,易于理解和实现。

  • 无需训练过程:KNN算法是一种基于实例的学习方法,不需要显式的训练过程。它直接利用已有的训练数据进行分类或回归预测。

  • 适用于多类别问题:KNN算法可以应用于多类别问题,不受类别数目的限制。

  • 对于不平衡数据集有效:KNN算法在处理不平衡数据集时相对较为有效,因为它不假设数据分布的先验知识。

(2)KNN算法的一些缺点:

  • 计算复杂度高:在进行分类或回归预测时,KNN算法需要计算待分类数据点与所有训练数据点之间的距离。当训练数据集较大时,计算复杂度会显著增加。

  • 对特征空间维度敏感:KNN算法对于特征空间的维度敏感。当特征空间维度较高时,由于所谓的"维度灾难",KNN算法的性能可能会下降。在高维数据中,距离度量变得不准确,所有数据点都变得离得很远,失去了近邻的意义。

  • 需要选择合适的K值:KNN算法的性能很大程度上取决于选择合适的最近邻数量K。选择过小的K值可能导致模型过于敏感,容易受到噪声的影响;选择过大的K值可能导致模型过于平滑,无法捕捉到细微的类别特征。

  • 不适用于大规模数据集:由于KNN算法需要在预测阶段计算待分类数据点与所有训练数据点的距离,因此对于大规模数据集来说,存储和计算的开销可能会非常大。

KNN算法是一种简单但强大的分类和回归方法,适用于多种问题领域。但在使用时需要注意计算复杂度、维度敏感性、合适的K值选择以及适应大规模数据集的挑战。

2 KNN算法的应用场景

KNN算法的优点包括简单易懂、无需训练过程、适用于多类别问题等。KNN算法在许多领域中都有广泛的应用,KNN算法常见的应用场景如下:

  • 分类问题:KNN算法可以用于分类问题,如文本分类、图像分类、语音识别等。通过比较待分类数据点与已知数据点之间的相似性,KNN可以将新的数据点分配到最相似的类别中。

  • 回归问题:KNN算法也可以用于回归问题,如房价预测、股票价格预测等。通过计算最近邻数据点的平均值或加权平均值,KNN可以预测待分类数据点的数值属性。

  • 推荐系统:KNN算法可以应用于推荐系统,根据用户之间的相似性来推荐相似兴趣的物品。通过比较用户之间的行为模式或兴趣偏好,KNN可以找到与当前用户最相似的一组用户,并向其推荐相似的物品。

  • 异常检测:KNN算法可以用于检测异常数据点,如信用卡欺诈、网络入侵等。通过计算数据点与其最近邻之间的距离,KNN可以识别与大多数数据点不同的异常数据点。

  • 文本挖掘:KNN算法可以用于文本挖掘任务,如文本分类、情感分析等。通过比较文本之间的相似性,KNN可以将新的文本数据点归类到相应的类别中。

  • 图像处理:KNN算法可以应用于图像处理领域,如图像识别、图像检索等。通过比较图像之间的像素值或特征向量,KNN可以识别和检索相似的图像。

然而,该算法的缺点是计算复杂度高,特别是当训练数据集较大时,需要计算大量的距离。此外,KNN算法对于特征空间的维度敏感,对于高维数据的处理可能会出现问题。

针对部分数据(特征空间维度大,数据容量大)为了提高KNN算法的性能,可以使用特征选择和降维技术来减少特征空间的维度,以及采用KD树等数据结构来加速最近邻搜索过程。

KD Tree 是一种平衡二叉树,目的是实现对 k 维空间的划分。

9a0907d17e734ab9b81d3ff29e76e9d4.png

KDTree形似二叉搜索树,其实KDTree就是二叉搜索树的变种。这里的K = 3(维度).

KD树的组织原则

将每一个元组按0排序(第一项序号为0,第二项序号为1,第三项序号为2),在树的第n层,第 n%3 项被用粗体显示,而这些被粗体显示的树就是作为二叉搜索树的key值,比如,根节点的左子树中的每一个节点的第一个项均小于根节点的的第一项,右子树的节点中第一项均大于根节点的第一项,子树依次类推。

对于这样的一棵树,对其进行搜索节点会非常容易,给定一个元组,首先和根节点比较第一项,小于往左,大于往右,第二层比较第二项,依次类推。
 

KD树检索

假设我们的KDTree通过样本集{(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)}创建的。
我们来查找点(2.1,3.1),在(7,2)点测试到达(5,4),在(5,4)点测试到达(2,3),然后search_path中的结点为<(7,2), (5,4), (2,3)>,从search_path中取出(2,3)作为当前最佳结点nearest, dist为0.141 (欧氏距离);
然后回溯至(5,4),以(2.1,3.1)为圆心,以dist=0.141为半径画一个圆,并不和超平面y=4相交,如下图,所以不必跳到结点(5,4)的右子空间去搜索,因为右子空间中不可能有更近样本点了。
于是在回溯至(7,2),同理,以(2.1,3.1)为圆心,以dist=0.141为半径画一个圆并不和超平面x=7相交,所以也不用跳到结点(7,2)的右子空间去搜索。
至此,search_path为空,结束整个搜索,返回nearest(2,3)作为(2.1,3.1)的最近邻点,最近距离为0.141。
b786abe45e6847c2a8b21642887e0e7e.png

再举一个稍微复杂的例子,我们来查找点(2,4.5),在(7,2)处测试到达(5,4),在(5,4)处测试到达(4,7),然后search_path中的结点为<(7,2), (5,4), (4,7)>,从search_path中取出(4,7)作为当前最佳结点nearest, dist为3.202;
然后回溯至(5,4),以(2,4.5)为圆心,以dist=3.202为半径画一个圆与超平面y=4相交,如下图,所以需要跳到(5,4)的左子空间去搜索。所以要将(2,3)加入到search_path中,现在search_path中的结点为<(7,2), (2, 3)>;另外,(5,4)与(2,4.5)的距离为3.04 < dist = 3.202,所以将(5,4)赋给nearest,并且dist=3.04。
回溯至(2,3),(2,3)是叶子节点,直接平判断(2,3)是否离(2,4.5)更近,计算得到距离为1.5,所以nearest更新为(2,3),dist更新为(1.5)
回溯至(7,2),同理,以(2,4.5)为圆心,以dist=1.5为半径画一个圆并不和超平面x=7相交, 所以不用跳到结点(7,2)的右子空间去搜索。
至此,search_path为空,结束整个搜索,返回nearest(2,3)作为(2,4.5)的最近邻点,最近距离为1.5。

78f376e44c84495eae76808a3464a00a.png

2bc2d2c547724429a7567ed91d80efae.png

 3 基于pytorch在MNIST数据集上实现数据分类

3.1 获取MNIST数据集

(1)代码自动下载

train_dataset = datasets.MNIST(root='data',  # 选择数据的根目录
                            train=True,  # 选择训练集
                            transform=None,  # 不使用任何数据预处理
                            download=True)  # 从网络上下载图片

test_dataset = datasets.MNIST(root='data',  # 选择数据的根目录
                           train=False,  # 选择测试集
                           transform=None,  # 不适用任何数据预处理
                           download=True)  # 从网络上下载图片

但这个自动下载可能会出错,错误如下:

urllib.error.ContentTooShortError: <urlopen error retrieval incomplete: got only 5303709 out of 9912422 bytes>

 (2)手工下载数据集

下载地址:MNIST数据

下载完成后,放到data/MNIST/raw目录下

图片内容展示:

digit = train_loader.dataset.data[0] 
plt.imshow(digit, cmap=plt.cm.binary)
plt.show()
print(train_loader.dataset.targets[0])

4293984ad8ee4c03a4b9c3d7fb3bc4d1.png

 

3.2 KNN计算

以MNIST的60000张图片作为训练集,通过KNN计算对测试数据集的10000张图片全部打上标签。通过KNN算法比较测试图片与训练集中每一张图片,然后将它认为最相似的那个训练集图片的标签赋给这张测试图片
具体应该如何比较这两张图片呢?在本例中,比较图片就是比较28×28的像素块。最简单的方法就是逐个像素进行比较,最后将差异值全部加起来两张图片使用L1距离来进行比较。逐个像素求差值,然后将所有差值加起来得到一个数值。如果两张图片一模一样,那么L1距离为0,但是如果两张图片差别很大,那么,L1的值将会非常大。

def KNN_classify(k, dis_func, train_data, train_label, test_data):
    num_test = test_data.shape[0]  # 测试样本的数量
    label_list = []
    for idx in range(num_test):
        distances = dis_func(train_data, test_data[idx])
        nearest_k = np.argsort(distances)
        top_k = nearest_k[:k]  # 选取前k个距离
        class_count = {}
        for j in top_k:
            class_count[train_label[j]] = class_count.get(train_label[j], 0) + 1
        sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
        label_list.append(sorted_class_count[0][0])

    return np.array(label_list)

3.3 完整代码

#!/usr/bin/env python
# coding: utf-8


import operator
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


batch_size = 100
train_dataset = datasets.MNIST(root='data',  # 选择数据的根目录
                            train=True,  # 选择训练集
                            transform=None,  # 不使用任何数据预处理
                            download=True)  # 从网络上下载图片

test_dataset = datasets.MNIST(root='data',  # 选择数据的根目录
                           train=False,  # 选择测试集
                           transform=None,  # 不适用任何数据预处理
                           download=True)  # 从网络上下载图片

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

print("train_data:", train_dataset.data.size())
print("train_labels:", train_dataset.data.size())
print("test_data:", test_dataset.data.size())
print("test_labels:", test_dataset.data.size())


# digit = train_loader.dataset.data[0]  # 取第一个图片的数据
# plt.imshow(digit, cmap=plt.cm.binary)
# plt.show()
# print(train_loader.dataset.targets[0])

# 欧式顿距离计算
def e_distance(dataset_a, data_b):
    return np.sqrt(np.sum(((dataset_a - np.tile(data_b, (dataset_a.shape[0], 1))) ** 2), axis=1))

# 曼哈顿距离计算
def m_distance(dataset_a, data_b):
    return np.sum(np.abs(train_data - np.tile(test_data[i], (train_data.shape[0], 1))), axis=1)


def KNN_classify(k, dis_func, train_data, train_label, test_data):
    num_test = test_data.shape[0]  # 测试样本的数量
    label_list = []
    for idx in range(num_test):
        distances = dis_func(train_data, test_data[idx])
        nearest_k = np.argsort(distances)
        top_k = nearest_k[:k]  # 选取前k个距离
        class_count = {}
        for j in top_k:
            class_count[train_label[j]] = class_count.get(train_label[j], 0) + 1
        sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
        label_list.append(sorted_class_count[0][0])

    return np.array(label_list)


def get_mean(data):
    data = np.reshape(data, (data.shape[0], -1))
    mean_image = np.mean(data, axis=0)
    return mean_image


def centralized(data, mean_image):
    data = data.reshape((data.shape[0], -1))
    data = data.astype(np.float64)
    data -= mean_image  # 减去图像均值,实现领均值化
    return data


if __name__ == '__main__':
    # 训练数据
    train_data = train_loader.dataset.data.numpy()
    train_data = train_data.reshape(train_data.shape[0], 28 * 28)
    
    # 归一化处理
    mean_image = get_mean(train_data)  # 计算所有图像均值
    train_data = centralized(train_data, mean_image)
    
    print('train_data shape:', train_data.shape)
    train_label = train_loader.dataset.targets.numpy()
    print('train_lable shape', train_label.shape)

    # 测试数据
    test_data = test_loader.dataset.data[:1000].numpy()
    test_data = centralized(test_data, mean_image)
    test_data = test_data.reshape(test_data.shape[0], 28 * 28)
    print('test_data shape', test_data.shape)
    test_label = test_loader.dataset.targets[:1000].numpy()
    print('test_label shape', test_label.shape)

    # 训练
    test_label_pred = KNN_classify(5, e_distance, train_data, train_label, test_data)

    # 得到训练准确率
    num_test = test_data.shape[0]
    num_correct = np.sum(test_label == test_label_pred)
    print(num_correct)
    accuracy = float(num_correct) / num_test
    print('Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))

3.4 计算结果展示

train_data: torch.Size([60000, 28, 28])
train_labels: torch.Size([60000, 28, 28])
test_data: torch.Size([10000, 28, 28])
test_labels: torch.Size([10000, 28, 28])
train_data shape: (60000, 784)
train_lable shape (60000,)
test_data shape (1000, 784)
test_label shape (1000,)
963
Got 963 / 1000 correct => accuracy: 0.963000

使用欧氏距离计算,最终结果准确率达到了96.3%

4 完整工程及数据下载

下载地址:代码和数据

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/658882.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CMake中的find_package(xxx REQUIRED)在windows平台怎么解

最近在编译FastDDS时&#xff0c;遇到了这个问题&#xff0c;使用CMake构建时提示找不到库。 下载的源代码不能一次性编过是最让人头疼的问题&#xff0c;这种开源代码通常都是迭代了很多版本&#xff0c;各种配置信息如果不在文档中说明&#xff0c;全靠自己去摸索确实会让人头…

idea运行java项目提示异常: java.security.InvalidKeyException: Illegal key size

idea运行java项目提示异常&#xff1a;java.lang.IllegalArgumentException: java.security.InvalidKeyException: Illegal key size 参考&#xff1a;java.security.InvalidKeyException: Illegal key size_gqltt的博客-CSDN博客 产生错误原因&#xff1a;为了数据代码在传输过…

4、做什么类型的产品经理

1、如何选择适合自己的产品经理岗位 怎么选择适合自己的这个产品经理岗位呢&#xff1f;建议大家是先考虑行业&#xff0c;再考虑其他的。 考虑行业就是说我要做什么行业的产品经理,然后再考虑在这个行业里面具体的你要做前端还是后端或者是APP端&#xff0c;还是web端&#x…

【MySQL】不就是MySQL——索引

前言 嗨&#xff01;小伙伴们周末快乐呀&#xff01;想必你们周末都在家里边呆着吧&#xff0c;外面实在是太热了&#xff01;在家里吹着空调做着自己喜欢做的事情吧&#xff01;本期我们主要学习的是MySQL中的约束条件。 目录 前言 索引概述 外键约束 1.概念 2.语法 1.添加…

【HTML界面设计(二)】说说模块、登录界面

记录很早之前写的前端界面&#xff08;具体时间有点久远&#xff09; 一、说说模板 采用 适配器&#xff08;Adapter&#xff09;原理 来设计这款说说模板&#xff0c;首先看一下完整效果 这是demo样图&#xff0c;需要通过业务需求进行修改的部分 这一部分&#xff0c;就是dem…

ch8_2_CPU的指令周期,流水线技术

1.  指令周期 指令周期是指_ CPU从主存取出一条指令, 分析指令&#xff0c;加上执行这条指令的时间。 1.1指令周期 指令周期&#xff1a; 是指cpu&#xff0c;从内存中取出指令&#xff0c;并且执行一条指令所需要的全部时间。 比如 从内存单元中&#xff0c;取出操作数&…

【使用Neo4j进行图数据可视化】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

“面试造火箭,入职拧螺丝”2023最新最全的Java开发八股文合集来了

前言 金三银四招聘旺季马上就到了&#xff0c;不知道大家是否准备好了&#xff0c;面对金三银四的招聘旺季&#xff0c;如果没有精心准备那笔者认为那是对自己不负责任&#xff1b;就我们 Java 程序员来说&#xff0c;多数的公司总体上面试都是以自我介绍项目介绍项目细节/难点…

Java016——Java输入输出语句

一、输出语句 Java常用的输出语句有三种&#xff1a; 1&#xff09;System.out.println(); 换行输出&#xff0c;输出后会自动换行。 //示例 System.out.println("Hello"); System.out.println("World");//输出 Hello World2&#xff09;System.out.pri…

LIN-物理层(收发器)

文章目录 一、显性和隐性二、LIN的供电电压说明三、LIN通道数3.1 单通道3.2 双通道3.3 四通道 一、显性和隐性 LIN总线协议规定其物理层收发器的显性&#xff08;Dominant , 逻辑 “ 0”&#xff0c;电气特性为GND(0V)&#xff09;和隐性电平&#xff08;Recessive , 逻辑 “ …

cgi接口原理(boa服务器)

CGI&#xff1a;通用网关接口&#xff08;Common Gateway Interface&#xff09;是一个Web服务器主机提供信息服务的标准接口。通过CGI接口&#xff0c;Web服务器就能够获取客户端提交的信息&#xff0c;转交给服务器端的CGI程序进行处理&#xff0c;最后返回结果给客户端。 b…

字符串概述

字符串 一、API二、字符串2.1字符串的构造方法2.2 字符串构造时的内存2.2.1 直接赋值时的内存模型2.2.2 由new创建时的内存模型 2.3 字符串的比较三、StringBuilder 一、API 目前已学过的两个API&#xff1a;Random和Scanner。 对记不清的API可以去JDK-API帮助文档进行查找。 …

基于matlab对现代相控阵系统中常用的子阵列进行建模分析

一、前言 本示例说明如何使用相控阵系统工具箱对现代相控阵系统中常用的子阵列进行建模并进行分析。 相控阵天线与传统碟形天线相比具有许多优势。相控阵天线的元件更容易制造;整个系统受组件故障的影响较小;最重要的是&#xff0c;可以向不同方向进行电子扫描。 但是&#xff…

分布式系统学习第一天 fastDFS框架学习

目录 1. 项目架构图 1.1 一些概念 1.2 项目架构图 2. 分布式文件系统 2.1 传统文件系统 3. FastDFS 3.1 fastDFS介绍 3.2 fastDFS安装 3.3 fastDFS配置文件 3.4 fastDFS的启动 3.5 对file_id的解释 4. 上传下载代码实现 5. 源码安装 - 回顾 1. 项目架构图 1.1 一…

JDK8-2-流(2)- 流操作

JDK8-2-流&#xff08;2&#xff09;- 流操作 上篇 JDK8-2-流&#xff08;1&#xff09;-简介 中简单介绍了什么是流以及使用流的好处&#xff0c;本篇主要介绍流的操作类型以及如何操作。 如何返回一个流 ① collection.stream 即调用集合 java.util.Collection 下的 stre…

大学生如何申请一台免费服务器?

大学生如何申请一台免费服务器&#xff1f;阿里云学生服务器免费申请&#xff1a;高效计划&#xff0c;可以免费领取一台阿里云服务器&#xff0c;如果你是一名高校学生&#xff0c;想搭建一个linux学习环境、git代码托管服务器&#xff0c;或者创建个人博客网站记录自己的学习…

【小米技术分享】MySQL:一条数据的存储之旅

大家好&#xff0c;我是你们的小米&#xff0c;一个热爱技术分享的活泼小伙伴&#xff01;今天&#xff0c;我来给大家揭开一个神秘的面纱&#xff0c;带你们深入了解一下MySQL数据库是如何保存一条数据的。 客户端 首先&#xff0c;让我们从客户端&#xff08;Client&#x…

【雕爷学编程】Arduino动手做(114)---US-015高分辨超声波模块

37款传感器与执行器的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&am…

【免费】【sci】考虑不同充电需求的电动汽车有序充电调度方法(含matlab代码)

目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序复现sci文献《A coordinated charging scheduling method for electric vehicles considering different charging demands》&#xff0c;主要实现电动汽车协调充电调度方法&#xff0c;该方法主要有以…

如何使用PyTorch 在 OpenAI Gym 上的 CartPole-v0 任务上训练深度 Q 学习(DQN)智能体

强化学习&#xff08;DQN&#xff09;教程 本教程说明如何使用 PyTorch 在 OpenAI Gym 上的 CartPole-v0 任务上训练深度 Q 学习&#xff08;DQN&#xff09;智能体。 任务 智能体必须在两个动作之间做出决定-向左或向右移动推车-以便使与之相连的杆子保持直立。 您可以在 G…