01、Tensorflow实现二元手写数字识别

news2024/9/21 4:28:18

01、Tensorflow实现二元手写数字识别(二分类问题)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0

1、识别目标

识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。

2、Tensorflow算法实现

STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。

from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。


STEP2:屏蔽无用警告并允许中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。

tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。

plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。


STEP3:导入并划分数据集

划分10%作为测试:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型构建与训练

# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential(
    [
        tf.keras.Input(shape=(400,)),    #specify input size
        ### START CODE HERE ###
        Dense(100, activation='sigmoid'),
        Dense(10, activation='sigmoid'),
        Dense(1, activation='sigmoid')
        ### END CODE HERE ###
    ], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(
    X_train,y_train,
    epochs=20
)


STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])

    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T

    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')

    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    if prediction >= 0.5:
        yhat = 1
    else:
        yhat = 0
    # Display the label above the image
    ax.set_title(f"{y_test[random_index, 0]},{yhat}")
    ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()

# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

3、运行结果

按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。
在这里插入图片描述

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现二元手写数字识别(二分类问题)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_data/y.npy")
    X = X[0:1000]
    y = y[0:1000]
    return X, y


# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 将1*400的数据转换为20*20的图像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()

# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential(
    [
        tf.keras.Input(shape=(400,)),    #specify input size
        ### START CODE HERE ###
        Dense(100, activation='sigmoid'),
        Dense(10, activation='sigmoid'),
        Dense(1, activation='sigmoid')
        ### END CODE HERE ###
    ], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(
    X_train,y_train,
    epochs=20
)

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])

    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T

    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')

    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    if prediction >= 0.5:
        yhat = 1
    else:
        yhat = 0
    # Display the label above the image
    ax.set_title(f"{y_test[random_index, 0]},{yhat}")
    ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()

# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1249859.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NeurIPS 2023|AI Agents先行者CAMEL:第一个基于大模型的多智能体框架

AI Agents是当下大模型领域备受关注的话题,用户可以引入多个扮演不同角色的LLM Agents参与到实际的任务中,Agents之间会进行竞争和协作等多种形式的动态交互,进而产生惊人的群体智能效果。本文介绍了来自KAUST研究团队的大模型心智交互CAMEL框…

浅谈安科瑞无线测温设备在挪威某项目的应用

摘要:安科瑞无线温度设备装置通过无线温度收发器和各无线温度传感器直接进行温度值的传输,并采用液晶显示各无线温度传感器所测温度。 Absrtact:Acre wireless temperature device directly transmits the temperature value through the wireless temp…

Nginx安装与配置、使用Nginx负载均衡及动静分离、后台服务部署、环境准备、系统拓扑图

目录 1. 系统拓扑图 2. 环境准备 3. 服务器安装 3.1 mysql,tomcat 3.2 Nginx的安装 4. 部署 4.1 后台服务部署 4.2 Nginx配置负载均衡及静态资源部署 1. 系统拓扑图 说明: 用户请求达到Nginx若请求资源为静态资源,则将请求转发至静态…

【蓝桥杯省赛真题47】Scratch小猫踩球 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

目录 scratch小猫踩球 一、题目要求 编程实现 二、案例分析 1、角色分析

vue3.0使用leaflet

1、获取天地图密钥; 访问:https://www.tianditu.gov.cn/ 注册并登录,访问开发资源 》地图API 》 地图服务》申请key 应用管理》创建新应用》获取到对应天地图key 2、引入leaflet组件 参考资料:https://leafletjs.com/reference.html#pa…

一盏茶的时间,入门 Node.js

一、.什么是 Node.js? Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时,用于构建高性能、可伸缩的网络应用。 它采用事件驱动、非阻塞 I/O 模型,使其在处理并发请求时表现出色。 二、安装 Node.js 首先,让我们从 Node.…

CSS3新特性(2-1)

CSS3新特性 前言border:radius标签属性选择器box-sizing透明度 前言 本文主要讲解CSS3有哪些新的特性和内容,那么好,本文正式开始. border:radius 新增了圆角边框概念,可以通过具体数值或者百分比,来让边…

互联网上门洗鞋店小程序

上门洗鞋店小程序门店版是基于原平台版进行增强的,结合洗鞋行业的线下实际运营经验和需求,专为洗鞋人和洗鞋店打造的高效、实用、有价值的管理软件系统。 它能够帮助洗鞋人建立自己的私域流量,实现会员用户管理,实现用户与商家的点…

电源控制系统架构(PCSA)之电源控制框架概览

目录 6 电源控制框架 6.1 电源控制框架概述 6.1.1 电源控制框架低功耗接口 6.1.2 电源控制框架基础设施组件 6 电源控制框架 电源控制框架是标准基础设施组件、接口和相关方法的集合,可用于构建SoC电源管理所需的基础设施。 本章介绍框架的主要组件和低功耗接…

FFmpeg零基础学习(一)——初步介绍与环境搭建

目录 前言正文一、开发环境二、搭建环境二、测试代码 参考 前言 FFmpeg是一个开源的跨平台多媒体处理框架,它包含了一组用于处理音频、视频、字幕等多媒体数据的库和工具。FFmpeg提供了强大的功能和灵活性,被广泛用于多媒体应用开发、视频编辑、流媒体传…

每日OJ题_算法_双指针_力扣11. 盛最多水的容器

力扣11. 盛最多水的容器 11. 盛最多水的容器 - 力扣(LeetCode) 难度 中等 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线,使得它们与 x 轴共同构成…

Windows核心编程 跨进程操作

目录 进程A拿到进程B句柄是否能用 句柄的权限 关于句柄表 跨进程使用句柄-继承 CreateProcess:bInheritHandles OpenProcess FindWinodw GetCurrentProcess 跨进程使用句柄-拷贝 跨进程操作内存 WriteProcessMemory VirtualProtectEx ReadProcessMemo…

<蓝桥杯软件赛>零基础备赛20周--第7周--栈和二叉树

报名明年4月蓝桥杯软件赛的同学们,如果你是大一零基础,目前懵懂中,不知该怎么办,可以看看本博客系列:备赛20周合集 20周的完整安排请点击:20周计划 每周发1个博客,共20周(读者可以按…

AI人工智能对话系统网页版源码系统 附带完整的搭建教程

AI人工智能对话系统网页版源码系统的开发背景主要是基于自然语言处理技术和机器学习算法的不断发展。自然语言处理技术使得计算机能够理解和分析人类语言,而机器学习算法则能够让计算机自我学习和改进,不断提高对话系统的智能化水平。 此外,…

有序表的详解

目录 有序表的介绍 树的左旋和右旋操作 AVL树的详解 SB树的详解 红黑树的介绍 SkipList的详解 有序表的介绍 有序表是除具备哈希表所具备的功能外,有序表中的内容都是按照key有序排列的,并且增删改查等操作的时间复杂度都是,红黑树&…

单片非晶磁性测量系统非晶测量方法

非晶测量方法 单片法是国际主流的非晶测量方法之一,如美标 A932 和日标 H7152 均早已提出了该方法;2014 年 IEC 起草的标准,和我国 2015 年重新修订的 GB/T 19345.1 标准中均明确提出了单片法测量非晶磁性能。单片法与环样法相比&#xff0c…

表单考勤签到作业周期打卡打分评价评分小程序开源版开发

表单考勤签到作业周期打卡打分评价评分小程序开源版开发 表单打卡评分 表单签到功能:学生可以通过扫描二维码或输入签到码进行签到,方便教师进行考勤管理。 考勤功能:可以记录学生的出勤情况,并自动生成出勤率和缺勤次数等统计数…

SpringBoot项目连接,有Kerberos认证的Kafka

在连接Kerberos认证kafka之前,需要了解Kerberos协议 二、什么是Kerberos协议 Kerberos是一种计算机网络认证协议 ,其设计目标是通过密钥系统为网络中通信的客户机(Client)/服务器(Server)应用程序提供严格的身份验证服务,确保通信双方身份的真…

​LeetCode解法汇总2304. 网格中的最小路径代价

目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 描述: 给你一个下…

AMESim与MATLAB联合仿真demo

本文是AMESim与MATLAB联合仿真的demo,记录一下如何进行联合仿真。 AMESim与MATLAB联合仿真可以大幅度提高工作效率。 author:xiao黄 缓慢而坚定的生长 csdn:https://blog.csdn.net/Python_Matlab?typeblog主页传送门 博主的联合仿真环境如下&#xff…