【卷积神经网络】MNIST 手写体识别

news2025/1/11 6:00:03

LeNet-5 是经典卷积神经网络之一,1998 年由 Yann LeCun 等人在论文 《Gradient-Based Learning Applied to Document Recognition》中提出。LeNet-5 网络使用了卷积层、池化层和全连接层,实现可以应用于手写体识别的卷积神经网络。TensorFlow 内置了 MNIST 手写体数据集,可以很方便地读取数据集,并应用于后续的模型训练过程中。本文主要记录了如何使用 TensorFlow 2.0 实现 MNIST 手写体识别模型。

目录

1 数据集准备

2 模型建立

3 模型训练


1 数据集准备

        TensorFlow 内置了 MNIST 手写体数据集,安装 TensorFlow 之后,使用如下代码就可以加载 MNIST 数据集:

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

        使用 Matplotlib 查看前 25 张图片,并打印对应的标签。

from matplotlib import pyplot as plt

# 查看训练集
plt.figure(figsize=(3,3))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.imshow(train_x[i], cmap=plt.cm.binary)
    plt.xticks([])
    plt.yticks([])
plt.show()

        接着,使用 tf.one_hot() 函数,对图像的标签进行独热码编码。

# 预处理
train_y = tf.one_hot(train_y, depth=10)
test_y = tf.one_hot(test_y, depth=10)

2 模型建立

        MNIST 手写体数据集中,每张图像的大小是 28 × 28 × 1,按照 LeNet-5 模型的思路,构建卷积神经网络模型。选择 5 × 5 的卷积核,卷积层之后是 2 × 2 的平均池化,激活函数选择 sigmoid(除了最后一层)。

# the first layer can receive an 'input_shape' argument
model = tf.keras.models.Sequential([
   tf.keras.layers.Conv2D(filters=6,kernel_size=5,padding='valid',activation='sigmoid',input_shape=(28,28,1)),
   tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),
   tf.keras.layers.Conv2D(filters=16,kernel_size=5,padding='valid',activation='sigmoid'),
   tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),
   tf.keras.layers.Flatten(),
   tf.keras.layers.Dense(120,activation='sigmoid'),
   tf.keras.layers.Dense(84,activation='sigmoid'),
   tf.keras.layers.Dense(10,activation='softmax')
])

        使用 model.summary() 查看模型信息。

model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 24, 24, 6)         156       
                                                                 
 average_pooling2d (AverageP  (None, 12, 12, 6)        0         
 ooling2D)                                                       
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 16)          2416      
                                                                 
 average_pooling2d_1 (Averag  (None, 4, 4, 16)         0         
 ePooling2D)                                                     
                                                                 
 flatten (Flatten)           (None, 256)               0         
                                                                 
 dense (Dense)               (None, 120)               30840     
                                                                 
 dense_1 (Dense)             (None, 84)                10164     
                                                                 
 dense_2 (Dense)             (None, 10)                850       
                                                                 
=================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
_________________________________________________________________

3 模型训练

        使用 compile() 函数配置模型,优化算法为 Adam 算法,学习率为 0.001,损失函数为交叉熵损失函数。

# 模型配置
model.compile(
   optimizer=tf.keras.optimizer.Adam(learning_rate=1e-3),
   loss=tf.keras.losses.CategoricalCrossentropy(),
   metrics=['accuracy']
)

# 模型训练
model.fit(
   x=train_x,
   y=train_y,
   validation_split=0.0,
   epochs=10
)

Epoch 1/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.5806 - accuracy: 0.8206
Epoch 2/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.1254 - accuracy: 0.9620
Epoch 3/10
1875/1875 [==============================] - 75s 40ms/step - loss: 0.0870 - accuracy: 0.9735
Epoch 4/10
1875/1875 [==============================] - 82s 43ms/step - loss: 0.0699 - accuracy: 0.9785
Epoch 5/10
1875/1875 [==============================] - 69s 37ms/step - loss: 0.0604 - accuracy: 0.9809
Epoch 6/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0530 - accuracy: 0.9833
Epoch 7/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.0477 - accuracy: 0.9854
Epoch 8/10
1875/1875 [==============================] - 70s 38ms/step - loss: 0.0436 - accuracy: 0.9863
Epoch 9/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.0399 - accuracy: 0.9873
Epoch 10/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0357 - accuracy: 0.9883
<keras.callbacks.History at 0x20a56b65660>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/945735.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络基础知识socket编程

目录 网络通信概述网络互连模型&#xff1a;OSI 七层模型TCP/IP 四层/五层模型数据的封装与拆封 IP 地址IP 地址的编址方式IP 地址的分类特殊的IP 地址如何判断2 个IP 地址是否在同一个网段内 TCP/IP 协议TCP 协议TCP 协议的特性TCP 报文格式建立TCP 连接&#xff1a;三次握手关…

vue2 支持图片放大

添加 :preview-src-list属性 <el-imagev-for"item in specialData.urls":src"item":key"item.index":preview-src-list[item]class"pictrue"/>

李跳跳apk

李跳跳下载&#xff0c;密码 65c9

【VRRP】虚拟路由冗余协议

什么是VRRP&#xff1f; 虚拟路由冗余协议VRRP&#xff08;Virtual Router Redundancy Protocol&#xff09;是一种用于提高网络可靠性的容错协议。通过VRRP&#xff0c;可以在主机的下一跳设备出现故障时&#xff0c;及时将业务切换到备份设备&#xff0c;从而保障网络通信的…

【桌面小屏幕项目】ESP32开发环境搭建

视频教程链接&#xff1a; 【【有手就行系列】嵌入式单片机教程-桌面小屏幕实战教学 从设计、硬件、焊接到代码编写、调试 ESP32 持续更新2022】 https://www.bilibili.com/video/BV1wV4y1G7Vk/?share_sourcecopy_web&vd_source4fa5fad39452b08a8f4aa46532e890a7 一、esp…

C++标准库STL容器详解

目录 C标准模板库STL容器容器分类容器通用接口 顺序容器vectorlistdeque 容器适配器queuestackpriority_queue 关联容器&#xff1a;红黑树setmultisetmapmultimap 关联容器&#xff1a;哈希表unordered_set和unordered_multisetunordered_map和unordered_multimap 附1&#xf…

机械硬盘HDD的基础知识介绍

机械硬盘在价格&#xff0c;容量&#xff0c;磨损度上面都只有着SSD不可取代的地方&#xff0c;目前世界上80%的数据仍然存储在HDD上&#xff0c;不过随着科技的进步&#xff0c;以及SSD技术不断的突破和逐渐降低的价格&#xff0c;HDD的占比会越来越低,至于未来会不会被SSD完全…

任务执行和调度----Spring线程池/Quartz

定时任务 在服务器中可能会有定时任务&#xff0c;但是不知道分布式系统下次会访问哪一个服务器&#xff0c;所以服务器中的任务就是相同的&#xff0c;这样会导致浪费。使用Quartz可以解决这个问题。 JDK线程池 RunWith(SpringRunner.class) SpringBootTest ContextConfi…

Redis的五大数据类型的数据结构

概述 Redis底层有六种数据类型包括&#xff1a;简单动态字符串、双向链表、压缩列表、哈希表、跳表和整数数组。这六种数据结构五大数据类型关系如下&#xff1a; String&#xff1a;简单动态字符串List&#xff1a;双向链表、压缩列表Hash&#xff1a;压缩列表、哈希表Sorted…

指针(个人学习笔记黑马学习)

1、指针的定义和使用 #include <iostream> using namespace std;int main() {int a 10;int* p;p &a;cout << "a的地址为&#xff1a;" << &a << endl;cout << "a的地址为&#xff1a;" << p << endl;…

CPU和GPU的区别

介绍什么是GPU, 那就要从CPU和GPU的比较不同中能更好更快的学习到什么是GPU CPU和GPU的总体区别 CPU&#xff1a; 叫做中央处理器&#xff08;central processing unit&#xff09; 可以形象的理解为有25%的ALU(运算单元)、有25%的Control(控制单元)、50%的Cache(缓存单元)…

“短视频类”App个人信息收集情况测试报告

近期&#xff0c;中国网络空间安全协会对“短视频类”公众大量使用的部分App收集个人信息情况进行了测试。测试情况及结果如下&#xff1a; 一、测试对象 本次测试选取了19家应用商店⁽⁾累计下载量达到1亿次的“短视频类”App&#xff0c;共计6款&#xff0c;其基本情况如表…

StarRocks 在金融科技行业的存算分离应用实践

小编导读&#xff1a; 自从 2023 年 4 月正式推出 3.0 版本的存算分离功能以来&#xff0c;目前已有包含芒果TV、聚水潭、网易邮箱、浪潮、天道金科等数十家用户完成测试&#xff0c;多家用户也已开始逐步将其应用于实际业务中。目前&#xff0c;StarRocks 存算分离上线的场景…

【少年的救赎——放牛班的春天】

风中飞舞的风筝&#xff0c;请你别停下 池塘之底 这是马修在池塘之底写下的日记 他所有的故事&#xff0c;还有“我们”的 1949年一月十五日&#xff0c;在经历了所有领域的挫折后&#xff0c;马修来到了人生低谷期&#xff0c;“池塘之底”像专为他挑选的一般。那是在一个…

19 NAT穿透|python高级

文章目录 网络通信过程NAT穿透 python高级GIL锁深拷贝与浅拷贝私有化import导入模块工厂模式多继承以及 MRO 顺序烧脑题property属性property装饰器property类属性 魔法属性\_\_doc\_\_\_\_module\_\_ 和 \_\_class\_\_\_\_init\_\_\_\_del\_\_\_\_call\_\_\_\_dict\_\_\_\_str…

Gin 框架入门实战系列(一)

GIN介绍 Gin是一个golang的微框架,封装比较优雅,API友好,源码注释比较明确,具有快速灵活,容错方便等特点 对于golang而言,web框架的依赖要远比Python,Java之类的要小。自身的net/http足够简单,性能也非常不错 借助框架开发,不仅可以省去很多常用的封装带来的时间,…

8.28~~和学长的谈话

对于大二&#xff0c;我还想问问学长有什么建议&#xff1f; 熟练掌握一到两门开发语言&#xff0c;选好专业的重点学习方向&#xff0c;开始全面了解工程实践方面&#xff0c;10个以上工程开发&#xff0c;可自行规划二年级&#xff0c;着重加强基础技能的学习和提升&#xf…

JMeter性能测试基本过程及示例

jmeter 为性能测试提供了一下特色&#xff1a; jmeter 可以对测试静态资源&#xff08;例如 js、html 等&#xff09;以及动态资源&#xff08;例如 php、jsp、ajax 等等&#xff09;进行性能测试 jmeter 可以挖掘出系统最大能处理的并发用户数 jmeter 提供了一系列各种形式的…

【100天精通python】Day47:python网络编程_Web编程基础

目录 1 网络编程与web编程 1.1 网络编程 1.2 web编程 2 Web开发概述 3 Web开发基础 3.1 HTTP协议 3.2 Web服务器 3.3 前端基础 3.4 静态服务器 3.5 前后端交互的基本原理 4 WSGI接口 4.1 CGI 简介 4.2 WSGI 简介 4.3 定义 WSGI 接口 4.4 运行 WSGI 服务 4.5…

vue3:使用:图片生成二维码并复制

实现在 vue3 中根据 url 生成一个二维码码&#xff0c;且可以复制。 注&#xff09;复制功能 navigator.clipboard.write 只能在安全的localhost 这种安全网络下使用。https中需要添加安全证书&#xff0c;且在域名&#xff08;例&#xff1a;https://www.baidu.com&#xff0…