机器学习入门实例-MNIST手写数据集-简单探索二分分类

news2025/1/13 13:17:52

MNIST数据集介绍

MNIST数据集包含7w张带标签的手写数字图片。每次有新的分类算法出现时,常常会在改数据集测试效果。

from sklearn.datasets import fetch_openml

# 获取的mnist是一个字典
mnist = fetch_openml('mnist_784', version=1)
print(mnist.keys())
# dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
# DESCR是描述信息,data是数据集,target是标签

X, y = mnist["data"], mnist["target"]
print(X.shape)
print(y.shape)
# (70000, 784)
# (70000,)

X.shape表示一共7w张图,每个有784个特征。每个特征是28 x 28 像素中的一个点的数值,在0(白)~ 255(黑)之间。

查看其中一个图:

import matplotlib.pyplot as plt
print(y[5])
print(type(y[5]))
some_digit = X[5]
some_digit_image = some_digit.reshape(28, 28)
# cmp表示颜色映射,即实数值通过什么方法转成RGB图像。常用的还有'viridis'(很多颜色细节)、
# 'gray'(适合灰度图像)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

# 输出: 2
# <class 'str'>

在这里插入图片描述
注意,这段代码很可能报错如下:

Traceback (most recent call last):
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3621, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas\_libs\index.pyx", line 136, in pandas._libs.index.IndexEngine.get_loc
  File "pandas\_libs\index.pyx", line 163, in pandas._libs.index.IndexEngine.get_loc
  File "pandas\_libs\hashtable_class_helper.pxi", line 5198, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas\_libs\hashtable_class_helper.pxi", line 5206, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\xxxx\Desktop\study\classification.py", line 24, in <module>
    plt.imshow(X[0], cmap="gray")
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\frame.py", line 3505, in __getitem__
    indexer = self.columns.get_loc(key)
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 0

解决:获取数据集时添加参数 as_frame=False。 这个表示以原格式返回。

mnist = fetch_openml('mnist_784', version=1, as_frame=False)

保存数据集到本地&导入

pickle库可以序列化任何Python对象,所以可以用它保存数据集到本地。

from sklearn.datasets import fetch_openml
import pickle

# 获取数据集
# False表示以原始格式返回,每个特征是一个单独的数组。True表示返回Pandas
# DataFrame对象
# 自0.24.0(2020 年 12 月)以来,as_frame参数为auto(而不是之前的False默认选项)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# 保存数据集到本地
with open('mnist_data.pkl', 'wb') as f:
    pickle.dump(mnist, f)

# 从本地导入数据集
with open('mnist_data.pkl', 'rb') as f:
    mnist = pickle.load(f)
    X, y = mnist["data"], mnist["target"]

划分训练集和测试集

MNIST已经划分好了,前6w个是训练集,后1w个是测试集。而且已经打乱过顺序了。

通过第一节我们已经知道,y的所有元素都是字符串。因为很多算法的预测结果都是数字,将标签转为数字也有助于计算error,所以使用astype(np.uint8)将y里所有元素转为8位无符号整数。

# 将数组中所有元素转为8位无符号整数
y = y.astype(np.uint8)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

Binary Classifier

Binary Classifier就是分两类。比如以数字 2 为例,我们训练一个分类器,将图片分成是2的和不是2的。

这里使用一个Stochastic Gradient Descent (SGD,随机梯度下降)分类器。这个适合高效处理较大的数据集,而且每个训练实例是单独处理的,一次一个,所以可以online learning。

y_train_2 = (y_train == 2)
y_test_2 = (y_test == 2)

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_2)
some_digit = X[100]
print(sgd_clf.predict([some_digit]))
print(y[100])

some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

其中,为了使结果可以复现,设置了random_state=42。
输出为:

[False]
5

在这里插入图片描述
因此可以知道,分类正确。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/428883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hutool-crypto 加密、解密详解!

1. 介绍 在Java开发的过程中&#xff0c;很多场景下都需要加密解密。 比如对敏感数据的加密&#xff0c;对配置文件信息的加密&#xff0c;通信数据的加密等等。 今天介绍的是Hutool工具包中的加密模块 crypto。 2. 加密分类 加密分为三类&#xff1a; 对称加密&#xff0…

Embarcadero RAD Studio Crack

Embarcadero RAD Studio Crack RAD Studio(r)是一款终极IDE&#xff0c;用于使用Delphi(r)、现代C和适用于Windows 11的高级Windows桌面UI库为多个平台创建单源原生应用程序&#xff0c;它为IDE添加了高DPI支持。这使得开发人员可以在更大的屏幕上以更高的分辨率工作。IDE现在支…

在VS和g++下的string结构的区别

文章目录1. 在VS下的结构2.在gcc下的结构3.写时拷贝/共享内存在之前的时间里&#xff0c;我们学习了string类的使用和模拟实现&#xff0c;但是在VS和g下使用string&#xff0c;发现了一点问题&#xff0c;下面我们通过一段代码来重现一下这个问题#include <iostream> #i…

融合DE 端和FE端数据,利用小波变换生成时频图,再分别利用DCNN、KNN和DNN进行对比实验(python代码)

1.数据集介绍&#xff1a; 试验台如图所示&#xff0c;试验台左侧有电动机&#xff0c;中间有扭矩收集器&#xff0c;右侧有动力测试仪&#xff0c;控制电子设备在图中没有显示。SKF6203轴承使用16通道数据采集卡采集轴承的振动数据&#xff0c;并在驱动端部分&#xff08;DE&…

AI模型训练、实施工程师的职业前景怎么样?

本篇文章主要讲解ai模型训练、模型实施工程师的职业前景和趋势分析 作者&#xff1a;任聪聪 日期&#xff1a;2023年4月18日 ai训练师、模型实施工程师&#xff0c;一般是指opencv、pytorh、python、java、机械学习、深度学习、图像识别、视频检测等领域的模型数据训练工作。 …

07 - 深度学习处理器架构⭐⭐⭐⭐

架构设计需要解决的两个主要问题:(1)如何提高处理器的能效比(性能/功耗)- 硬化算法(2)如何提高处理器的可编程性(通用性) - CPU 一、单核深度学习处理器(DLP-S) 1. 总体架构 (1)架构图 DMA是一种硬件机制,允许外围组件将其I/O数据直接传输到主存储器中,而无需…

CentOS 8 手动安装MongoDB

文章目录1. MongoDB概述2. 安装MongoDB2.1 在MongoDB官网选择对应版本2.2 去到MongoDB安装目录&#xff0c;并下载MongoDB安装包2.3 解压MongoDB安装包2.4 重命名解压后的MongoDB文件夹名2.5 创建MongoDB数据库数据存放路径2.6 创建MongoDB日志文件存放路径2.7 进入MongoDB文件…

Pixhawk基础—认识Pixhawk

Pixhawk简介 pixhawk是由3DR联合APM小组与PX4小组于2014年推出的飞控PX4的升级版&#xff0c;它同时拥有PX4和APM两套固件和相应的地面站软件。该飞控是目前全世界飞控产品中硬件规格最高的产品。 Pixhawk基础 端口介绍 1、Spektrum DSM receiver(Spektrum DSM信号转换为PWM…

Java基础总结(一)

文章目录前言封装继承多态抽象方法接口内部类static权限修饰符this superprivate关键字final关键字就近原则构造方法号StringBuilderStringJoiner字符串原理总结&#xff1a;1、字符串存储的内存原理2、号比较的是什么&#xff1f;3、字符串拼接的底层原理4、StringBuilder提高…

ASIC-WORLD Verilog(1)一日Verilog

写在前面 在自己准备写一些简单的verilog教程之前&#xff0c;参考了许多资料----asic-world网站的这套verilog教程即是其一。这套教程写得极好&#xff0c;奈何没有中文&#xff0c;在下只好斗胆翻译过来&#xff08;加了自己的理解&#xff09;分享给大家。 这是网站原文&…

Java反射面试总结(二)

为什么引入反射概念&#xff1f;反射机制的应用有哪些&#xff1f; 我们来看一下 Oracle 官方文档中对反射的描述&#xff1a; 从 Oracle 官方文档中可以看出&#xff0c;反射主要应用在以下几方面&#xff1a; 反射让开发人员可以通过外部类的全路径名创建对象&#xff0c;…

详解C语言结构体内存对齐:你知道如何快速计算结构体大小吗?

本篇博客会讲解C语言结构体的内存对齐&#xff0c;并且给出一种快速计算结构体大小的方式。主要讲解下面几点&#xff1a; 结构体的内存对齐是什么&#xff1f;如何快速计算结构体的大小&#xff1f;如何利用内存对齐节省结构体占用的内存空间&#xff1f;为什么结构体要内存对…

分布式数据库架构路线大揭秘

文章目录分布式数据库是如何演进的&#xff1f;数据库与分布式中间件有什么区别&#xff1f;如何处理分布式事务&#xff0c;提供外部一致性&#xff1f;如何处理分布式SQL&#xff1f;如何实现分布式一致性&#xff1f;数据库更适合金融政企的未来这些年大家都在谈分布式数据库…

MySQL-中间件mycat(一)

目录 &#x1f341;mycat基础概念 &#x1f341;Mycat安装部署 &#x1f343;初始环境 &#x1f343;测试环境 &#x1f343;下载安装 &#x1f343;修改配置文件 &#x1f343;启动mycat &#x1f343;测试连接 &#x1f990;博客主页&#xff1a;大虾好吃吗的博客 &#x1f9…

边缘网关thingsboard-gateway DTU902

thingsboard-gateway是一个采用python语言编写的开放源代码网关程序&#xff0c;用于将传统或第三方系统的设备与thingsboard平台连接。 支持 采集Modbus slaves、CAN、MQTT 、OPC-UA servers, Sigfox Backend。 除了具备普通 网关外&#xff0c;还具备可配置的边缘能力&…

rabbitmq深入实践

生产者&#xff0c;交换机&#xff0c;队列&#xff0c;消费者 交换机和队列通过 rounting key 绑定者&#xff0c;rounting key 可以是#.,*.这类topic模式&#xff0c; 生产者发送消息内容 rountingkey&#xff0c; 到达交换机后交换机检查与之绑定的队列&#xff0c; 如果能匹…

Yolov5之common.py文件解读

深度学习训练营原文链接前言0.导入需要的包以及基本配置1.基本组件1.1 autopad1.2 ConvDWConv模块1.3TransformerLayer模块1.4 Bottleneck和BottleneckCSPBottleneck模型结构1.5 CrossConv模块1.6 C3模块基于C3的改进1.7SPP1.8Focus模块1.9 Concat模块1.10 Contract和Expand1.1…

好东西!!!多亏几位大牛整理的面试题,让我成功上岸!!

凡事预则立&#xff0c;不预则废。相信很多程序员朋友在跳槽前都会临阵磨枪&#xff0c;在网络上搜集一些面试题进行准备。 然而&#xff0c;当机会来临时&#xff0c;却发现这些面试题往往“不快也不光”.... 由于Java面试涉及的范围很广&#xff0c;很杂&#xff0c;而且技…

使用MyBatis实现简单查询

文章目录一&#xff0c;创建数据库与表&#xff08;一&#xff09;在Navicat里创建MySQL数据库testdb&#xff08;二&#xff09;创建用户表 - t_user&#xff08;三&#xff09;在用户表里插入3条记录二&#xff0c;案例演示MyBatis基本使用&#xff08;一&#xff09;创建Mav…

解决idea每次打开新的项目都需要重新配置maven

原理&#xff1a;就是通过 idea 来进行全局配置【非当前工程配置】 IDEA 版本&#xff1a;2023.1 如何查看版本信息 &#xff1f; 【主菜单】——【帮助】——【关于】 我在网上查找了许多文章 &#xff0c;我混淆了一点&#xff01;当前工程的设置 & 全局设置 不在一个地方…