深度学习训练营实现minist手写数字识别

news2024/10/1 12:23:54

深度学习训练营

  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
    • 导入要使用的包
    • 进行归一化操作
    • 样本可视化
    • 调整图片格式
  • 构建CNN网络
  • 编译模型
  • 模型训练
  • 预测操作

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

设置GPU

如果

# K同学啊深度学习练习
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

导入要使用的包

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

进行归一化操作

# 将像素的值标准化至0到1的区间内。
train_images, test_images = train_images / 255.0, test_images / 255.0

train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

样本可视化

使用到的是python当中专门用来画图的matplotlib.pyplot

# 将数据集前20个图片数据可视化显示
# 进行图像大小为20宽、10长的绘图(单位为英寸inch)
plt.figure(figsize=(20,10))
# 遍历MNIST数据集下标数值0~49
for i in range(20):
    # 将整个figure分成5行10列,绘制第i+1个子图。
    plt.subplot(2,10,i+1)
    # 设置不显示x轴刻度
    plt.xticks([])
    # 设置不显示y轴刻度
    plt.yticks([])
    # 设置不显示子图网格线
    plt.grid(False)
    # 图像展示,cmap为颜色图谱,"plt.cm.binary"为matplotlib.cm中的色表
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    # 设置x轴标签显示为图片对应的数字
    plt.xlabel(train_labels[i])
# 显示图片
plt.show()

在这里插入图片描述

调整图片格式

#调整数据到我们需要的格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

调整图片格式为我们需要的
在这里插入图片描述

构建CNN网络

CNN名为卷积神经网络,是一种专门用来处理类似网络结构的数据的神经网络

  • 卷积层:通过卷积操作对输入图像进行降维和特征抽取
  • 池化层:是一种非线性形式的下采样。主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的鲁棒性。
  • 全连接层:在经过几个卷积和池化层之后,神经网络中的高级推理通过全连接层来完成。
model = models.Sequential([
    # 设置二维卷积层1,设置32个3*3卷积核,activation参数将激活函数设置为ReLu函数,input_shape参数将图层的输入形状设置为(28, 28, 1)
    # ReLu函数作为激活励函数可以增强判定函数和整个神经网络的非线性特性,而本身并不会改变卷积层
    # 相比其它函数来说,ReLU函数更受青睐,这是因为它可以将神经网络的训练速度提升数倍,而并不会对模型的泛化准确度造成显著影响。
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    #池化层1,2*2采样
    layers.MaxPooling2D((2, 2)),                   
    # 设置二维卷积层2,设置64个3*3卷积核,activation参数将激活函数设置为ReLu函数
    layers.Conv2D(64, (3, 3), activation='relu'),  
    #池化层2,2*2采样
    layers.MaxPooling2D((2, 2)),                   
    
    layers.Flatten(),                    #Flatten层,连接卷积层与全连接层
    layers.Dense(64, activation='relu'), #全连接层,特征进一步提取,64为输出空间的维数,activation参数将激活函数设置为ReLu函数
    layers.Dense(10)                     #输出层,输出预期结果,10为输出空间的维数
])
# 打印网络结构
model.summary()

在这里插入图片描述
Sequential:连续的,序列的
Non-trainable params:不可训练参数

编译模型

# model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
model.compile(
	# 设置优化器为Adam优化器
    optimizer='adam',
	# 设置损失函数为交叉熵损失函数(tf.keras.losses.SparseCategoricalCrossentropy())
    # from_logits为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 设置性能指标列表,将在模型训练时监控列表中的指标
    metrics=['accuracy'])

模型训练

"""
这里设置输入训练数据集(图片及标签)、验证数据集(图片及标签)以及迭代次数epochs
关于model.fit()函数的具体介绍可参考K同学的博客:
https://blog.csdn.net/qq_38251616/category_10258234.html
"""
history = model.fit(
    # 输入训练集图片
	train_images, 
	# 输入训练集标签
	train_labels, 
	# 设置10个epoch,每一个epoch都将会把所有的数据输入模型完成一次训练。
	epochs=10, 
	# 设置验证集
    validation_data=(test_images, test_labels))

训练过程以及结果如下
在这里插入图片描述

预测操作

plt.imshow(test_images[7])

在这里插入图片描述
查看图片的预测结果

pre = model.predict(test_images) # 对所有测试图片进行预测
pre[7] # 输出第七张图片的预测结果

在这里插入图片描述
在预测图片的过程当中,对图像进行预测,预测图像为哪一个0到9的数值,数值越大,则代表越靠近该值,所以由以上结果可以知道,图片的预测结果为9

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/54742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV图像处理——卷积操作

总目录 图像处理总目录←点击这里 二十五、卷积操作 25.1、预处理 # 指定输入图像 ap argparse.ArgumentParser() ap.add_argument("-i", "--image", requiredTrue, help"path to the input image") args vars(ap.parse_args())# 分别构建…

【数据结构趣味多】顺序表基本操作实现(Java)

目录 顺序表 1.定义顺序顺序表 2.顺序表功能 3.函数实现(java实现)? 打印顺序表display()函数 新增元素函数add() (默认在数组最后新增) 在 pos 位置新增元素add()函数(与上方函数构成重载) 判定是否包含某个元素…

XctNet:从单个X射线图像重建体积图像的网络

摘要 传统的计算机断层扫描(CT)通过使用不同角度的X射线投影计算逆氡变换来生成体积图像,这导致高剂量辐射、长重建时间和伪影。生物学上,可以利用先前的知识或经验在一定程度上从2D图像中识别体积信息。提出了一种深度学习网络Xc…

为什么要使用 kafka,为什么要使用消息队列?

总结以下两点: 1、缓冲和削峰: 上游数据时有突发流量,下游可能扛不住,或者下游没有⾜够多的机器来保证冗余,kafka在中间可以起到⼀个缓冲的作⽤,把消息暂存在kafka中,下游服务就可以按照⾃⼰的节…

B. Moderate Modular Mode(nmodx=ymodn )

Problem - 1603B - Codeforces 帮助他找到一个整数n,使得1≤n≤2⋅1018,并且nmodxymodn。这里,amodb表示a除以b后的余数。如果有多个这样的整数,请输出任何一个。可以证明,在给定的约束条件下,这样的整数总…

图的关键路径(含多支交叉路径分离输出)

文章目录关键路径的理解关键路径求解的图解与分析关键路径查找的代码实现多支交叉路径的分离输出总结此文代码均可在Windows与Linux操作系统下的常用编译器上运行,例如:vs、vscode、Dev-C等等。关键路径的理解 图的关键路径一般是在求从一个顶点到另一个…

RocketMQ-RocketMQ部署(Linux、docker)

文章目录一、Linux1、单机部署RocketMQ> 前置条件第一步、官网下载 并 上传至服务器第二步、配置jdk环境第三步、修改初始内存第四步、启动 NameServer第五步、启动 Broker第六步、关闭RocketMQDemo:发送与接收消息测试 (Linux端)2、部署可视化管理工具—rocketm…

tictoc 例子理解 13-15

tictoc13-tictoc13 子类化cMessage生成消息,随机目标地址tictoc 14 在13的基础上增加两变量显示于仿真界面tictoc 15 模型数据输出为直方图tictoc13 子类化cMessage生成消息,随机目标地址 在这一步中,目标地址不再是节点2——我们绘制了一个…

[附源码]计算机毕业设计springboot现代诗歌交流平台

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

MySQL存储过程

目录 存储过程 1、存储过程的概念 2、存储过程的优点 3、创建存储过程 格式: 4、调用存储过程 格式 5、查看存储过程 格式: 6、存储过程的参数 7、删除存储过程 格式: 8、存储过程的控制语句 准备a表 (1)条…

Spring基础篇:注入

第一章:注入 一:什么是注入 (Injection)注入就是通过Spring的工厂类和spring的配置文件,对spring所创建的对象进行赋值,为成员变量进行赋值 二:为什么注入 为什么需要Spring工厂创建对象的时…

[附源码]Python计算机毕业设计SSM开放式在线课程教学与辅助平台(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

[附源码]计算机毕业设计JAVA校园闲置物品租赁系统

[附源码]计算机毕业设计JAVA校园闲置物品租赁系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM my…

Maven使用指南(超详细)

Maven高级 目标 理解并实现分模块开发能够使用聚合工程快速构建项目能够使用继承简化项目配置能够根据需求配置生成、开发、测试环境,并在各个环境间切换运行了解Maven的私服 1,分模块开发 1.1 分模块开发设计 (1)按照功能拆分 我们现在的项目都是在…

Delay Penalty for RNN-T and CTC

1. 背景 之前介绍了如何在 RNN-T 流式模型上应用时延正则,以及在 Conformer 和 LSTM 上的实验结果。 本期公众号重点带大家回顾下具体的思路,以及如何类似地在 CTC 流式模型上应用时延正则。 有些内容可能有所重复,读者可适当跳过。2. Dela…

iwebsec靶场 SQL注入漏洞通关笔记12-等价函数替换绕过

系列文章目录 iwebsec靶场 SQL注入漏洞通关笔记1- 数字型注入_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记2- 字符型注入(宽字节注入)_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记3- bool注入(布尔型盲注&#…

Ajax学习:同源策略(与跨域相关)ajax默认遵循同源策略

同源策略:是浏览器的一种安全策略 同源意味着:协议、域名、端口号必须相同 违背同源便是跨域 当前网页的url和ajax请求的目标资源的url必须协议、域名、端口号必须相同 比如:当前网页:协议http 域名 a.com 端口号8000 目标请求…

python——spark入门

Hadoop是对大数据集进行分布式计算的标准工具,这也是为什么当你穿过机场时能看到”大数据(Big Data)”广告的原因。它已经成为大数据的操作系统,提供了包括工具和技巧在内的丰富生态系统,允许使用相对便宜的商业硬件集群进行超级计算机级别的…

Android Poco初始化时,不大起眼但可能存在坑点的参数们

1. 前言 进行Android poco初始化的时候,可能大多数同学都是直接在Poco辅助窗里选择Android模式,然后选择自动帮我们补充poco的初始化脚本: 这种情况下,我们大多数都不会关注初始化的参数。但如果我们不了解这些参数的含义&#x…

Spring之@RequestMapping、@GetMapping、 @PostMapping 三者的区别

我的理解:其实RequestMapping、GetMapping、 PostMapping 三者就是父类和子类的区别,RequestMapping是父类,GetMapping、 PostMapping为子类集成了RequestMapping更明确了http请求的类型 分析三者的源码: RequestMapping .class&…