使用MNIST数据集训练手写数字识别模型

news2025/1/13 9:43:51

一、MNIST数据集介绍
MNIST 数据集(手写数字数据集)是一个公开的公共数据集,任何人都可以免费获取它。目前,它已经是一个作为机器学习入门的通用性特别强的数据集之一,所以对于想要学习机器学习分类的、深度神经网络分类的、图像识别与处理的小伙伴,都可以选择MNIST数据集入门。

二、MNIST数据集结构
MNIST 数据集包含70000(60000+10000)个样本,其中有60000个训练样本和10000个测试样本,每个样本的像素大小为28*28。

1.MNIST数据集下载方式


方法一
下载地址:http://yann.lecun.com/exdb/mnist/

可以直接下载这四个文件,这四个文件分别为:
①训练样本的图像(60000个)
②对应训练样本上每一张图像上数字的标签(0~9)(60000个)
③测试样本的图像(10000个)
④对应测试样本上每一张图像上数字的标签(0~9)(10000个)

方法二
在Keras中已经内置了多种公共数据集,其中就包含MNIST数据集,如图所示。

所以可以直接调用 tf.keras.datasets.mnist,直接下载数据集。


2.开始训练

可以跟着一步一步做,不会出错

(1)导包

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time

(2)打印开始时间

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)

 效果展示:

(3)预处理

#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

#加载数据
mnist = tf.keras.datasets.mnist

(train_x,train_y),(test_x,test_y) = mnist.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape)) 

#数据预处理
#X_train = train_x.reshape((60000,28*28))
#Y_train = train_y.reshape((60000,28*28))       #后面采用tf.keras.layers.Flatten()改变数组形状
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

效果展示:

(4)建立模型查看结构

# 建立模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
print('\n',model.summary())     #查看网络结构和参数信息

    效果展示:       

(5)开始训练

#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])   

#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))

效果展示:

(6)评估模型

#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力

效果展示:

(7)结果可视化

#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()

#使用模型
plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)

    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,28,28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))

#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

展示效果:

3.测试模型

1.修改测试图片的路径

2.修改保存模型的路径

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2

# 建立模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))     # 添加Flatten层说明输入数据的形状
model.add(tf.keras.layers.Dense(128, activation='relu'))     # 添加隐含层,为全连接层,128个节点,relu激活函数
model.add(tf.keras.layers.Dense(10, activation='softmax'))   # 添加输出层,为全连接层,10个节点,softmax激活函数

# 加载模型参数
model.load_weights('mnist_weights.h5') # 路径根据文件实际位置修改,不然会报错

# 定义一个函数来预处理图片
def preprocess_image(image_path):
    # 读取图片,转换为灰度图
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"无法加载图片:{image_path}")
        return None
    # 调整图片大小为28x28像素
    img = cv2.resize(img, (28, 28))
    # 归一化图片像素值到0-1范围
    img = img / 255.0
    # 转换图片形状以匹配模型输入
    img = img.reshape(1, 28, 28)
    return img

# 使用模型进行预测
plt.figure()
# 这里替换为你的图片路径列表
image_paths = ['shouxieti_img/test/0_10.jpg']
for i, image_path in enumerate(image_paths):
    # 预处理图片
    img = preprocess_image(image_path)
    if img is not None:
        # 使用模型进行预测
        y_pred = np.argmax(model.predict(img))
        
        # 显示图片和预测结果
        plt.subplot(1, len(image_paths), i+1)
        plt.imshow(img[0], cmap='gray')
        plt.axis('off')
        plt.title('预测值:' + str(y_pred))
    
plt.show()

 效果展示:

话不多说 源码奉上!

 4.全部代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)

#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

#加载数据
mnist = tf.keras.datasets.mnist

(train_x,train_y),(test_x,test_y) = mnist.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape)) 

#数据预处理
#X_train = train_x.reshape((60000,28*28))
#Y_train = train_y.reshape((60000,28*28))       #后面采用tf.keras.layers.Flatten()改变数组形状
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

# 建立模型
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
print('\n',model.summary())     #查看网络结构和参数信息

#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])   

#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))

#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力


#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()

#使用模型
plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)

    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,28,28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))

#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

 转载于:【神经网络与深度学习】使用MNIST数据集训练手写数字识别模型——[附完整训练代码]_使用mnist数据集进行模型训练时-CSDN博客

 谢谢支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1819809.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot+vue3-学生信息管理系统

以下为学生信息管理的主要前端界面 项目仓库 CYHone/vue-student-ms: 学生信息管理系统前端界面 (github.com) 项目仓库 CYHone/student-ms-server: 学生信息管理系统后端代码 (github.com) 4.1、学生功能 登录界面可以以三种身份登录:学生、教师、管理员。 选…

jsp 实验20

三、源代码以及执行结果截图&#xff1a; NewFile.jsp <% page import "java.io.*" %> <% page contentType"text/html" %> <% page pageEncoding "utf-8" %> <jsp:useBean id"english" class "web.Engli…

oracle 删除当前用户下所有表

荆轲刺秦王 通常呢 我们将正式环境的 oracle 数据库 导出成 dmp 文件&#xff0c;然后导入到测试环境或者本地环境&#xff0c;期间可能会出现各种问题。那么如何使错误的导入数据全部删除呢。可以这样做&#xff1a; 1. 本地虚拟机启动 oracle 服务 2. sqldeveloper 连接 o…

多商户小程序开发步骤和方法

在当今的数字经济中&#xff0c;多商户小程序作为一种创新的商业平台&#xff0c;提供了一种新的商业模式&#xff0c;使多个商户能够在同一平台上展示和销售他们的产品或服务。这种模式不仅增强了消费者选择的多样性&#xff0c;也为商家提供了一个更广泛的销售渠道。以下是详…

探索Java 8 Stream API:现代数据处理的新纪元

Stream流 Stream初探&#xff1a;何方神圣&#xff1f; Stream流是一种处理集合数据的高效工具&#xff0c;它可以让你以声明性的方式处理数据集合。Stream不是存储数据的数据结构&#xff0c;而是对数据源&#xff08;如集合、数组&#xff09;的运算操作概念&#xff0c;支…

微信小程序开发教程

尚硅谷微信小程序开发教程&#xff0c;2024最新版微信小程序项目实战&#xff01; 一、小程序基础 1. 初始小程序 微信小程序是一种运行在微信内部的 轻量级 应用程序。 使用小程序时 不需要下载&#xff0c;用户 扫一扫 或 搜一下 即可打开应用&#xff0c;它也体现了 “用…

基于Django和Vue的商城管理系统

文章目录 前言一、系统运行结果二、相关技术简介三、系统设计四、系统测试五、总结 前言 近年来&#xff0c;互联网技术的飞速发展极大地改变了人们的生活方式。网络购物作为一种新的购物模式&#xff0c;因其方便、快捷、选择多样等优点&#xff0c;迅速普及。为了满足人们日…

Python第二语言(十二、SQL入门和实战)

目录 1. Python中使用MySQL 1.1 pymysql第三方库使用MySQL 1.2 连接MySQL 1.3 操作数据库&#xff0c;创建表 1.4 执行查询数据库语句 2. python中MySQL的插入语句 2.1 commit提交 2.2 自动提交 3. pymysql案例 3.1 数据内容 3.2 DDL定义 3.3 实现步骤 3.4 文件操…

UITableView之cell复用

关于cell复用的必要性 cellForRowAtIndexPath会随着屏幕滚动而调用&#xff0c;每次出现新行时因为行号变化&#xff0c;就会被调用。 底层原理&#xff1a;当前单元格滚出屏幕时cell销毁&#xff0c;当前单元格又滚回来时cell创建。短时间内频繁创建和销毁cell会影响系统性能…

vue2文件下载和合计表格

vue2文件数据不多可以直接下载不需要后端的时候 1.首先&#xff0c;确保你已经安装了 docx 和 file-saver 库 npm install file-saver npm install docx file-saver2.全部代码如下 <template><a-modaltitle"详情"width"80%"v-model"visi…

舵机堵转的危害与简单解决方式

舵机的堵转保护是一种安全特性&#xff0c;用于防止舵机在遇到阻力无法正常旋转时受到损害。当舵机尝试移动到某个位置但因为外部阻力&#xff08;如卡住或碰撞&#xff09;而无法完成动作时&#xff0c;它会持续施加力直至达到其最大扭矩。如果没有堵转保护&#xff0c;这种情…

OpenCV学习(4.14) 基于分水岭算法的图像分割

1. 目标 在这一章当中&#xff0c; 我们将学习使用分水岭算法使用基于标记的图像分割我们将看到&#xff1a;cv.watershed() 2.理论 任何灰度图像都可以看作是地形表面&#xff0c;其中高强度表示峰和丘陵&#xff0c;而低强度表示山谷。您开始用不同颜色的水&#xff08;标…

什么是基于风险的漏洞管理RBVM,及其优势

文章目录 一、什么是漏洞管理二、什么是基于风险的漏洞管理RBVM三、RBVM的基本流程四、RBVM的特点和优势 一、什么是漏洞管理 安全漏洞是网络或网络资产的结构、功能或实现中的任何缺陷或弱点&#xff0c;黑客可以利用这些缺陷或弱点发起网络攻击&#xff0c;获得对系统或数据…

Vue前端ffmpeg压缩视频再上传(全网唯一公开真正实现)

1.Vue项目中安装插件ffmpeg 1.1 插件版本依赖配置 两个插件的版本 "ffmpeg/core": "^0.10.0", "ffmpeg/ffmpeg": "^0.10.1"package.json 和 package-lock.json 都加入如下ffmpeg的版本配置&#xff1a; 1.2 把ffmpeg安装到项目依…

飞腾派初体验(2)

水个字数&#xff0c;混个推广分&#xff0c;另外几个点还是想吐槽一下 - 1&#xff0c;上篇文章居然没有给开发板一个硬照&#xff0c;补上 - 飞腾派 自拍 2. 现在做镜像用Win32DiskImager的多吗&#xff1f;我记得当年都是dd命令搞定&#xff0c;玩树莓派的应该记得这个命令…

离散化——Acwing.802区间和

离散化 定义 离散化可以简单理解为将连续的数值或数据转换为离散的、有限个不同的值或类别。离散化就是将一个可能具有无限多个取值或在一个较大范围内连续取值的变量&#xff0c;通过某种规则或方法&#xff0c;划分成若干个离散的区间或类别&#xff0c;并将原始数据映射到…

【课程总结】Day8(下):计算机视觉基础入门

前言 数据结构 在人工智能领域&#xff0c;机器可以处理的数据类型如上图&#xff0c;大约可以分为以上类别。其中较为常用的数据类别有&#xff1a; 表格类数据 数据特点&#xff1a; 成行成列&#xff1a;一行一个样本&#xff0c;一列一个特征特征之间相互独立&#xff0…

最实用的AI软件开发工具CodeFlying测评

就在上个月&#xff0c;OpenAI宣布GPT-4o支持免费试用&#xff0c;调用API价格降到5美元/百万token。 谷歌在得到消息后立马将Gemini 1.5 的价格下降到0.35美元/百万token。 Anthropic的API价格&#xff0c;直接干到了0.25美元/百万token。 国外尚且如此&#xff0c;那么国内…

科技赋能,避震婴儿车或成为行业硬通货

全球知识经济发展发展到今天&#xff0c;消费者对于品质、服务、体验的要求越来越高&#xff0c;与之对应的产品也就越来越科技化、智能化、个性化&#xff0c;品牌化和差异化逐步成为产品的竞争核心。 婴儿推车作为关系婴幼儿出行安全的支柱性产业之一&#xff0c;从车架结构…

分享一份 .NET Core 简单的自带日志系统配置,平时做一些测试或个人代码研究,用它就可以了

前言 实际上&#xff0c;.NET Core 内部也内置了一套日志系统&#xff0c;它是一个轻量级的日志框架&#xff0c;用于记录应用程序的日志信息。 它提供了 ILogger 接口和 ILoggerProvider 接口&#xff0c;以及一组内置的日志提供程序&#xff08;如 Console、Debug、EventSo…