用C语言构建一个手写数字识别神经网络

news2024/11/17 21:49:26

59bc0d11101b414b8ec3a3ce89e498b7.jpg

 

(原理和程序基本框架请参见前一篇 "用C语言构建了一个简单的神经网路")

1.准备训练和测试数据集
从http://yann.lecun.com/exdb/mnist/下载手写数字训练数据集, 包括图像数据train-images-idx3-ubyte.gz 和标签数据 train-labels-idx1-ubyte.gz.
分别将他们解压后放在本地文件夹中,解压后文件名为train-images-idx3-ubyte和train-labels-idx1-ubyte. 训练数据集一共包含了6万个手写数字灰度图片和对应的标签.
为图方便,我们直接从训练数据集中提取5000个作为测试数据.当然,实际训练数据中并不包含这些测试数据.


2.设计神经网络
采用简单的三层全连接神经网络,包括输入层(wi),中间层(wm)和输出层(wo).这里暂时不使用卷积层,下次替换后进行比较.
输入层: 一共20个神经元,每一张手写数字的图片大小为28x28,将其全部展平后的784个灰度数据归一化,即除以255.0, 使其数值位于[0 1]区间,这样可以防止数据在层层计算和传递后变得过分大.将这784个[0 1]之间的数据与20个神经元进行全连接.神经元激活函数用func_ReLU.
中间层: 一共20个神经元,与输入层的20个神经元输出进行全连接.神经元激活函数用func_ReLU.
输出层: 一共10个神经元,分别对应0~9数字的可能性,与层中间的20个神经元输出进行全连接.神经层激活函数用func_softmax.
特别地,神经元的激活函数在new_nvcell()中设定,层的激活函数直接赋给nerve_layer->transfunc.
损失函数: 采用期望和预测值的交叉熵损失函数func_lossCrossEntropy. 损失函数在nvnet_feed_forward()中以参数形式输入.

3.训练神经网络
由于整个程序是以nvcell神经元结构为基础进行构建的,其不同于矩阵/张量形式的批量数据描述,因此这个神经网络只能以神经元为单位,逐个逐层地进行前向和反向传导.
相应地,这里采用SGD(Stochastic Gradient Descent)梯度下降更新法,即对每一个样本先进行前向和反向传导计算,接着根据计算得到的梯度值马上更新所有参数.与此不同,mini-batch GD方法采用的对小批量样本进行前向和反向传导计算,然后根据累积的梯度数值做1次参数更新.显然,采用SGD方法参数更新更加频繁,计算时间相应也变长了,但是,据分析,采用SGD也更容易达到全局最优解附近.本文程序里所做的分批计算是为了方便监控计算过程和打印中间值.(当然,要实现mini-batch GD也是可以的,先完成一批量样本的前后传导计算,期间将各参数的梯度累计起来,  最后取其平均值更新一次参数.)
这里使用平均损失值mean_err=0.0025来作为训练的终止条件,为防止无法收敛到此数值,同时设置最大的epoch计数.
训练的样本数量由TRAIN_IMGTOTAL来设定, 训练时,先读取一个样本数据和一个标签,分别存入到data_input[28*28]和data_target[10], 为了配合应用softmax函数,这里data_target[]是one-hot编码格式.读入样本数据后先进行前向传导计算nvnet_feed_forward(),接着进行反向传导计算nvnet_feed_backward(), 最后更新参数nvnet_update_params(), 这样就完成了一个样本的训练.如此循环计算,完成一次所有样本的训练(epoch)后计算mean_err, 看是否达到预设目标.

4.测试训练后的神经网络
训练完成后,对模型进行简单评估.方法就是用训练后的模型来预测(predict)或推理(infer)前面的测试数据集中的图像数据,将结果与对应的标签值做对比.
同样,将一个测试样本加载到data_input[], 跑一次nvnet_feed_forward(),直接读取输出层的wo_layer->douts[k] (k=0~9),如果其值大于0.5,就认为模型预测图像上的数字是k.

5.小结
取5万条训练样本进行训练,训练后再进行测试,其准确率可接近94%.
与卷积神经网络相比较,为达到相同的结果,全连接的神经网络的所需要的训练时间会更长.

源代码:
https://github.com/midaszhou/nnc
下载后编译:
make TEST_NAME=test_nnc2

ca7ccbb483734fa191163ce2b05ce968.png

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/800142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【语音识别】- 声学,词汇和语言模型

一、说明 语音识别是指计算机通过处理人类语言的音频信号,将其转换为可理解的文本形式的技术。也就是说,它可以将人类的口语语音转换为文本,以便计算机能够进一步处理和理解。它是自然语言处理技术的一部分,被广泛应用于语音识别助…

代码随想录训练营day2

问题一:长度最小的子数组 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其和 ≥ target 的长度最小的 连续子数组 [numsl, numsl1, ..., numsr-1, numsr] ,并返回其长度。如果不存在符合条件的子数组,返回 0 输入…

STM32 CAN控制的相关结构体

目录 CAN结构体 CAN初始化结构体 CAN发送及接收结构体 CAN筛选器结构体 CAN结构体 从STM32的CAN外设我们了解到它的功能非常多,控制涉及的寄存器也非常丰富,而使用STM32标准库提供的各种结构体及库函数可以简化这些控制过程。跟其它外设一样&#xf…

Mac配置android studio的终端terminal

一共6步 首先打开terminal 1.echo $HOME 2.touch .bash_profile 3.open -e .bash_profile 4.在弹出框中输入 export PATH${PATH}:你电脑sdk的路径/tools:你电脑sdk的路径/platform-tools 5.source .bash_profile 6.adb version 出现类似上图即为成功

将标签中某一个类别添加到另一个标签中

现在有两张CItyscapes数据集的标签,假设我想把第二张图骑车的人添加到第一张图,暂且不考虑添加位置的变换,那么该如何操作呢? 1:将骑车的人和车作为两个类别独立于其他的类别出来。 2:将这两个类别作为一个…

【LeetCode-简单】剑指 Offer 24. 反转链表(详解)

题目 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 方法:迭代 思路 定义三个指针,一起往后走,走一步就修改mid指针的next,原本是mid的next 是right,我们修改成l…

ERROR in unable to locate ‘***/public/**/*‘ glob

前提 自己搭了一个react项目的脚手架,npm包下载一切都很正常,启动的时候突然就报ERROR in unable to locate ***/public/**/* glob这个错误,根据百度分析了一下产生的原因:webpack配置文件中的CopyWebpackPlugin导致的 网上给出的…

用Rust生成Ant-Design Table Columns | 京东云技术团队

经常开发表格,是不是已经被手写Ant-Design Table的Columns整烦了? 尤其是ToB项目,表格经常动不动就几十列。每次照着后端给的接口文档一个个配置,太头疼了,主要是有时还会粘错就尴尬了。 那有没有办法能自动生成colu…

Windows下安装python3教程

参考:https://blog.csdn.net/kailingr/article/details/128193083 一、安装步骤图解 准备工作: 进官网https://www.python.org/下载Python 安装包,注意:Python 3.9不能在Windows 7或更早版本上使用 安装: 1.下载完之后双击该文…

STM32 串口基础知识学习

串行/并行通信 串行通信:数据逐位按顺序依次传输。 并行通信:数据各位通过多条线同时传输。 对比 传输速率:串行通信较低,并行通信较高。抗干扰能力:串行通信较强,并行通信较弱。通信距离:串…

【具有非线性反馈的LTI系统识别】针对反馈非线性的LTI系统,提供非线性辨识方案(SimulinkMatlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码、Simulink仿真实现 💥1 概述 本文为具有反馈非线性的LTI系统提供了一种非线性识别方案,这取决于输入和LTI系统输出。对于MEMS来说尤其如此&#…

uniapp:手写签名,多张图合成一张图

要实现的内容&#xff1a;手写签名&#xff0c;协议内容。点击提交后&#xff1a;生成1张图片&#xff0c;有协议内容和签署日期和签署人。 实现的效果图如下&#xff1a; 1、签名页面 <template><view class"index"><u-navbar title"电子协议…

【数据结构】实验六:队列

实验六 队列 一、实验目的与要求 1&#xff09;熟悉C/C语言&#xff08;或其他编程语言&#xff09;的集成开发环境&#xff1b; 2&#xff09;通过本实验加深对队列的理解&#xff0c;熟悉基本操作&#xff1b; 3&#xff09; 结合具体的问题分析算法时间复杂度。 二、…

Qt完成文本转换为语音播报与保存(系统内置语音引擎)(一)

一、前言 在当今数字化社会,人们对于交互式应用程序的需求越来越高。除了传统的图形用户界面,语音交互也成为了一种流行的交互方式。在这种情况下,将文本转换为语音成为了一项重要的技术,它可以为用户提供更加人性化和便捷的交互方式。在此背景下,Qt提供了QTextToSpeech类…

python测试开发面试常考题:装饰器

目录 简介 应用 第一类对象 装饰器 描述器descriptor 资料获取方法 简介 Python 装饰器是一个可调用的(函数、方法或类)&#xff0c;它获得一个函数对象 func_in 作为输入&#xff0c;并返回另一函数对象 func_out。它用于扩展函数、方法或类的行为。 装饰器模式通常用…

flask中的常用装饰器

flask中的常用装饰器 Flask 框架中提供了一些内置的装饰器&#xff0c;这些装饰器可以帮助我们更方便地开发 Web 应用。以下是一些常用的 Flask 装饰器&#xff1a; app.route()&#xff1a;这可能是 Flask 中最常用的装饰器。它用于将 URL 路由绑定到一个 Python 函数&#x…

【前端知识】React 基础巩固(三十五)——ReduxToolKit (RTK)

React 基础巩固(三十五)——ReduxToolKit (RTK) 一、RTK介绍 Redux Tool Kit &#xff08;RTK&#xff09;是官方推荐的编写Redux逻辑的方法&#xff0c;旨在成为编写Redux逻辑的标准方式&#xff0c;从而解决上面提到的问题。 RTK的核心API主要有如下几个&#xff1a; confi…

Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…

Cadence OrCAD Capture绘制符号时缩小栅格距离的方法图文教程

🏡《总目录》   🏡《宝典目录》 目录 1,概述2,问题概述3,常规方法4,正确方法3,总结1,概述 本文简单介绍,使用Capture软件绘制原理图符号时,缩小或自定义栅格间距的方法。 2,问题概述 如下图所示,在进行原理图符号绘制时,管脚元件实体的线条等,均只能放置在栅…

佰维存储面向旗舰智能手机推出UFS3.1高速闪存

手机“性能铁三角”——SoC、运行内存、闪存决定了一款手机的用户体验和定位&#xff0c;其中存储器性能和容量对用户体验的影响越来越大。 针对旗舰智能手机&#xff0c;佰维推出了UFS3.1高速闪存&#xff0c;写入速度最高可达1800MB/s&#xff0c;是上一代通用闪存存储的4倍以…