基于全连接神经网络模型的手写数字识别

news2024/9/24 15:25:15

基于全连接神经网络模型的手写数字识别

  • 一. 前言
  • 二. 设计目的及任务描述
    • 2.1 设计目的
    • 2.2 设计任务
  • 三. 神经网络模型
    • 3.1 全连接神经网络模型方案
    • 3.2 全连接神经网络模型训练过程
    • 3.3 全连接神经网络模型测试
  • 四. 程序设计

一. 前言

手写数字识别要求利用MNIST数据集里的70000张手写体数字的图像,建立神经网络模型,进行0到9的分类,并能够对其他来源的图片进行识别,识别准确率大于97%。图片示例如下。
alt

图1.1 mnist数据集图片示例

该设计要求学生基于TensorFlow深度学习平台,利用自动下载的MNIST数据集,建立全连接或者CNN神经网络模型,对MNIST或者其他图片中的数字进行正确识别。同时,在数据获取、处理和分析过程中考虑数据安全、技术经济、工程伦理、行业规范等要素。

通过该题目的训练,使学生对深度学习技术有一定的了解,掌握深度学习模型建立、训练、测试和调优的过程,理解监督学习、数据处理、神经网络、卷积计算等概念并通过实例进行实践,学习TensorFlow并搭建深度学习平台,加深学生对深度学习技术的理解和实际引用,并能够利用深度学习方法解决实际问题。

二. 设计目的及任务描述

2.1 设计目的

深入学习TensorFlow深度学习平台,通过构建全连接神经网络和卷积神经网络的手写数字识别模型,实现对MNIST数据集中的数字0到9的分类,并具备对其他来源的图片进行准确识别的能力,要求识别准确率大于97%。这一设计旨在深入理解深度学习技术,并掌握模型的建立、训练、测试和调优的全过程。

首先,进行文献资料查阅,至少阅读5篇相关文献,以确保对深度学习领域的最新进展有所了解。通过文献的学习,将为设计过程提供前沿的理论支持,在实践中融入最新的研究成果。

学习TensorFlow深度学习平台的搭建是课程设计的第二步,这将提供一个强大而灵活的工具,用以实现神经网络的建模和训练。通过掌握TensorFlow,学生将具备在深度学习领域进行实际工作的基本能力。

在全连接神经网络的学习中,理解神经网络的基本原理,包括监督学习、数据处理、损失率函数的构建方法等。通过构建手写数字识别模型,亲身经历模型训练、测试和调优的过程,深入理解各参数的作用及其对模型准确率的影响。

通过这个课程设计,不仅获得深度学习技术的实际应用经验,还将培养文献查阅、团队协作、数据伦理等方面的能力,为将来深入科研或产业实践打下坚实基础。

2.2 设计任务

  1. 查阅文献资料,一般在5篇以上;
  2. 学习TensorFlow深度学习平台的搭建。
  3. 学习全连接神经网络,建立全连接网络的手写数字识别模型,并进行模型训练、测试和调优。
  4. 理解学习率、衰减率等参数的作用。
  5. 理解监督学习的过程。
  6. 学习损失率函数构建方法。
  7. 经过模型调优,理解模型中各参数的作用以及影响模型准确率的因素。
  8. 模型识别准确率大于97%。
  9. 撰写课程设计说明书,须达到以下要求:
    (1) 陈述设计题目、设计任务;
    (2) 描述TensorFlow深度学习平台的搭建过程;
    (3) 写出全连接神经网络模型方案;
    (4) 记录全连接神经网络模型训练过程;
    (5) 记录全连接神经网络模型测试准确率;
    (6) 陈述模型调优过程,包括调优过程中遇到的主要问题,是如何解决的;对模型设计和编码的回顾、反思和体会等,与同学对问题的讨论、分析、改进设想以及收获等。同时,分析数据处理及分析过程中面临的数据安全、工程伦理等问题。

三. 神经网络模型

3.1 全连接神经网络模型方案

设计中使用的全连接神经网络模型采用了典型的多层感知器(Multi-Layer Perceptron,MLP)架构,旨在解决手写数字识别任务。模型的输入层与输出层之间,有两个隐藏层负责提取和学习输入图像的特征。

模型的输入层包含了784个节点,对应于MNIST数据集中的每个图像像素。这个输入层将图像展平为一维向量,使得神经网络能够处理每个像素的信息。第一个隐藏层包含512个节点,通过ReLU激活函数引入非线性特性,帮助网络学习复杂的特征和模式。第二个隐藏层也有512个节点,并同样使用ReLU激活函数。这两个隐藏层的存在增强了网络对抽象特征的学习能力。

最后,输出层包含10个节点,对应于手写数字的10个可能类别。使用softmax激活函数,输出层将模型的原始输出转换为概率分布,表示每个类别的概率。

在模型的编译阶段,采用了交叉熵作为损失函数,这是多类别分类问题中常用的损失函数。模型的优化器选择了Adam,这是一种自适应学习率的优化算法。为了评估模型性能,选择了准确率作为指标,它度量了模型在训练和测试数据上的分类准确性。

3.2 全连接神经网络模型训练过程

训练过程是深度学习中至关重要的一部分,通过多次迭代优化模型参数,使其能够更好地适应训练数据。在这个训练过程中,采用了全连接神经网络模型,旨在实现手写数字的准确识别。

加载并预处理了MNIST数据集,将图像数据归一化到 [0, 1] 的范围,并进行了独热编码以适应模型的训练需求。构建了一个具有两个隐藏层的全连接神经网络模型,其中包含了512个节点,并使用ReLU激活函数,最终输出层具有10个节点,使用softmax激活函数进行多类别分类。

然后,对模型进行了编译,选择了交叉熵作为损失函数和Adam作为优化器。为了更充分地训练模型,将训练轮数设置为5。每次训练迭代,模型根据梯度下降的原理,不断更新权重和偏差,以最小化损失函数。

训练过程的 fit 函数的参数中,verbose=1表示在训练过程中输出详细信息,包括每个epoch的损失和准确率。模型的性能将在整个训练过程中逐渐提升,反映出它对训练数据的更好拟合能力。在迭代的过程中,我期望看到损失降低,而训练和验证准确率逐步提高。

通过增加训练轮数,提高模型学习的迭代次数,有望取得更好的性能和更强的泛化能力,使模型在未见过的数据上表现出色。
在这里插入图片描述

图 3-1 全连接神经网络_训练结果
如图3-1所示,通过5次训练模型的准确度达到97%。

3.3 全连接神经网络模型测试

使用 Keras 模型的 evaluate 方法在测试集上进行评估。

在这里插入图片描述

图 3-2 全连接神经网络_测试结果
经测试,如图3-2所示,模型准确度为97.66%。

四. 程序设计

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam

def load_and_preprocess_data():
    # 加载并预处理MNIST数据集
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # 重塑和手动归一化数据
    x_train = x_train.reshape((x_train.shape[0], 28, 28, 1)).astype('float32') / 255.0
    x_test = x_test.reshape((x_test.shape[0], 28, 28, 1)).astype('float32') / 255.0

    # 对标签进行多分类编码
    num_categories = 10
    y_train = tf.keras.utils.to_categorical(y_train, num_categories)
    y_test = tf.keras.utils.to_categorical(y_test, num_categories)

    return x_train, y_train, x_test, y_test

def build_model_Fully_connected():
    # 构建全连接神经网络模型
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(units=512, activation='relu'))
    model.add(Dense(units=512, activation='relu'))
    model.add(Dense(units=10, activation='softmax'))
    model.summary()

    return model

def compile_and_train_model(model, x_train, y_train, x_test, y_test):
    # 编译并训练模型
    optimizer = Adam(learning_rate=0.0001)
    model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])

    history = model.fit(x_train, y_train, epochs=10, verbose=1, validation_data=(x_test, y_test))
    return history

if __name__ == "__main__":
    # 加载并预处理数据
    x_train, y_train, x_test, y_test = load_and_preprocess_data()

    # 构建全连接神经网络模型
    model = build_model_Fully_connected()

    # 编译并训练模型
    history = compile_and_train_model(model, x_train, y_train, x_test, y_test)

    # 保存训练模型
    model.save("mnist_dnn_model.h5", include_optimizer=True)
    print("Model saved successfully.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1439357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Multisim14.0仿真(五十五)汽车转向灯设计

一、功能描述: 左转向:左侧指示灯循环依次闪亮; 右转向:右侧指示灯循环依次闪亮; 刹车: 所有灯常亮; 正常: 所有灯熄灭。 二、主要芯片: 74LS161D 74LS04D 74…

深入理解Spark BlockManager:定义、原理与实践

深入理解Spark BlockManager:定义、原理与实践 1.定义 Spark是一个开源的大数据处理框架,其主要特点是高性能、易用性以及可扩展性。在Spark中,BlockManager是其核心组件之一,它负责管理内存和磁盘上的数据块,并确保…

通过docker-compose部署NGINX服务,并使该服务开机自启

要在通过docker-compose部署的NGINX服务实现开机自启,你需要确保Docker守护进程在系统启动时自动运行,并配置docker-compose.yml文件以在容器中运行NGINX服务。以下是步骤: 确保Docker守护进程开机启动: 在Ubuntu/Debian上&#x…

Spring IoC容器(四)容器、环境配置及附加功能

本文内容包括容器的Bean 及 Configuration 注解的使用、容器环境的配置文件及容器的附加功能(包括国际化消息、事件发布与监听)。 1 容器配置 在注解模式下,Configuration 是容器核心的注解之一,可以在其注解的类中通过Bean作用…

32USART串口

目录 一.通信接口 二.时序 三.USART简介 ​编辑四.数据帧 五.起始位侦测和采样位置对齐 &波特率计算 六.相关函数 七.编码格式设置 (1) UTF-8编码(有的软件兼容性不好)​编辑 (2)GB2312编码 八.…

【Nicn的刷题日常】之有序序列合并

1.题目描述 描述 输入两个升序排列的序列,将两个序列合并为一个有序序列并输出。 数据范围: 1≤�,�≤1000 1≤n,m≤1000 , 序列中的值满足 0≤���≤30000 0≤val≤30000 输入描述…

前端异步相关知识总结

目录 一、同步和异步简介 同步(按顺序执行) 异步(不按顺序执行) 异步出现的原因和需求 二、实现异步的方法 回调函数 Promise 生成器Generators/ yield async await 三、promise和 async await 区别 概念 两者的区别 …

07-Java桥接模式 ( Bridge Pattern )

Java桥接模式 摘要实现范例 桥接模式(Bridge Pattern)是用于把抽象化与实现化解耦,使得二者可以独立变化 桥接模式涉及到一个作为桥接的接口,使得实体类的功能独立于接口实现类,这两种类型的类可被结构化改变而互不影…

实践动物姿态估计,基于最新YOLOv8全系列【n/s/m/l/x】参数模型开发构建公共场景下行人人员姿态估计分析识别系统

姿态估计(PoseEstimation)在我们前面的相关项目中涉及到的并不多,CV数据场景下主要还是以目标检测、图像识别和分割居多,最近正好项目中在使用YOLO系列最新的模型开发项目,就想着抽时间基于YOLOv8也开发构建实现姿态估…

Open CASCADE学习|创建多段线与圆

使用Open CASCADE Technology (OCCT)库来创建和显示一些2D几何形状。 主要过程如下: 包含头文件:代码首先包含了一些必要的头文件,这些头文件提供了创建和显示几何形状所需的类和函数。 定义变量:在main函数中,定义…

如何查看端口映射?

端口映射是一种用于实现远程访问的技术。通过将外网端口与内网设备的特定端口关联起来,可以使外部网络用户能够通过互联网访问内部网络中的设备和服务。在网络中使用端口映射可以解决远程连接需求,使用户能够远程访问设备或服务,无论是在同一…

JAVA生产使用登录校验模式

背景 目前我们的服务在用户登录时,会先通过登录接口进行密码校验。一旦验证成功,后端会利用UUID生成一个独特的令牌(token),并将其存储在Redis缓存中。同时,前端也会将该令牌保存在本地。在后续的接口请求…

常用对象和常用成员函数

常量对象与常量成员函数来防止修改对象,实现最低权限原则。 在Obj被定义为常量对象的情况下,下面这条语句是错误的。 错误的原因是常量对象一旦初始化后,其值就再也不能改变。因此,不能通过常量对象调用普通成员函数,因…

海外云手机的核心优势

随着5G时代的到来,云计算产业正处于高速发展的时期,为海外云手机的问世创造了一个可信任的背景。在资源有限且需求不断增加的时代,将硬件设备集中在云端,降低个人用户的硬件消耗,同时提升性能,这一点单单就…

得物自研API网关实践之路

一、业务背景 老网关使用 Spring Cloud Gateway (下称SCG)技术框架搭建,SCG基于webflux 编程范式,webflux是一种响应式编程理念,响应式编程对于提升系统吞吐率和性能有很大帮助; webflux 的底层构建在netty之上性能表…

广度优先搜索(BFS)

力扣刷题之旅:进阶篇(二) 继续我的力扣刷题之旅,我在进阶篇的第一部分中深入探索了BFS(广度优先搜索)算法,并感受到了它在图形搜索中的强大威力。现在,我进入了进阶篇的第二部分&am…

百卓Smart管理平台 uploadfile.php 文件上传漏洞复现(CVE-2024-0939)

0x01 产品简介 百卓Smart管理平台是北京百卓网络技术有限公司(以下简称百卓网络)的一款安全网关产品,是一家致力于构建下一代安全互联网的高科技企业。 0x02 漏洞概述 百卓Smart管理平台 uploadfile.php 接口存在任意文件上传漏洞。未经身份验证的攻击者可以利用此漏洞上传…

单片机无线发射的原理剖析

目录 一、EV1527编码格式 二、OOK&ASK的简单了解 三、433MHZ 四、单片机的地址ID 五、基于STC15W104单片机实现无线通信 无线发射主要运用到了三个知识点:EV1527格式;OOk;433MHZ。下面我们来分别阐述: EV1527是数据的编…

Stable Diffusion 模型下载:Samaritan 3d Cartoon SDXL(撒玛利亚人 3d 卡通 SDXL)

文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十 下载地址 模型介绍 由“PromptSharingSamaritan”创作的撒玛利亚人 3d 卡通类型的大模型,该模型的基础模型为 SDXL 1.0。 条目内容类型大模型基础模型SDXL 1.0来源CIVITA…

IDEA创建Java类时自动添加注释(作者、年份、月份)

目录 IDEA创建Java类时自动添加注释(作者、年份、月份)如图: IDEA创建Java类时自动添加注释(作者、年份、月份) 简单记录下,IDEA创建Java类时自动添加注释(作者、年份、月份)&#…