基于 Tensorflow 2.x 实现 BP 神经网络,实践 MNIST 手写数字识别

news2024/9/24 3:27:19

一、MNIST 数据集

MNIST 是一个非常有名的手写数字识别数据集,在很多资料中都会被用作深度学习的入门样例。在 Tensorflow 2.x 中该数据集已被封装在了 tf.keras.datasets 工具包下,如果没有指定数据集的位置,并先前也没有使用过,会自动联网下载该,使该数据集使用起来更加方便,它包括了 70000 张图片数据,大小统一是 28x28的长宽,其中 60000 张作为训练数据,10000张作为测试数据,每一张图片都代表 0~9 中的一个数字,可以通过下面程序对该数据进行可视化预览:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
mnist = tf.keras.datasets.mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

# 加载 fashion_mnist 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

for i in range(10):
    image = x_test[i]
    label = y_test[i]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label=('标签值: ' + str(label)))
    plt.show()

运行后可以看下如下的图片示例:

在这里插入图片描述

二、搭建多层BP神经网络模型

本文基于 Tensorflow 2.x 构建 BP 神经网络,在 Tensorflow 2.x 中官方更推荐的上层API工具为 Keras ,本文也是使用 Keras 进行实验测试。

设计模型结构如下所示:

在这里插入图片描述

通过 Keras 建立模型结构,具体的解释都写在了注释中:

import tensorflow as tf

keras = tf.keras
mnist = tf.keras.datasets.mnist

# 定义模型类
class mnistModel():
    # 初始化结构
    def __init__(self, checkpoint_path, log_path, model_path):
        # checkpoint 权重保存地址
        self.checkpoint_path = checkpoint_path
        # 训练日志保存地址
        self.log_path = log_path
        # 训练模型保存地址:
        self.model_path = model_path
        # 初始化模型结构
        self.model = tf.keras.models.Sequential([
            # 输入层,平坦层,输入 (None, 784)
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            # 隐藏层 一 ,输出 (None, 64)
            tf.keras.layers.Dense(64,
                                  kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                  activation=tf.nn.relu,
                                  kernel_regularizer=keras.regularizers.l2(0.001)),
            # 隐藏层 二 ,输出(None, 128)
            tf.keras.layers.Dense(128,
                                  kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                  activation=tf.nn.relu,
                                  kernel_regularizer=keras.regularizers.l2(0.001)),

            # Dropout 随机失活,防止过拟合,20%的神经元失活,输出  (None, 128)
            tf.keras.layers.Dropout(0.2),
            # 隐藏层 三 ,输出 (None, 256)
            tf.keras.layers.Dense(256,
                                  kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                  activation=tf.nn.relu,
                                  kernel_regularizer=keras.regularizers.l2(0.001)),
            # Dropout 随机失活,防止过拟合,20%的神经元失活,输出 (None, 254)
            tf.keras.layers.Dropout(0.2),
            # softmax 层,输出 (None, 10)
            tf.keras.layers.Dense(10, activation='softmax')
        ])

    # 编译模型
    def compile(self):
        # 输出模型摘要
        self.model.summary()
        # 定义训练模式
        self.model.compile(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])

    # 训练模型
    def train(self, x_train, y_train):
        # tensorboard 训练日志收集
        tensorboard = keras.callbacks.TensorBoard(log_dir=self.log_path)

        # 训练过程保存 Checkpoint 权重,防止意外停止后可以继续训练
        model_checkpoint = keras.callbacks.ModelCheckpoint(self.checkpoint_path,  # 保存模型的路径
                                                           monitor='val_loss',  # 被监测的数据。
                                                           verbose=0,  # 详细信息模式,0 或者 1
                                                           save_best_only=True,  # 如果 True, 被监测数据的最佳模型就不会被覆盖
                                                           save_weights_only=True,
                                                           # 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)),否则的话,整个模型会被保存,(model.save(filepath))
                                                           mode='auto',
                                                           # {auto, min, max}的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto模式中,方向会自动从被监测的数据的名字中判断出来。
                                                           period=3  # 每3个epoch保存一次权重
                                                           )
        # 填充数据,迭代训练
        self.model.fit(
            x_train,  # 训练集
            y_train,  # 训练集的标签
            validation_split=0.2,  # 验证集的比例
            epochs=30,  # 迭代周期
            batch_size=30,  # 一批次输入的大小
            verbose=2,  # 训练过程的日志信息显示,一个epoch输出一行记录
            callbacks=[tensorboard, model_checkpoint]
        )
        # 保存训练模型
        self.model.save(self.model_path)

    def evaluate(self, x_test, y_test):
        # 评估模型
        test_loss, test_acc = self.model.evaluate(x_test, y_test)
        return test_loss, test_acc

上面优化器使用的 adamlosssparse_categorical_crossentropy,一共训练 30 个周期,每个 batch 30 张图片,验证集的比例为 20%,并且每三个周期保存一次权重,防止意外停止后继续训练,最后保存了 h5 的训练模型,方便后面进行测试预测效果。

下面开始训练模型:

import tensorflow as tf

keras = tf.keras
mnist = tf.keras.datasets.mnist

def main():
    # 加载 MNIST 数据集
    (x_train, y_train), (x_test, y_test) = mnist.load_data(path='F:/Tensorflow/datasets/mnist.npz')
    # 修改shape 数据归一化
    x_train, x_test = x_train / 255.0, x_test / 255.0

    checkpoint_path = './checkout/'
    log_path = './log'
    model_path = './model/model.h5'

    # 构建模型
    model = mnistModel(checkpoint_path, log_path, model_path)
    # 编译模型
    model.compile()
    # 训练模型
    model.train(x_train, y_train)
    # 评估模型
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print(test_loss, test_acc)

if __name__ == '__main__':
    main()

运行后可以看到网络结构:

在这里插入图片描述
训练日志,可以看到 loss 一直在减小:

在这里插入图片描述
最后看下评估模型的结果:
在这里插入图片描述
毕竟是BP神经网络,这里评估的准确率只有 96.8 %,在本专栏后面博客,会使用多层卷积训练模型,可以实现更好的效果。

下面看下 tensorboard 中可视化的损失及准确率:

tensorboard --logdir=log/train

在这里插入图片描述
使用浏览器访问:http://localhost:6006/ 查看结果:

在这里插入图片描述

三、模型预测

上面搭建的模型,训练后会在 model 下生成 model.h5 模型,下面直接加载该模型进行预测:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
mnist = tf.keras.datasets.mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

model = keras.models.load_model('./model/model.h5')

for i in range(10):
    image = x_test[i]
    label = y_test[i]
    softmax = model.predict(image.reshape([1, 784]))
    y_label = tf.argmax(softmax, axis=1).numpy()[0]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label = ('预测结果: '+ str(y_label) + ',  真实结果: '+ str(label)))
    plt.show()

运行后可以看下如下的图片示例:

在这里插入图片描述
可以看到面对书写较工整的数字都可以较好的进行识别,但是对于不工整的就有点吃力,下一篇使用卷积神经网络进行优化,提高识别的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/98320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java语言】— Java基础语法01

Java基础语法 1.注释 什么是注释 注释是写在程序中对代码进行解释说明的文字,方便自己和他人查看,以便理解程序。 注释有哪些 (1)单行注释 // 注释内容,只能写一行(2)多行注释 /* 注释内…

数据结构---判断一个数是否为2的整数次幂

判断一个数是否为2的整数次幂穷举法JAVA实现移位操作优化性能JAVA实现按位与JAVA实现实现一个方法,来判断一个正整数是否是2的整数次幂(如16是2的4次方,返回true;18不是2的整数次幂,则返回false)。要求性能…

BFS——Flood Fill模型及最短路模型

文章目录Flood Fill模型概述模板池塘计数城堡问题山峰和山谷最短路模型概述迷宫问题武士风度的牛抓住那头牛总结Flood Fill模型 概述 定义 从一个起始节点开始把附近与其连通的节点提取出或填充成不同颜色颜色,直到封闭区域内的所有节点都被处理过为止&#xff0c…

关于 SAP Gateway 响应头部 Last Modified 字段的赋值逻辑

本教程迄今为止,讨论的绝大多数都是 OData 服务数据实现类(Data Provider Class) 的实现。而要讨论 OData 服务的元数据话题,就得去 MPC 类研究。 MPC 类的 define 方法,负责生成 OData metadata 元数据: Postman 里请求元数据&…

​6. 独享锁 VS 共享锁

独享锁和共享锁同样是一种概念。我们先介绍一下具体的概念,然后通过ReentrantLock和ReentrantReadWriteLock的源码来介绍独享锁和共享锁。 独享锁也叫排他锁,是指该锁一次只能被一个线程所持有。如果线程T对数据A加上排它锁后,则其他线程不能…

Linux典型IO模型:阻塞、非阻塞、信号驱动、异步

目录 一、阻塞IO 二、非阻塞IO 三、信号驱动IO 四、异步IO 五、阻塞VS非阻塞(概念) 1.阻塞 2.非阻塞 3.区别与联系 六、同步VS异步(概念) 1.同步 2.异步 3.区别与联系 IO就是输入输出 一、阻塞IO 为了完成IO发起IO调…

高通平台开发系列讲解(充电篇)充电管理芯片PM7250B详解

文章目录 一、PM7250B硬件组成二、充电功能沉淀、分享、成长,让自己和他人都能有所收获!😄 📢充电管理芯片PM7250B,用于控制电池充电相关逻辑。 一、PM7250B硬件组成 PWM = Pulse Width Modulator,脉宽调制。SPMS = Switched Mode Power Supply,开关电源。GPIO = Gen…

接口测试(八)—— 日志收集、全量字段校验、JSON Schema语法

目录 一、日志收集 1、日志简介 2、日志的级别 3、日志代码实现分析 4、日志使用 二、全量字段校验 1、简介和安装 2、JSON Schema⼊⻔ 2.1 入门案例 2.2 校验方式 3、JSON Schema语法 3.1 type关键字 3.2 properties关键字 3.3 required关键字 3.4 const关键字…

ADI Blackfin DSP处理器-BF533的开发详解61:DSP控制ADXL345三轴加速度传感器-LCD(含源码)

硬件准备 ADSP-EDU-BF533:BF533开发板 AD-HP530ICE:ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 MEMS三轴加速度传感器 我做了一个三轴加速度传感器的子卡,插在这个板子上,然后写了一些有意思的应用程序。 代码实现功能…

[C++]类和对象【中】

🥁作者: 华丞臧 📕​​​​专栏:【C】 各位读者老爷如果觉得博主写的不错,请诸位多多支持(点赞收藏关注)。如果有错误的地方,欢迎在评论区指出。 推荐一款刷题网站 👉LeetCode 文章目录类的六个…

推荐一个.Net分布式微服务开发框架

在给大家介绍之前,我们一起来看看分布式架构的使用场景与好处。 针对一些互联网系统,大数据、高并发和快速响应,都是系统必须满足的,而单机系统的架构是无法满足这样的需求的,这时候我们就需要用到分布式的架构。 分…

ADI Blackfin DSP处理器-BF533的开发详解60:DSP控制ADXL345三轴加速度传感器-电子水平仪(含源码)

硬件准备 ADSP-EDU-BF533:BF533开发板 AD-HP530ICE:ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 MEMS三轴加速度传感器 我做了一个三轴加速度传感器的子卡,插在这个板子上,然后写了一些有意思的应用程序。 代码实现功能…

SpringBoot集成JWT实现Token登录验证

1JWT 1.1 JWT是什么? JSON Web令牌(JWT)是一种开放的标准(RFC 7519),它定义了一种紧凑而独立的方式在各方之间安全地传输信息为JSON对象。该信息可以被验证和信任,因为它是数字签名的。JWT可以使用秘密(使用HMAC算法)或使用RSA或ECDSA的公开…

全国A级景区数据(12000条)

中华人民共和国旅游景区依据质量等级划分景区级别,共分为五级。其中5A级为中国旅游景区最高等级,代表着中国世界级精品的旅游风景区。 而随着国家旅游管理部门对于A级景区实行“有进有出”的动态管理以来,A级景区的调整越来越常态化,其中不乏4A、5A级景区的调整,这也为A级…

使用 NuGet 快速创建 OpenGL 项目

C 目前还没有一个标准的 C 依赖包管理器,传统上都是手动去下载源码编译(经典的例如 make),或者直接下载预编译好的库文件(例如没有开源的)和头文件。之后在项目里配置对应的头文件路径和库路径。这个过程非…

[附源码]Nodejs计算机毕业设计基于响应式交友网站Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

蚁巢相遇问题

一 问题描述 有 N 个蚁巢,编号为 1~N 。第 i 个蚁巢的位置是(xi , yi),没有两个蚁巢在同一位置。所有蚂蚁都遵守一些规律: ① 当一只蚂蚁在蚁巣 p 时,它总是移动到离 p 最近的另一个蚁巣,若有多个蚁巣与 …

计算机毕设Python+Vue心理健康网站(程序+LW+部署)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

[leetcode 739] 每日温度

题目链接:https://leetcode.cn/problems/daily-temperatures/ 第一个想法是简单两个 for 循环,但是可能会超时(其实用C不会超时)。 因为最近在做栈的题目,所以想到了最小栈(原来叫作最小栈啊~)…

Rust 从入门到放弃,再入门到贡献 nacos-sdk-rust

Rust 从入门到放弃,再入门到贡献 nacos-sdk-rust Rust 上手难度大?我想是的。从文章标题便可知一二,小编水平有限经历了多次入门,得来的经验之谈。本文不涉及详细的技术剖析,仅表达入门的心路历程,供客官参…