对比K近邻算法与决策树算法在MNIST数据集上的分类性能

news2025/1/3 11:20:05

目录

  • 1. 作者介绍
  • 2. K近邻算法与决策树算法介绍
    • 2.1 K近邻(KNN)简介
    • 2.2 决策树算法简介
    • 2.3 MNIST数据集简介:
  • 3. K近邻算法和决策树算法在Mnist数据集分类实验对比
    • 3.1 K近邻算法对Mnist数据集分类实验
    • 3.2 K近邻代码实现
    • 3.3 决策树算法实验
    • 3.4 决策树代码实现
    • 3.5 实验结果对比

1. 作者介绍

郝特吉,男,西安工程大学电子信息学院,2022级研究生
研究方向:机器视觉与人工智能
电子邮件:826844822@qq.com

路治东,男,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:2063079527@qq.com

2. K近邻算法与决策树算法介绍

2.1 K近邻(KNN)简介

K近邻是一种经典且简单的监督学习方法,既能够用来解决分类问题,也能够解决回归问题。
原理:当对测试样本进行分类时,通过扫描训练样本集,找到与该测试样本最相似的个训练样本,根据这个样本的类别进行投票确定测试样本的类别。
基本要素:
1.分类决策规则
一般采用少数服从多数的投票制规则,但可以根据具体问题,实现分段距离加权的方式进行,本次KNN主要采用多数服从少数的投票制规则。

2.距离度量
Lp 距离:
在这里插入图片描述

p = 1 ,为曼哈顿距离
p = 2,为欧氏距离
p = ∞ ,为各个坐标距离的最大值
本次实验采用欧氏距离
3.k 值的选择
在这里插入图片描述
本次实验主要对k的值为15的准确率变化进行研究

2.2 决策树算法简介

决策树,是一个类似流程图的树形结构,树内部的每一个节点代表对一个特征的测试,树的分支代表该特征的每一个测试结果,而树的每一个叶子节点代表一个类别。树的最高层是就是根节点。
举个例子,以面试机器学习算法工程师为例,下图说明了如何利用决策树进行面试。
在这里插入图片描述
从中不难总结出决策树的主要问题就是:
1.哪个维度划分?
2.该维度的哪个值划分?

决策树算法的策略
1. 信息熵:
代表随机变量不确定度,熵越大,数据不确定性越高。熵越小,数据不确定性越低。 目的:希望在树节点划分后使信息熵降低。
在这里插入图片描述
在这里插入图片描述
二分类信息熵曲线如上图所示。
**2.基尼指数(基尼不纯度):**表示在样本集合中一个随机选中的样本被分错的概率。
目的:希望在划分后使得基尼指数降低。
在这里插入图片描述
在这里插入图片描述
二分类基尼系数曲线如上图所示

2.3 MNIST数据集简介:

MNIST是一个手写体数字 0-9 的图片数据集,一共统计了来自250个不同的人手写数字图片,其中,每张图片为:28*28的灰度图片,对应标签采用 one-hot -vector 形式编码
Mnist数据集官网:http://yann.lecun.com/exdb/mnist/

MNIST数据集的下载内容:
在这里插入图片描述
使用后自动解压为:
在这里插入图片描述
train_and_test 划分:
Train_datas:60000张
Test_datas:10000张
数据集划分
对数据集进行可视化,下图为Mnist数据集中的图和标签
在这里插入图片描述
下图为Mnist数据集在FDA降维下的分布
在这里插入图片描述

3. K近邻算法和决策树算法在Mnist数据集分类实验对比

3.1 K近邻算法对Mnist数据集分类实验

本次实验对Mnist数据集进行分类,在距离度量为:欧式距离,分类决策规则为:少数服从多数的基础上,研究:在使用KNN算法达到最高准确率的情况下,K在1-5之间的取值。
在测试集(10000)sample:300个样本
在训练集(60000)sample:10000个样本
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.2 K近邻代码实现

import torch
import numpy as np
from torch.utils.data import DataLoader
from torch import nn, optim
from math import sqrt
from torchvision import transforms, datasets
import visdom
from collections import Counter

viz = visdom.Visdom()
batchsize_all = 10000
batchsize = 1

minist_train = datasets.MNIST('minist', True, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
minist_train = DataLoader(minist_train, batch_size=batchsize_all, shuffle=True)

minist_test = datasets.MNIST('minist', True, transform=transforms.Compose([
    transforms.ToTensor()
]), download=True)
minist_test = DataLoader(minist_test, batch_size=batchsize, shuffle=True)

# X = []
# Y = []
# for batchidx, (X_train, Y_train) in enumerate(minist_train):
#     X.append(X_train)
#     Y.append(Y_train)
# print(len(Y))
#
# X1 = []
# Y1 = []
# for batchidx, (X_test, Y_test) in enumerate(minist_test):
#     X1.append(X_test)
#     Y1.append(Y_test)
# print(len(Y1))

# x, y = next(iter(minist_train))

# print('x:', x.shape)
# print(y.shape)
# print(y)
acc_sum = 0
sum = 0

viz.line([0], [-1], win='knn_accuracy', opts=dict(title='knn_accuracy'))


k = int(input('请输入选择最近邻的个数:'))
for _ in range(300):
    for batchidx, (X_train, Y_train) in enumerate(minist_train):
        #KNN
        distances = []
        x_test, y_test = next(iter(minist_test))
        viz.images(x_test, nrow=1, win='x', opts=dict(title='x'))
        # print(x_test.shape)
        # print()



        for x_train in X_train:
            x_train = x_train.unsqueeze(0)
            # print(x_train.shape)
            #k_nn
            pp = pow(x_train - x_test, 2).view(1, 28*28)
            D = sqrt(pp.sum(dim=1))
            # print(pp)
    #         D = sqrt(np.sum(((x_train - x_test)**2).view(28*28)))  #欧拉距离
            distances.append(D)
        nearest = np.argsort(distances)  #索引排序从近到远
        # print(nearest)
        k_top = [Y_train[i] for i in nearest[:k]]  #前K个标签值
        nears_value = [X_train[i] for i in k_top]
        votes = Counter(k_top)  #得到投票结果
        pre_label = votes.most_common(1)[0][0]  #预测最可能结果
        sum += 1
        if pre_label == y_test:   #计算分类准确率
            acc_sum += 1
        accuracy = acc_sum / sum
        print('准确率为:', accuracy)
        np_pre_label = pre_label.numpy()
        viz.line([accuracy], [sum], win='knn_accuracy', update='append')
        viz.text(str(np_pre_label), win='pre_label', opts=dict(title='prelabel'))
        break

    # print(pre_label)
    # print(y_test)

3.3 决策树算法实验

本次实验在scikit-learn中集成的决策树CART下进行:
发现:
1.信息熵计算相对较慢,
2.scikit-learn中默认为基尼系数
3.没有特别大的差距
决策树的局限性:严重的过拟合
在这里插入图片描述
下图为1-500个不同深度gini分类准确率曲线
在这里插入图片描述
下图为1-500不同深度entropy分类准确率
在这里插入图片描述

3.4 决策树代码实现

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import visdom
import numpy as np

mnist = load_digits()
x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.2, random_state=40)

viz = visdom.Visdom()
viz.line([0], [-1], win='decision_tree_accuracy_entropy', opts=dict(title='decision_tree_accuracy_entropy'))
# viz.line([0], [-1], win='decision_tree_accuracy_gini', opts=dict(title='decision_tree_accuracy_gini'))

for i in range(500):
    model = DecisionTreeClassifier(max_depth=i+1, criterion="entropy")
    # model = DecisionTreeClassifier(max_depth=i + 1, criterion="gini")
    model.fit(x, y)
    pre = model.predict(test_x)
    acc = np.sum(pre == test_y) / pre.size
    print('accuracy:', acc)
    viz.line([acc], [i+1], win='decision_tree_accuracy_entropy', update='append')
    # viz.line([acc], [i + 1], win='decision_tree_accuracy_gini', update='append')

3.5 实验结果对比

在本次针对Mnist数据集,分别采用KNN和决策树算法进行分类的对比实验中发现KNN的分类准确率优于决策树!
KNN_average_accuracy = 93.126% > 88.055% = decision_tree_max_accuracy
why?
分析:
Mnist数据集的样本的特殊之处:它是一个250多人手写的数字体,且在Mnist数据集官网上发布的是一个经过居中和裁边处理过的28*28的灰度图。

如果没有居中和裁边处理:
在这里插入图片描述
经过居中裁边处理:
在这里插入图片描述
可以看出,经过裁边的mnist数据集非常适合运用KNN将其区分开来
而灰度图起作用在计算距离时,只考虑其图像的空间分布,没有将颜色通道干扰考虑在内,这样就更高效更有针对性的计算出分布之间的距离!
第二个原因则是,在是决策树本身的局限性和mnist数据集特性的共同作用:
我们可以从mnist的降维分布看出,总会有一些不同类的分布距离十分靠近:
在这里插入图片描述
观察准确率减少时发现分类错误的样本:
在这里插入图片描述
总结:
但是决策树在高维空间中对决策边界的划分,总是有一部分很难被完美区分,而KNN却可以根据K值的选择一定程度上将样本做出正确划分!这就是为什么KNN在MNIST数据集上的分类性能要优于决策树的主要原因!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/631141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3:组件高级(上)

Vue3:组件高级(上) Date: May 20, 2023 Sum: watch倾听器、组件的生命周期、组件之间的数据共享、vue3.x中全局配置axios 目标: 能够掌握 watch 侦听器的基本使用 能够知道 vue 中常用的生命周期函数 能够知道如何实现组件之间…

写 bug 速度提升200%!吊爆的 IDEA 使用技巧

背景 Java 开发过程经常需要编写有固定格式的代码,例如说声明一个私有变量,logger或者bean等等。 对于这种小范围的代码生成,我们可以利用 IDEA 提供的 Live Templates功能。 刚开始觉得它只是一个简单的Code Snippet,后来发现…

msf渗透练习-震网三代

说明: 本章内容,仅供学习,不要用于非法用途(做个好白帽) (一)震网三代漏洞 “震网三代”官方漏洞编号是CVE-2017-8464,2017年6月13日,微软官方发布编号为CVE-2017-8464的…

Redis Cluster集群运维-03

1、Redis集群方案比较 哨兵模式 在redis3.0以前的版本要实现集群一般是借助哨兵sentinel工具来监控master节点的状态,如果master节点异 常,则会做主从切换,将某一台slave作为master,哨兵的配置略微复杂,并且性能和高可…

【CSS】常见的选择器

1.标签选择器 语法 标签 { }作用 标签选择器用于选择某种标签比如 选择p标签,并设置背景颜色 p { background-color:yellow; }例子 选择div标签,并将其字体大小设置为100px,字体设置为"微软雅黑",文字颜色设置为r…

怎么学习渗透测试?路线是什么

我知道很多人肯定觉得,报班什么的太贵了,但是人家贵有贵的道理 owap讲来讲去,还是那个样,但是有人给你解答问题是两个概念,有时候一个虚拟机都能卡死你很久,我就随便说说,我给你们想的学习路线…

2023网络安全工程师面试题汇总(附答案)

又到了毕业季,大四的漂亮学姐即将下架,大一的小学妹还在来的路上,每逢这时候我心中总是有些小惆怅和小激动…… 作为学长,还是要给这些马上要初出茅庐的学弟学妹们,说说走出校园、走向职场要注意哪些方面。 走出校园后…

基于XC7Z100的PCIe采集卡(GMSL FMC采集卡)

GMSL 图像采集卡 特性 ● PCIe Gen2.0 X8 总线; ● 支持V4L2调用; ● 1路CAN接口; ● 6路/12路 GMSL1/2摄像头输入,最高可达8MP; ● 2路可定义相机同步触发输入/输出; 优势 ● 采用PCIe主卡与FMC子…

服务器是什么?它是用来干什么的?

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 目录 一、服务器是什么? 二、服务器的作用 1、提高访问速度 2、提高安全性 三、云服务器与物理服务器 1、云服务器 云服务…

[架构之路-210]- 人人都是产品经理 - 互联网产品需求分析思路和方法笔记

目录 前言: 一、产品需求分析思路和方法--产品需求 1、产品需求的内涵 ①什么是产品? ②什么是需求? ③需求的产品的关系 ④案例分析: ⑤理解需求的误区 2、需求的分类及层次、规律、拆解用户需求 ①需求分类 ②需求层…

算法刷题-链表-设计链表

设计链表 707.设计链表思路代码其他语言版本 听说这道题目把链表常见的五个操作都覆盖了? 707.设计链表 力扣题目链接 题意: 在链表类中实现这些功能: get(index):获取链表中第 index 个节点的值。如果索引无效,则…

MySQL数据库,从入门到精通:第二篇——MySQL关系型数据库与非关系型数据库的比较

MySQL数据库,从入门到精通:第二篇——MySQL关系型数据库与非关系型数据库的比较 1. RDBMS 与 非RDBMS1.1 关系型数据库(RDBMS)1.1.1 实质1.1.2 优势1.2 非关系型数据库(非RDBMS)1.2.1 介绍1.2.2 有哪些非关系型数据库1.2.3 NoSQL的演变1.3 小结 2. 关系型…

SQL开源替代品,诞生了

发明 SQL 的初衷之一显然是为了降低人们实施数据查询计算的难度。SQL 中用了不少类英语的词汇和语法,这是希望非技术人员也能掌握。确实,简单的 SQL 可以当作英语阅读,即使没有程序设计经验的人也能运用。 然而,面对稍稍复杂的查…

Python+QT停车场车牌识别计费管理系统-升级版

程序示例精选 PythonQT停车场车牌识别计费管理系统-升级版 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<PythonQT停车场车牌识别计费管理系统-升级版>>编写代码&#xff0c;代…

RabbitMQ入门案例之Direct模式

前言 RabbitMQ的Direct模式是一种可以根据指定路由key&#xff0c;Exchang将消息发送到具有该路由key下的Queue下进行存储。也就类似于将数据写进指定数据库表中。这个路由Key可以类比为SQL语句中的&#xff1a;where routeKey … 官方文档地址&#xff1a;https://www.rabbi…

DragGAN部署全流程

写在前面 看了DragGAN 官方&#xff0c;并没有找到软件&#xff0c;或者程序&#xff0c;github上也没有程序&#xff0c;如果大佬们能找到&#xff0c;可以评论通知下。不过也有技术大佬已经提前开发出来了&#xff0c;我们抢先体验下。 这里本地部署了 DragGAN。经历了报错&…

【LeetCode】HOT 100(5)

题单介绍&#xff1a; 精选 100 道力扣&#xff08;LeetCode&#xff09;上最热门的题目&#xff0c;适合初识算法与数据结构的新手和想要在短时间内高效提升的人&#xff0c;熟练掌握这 100 道题&#xff0c;你就已经具备了在代码世界通行的基本能力。 目录 题单介绍&#…

CTFShow-WEB入门篇--命令执行详细Wp

WEB入门篇--命令执行详细Wp 命令执行&#xff1a;Web29&#xff1a;Web30&#xff1a;Web31&#xff1a;web32&#xff1a;web33&#xff1a;web34&#xff1a;web35&#xff1a;web36&#xff1a;web37&#xff1a;web38&#xff1a; CTFShow 平台&#xff1a;https://ctf.sho…

【Kubernetes资源篇】Service四层代理入门实战详解

文章目录 一、Service四层代理概念、原理1、Service四层代理概念2、Service工作原理3、Service原理解读4、Service四种类型 二、Service四层代理三种类型案例1、创建ClusterIP类型Service2、创建NodePort类型Service3、创建ExternalName类型Service 三、拓展1、Service域名解析…

Nvidia AGX Orin MAX9296 GMSL 载板设计要点

因为项目的需求&#xff0c;我们设计了Nvidia AGX Orin MAX9296 GMSL 载板(绿板&#xff09;&#xff0c;项目完成&#xff0c;总结以下。需要参考原理图的&#xff0c;可以微我&#xff0c;索取。共同交流。 Jetson AGX Orin建立在NVIDIA Ampere架构之上&#xff0c;全新Jets…