基于 Tensorflow 2.x 实现多层卷积神经网络,实践 MNIST 手写数字识别

news2025/1/12 12:21:24

一、MNIST 数据集

上篇文章中使用了Tensorflow 2.x 搭建了对层的 BP 神经网络,经过训练后发现准确率只有 96.8% 对于单环境的图片识别场景来说,还是有点偏低,本文使用多层的卷积代替BP网络中的隐藏层对模型进行优化。

下面是上篇文章地址:https://blog.csdn.net/qq_43692950/article/details/128361681

下面再次简单介绍一下MNIST 数据集,有的小伙伴可能没有看过上篇文章。该数据集已被封装在了 tf.keras.datasets 工具包下,包括了 70000 张图片数据,大小统一是 28x28的长宽,其中 60000 张作为训练数据,10000张作为测试数据,每一张图片都代表 0~9 中的一个数字,可以通过下面程序对该数据进行可视化预览:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
mnist = tf.keras.datasets.mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

# 加载 fashion_mnist 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

for i in range(10):
    image = x_test[i]
    label = y_test[i]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label=('标签值: ' + str(label)))
    plt.show()

在这里插入图片描述

二、搭建多层卷积神经网络模型

本文基于 Tensorflow 2.x 构建多层卷积神经网络,在 Tensorflow 2.x 中官方更推荐的上层API工具为 Keras ,本文也是使用 Keras 进行实验测试。

设计模型结构如下所示:

在这里插入图片描述

通过 Keras 建立模型结构,具体的解释都写在了注释中:

import tensorflow as tf

keras = tf.keras
mnist = tf.keras.datasets.mnist

# 定义模型类
class mnistModel():
    # 初始化结构
    def __init__(self, checkpoint_path, log_path, model_path):
        # checkpoint 权重保存地址
        self.checkpoint_path = checkpoint_path
        # 训练日志保存地址
        self.log_path = log_path
        # 训练模型保存地址:
        self.model_path = model_path
        # 初始化模型结构
        self.model = tf.keras.models.Sequential([
            # 输入层,第一层卷积 ,卷积核 3x3 ,输出 (None, 28, 28, 32),卷积模式 same
            keras.layers.Conv2D(32, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same',
                                input_shape=(28, 28, 1)),
            # 第一层卷积的池化层,2x2 MaxPool,输出 (None, 14, 14, 32)
            keras.layers.MaxPooling2D(2, 2),

            # 第二层卷积 ,卷积核 3x3 ,输出 (None, 14, 14, 64) ,卷积模式 same
            keras.layers.Conv2D(64, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same'),
            # 第二层卷积的池化层,2x2 MaxPool,输出 (None, 7, 7, 64)
            keras.layers.MaxPooling2D(2, 2),
            # Dropout 随机失活,防止过拟合,输出 (None, 7, 7, 64)
            keras.layers.Dropout(0.2),
            # 转为全链接层,输出 (None, 3136)
            keras.layers.Flatten(),
            # 第一层全链接层,输出 (None, 512)
            keras.layers.Dense(512,
                               kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                               kernel_regularizer=keras.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            # softmax 层,输出 (None, 10)
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])

    # 编译模型
    def compile(self):
        # 输出模型摘要
        self.model.summary()
        # 定义训练模式
        self.model.compile(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])

    # 训练模型
    def train(self, x_train, y_train):
        # tensorboard 训练日志收集
        tensorboard = keras.callbacks.TensorBoard(log_dir=self.log_path)

        # 训练过程保存 Checkpoint 权重,防止意外停止后可以继续训练
        model_checkpoint = keras.callbacks.ModelCheckpoint(self.checkpoint_path,  # 保存模型的路径
                                                           monitor='val_loss',  # 被监测的数据。
                                                           verbose=0,  # 详细信息模式,0 或者 1
                                                           save_best_only=True,  # 如果 True, 被监测数据的最佳模型就不会被覆盖
                                                           save_weights_only=True,
                                                           # 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)),否则的话,整个模型会被保存,(model.save(filepath))
                                                           mode='auto',
                                                           # {auto, min, max}的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto模式中,方向会自动从被监测的数据的名字中判断出来。
                                                           period=3  # 每3个epoch保存一次权重
                                                           )
        # 填充数据,迭代训练
        self.model.fit(
            x_train,  # 训练集
            y_train,  # 训练集的标签
            validation_split=0.2,  # 验证集的比例
            epochs=30,  # 迭代周期
            batch_size=30,  # 一批次输入的大小
            verbose=2,  # 训练过程的日志信息显示,一个epoch输出一行记录
            callbacks=[tensorboard, model_checkpoint]
        )
        # 保存训练模型
        self.model.save(self.model_path)

    def evaluate(self, x_test, y_test):
        # 评估模型
        test_loss, test_acc = self.model.evaluate(x_test, y_test)
        return test_loss, test_acc

上面优化器使用的 adamlosssparse_categorical_crossentropy,一共训练 30 个周期,每个 batch 30 张图片,验证集的比例为 20%,并且每三个周期保存一次权重,防止意外停止后继续训练,最后保存了 h5 的训练模型,方便后面进行测试预测效果。

下面开始训练模型:

import tensorflow as tf

keras = tf.keras
mnist = tf.keras.datasets.mnist

def main():
    # 加载 MNIST 数据集
    (x_train, y_train), (x_test, y_test) = mnist.load_data(path='F:/Tensorflow/datasets/mnist.npz')
    # 修改shape 数据归一化
    x_train, x_test = x_train.reshape(60000, 28, 28, 1) / 255.0, \
                      x_test.reshape(10000, 28, 28, 1) / 255.0

    checkpoint_path = './checkout/'
    log_path = './log'
    model_path = './model/model.h5'

    # 构建模型
    model = mnistModel(checkpoint_path, log_path, model_path)
    # 编译模型
    model.compile()
    # 训练模型
    model.train(x_train, y_train)
    # 评估模型
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print(test_loss, test_acc)

if __name__ == '__main__':
    main()

运行后可以看到打印的网络结构:

在这里插入图片描述
从训练日志中,可以看到 loss 一直在减小:

在这里插入图片描述

等待训练结束后看下评估模型的结果:

在这里插入图片描述
如果看过上篇文章可以发现损失和准确率都有明显的提升。

下面看下 tensorboard 中可视化的损失及准确率:

tensorboard --logdir=log/train

在这里插入图片描述

使用浏览器访问:http://localhost:6006/ 查看结果:

在这里插入图片描述

三、模型预测

上面搭建的模型,训练后会在 model 下生成 model.h5 模型,下面直接加载该模型进行预测:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
mnist = tf.keras.datasets.mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data(path='F:/Tensorflow/datasets/mnist.npz')
# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

model = keras.models.load_model('./model/model.h5')

for i in range(10):
    image = x_test[i]
    label = y_test[i]
    softmax = model.predict(image.reshape([1, 28, 28, 1]))
    y_label = tf.argmax(softmax, axis=1).numpy()[0]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label=('预测结果: ' + str(y_label) + ',  真实结果: ' + str(label)))
    plt.show()

运行后可以看下如下的图片示例:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/98586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言重点解剖第12课笔记

1.int* a,b; a和b的类型不一样, a是指针,b是整型。 typedef int* int_p; int_p a,b; 或者int* a,*b; 这样写的话,a和b都是指针类型。 #define int_p int*;这是纯粹的文本替换。 typedef定义之后是一种独立类型。 2.大部分注释都换成了…

Linux网络协议之HTTP协议(应用层)

Linux网络协议之HTTP协议(应用层) 文章目录Linux网络协议之HTTP协议(应用层)1.HTTP协议的概念2.HTTP协议中URL的理解3.HTTP协议的数据流4.HTTP协议的格式4.1 HTTP请求格式4.2 HTTP响应格式5.HTTP协议格式图解6.HTTP协议版本7.HTTP协议请求方法7.1 GET方法:获取资源7…

OWASP API安全Top 10

文章目录API1-失效的对象级授权API2-失效的用户认证API3-过度的数据暴露API4-缺乏资源和速率控制API5-失效的功能级授权API6-批量分配API7-安全性配置错误API8-注入API9-资产管理不当API10-日志记录和监控不足在API安全发展的过程中,除了各大安全厂商和头部互联网企…

计算机基础学习笔记:操作系统篇之硬件结构,CPU的基本工作原理

一、CPU的是如何运行程序的? 本文知识来源小林Coding阅读整理思考,原文链接请见以下: https://xiaolincoding.com/os/1_hardware/how_cpu_run.html#图灵机的工作方式 问题引入 程序的执行过程?例如 12 的具体过程是怎么样的&…

Windows VS2015 cmake编译Gtest并进行测试

1.下载Gtest 下载网址:https://github.com/google/googletest/releases 也可以直接使用下载好的附件 解压,放到一个目录中,演示所用,直接存放D盘了。 2.使用CMake生成vs编译工程 选好下图中两个路径,点击Configure…

用 AWTK 和 AWPLC 快速开发嵌入式应用程序 (8)- AWBlock

AWPLC 目前还处于开发阶段的早期,写这个系列文章的目的,除了用来验证目前所做的工作外,还希望得到大家的指点和反馈。如果您有任何疑问和建议,请在评论区留言。 1. 背景 AWTK 全称 Toolkit AnyWhere,是 ZLG 开发的开源…

玩以太坊链上项目的必备技能(OOP-接口-Solidity之旅十一)

接口(interface) 我们知道在Java里接口是特殊的抽象类,限制多于抽象类,但随着Java版本的更新,Java中的接口是越来越趋于抽象类了(这样说,可能有点不妥,因为接口本就是特殊的抽象类&…

自己整理的Java面试题(下)

目录五.Java框架部分Spring1.Spring中的拦截器,过滤器组件介绍?2.说一下spring的IOC?3.Spring中的异常处理:4.jdk动态代理和cglib动态代理:5.Spring Bean生命周期:6.Spring IOC原理:7.BeanFacto…

RK3568平台开发系列讲解(Camera篇)Camera API v2框架

🚀返回专栏总目录 文章目录 一、Camera API v2框架二、preview流程三、核心模块沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇讲介绍 Camera API v2框架。 一、Camera API v2框架 应用框架:应用代码位于应用框架级别,它使用 Camera 2 API 与相机硬件进行交互…

【正点原子I.MX6U-MINI】u-boot过程移植详解

正点原子的I.MX6ULL开发板参考的是NXP官方的I.MX6ULL EVK开发板做的硬件。 Linux的移植要复杂的多,在移植Linux之前我们需要先移植一个 bootloader 代码,这个 bootloader 代码用于启动Linux 内核,bootloader有很多,常用的就是 U-…

蓝桥杯C/C++百校真题赛(1期)Day3题解(等差数列、回路计数)

Q1 等差数列 由于保证了题目给出的一定是一个等差数列的部分项,且等差数列具有单调性质,所以根据大小排序后最小的did_idi​就是所求等差数列的公差ddd, 又因为求的是最小,所以n(an−a1)/d1,特别的,当ana1,d0时,特判输…

[数据库]复习杂项

(画师蓝鸟mo13tto) 数据库笔记(补充)——候选码的确定方法 求最小依赖集 最小函数依赖集Fm的定义,求法以及举例 当然这篇文章后半部分有误:【通俗易懂】关系模式范式分解教程 3NF与BCNF口诀!小白也能看…

企业数字化转型:数据集成是成功的关键

按照数据的生命周期,我们通常将大数据技术分为数据集成、数据存储、批/流处理、数据查询与分析、数据调度与编排、数据开发、BI 7 个部分。 数据集成是什么? 可以看到数据集成在数据生命周期的最前面位置,它负责将多个来自不同数据源的数据…

[附源码]计算机毕业设计Python保护濒危动物公益网站(程序+源码+LW文档)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等…

vue3 中的响应式设计原理

Vue 3 中的响应式原理可谓是非常之重要,通过学习 Vue3 的响应式原理,不仅能让我们学习到 Vue.js 的一些设计模式和思想,还能帮助我们提高项目开发效率和代码调试能力。 一、Vue 3 响应式使用 1. Vue 3 中的使用 当我们在学习 Vue 3 的时候&…

51单片机——动态数码管实验,小白讲解,相互学习

多位数码管介绍: 多位数码管,即两个或两个以上单个数码管并列集中在一起形成一体的数码管。当多位一体时,他们内部的公共端是独立的,二负责显示什么数字的段线(a-dp)全部是连接在一起的,独立的公…

中国水文地质图集

概述 水文地质图集部分来源于 《中华人民共和国水文地质图集》(地质出版社1979年版)的GIS数字化版(数据格式:JPEG),图集是由全国性、地区性和分省/自治区/直辖市等三类图幅组成,共68幅图(实际收集到55幅图)。 主要内容包括:水文地质图、地下热水分布图、水化学图、…

数据结构C语言版 —— 栈的实现

文章目录栈1. 基本概念2. 栈的实现1) 初始化栈2) 栈的扩容3) 判断栈是否为空4) 入栈5) 出栈6) 获取栈顶元素7) 获取栈中元素个数8) 销毁栈栈 1. 基本概念 栈(Stack):一种特殊的线性表,其只限定于在表尾进行插入或者删除操作。进行数据插入和删除操作的…

RocketMq02_复制刷盘、Broker常用模式、磁盘阵列、集群搭建

文章目录①. 单机版本安装与启动②. 控制台的安装与启动③. 复制刷盘、Broker集群模式④. 磁盘阵列 - RAID⑤.JBOD、RAID0⑥. RAID1、RAID10、01⑦. 搭建集群 - 异步两主两从①. 单机版本安装与启动 ①. 系统要求是64位的,JDK要求是1.8及其以上版本的 ②. 将下载的安装包上传到…

NFT及智能合约开发

文章目录1.Web3.01.1 GameFi1.2 DeFi1.3 dApp2.NFT2.1 NFT Applications2.2 NFT Earning2.3 NFT结构2.3 IPFS2.4 Wallet3.Smart Contract3.1 Smart Contract System3.2 Smart Contract Development3.2.1 Language3.2.2 IDE3.2.3 BlockChain3.2.4 FrontEnd3.2.5 NFT Test WebSit…