7步搞懂手写数字识别Mnist

news2024/11/27 17:50:48

大家好啊,我是董董灿。

图像识别有很多入门项目,其中Mnist 手写数字识别绝对是最受欢迎的。

该项目以数据集小、神经网络简单、任务简单为优势,并且集合了CNN网络中该有的东西,可谓麻雀虽小,五脏俱全。

非常适合新手上手学习。

本文以代码走读的形式,带你一览该项目的每一处细节。

文章末尾附代码下载链接,不用GPU, 你也可以从头训练一个神经网络出来。

什么是手写数字识别

简答来说,就是搭建了一个卷积神经网络,可以完成手写数字的识别。

我用笔在纸上写了个6,神经网络就能认识这是个6,我写了个8,它就识别出来这是个8,就这么简单。

之所以说该任务简单,是因为它的标签只有 0-9 这 10 种分类,相比于 resnet 等网络在 ImageNet 上 1000 个分类,确实小很多。

虽然简单,但背后的原理却一点都不少,典型的CNN训练和算法无一缺席。

与该项目一起出名的,便是大名鼎鼎的 MNIST(Mathematical Numbers In Text) 数据集。

该数据集中包含了 60,000 个训练图像和 10,000 个测试图像,图像都是各种手写的数字,基本都是长这样的。

7步精读代码

在简单了解了项目背景后,我以代码走读的形式,一点点介绍该神经网络。

第一步:导入必要的库

# 导入NumPy数学工具箱
import numpy as np 
# 导入Pandas数据处理工具箱
import pandas as pd
# 从 Keras中导入 mnist数据集
from keras.datasets import mnist

keras 是一个开源的人工神经网络库,里面有很多经典的神经网络和数据集,要用的 mnist 数据集就在其中。

第二步:加载数据集

(x_train, y_train), (x_test, y_test)
=  mnist.load_data() 

这条命令利用 keras 中自带的 mnist 模块,加载数据集(load_data)进来,分别赋值给四个变量。

其中:x_train 保存用来训练的图像,y_train 是与之对应的标签。假设图像中的数字是1,那么标签就是1。

x_test 和 y_test 分别为用来验证的图像和标签,也就是验证集。训练完神经网络后,可以使用验证集中的数据进行验证。

第三步:数据预处理

其中一个预处理内容是改变数据集的 shape,使其满足模型的要求。

 # 导入keras.utils工具箱的类别转换工具
from tensorflow.keras.utils import to_categorical
 # 给标签增加维度,使其满足模型的需要
 # 原始标签,比如训练集标签的维度信息是[60000, 28, 28, 1]
X_train = X_train_image.reshape(60000,28,28,1)
X_test = X_test_image.reshape(10000,28,28,1)
 # 特征转换为one-hot编码
y_train = to_categorical(y_train_lable, 10)
y_test = to_categorical(y_test_lable, 10)

这个数据集中的共 60000 张训练图像,10000 张验证图像,每张图像的长宽均为 28 个像素,通道数为 1。

那么对于训练集 x_train 而言,将其形状变为 NHWC = [60000, 28, 28, 1], 验证集类似。

to_categorical 的作用是将样本标签转为 one-hot 编码,而 one-hot 编码的作用是可以对于类别更好的计算概率或得分。

one-hot

之所以用 one-hot 编码,是因为对于输出 0-9 这10个标签而言,每个标签的地位应该是相等的,并不存在标签数字 2 大于数字 1 的情况。

但如果我们直接利用标签的原始值(0-9)进行最终结果的计算,就会出现标签2 大于标签 1的情况。

因此,在大部分情况下,都需要将标签转换为 one-hot 编码,也就独热编码,这样标签之间便没有任何大小而言。

这个例子中,数字 0-9 转换为的独热编码为:

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]

每一行的向量代表一个标签。

假设 [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.] 代表 0 而 [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.] 代表1,可以看到这两者之间是正交独立的,不存在谁比谁大的问题。

第四步:创建神经网络。

# 从 keras 中导入模型
from keras import models 
# 从 keras.layers 中导入神经网络需要的计算层
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
# 构建一个最基础的连续的模型,所谓连续,就是一层接着一层
model = models.Sequential()
# 第一层为一个卷积,卷积核大小为(3,3), 输出通道32,使用 relu 作为激活函数
model.add(Conv2D(32, (3, 3), activation='relu', 
                 input_shape=(28,28,1)))
# 第二层为一个最大池化层,池化核为(2,2)
# 最大池化的作用,是取出池化核(2,2)范围内最大的像素点代表该区域
# 可减少数据量,降低运算量。
model.add(MaxPooling2D(pool_size=(2, 2)))
# 又经过一个(3,3)的卷积,输出通道变为64,也就是提取了64个特征。
# 同样为 relu 激活函数
model.add(Conv2D(64, (3, 3), activation='relu'))
# 上面通道数增大,运算量增大,此处再加一个最大池化,降低运算
model.add(MaxPooling2D(pool_size=(2, 2)))
# dropout 随机设置一部分神经元的权值为零,在训练时用于防止过拟合
# 这里设置25%的神经元权值为零
model.add(Dropout(0.25)) 
# 将结果展平成1维的向量
model.add(Flatten())
# 增加一个全连接层,用来进一步特征融合
model.add(Dense(128, activation='relu'))
# 再设置一个dropout层,将50%的神经元权值为零,防止过拟合
# 由于一般的神经元处于关闭状态,这样也可以加速训练
model.add(Dropout(0.5)) 
# 最后添加一个全连接+softmax激活,输出10个分类,分别对应0-9 这10个数字
model.add(Dense(10, activation='softmax'))

上面每一行代码都加了注释,说明每一行的作用,短短几行,便是这个手写数字识别神经网络的全部了。

第五步:训练

# 编译上述构建好的神经网络模型
# 指定优化器为 rmsprop
# 制定损失函数为交叉熵损失
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# 开始训练              
model.fit(X_train, y_train, # 指定训练特征集和训练标签集
          validation_split = 0.3, # 部分训练集数据拆分成验证集
          epochs=5, # 训练轮次为5轮
          batch_size=128) # 以128为批量进行训练

Epoch 5/5
329/329 [==============================] - 15s 46ms/step - loss: 0.1054 - accuracy: 0.9718 - val_loss: 0.0681 - val_accuracy: 0.9826
训练结果如上,可以看到最后的训练精度达到了98.26%,还是挺高的。

第6步:验证集验证

# 在测试集上进行模型评估
score = model.evaluate(X_test, y_test) 
print('测试集预测准确率:', score[1]) # 打印测试集上的预测准确率

313/313 [==============================] - 1s 4ms/step - loss: 0.0662 - accuracy: 0.9815 测试集预测准确率: 0.9815000295639038

可以看到在验证集上也能有98%的准确率。

第7步:验证一张图片

# 预测验证集第一个数据
pred = model.predict(X_test[0].reshape(1, 28, 28, 1)) 
# 把one-hot码转换为数字
print(pred[0],"转换一下格式得到:",pred.argmax())
 # 导入绘图工具包
import matplotlib.pyplot as plt
# 输出这个图片
plt.imshow(X_test[0].reshape(28, 28),cmap='Greys')

以验证集中的第一张图片为例来进行验证。

1/1 [==============================] - 0s 17ms/step
[4.2905590e-15 2.6790809e-11 2.8249305e-09 2.3393848e-11 7.1304548e-14
1.8217797e-18 5.7493907e-19 1.0000000e+00 8.0317367e-15 4.6352322e-10]

转换一下格式得到:7

得到的数字是7,将该图片显示出来,确实是7。说明训练的模型确实达到了识别数字的水平。

总结

手写数字识别项目比较简单,仅仅两个卷积层,整体运算量不大,就目前计算机的配置,即使笔记本基本上都可以完成该神经网络的训练和验证。

如果你感兴趣,关注公众号《董董灿是个攻城狮》后台回复【mnist】获取源码,实操起来吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/590751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Fourier分析入门——第12章——Fourier变换的性质

目录 第12章 Fourier变换的性质 12.1 引言 12.2 Fourier变换性质的相关定理 12.2.1 线性定理(Linearity) 12.2.2 伸缩性定理(Scaling) 12.2.3 时间/空间平移定理(Shift) 12.2.4 频移定理 12.2.5 调制定理(Modulation) 12.2.6 微分定理(Differentiation) 12.2.7 积分定…

冒泡排序详解(Bubble Sort)

本文已收录于专栏 《算法合集》 目录 一、简单释义1、算法概念2、算法目的3、算法思想4、算法性质 二、核心思想构建排序 三、图形展示宏观展示微观展示 四、算法实现实现思路代码实现客户端调用构造堆的方法元素交换的方法元素比较的方法 运行结果 五、算法描述1、问题描述2、…

数据库管理-第七十八期 记第一次数据库吐槽大会(20230530)

数据库管理 2023-05-30 第七十八期 记第一次数据库吐槽大会1 主席2 三六九等3 数据库吐槽大会总结 第七十八期 记第一次数据库吐槽大会 昨天晚上终于还是把Exadata X9M-2和之前用于展示RAC搭建及升级的那套库做好了ADG,这部分操作在整理后会在下个月发出来。因为之…

Python列表类型的使用

文章目录 Python中的列表类型一、列表的常用操作二、列表的增删改查三、列表常用的函数 Python中的列表类型 将各个元素用方括号([])括起来,用逗号(,)分隔开,这种形式的数据类型就是列表。各个元素的数据类…

HNU-电子测试平台与工具2-串口实验5次

计算机串口使用与测量 【实验属于电子测试平台与工具】 湖南大学信息科学与工程学院 计科 210X wolf (学号 202108010XXX) 0.环境搭建 在实验开始之前,安装好Ubuntu 20.04操作系统。(这个没有难度) 但要提醒的是,这个ubuntu是xubuntu,而且虚拟硬盘只有10GB的大小…

智警杯1.4---excel可视化

视频要点: 首先就是有数据透视表 点击数据透视表,分析,字段项目, 切片器筛选 切片器(我希望用什么对数据进行一个筛选) 跟下拉列表有点像,只不过切片器仅仅之对于数据透视表 依旧需要用su…

HBase集群搭建

hbase 1.解压HBase安装包 先 下载HBase压缩包,并解压安装文件,示例代码如下: tar -zxvf hbase-2.0.1-bin.tar.gz2. 修改配置文件 编辑 conf目录下的 hbase-env.sh文件,示例代码如下: cd conf vi hbase-env.sh添加…

压缩感知入门③基于ADMM的全变分正则化的压缩感知重构算法

压缩感知系列博客:压缩感知入门①从零开始压缩感知压缩感知入门②信号的稀疏表示和约束等距性压缩感知入门③基于ADMM的全变分正则化的压缩感知重构算法 文章目录 1. Problem2. 仿真结果3. MATLAB算法4. 源码地址参考文献 1. Problem 信号压缩是是目前信息处理领域非…

Frame Pacing

Frame Pacing是每个游戏都要遇到的问题,这里面有很多细节值得探讨。 为什么需要做Frame Pacing? 从我们的游戏线程渲染一帧到最终屏幕上绘制出一帧不是一个概念,这种间会经历CPU,GPU,屏幕合成器等多个角色的协同工作&a…

【xv6操作系统】安装、运行与调试

一、构建、装入过程 1.编写“启动代码主体代码”(在下载的xv6的原始代码上进行修改) 2.源代码进行编译、链接生成系统镜像(elf格式的目标文件) 3.将系统镜像保存起来(如保存到磁盘、flash或者网络服务器上&#xff…

上海斯歌K2 x 赛博威 | 战略合作深度交流暨签约仪式

2月16日,上海斯歌K2与赛博威进行了战略合作深度交流,并在赛博威广州科学城办公室举办战略合作签约仪式。 为满足客户在数智化建设过程中的多元化需求,上海斯歌K2与赛博威曾多次产生交集。凭借双方多年合作的良好基础,自2022年始&a…

【C++】右值引用和移动语义(详细解析)

文章目录 1.左值引用和右值引用左值引用右值引用 2.左值引用和右值引用的比较左值引用总结右值引用总结 3.右值引用的使用场景和意义知识点1知识点2知识点3知识点4总结 4.完美转发万能引用见识完美转发的使用完美转发的使用场景 1.左值引用和右值引用 传统的C语法中就有引用的…

【C++】引用 - 基本语法,注意事项,函数参数,函数返回值,本质

文章目录 1. 引用的基本语法2. 引用的注意事项3. 引用做函数参数4. 引用做函数返回值5. 引用的本质6. 常量引用 1. 引用的基本语法 作用是:给变量起别名 语法:数据类型 &别名 原名 2. 引用的注意事项 引用必须初始化引用在初始化后,不…

量子计算:基本概念

选了课程 《量子计算与量子信息》,没学过量子力学的博主实在是听不懂啊 (ㄒoㄒ) 简略整理了下 可能大概也许 明白一二都没有 的课程最开始两节的内容,如有错误欢迎指出 ~ ~ ~ 文章目录 矩阵论复空间中的矩阵矩阵上的运算 量子力学量子态基本假设 量子计算…

阿里云的内容识别技术可以实现哪些场景下的智能化应用?

阿里云的内容识别技术可以实现哪些场景下的智能化应用? [本文由阿里云代理商[聚搜云]撰写]   随着人工智能技术的快速发展,阿里云借助自身的技术和资源优势,开发了一种名为“内容识别”的技术。这项技术能够高效、准确地识别出图片、视频、…

有个规划文档,会让软件开发更有效

有个规划文档,会让软件开发更有效 中小企业,业务部门不太清楚软件生产过程 软件生产有一定的抽象和复杂性 要形成一个共识 趣讲大白话:要有点整体观 【趣讲信息科技181期】 **************************** 2019年整理出了一个目录框架 用在很多…

windows的cmd命令窗口介绍

1.打开cmd 1.1.方式一 左下角搜索:“运行” -> 打开 输入"cmd" -> 确定 1.2.方式二 直接使用快捷键 windows r 即可打开 然后输入cmd,点击确认 1.3.方式三 打开文件管理器,输入cmd,回车 即可在该文件路径下…

统计软件与数据分析Lesson16----pytorch基本知识及模型构建

统计软件与数据分析Lesson16----pytorch基本知识及模型构建 0.上节回顾0.1 一元线性回归数据生成数据处理初始数据可视化 0.2 梯度下降Gradient DescentStep 0: 随机初始化 Random InitializationStep 1: 计算模型预测值 Compute Models PredictionsStep 2: 计算损失 Compute t…

让进程能够“相互沟通”的高级方式一:匿名管道

代码运行及测试环境:linux centos7.6 在阅读这篇文章时,需要掌握OS对文件管理的基础知识(文件打开表、文件描述符、索引结点…) 前言 我们都知道进程是具有独立性的,意味着进程之间无法相互通信。但在一些情况下&…

当我们谈笔记的时候,我们在谈什么

文章具体内容如图,感谢妙友分享好文🎉 本篇内容来源于网站Untag Minja 上传的内容《当我们谈笔记的时候,我们在谈什么》 如有侵权请联系删除!