02、Tensorflow实现手写数字识别(数字0-9)

news2025/2/22 18:39:13

02、Tensorflow实现手写数字识别(数字0-9)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0与pycharm

1、识别目标

识别手写仅仅是为了区分手写的0到9,所以实际上是一个多分类问题。
STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

from sklearn.metrics import classification_report 这行代码的主要作用是导入classification_report 函数,以便在后续的代码中使用它来评估分类模型的性能。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。


STEP2:屏蔽无用警告并允许中文

# 使用warnings模块来忽略特定类型的警告  
warnings.simplefilter(action='ignore', category=FutureWarning)  
# 配置tensorflow的日志记录级别  
logging.getLogger("tensorflow").setLevel(logging.ERROR)  
# 设置TensorFlow的autograph模块的详细级别  
tf.autograph.set_verbosity(0)  
# 设置numpy的打印选项  
np.set_printoptions(precision=2)  

STEP3:加载数据集并分割测试集

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")
    return X, y

# load dataset
X, y = load_data()

print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

原始的输入的数据集是5000* 400数组,共包含5000个手写数字的数据,其中400为20*20像素的图片,
在这里插入图片描述


STEP4:模型构建与训练

# 构建模型  
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的  
model = Sequential(
    [
        ### START CODE HERE ###  
        tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维   
        Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"  
        Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"  
        Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"  
        ### END CODE HERE ###  
    ], name="my_model"
)  # 定义模型名称为"my_model"  
model.summary()  # 打印模型的概述信息  

# 配置模型的训练参数  
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)  
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001  
)

# 训练模型  
history = model.fit(
    X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据  
    epochs=100  # 训练100轮  
)


STEP5:结果可视化与打印准确度信息

fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])
    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T
    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')
    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    prediction_p = tf.nn.softmax(prediction)
    yhat = np.argmax(prediction_p)
    # 错误结果标红
    if y_test[random_index, 0] == yhat:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)
        ax.set_axis_off()
    else:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')
        ax.set_axis_off()

fig.suptitle("Label, yhat", fontsize=14)
plt.show()

# 给出预测的测试集误差
def evaluation(y_test, y_predict):
    accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']
    s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']
    precision=s['precision']
    recall=s['recall']
    f1_score=s['f1-score']
    #kappa=cohen_kappa_score(y_test, y_predict)
    return accuracy,precision,recall,f1_score #, kappa

y_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)

print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

3、运行结果

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现手写数字识别(数字0-9)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

# 使用warnings模块来忽略特定类型的警告
warnings.simplefilter(action='ignore', category=FutureWarning)
# 配置tensorflow的日志记录级别
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# 设置TensorFlow的autograph模块的详细级别
tf.autograph.set_verbosity(0)
# 设置numpy的打印选项
np.set_printoptions(precision=2)

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")
    return X, y

# load dataset
X, y = load_data()

print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# # 绘图可选
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(5, 5))
# fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
# # fig.tight_layout(pad=0.5)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # reshape the image
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
#     fig.suptitle("Label, image", fontsize=14)
# plt.show()

# 构建模型
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的
model = Sequential(
    [
        ### START CODE HERE ###
        tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维
        Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"
        Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"
        Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"
        ### END CODE HERE ###
    ], name="my_model"
)  # 定义模型名称为"my_model"
model.summary()  # 打印模型的概述信息

# 配置模型的训练参数
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001
)

# 训练模型
history = model.fit(
    X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据
    epochs=100  # 训练100轮
)


fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])
    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T
    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')
    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    prediction_p = tf.nn.softmax(prediction)
    yhat = np.argmax(prediction_p)
    # Display the label above the image
    if y_test[random_index, 0] == yhat:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)
        ax.set_axis_off()
    else:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')
        ax.set_axis_off()

fig.suptitle("Label, yhat", fontsize=14)
plt.show()

# 给出预测的测试集误差
def evaluation(y_test, y_predict):
    accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']
    s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']
    precision=s['precision']
    recall=s['recall']
    f1_score=s['f1-score']
    #kappa=cohen_kappa_score(y_test, y_predict)
    return accuracy,precision,recall,f1_score #, kappa

y_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)

print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1255693.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WebSocket协议测试实战

当涉及到WebSocket协议测试时,有几个关键方面需要考虑。在本文中,我们将探讨如何使用Python编写WebSocket测试,并使用一些常见的工具和库来简化测试过程。 1、什么是WebSocket协议? WebSocket是一种在客户端和服务器之间提供双向…

基于OPC UA 的运动控制读书笔记(1)

最近一段时间集中研究OPCUA 在机器人控制应用中应用的可能性。这个话题自然离不开运动控制。 笔者对运动控制不是十分了解。于是恶补EtherCAT 驱动,PLCopen 运动控制的知识,下面是自己的读书笔记和实现OPCUA /IEC61499 运动控制器的实现方案设想。 PLCo…

【Spring整合Junit】Spring整合Junit介绍

本文内容基于【Spring整合MyBatis】Spring整合MyBatis的具体方法进行测试 文章目录 1. 导入相关坐标2. 使用Junit测试所需注解3. 在测试类中写相关内容 1. 导入相关坐标 在pom.xml中导入相关坐标&#xff1a; <dependency><groupId>junit</groupId><ar…

CSS常用笔记

1. 脱离文档流&#xff0c;用于微调 {position: relative; top: 10px; right: 0; } 2. flex布局大法 <div class"demo"><div class"demo-1"></div><div class"demo-2"></div><div class"demo-3"&…

Linux面试题(二)

目录 17、怎么使一个命令在后台运行? 18、利用 ps 怎么显示所有的进程? 怎么利用 ps 查看指定进程的信息&#xff1f; 19、哪个命令专门用来查看后台任务? 20、把后台任务调到前台执行使用什么命令?把停下的后台任务在后台执行起来用什么命令? 21、终止进程用什么命令…

【计网 面向连接的传输TCP】 中科大笔记 (十 二)

目录 0 引言1 TCP 的特性1.1 拓展&#xff1a;全双工、单工、半双工通信 2 TCP报文段结构3 TCP如何实现RDT4 TCP 流量控制4.1 题外话&#xff1a;算法感悟 5 TCP连接3次握手、断开连接4次握手5.1 连接5.2 断开连接 6 拥塞控制6.1 拥塞控制原理6.2 TCP拥塞控制 &#x1f64b;‍♂…

shell脚本 ( 函数 数组 冒泡排序)

目录 什么是函数 使用函数的方法 格式 注意事项 函数的使用 函数可以直接使用 函数变量的作用范围 函数返回值 查看函数 删除函数 函数的传递参数 使用函数文件 ​编辑 拓展递归函数 例&#xff1a;求5的阶乘 什么是数组 使用数组的方法 1.先声明 2.定义数组 3…

Python---函数的数据---拆包的应用案例(两个变量值互换,*args, **kwargs调用时传递参数用法)

案例&#xff1a; 使用至少3种方式交换两个变量的值 第一种方式&#xff1a;引入一个临时变量 c1 10 c2 2# 引入临时变量temp temp c2 c2 c1 c1 tempprint(c1, c2) 第二种方式&#xff1a;使用加法与减法运算交换两个变量的值&#xff08;不需要引入临时变量&#xff09…

ArcGIS制作广场游客聚集状态及密度图

文章目录 一、加载实验数据二、平均最近邻法介绍1. 平均最近邻工具2. 广场游客聚集状态3. 结果分析三、游客密度制图一、加载实验数据 二、平均最近邻法介绍 1. 平均最近邻工具 “平均最近邻”工具将返回五个值:“平均观测距离”、“预期平均距离”、“最近邻指数”、z 得分和…

C++学习之路(五)C++ 实现简单的文件管理系统命令行应用 - 示例代码拆分讲解

简单的文件管理系统示例介绍: 这个文件管理系统示例是一个简单的命令行程序&#xff0c;允许用户进行文件的创建、读取、追加内容和删除操作。这个示例涉及了一些基本的文件操作和用户交互。 功能概述&#xff1a; 创建文件 (createFile())&#xff1a; 用户可以输入文件名和内…

计算机系统的层次结构与性能指标

目录 一. 计算机系统的层次结构二. 计算机性能指标2.1. 存储器的性能指标2.2 CPU的性能指标2.3 系统整体的性能指标2.4 系统整体的性能指标(动态测试) \quad 一. 计算机系统的层次结构 \quad \quad 虚拟机器的意思是看起来像是机器直接就能执行程序员所写的代码, 其实是需要通过…

Java王者荣耀

一、创建项目 二、代码 package com.sxt;import javax.swing.*; import java.awt.*;public class Background extends GameObject {public Background(GameFrame gameFrame) {super(gameFrame);// TODO Auto-generated constructor stub}Image bg Toolkit.getDefaultToolkit(…

基于helm的方式在k8s集群中部署gitlab - 备份恢复(二)

接上一篇 基于helm的方式在k8s集群中部署gitlab - 部署&#xff08;一&#xff09;&#xff0c;本篇重点介绍在k8s集群中备份gitlab的数据&#xff0c;并在虚拟机上部署相同版本的gitlab&#xff0c;然后将备份的数据进行还原恢复 文章目录 1. 备份2. 恢复到虚拟机上的gitlab2.…

java学习part13Object类和常用方法

1.Object 2.常用方法 2.1clone() clone()就是深拷贝&#xff0c;创建一个同内容新对象。需要实现接口 2.2finalize()已废弃 类似于析构函数&#xff0c;在GC回收之前调用。 System.gc()强制调用gc&#xff0c;然后就能看到finalize()的输出 2.3equals() 对于引用类型可用。…

帮管客CRM SQL注入漏洞复现

0x01 产品简介 帮管客CRM是一款集客户档案、销售记录、业务往来等功能于一体的客户管理系统。帮管客CRM客户管理系统&#xff0c;客户管理&#xff0c;从未如此简单&#xff0c;一个平台满足企业全方位的销售跟进、智能化服务管理、高效的沟通协同、图表化数据分析帮管客颠覆传…

Linux(8):BASH

硬件、核心与 Shell 操作系统其实是一组软件&#xff0c;由于这组软件在控制整个硬件与管理系统的活动监测&#xff0c;如果这组软件能被用户随意的操作&#xff0c;若使用者应用不当&#xff0c;将会使得整个系统崩溃。因为操作系统管理的就是整个硬件功能。 应用程序在最外层…

ELF分析(以CS:APP linkLab的文件为例)

文件结构&#xff1a;gcc -o test main.o phase1.o 可执行文件的段头表&#xff08;又称程序头表&#xff09;&#xff08;用于描述本文件到虚拟内存的映射&#xff09; text文件的段头表如下。 上图有两个LOAD。它们的区别是权限不同。LOAD1是可读可执行&#xff08;这里面存…

拍这个视频把脸都扇肿了,midjourney官网效果复现

我是如何复现midjourney官网首页效果的&#xff1f; 视频讲解地址&#xff1a;[https://www.bilibili.com/video/BV1FQ4y1p7HC/](https://www.bilibili.com/video/BV1FQ4y1p7HC/)原理&#xff0c;过程&#xff0c;代码讲解 大家好&#xff0c;这一集我来讲一下 字符花园里 总结…

ehr人力资源管理系统(实际项目源码)

eHR人力资源管理系统&#xff1a;功能强大的人力资源管理工具 随着企业规模的不断扩大和业务需求的多样化&#xff0c;传统的人力资源管理模式已无法满足现代企业的需求。eHR人力资源管理系统作为一种先进的管理工具&#xff0c;能够为企业提供高效、准确、实时的人力资源管理…

04_MySQL备份与恢复

任务背景 一、真实案例 某天&#xff0c;公司领导安排刚入职不久的小冯同学将生产环境中的数据(MySQL数据库)全部导入到测试环境给测试人员使用。当小冯去拿备份数据时发现&#xff0c;备份数据是1个礼拜之前的。原因是之前运维同事通过脚本每天对数据库进行备份&#xff0c;…