【机器学习 P19】【实战 P1】 MINST 手写数字识别

news2024/11/24 17:47:57

MINST 手写数字识别

  • 引入数据
  • 模型训练
    • 模型创建程序
    • 模型编译程序
    • 模型训练程序
    • 模型预测程序
  • 完整代码


在这里插入图片描述

引入数据

MINST数据集是一个经典的手写数字识别数据集,由Yann LeCun等人创建。它包含了来自真实手写数字图片的70000个灰度图像,这些图像是由250个不同的人手写而成的,其中60000个图像被用作训练数据,10000个图像用作测试数据。

每个图像都是28x28像素大小的,并且已经被规范化和中心化处理,以便于输入到机器学习模型中。每个图像还带有一个标签,指示该图像所代表的数字是什么。

通过Keras可以方便地获取MINST数据集。Keras提供了一个简单的API,可以直接从官方网站下载并加载MNIST数据集:

  • 下载加载数据集:
from tensorflow import keras
from tensorflow.keras.datasets import mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  • 查看数据集大小:
print(x_train.shape)
# (60000, 28, 28)
print(x_test.shape)
# (10000, 28, 28)
- 60000张训练图片
- 每张图片 28 * 28 pixel
- 10000张测试图片
- 每张图片 28 * 28 pixel

在这里插入图片描述

  • 使用matplotlib查看MINST数据集中图片:
import matplotlib.pyplot as plt

# 显示训练集中前25个图像
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10,10))
axes = axes.ravel()

for i in range(25):
    axes[i].imshow(x_train[i], cmap='gray')
    axes[i].set_title(y_train[i])
    axes[i].axis('off')

plt.subplots_adjust(hspace=0.5)
plt.show()

cmap='gray' 参数使图像以灰度模式显示。
axis('off') 方法可以将坐标轴关闭,以便更好地查看图像。
plt.subplots_adjust(hspace=0.5) 将相邻子图之间的间距设置为一个子图高度的50%,以便更好地分隔每个子图。

  • 展示效果如下图所示:

在这里插入图片描述


模型训练

在这里插入图片描述
模型结构:

模型采用 输入层 + 两个隐藏层 + 输出层 结构:

  • 输入层:
    x_train 包含60000张图片,每张图片 28 * 28 pixel;

  • 隐藏层:
    总共两层隐藏层,第一层包含25个神经元,第二层包含15个神经元;
    隐藏层采用 ReLU 激活函数;

  • 输出层
    输出层总共10个神经元(因为输出结果 0~9 共十个);
    采用 linear 线性激活函数 以及 Softmax 激活函数做对比;

模型参数:

  • layer1:
    • w [ 1 ] w^{[1]} w[1].shape is (784, 25)
    • b [ 1 ] b^{[1]} b[1].shape is (25,)
  • layer2:
    • w [ 2 ] w^{[2]} w[2].shape is (25, 15)
    • b [ 2 ] b^{[2]} b[2].shape is: (15,)
  • layer3:
    • w [ 3 ] w^{[3]} w[3].shape is (15, 10)
    • b [ 3 ] b^{[3]} b[3].shape is: (10,)

模型创建程序

from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

model = Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(25, activation='relu', name='layer1'),
    layers.Dense(15, activation='relu', name='layer2'),
    layers.Dense(10, activation='softmax', name='layer3')
], name='Minst_Model'
)

model.summary()
  • layers.Dense 是TensorFlow中的一个类,用于定义全连接层。全连接层也被称为密集层(Dense Layer),因为其中每个神经元都与前一层的每个神经元相连,可以将前一层的所有输入都与权重相乘并加上偏置,得到一组输出。

    • units:全连接层中的神经元数量;
    • activation:全连接层的激活函数。如果没有指定,默认使用线性激活函数;
    • name:全连接层的名称;
  • model.summary() 用于打印模型的摘要信息,包括每一层的名称、形状和参数数量。它非常有用,可以让我们快速查看模型的架构和大小,以便进行调试和优化。

在这里插入图片描述


模型编译程序

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
)
  • SparseCategoricalCrossentropy(from_logits=True) 损失函数
    • Sparse 表示标签是稀疏编码(即单个整数),而不是独热编码;Categorical 表示标签是分类数据;Crossentropy 表示使用交叉熵来计算损失。
    • from_logits=True 如果在输出层的Dense中选择Softmax作为激活函数,那么在使用交叉熵损失函数时,需要将from_logits参数设置为False;若未使用 Softmax 则需要设置 from_logits 参数为 True。
  • 使用 Adam 优化算法来更新模型的权重参数,并且设置学习率 learning_rate 为0.001。

模型训练程序

history = model.fit(
    x_train, y_train,
    epochs=40
)

model.fit() 是Keras中用于训练模型的方法,其作用是对给定的训练数据进行模型训练,并返回训练过程中的历史记录,包括损失函数和指定的评价指标的值。在模型训练过程中,model.fit()方法将对模型的参数进行更新,使得模型的预测结果逐渐接近真实结果,从而提高模型的性能。

  • x_trainy_train 分别表示训练数据的输入特征和标签;
  • epochs 表示训练的轮数
  • history 记录历史记录信息,包含每次的损失值等

在这里插入图片描述
使用损失值构建损失值下降图像:

import matplotlib.pyplot as plt

# 获取训练历史中的 loss 值和验证集上的 loss 值
train_loss = history.history['loss']

# 绘制 loss 曲线
plt.plot(train_loss, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

在这里插入图片描述


模型预测程序

from sklearn.metrics import accuracy_score
import numpy as np

pred_y = model.predict(x_test)
pred_labels = np.argmax(pred_y, axis=1)
acc = accuracy_score(y_test, pred_labels)
print('Test accuracy:', acc)

model.predict() 是神经网络模型用于进行预测的方法;
np.argmax() 获取每个输入的最大值索引,以获取预测标签;
accuracy_score 是一个来自Scikit-learn库的函数,用于计算分类问题的准确率;


完整代码

import tensorflow as tf
from tensorflow import keras

# # # 输入数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
print(x_train.shape)

# # # 显示训练集中前25个图像
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10,10))
axes = axes.ravel()

for i in range(25):
    axes[i].imshow(x_train[i], cmap='gray')
    axes[i].set_title(y_train[i])
    axes[i].axis('off')

plt.subplots_adjust(hspace=0.5)
plt.show()

# # # 设定神经网络模型
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    layers.Flatten(input_shape=(28, 28, 1)),
    layers.Dense(25, activation='relu', name='layer1'),
    layers.Dense(15, activation='relu', name='layer2'),
    layers.Dense(10, activation='softmax', name='layer3')
],name='MINST_Model'
)

model.summary()

# # # 神经网络编译部分
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
)

# # # 训练模型
history = model.fit(
    x_train, y_train,
    epochs=40
)

# # # 根据 Loss 绘制曲线
import matplotlib.pyplot as plt

train_loss = history.history['loss']

plt.plot(train_loss, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# # # 进行预测
from sklearn.metrics import accuracy_score
import numpy as np

pred_y = model.predict(x_test)
pred_labels = np.argmax(pred_y, axis=1)
acc = accuracy_score(y_test, pred_labels)
print('Test accuracy:', acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

三行Python代码,让数据处理速度提高2到6倍

本文可以教你仅使用 3 行代码,大大加快数据预处理的速度。 Python 是机器学习领域内的首选编程语言,它易于使用,也有很多出色的库来帮助你更快处理数据。但当我们面临大量数据时,一些问题就会显现…… 在默认情况下,…

OpenShift 4 - 使用 virtctl 远程访问 OpenShift Virtualization 的虚拟机

《OpenShift / RHEL / DevSecOps 汇总目录》 说明:本文已经在支持 OpenShift 4.12 的 OpenShift 环境中验证 在《OpenShift 4 - 用 OpenShift Virtualization 运行容器化虚拟机 (视频)》一文中使用了 OpenShift 控制台直接访问运行在 OpenSh…

SQL中去除重复数据的几种方法,我一次性都告诉你​

使用SQL对数据进行提取和分析时,我们经常会遇到数据重复的场景,需要我们对数据进行去重后分析。以某电商公司的销售报表为例,常见的去重方法我们用到distinct 或者group by 语句, 今天介绍一种新的方法,利用窗口函数对…

MIT 6.S965 韩松课程 05

Lecture 05: Quantization (Part 1) 文章目录Lecture 05: Quantization (Part 1)动机数字的数据类型整数定点数浮点数量化基于 K-Means 的量化 [[Han et al., ICLR 2016]](https://arxiv.org/pdf/1510.00149v5.pdf)线性量化 [[Jacob et al. CVPR 2018]](https://arxiv.org/pdf/…

Makefile项目管理-----在Linux下编译c/c++程序

这里写目录标题起因makefile项目管理一、用途:二、 makefile的基础规则1.多文件联合编译2. makefile检测原理3. ALL来指定终极目标三、 makefile的两个函数和clean四、 makefile中的三个自动变量五、模式规则六、 静态模式规则七、 扩展1. 扩展1 伪目标2. 扩展2 可添…

在 Python 中检查字符串是否为 ASCII

使用 str.isascii() 方法检查字符串是否为 ASCII,例如 if my_str.isascii():。 如果字符串为空或字符串中的所有字符都是 ASCII,则 str.isascii() 方法返回 True,否则返回 False。 my_str www.jiyik.comif my_str.isascii():# &#x1f447…

网络安全工程师做什么?

​ 网络安全很复杂。数字化转型、远程工作和不断变化的威胁形势需要不同的工具和不同的技能组合。 系统必须到位以保护端点、身份和无边界网络边界。负责处理这种复杂安全基础设施的工作角色是网络安全工程师。 简而言之,网络安全工程师是负责设计和实施组织安全系…

基于TF-IDF+KMeans聚类算法构建中文文本分类模型(附案例实战)

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

UHD安装教程

UHD Universal Hardware Driver,即USRP驱动。 UHD,Windows平台安装教程 uhd驱动安装 http://files.ettus.com/binaries/misc/erllc_uhd_winusb_driver.zip 安装LibUSBx http://files.ettus.com/binaries/uhd/latest_release 下载默认C盘 环境配置 将…

Android FrameWork 知识点与面试题整合~

1.如何对 Android 应用进行性能分析 android 性能主要之响应速度 和UI刷新速度。 首先从函数的耗时来说,有一个工具TraceView 这是androidsdk自带的工作,用于测量函数耗时的。 UI布局的分析,可以有2块,一块就是Hierarchy Viewe…

面试-Sqrt(x)

题目 给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。 注意:不允许使用任何内置指数函数和算符,例如 pow(x, 0.5) 或者 x ** 0.5 。 思路 二分查…

项目管理:项目进度难以把控,项目经理应该怎么办?

项目管理中,对进度的管理也是保障整个项目顺利完成的重要条件。项目进度难以把控,项目常常延期,项目经理怎么办?如何跟进整个项目的进度? 对于如何做好项目进度管理,有几点建议,希望能对大家有…

Java实现导出多个excel表打包到zip文件中,供客户端另存为窗口下载

文章目录一、业务背景二、实现思路二、准备工作1.准备data模板.xlsx2.引入poi相关依赖,用于操作excel3.针对WorkBookZIP压缩输入/输出流,相关方法知识点要有所了解三、完整的项目代码四、可能遇到的问题错误场景1:java.io.IOException: Strea…

【RabbitMQ】SpringBoot整合RabbitMQ实现延迟队列、TTL、DLX死信队列

目录 一、TTL 1、什么是TTL 2、设置TTL的两种方式 3、控制台设置TTL 4、SpringBoot实现两种方式设置TTL 1.给消息设置过期时间 2.给队列设置过期时间 二、DLX死信队列 1、什么是死信交换机与死信队列 2、消息何时会成为死信 3、队列如何绑定死信交换机与死信队列 4…

vscode“检测到 #include 错误,请更新 includepath。”的问题解决办法

目录 一.报错更新includepath​编辑 二.原因 三.解决方法 一.报错更新includepath 如图 二.原因 1.没有安装gcc 2.没有配置好环境 winR打开cmd,输入gcc -v,如果安装了gcc,会返回版本 三.解决方法 1.安装MinGW 2.添加MinGW环境变量 将bin文件夹的位置添加到系统环境变量中…

三分钟搭建个人博客技术栈Nuxt3+vite+mysql+koa2

最近也是想入一下Nuxt3的坑,然后就写了一个博客系统,目前已开源github,欢迎大家star!!! 效果预览 网址:http://180.76.121.2:3000/ github地址 https://github.com/ztzzhi/ztzzhi-nuxt3-vite…

MySQL事物(基础篇)

MySQL事务事物的基本概念事物的ACID属性事务的使用事务隔离级别MVCC&ReadViewMySQL是否还存在幻读事物的基本概念 Transaction作为关系型数据库的核心组成,在数据安全方面有着非常重要的作用,本文会一步步解析事务的核心特性,以获得对事…

多云数据存储,理想与现实之间还差着什么?

去年底,“数据二十条”正式颁布,数据要素全面提速已是指日可待。 无疑,数据作为数字经济的基础,其价值的释放依赖于数据的流动、共享和应用。数据要素只有充分地流动和应用起来,才能够实现价值的最大化。 换而言之&a…

VPN、IPSEC、AH、ESP、IKE、DSVPN

目录 1.什么是数据认证,有什么作用,有哪些实现的技术手段? 2.什么是身份认证,有什么作用,有哪些实现的技术手段? 3.什么VPN技术? 4. VPN技术有哪些分类? 5. IPSEC技术能够提供哪些安全服务? 6. IPSEC的技术架构是什么?…

idea中使用git工具

目录一、IDEA中配置git二、git操作将项目设置成git仓库一、IDEA中配置git 打开idea,点击File–>Settings 点击版本控制,然后点击git 将你的git.exe安装目录填到下面位置 点击test可以看到显示了版本,说明配置成功 二、git操作 将项目设置…