神经网络的公式推导与代码实现(论文复现)

news2024/11/18 9:23:37

神经网络的公式推导与代码实现(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

本文将详细推导一个简单的神经网络模型的正向传播、反向传播、参数更新等过程,并将通过一个手写数字识别的例子,使用python手写和pytorch分别实现,能够让读者深刻地理解神经网络的具体参数更新训练的工作流程,文末将包含数据+代码+PPT。

这些内容是基于神经网络和机器学习的通用知识,正向传播和反向传播,如今几乎所有的深度学习模型的训练都是基于这样相同或者相似的方法进行训练的,有助于帮助我们更加深入的理解深度学习模型。

引言

多层感知机(Multilayer Perceptron,简称MLP)是神经网络的一种。MLP是一种前馈神经网络,它包含一个或多个隐藏层,以及非线性激活函数,这使得MLP能够学习和模拟复杂的非线性关系。MLP是最基础也是最广泛研究的神经网络类型之一,本文将以一个MLP模型来展开。

MLP的结构通常如下:

输入层:接收外部输入数据。

隐藏层:一个或多个隐藏层,每层包含多个神经元。隐藏层负责从输入数据中提取特征并进行初步的非线性变换。

输出层:输出网络的预测结果,对于分类问题,输出层通常使用softmax激活函数进行多类分类。

MLP的训练过程通常包括以下几个步骤:

前向传播:输入数据通过网络,通过每个神经元的加权和和激活函数,最终得到输出。
计算损失:使用损失函数(如均方误差、交叉熵等)计算网络输出与真实标签之间的差异。

反向传播:根据损失函数的梯度,计算每一层的权重对损失的贡献,即梯度。

权重更新:使用梯度下降或其他优化算法(如Adam、RMSprop等)根据梯度更新网络的权重和偏置。

MLP在许多领域都有应用,包括图像识别、语音识别、自然语言处理、游戏AI等。随着深度学习的发展,MLP作为深度神经网络的基础,其结构和训练方法也在不断地被改进和优化。

实际上,几乎所有的深度学习模型中都会有MLP的身影,相当于深度学习模型的骨架,特别是在深度学习模型中最后一步,通常会接个MLP来使得输出的维度符合我们任务的需求,例如我们当前需要要对手写数字识别,那就是一个10分类问题,最后输出可以通过接一个MLP变成10维,每一维代表一个分类,从而顺利地使模型适配我们的任务。

神经网络公式推导

在这里插入图片描述

假设我们有这么一个神经网络,由输入层、一层隐藏层、输出层构成:(这里为了方便,不考虑偏置bias)

在这里插入图片描述

在这里插入图片描述

前向传播(forward)

首先,我们可以试着表示一下y1
如模型图所示可以表示为:

在这里插入图片描述

那么我要表示yj呢?

在这里插入图片描述

其中j=1时,就是y1的表示,j=m时,就是ym的表示。

同理我们可以得到:

在这里插入图片描述

ok表示输出层第k个神经元的预测值,这就是我们需要的输出。
至此,正向传播完毕

反向传播(backward)

光正向传播,我们只能得到模型的预测值,不能更新模型的参数,也就是说,正向传播的时候,模型是不会被更新的。

因为我们得到了模型输出的预测值,并且我们手上有对应的真实值,我们就能够将误差反向传播,更新模型参数。

具体操作怎么操作呢?

首先,我们需要定义误差,即预测值和真实值差了多少,以此来决定模型参数更新的方向和力度。

这里我们采用简单的差的平方的损失函数:

在这里插入图片描述

注意,这里只是更新输出层第k个神经元所反馈的误差。

隐藏层和输出层的权重更新
首先根据已知如下:

输出层预测值ok

在这里插入图片描述

激活函数Sigmoid

在这里插入图片描述

那我们可以试着展开一下Ek

在这里插入图片描述

因为我们现在需要更新的是wjk,因此展开到wjk我们就能有一个比较形象的认识了。

根据梯度下降法可得,我们现在只需要求出

在这里插入图片描述

在这里插入图片描述

接下来我们分别求出:

在这里插入图片描述

在这里插入图片描述

我们先给出激活函数的导数推导过程:

在这里插入图片描述

就是使用复合函数除的求导法则进行求导。我们可以发现sigmoid函数求导之后还是挺好看的。

接下来就是计算两个导数即可。

在这里插入图片描述
在这里插入图片描述

一眼就能看出来了吧,就是别忘了里面的-ok也要导,负号别漏了,然后是

在这里插入图片描述

这个可能会有点困难,但是仔细看看,发现还是很简单的;首先

在这里插入图片描述

在这里插入图片描述

(链式求导法)因此:

在这里插入图片描述

那么这个结果计算起来就比较简单了;既然如此,将结果拼起来就是我们要求的结果了:

在这里插入图片描述

在这里插入图片描述

全是已知的,不就可以更新参数了嘛;因此,加个学习率这层权重更新推导就大功告成了

在这里插入图片描述

输入层和隐藏层的权重更新;如果上面的推导看懂了,下面的推导就非常简单了,无非就是多展开一级,多求一次导数而已;首先(前面已经推到过了)

在这里插入图片描述

那么我们可以将误差再展开一级(接着链导下去):

在这里插入图片描述

那么下面这个就非常直观了

在这里插入图片描述

同样的,我们也分别求出三次的导数,最后拼起来就行了。

在这里插入图片描述

至此分别求出来了,拼起来就是我们要的结果了:

在这里插入图片描述

通过观察,里面全是已知的变量;那么更新公式也就有了:

在这里插入图片描述

数据集介绍

实验数据就是mnist手写数据集

在这里插入图片描述

第一列为label,表示这个图片是什么数字;后面都为图片的像素值,表示图片的数据;模型的输入就是像素值,输出就是预测值,即通过像素预测出是什么数字。

核心代码

其中比较关键的就是那两个参数的更新公式;隐藏层和输出层的权重更新:

在这里插入图片描述

输入层和隐藏层的权重更新:

在这里插入图片描述

数据集+python手写代码+pytorch代码+ppt都在附件里哦

运行结果

在这里插入图片描述

在这里插入图片描述

总结

感觉从推导到代码实现也是一个反复的过程,从推导发现代码写错了,写不出代码了就要去看看推导的过程,这个过程让我对反向传播有了较全面的理解。

我们发现,手写代码运行时间要一分多钟而pytorch其实只要10s不到,毕竟框架,底层优化很多,用起来肯定用框架。

以及二者准确率有一些差距,可能是因为pytorch里使用了交叉熵损失函数,比较适合分类任务;手写的并没有分batch,而是所有数据直接更新参数,但是pytorch里分了batch,分batch能够使得模型训练速度加快(并行允许),也使得模型参数更新的比较平稳。

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2136556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

充电管理芯片

1.接口 typec SBU引脚(辅助通道) SBU引脚是Type-C接口母座中的辅助通道,用于支持附加的功能,如模拟音频和视频传输。通过SBU引脚,Type-C接口母座可以实现更广泛的连接应用,包括连接到外部显示器、音频设备…

【前端】main.js中app.vue中 render函数的作用及使用背景

vue.js中的main.js中的作用是将app组件挂载到页面中&#xff0c;其中app组件是汇总所有组件元素的组件。main.js的创建vue实例。 #认为的版本 import APP from ./App.vue;new Vue({el:#root,template:<App></App>,components:{App}, })#实际的版本 /* 整个项目的入…

基于imx6ull平台opencv的图像采集和显示屏LCD显示功能(带Qt界面)

目录 一、概述二、环境要求2.1 硬件环境2.2 软件环境三、开发流程3.1 编写测试3.2 验证功能一、概述 本文档是针对imx6ull平台opencv的图像采集和显示屏LCD显示功能,创建Qt工程,在工程里面通过点击按钮,实现opencv通过摄像头采集视频图像,将采集的视频图像送给显示屏LCD进…

docker-compose elasticsearch 集群搭建(用户登录+https访问)

文章目录 前言docker-compose elasticsearch 集群搭建(用户登录+https访问)1. 效果2. 制作elasticsearch + 分词器镜像2.1. 拉取elasticsearch:7.11.12.2. 制作特定版本镜像3. docker-compose elasticsearch 集群制作4. es账户密码初始化前言 如果您觉得有用的话,记得给博主点…

<Python>基于python使用PyQt6编写一个延迟退休计算器

前言 这两天关于延迟退休的话题比较火&#xff0c;官方也退出了延迟退休计算器的小程序&#xff0c;我们使用python来实现一个。 环境配置 系统&#xff1a;windows 平台&#xff1a;visual studio code 语言&#xff1a;python 库&#xff1a;pyqt6 程序依据 程序的算法依据…

常见本地大模型个人知识库工具部署、微调及对比选型

文章目录 常见本地大模型个人知识库工具部署、微调及对比选型知识库侧AnythingLLMMaxKBRAGFlowFastGPTDifyOpen WebUI小结大模型侧OllamaLM StudioXinference小结大模型侧工具安装部署实践Ollama部署Windows部署OllamaLinux部署OllamaOllama使用技巧模型更换存储路径导出某个模…

外国车牌字符识别与分类系统源码分享

外国车牌字符识别与分类检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Co…

【Java面试】第九天

&#x1f31f;个人主页&#xff1a;时间会证明一切. 目录 Spring中如何开启事务&#xff1f;编程式事务声明式事务声明式事务的优点声明式事务的粒度问题声明式事务用不对容易失效 Spring的事务传播机制有哪些&#xff1f;Spring事务失效可能是哪些原因&#xff1f;代理失效的情…

Ubuntu下beanstalkd无法绑定局域网IP地址以及消息队列beanstalkd上的error: JOB_TOO_BIG的解决

一、ubuntu下beanstalkd无法绑定局域网IP地址 今天因为业务需要&#xff0c;我把之前安装的beanstalkd所绑定的IP地址由127.0.0.1改成局域网IP地址&#xff0c;但是怪了&#xff0c;显示beanstalkd已经启动&#xff0c;查看端口监控也显示IP地址变了&#xff0c;但是使用telnet…

matlab绘制不同区域不同色彩的图,并显示数据(代码)

绘图结果如下&#xff1a; 代码如下&#xff1a; A为绘图的数据&#xff0c;每个数据对应着上图中的一个区域&#xff0c;数据大小决定区域的颜色 % 假设有一系列的数据点 Arand(5,6); %A为绘图的数据&#xff0c;数据大小决定颜色 wei_shu%.3f; %代表数据保留三位小…

[Golang] Channel

[Golang] Channel 文章目录 [Golang] Channel什么是Channelchannel的初始化channel的操作双向channel和单向channel为什么有channel有缓冲channel和无缓冲channlechannel做一把锁 从之前我们知道go关键字可以开启一个Goroutine&#xff0c;但是Goroutine之间的通信还需要另一个…

Recyclerview实现滑动居中缩放菜单

最近项目中需要的一个滑动菜单效果:要求当前居中选项放大、滑动时有缩放效果、点击两边的选项滑动到屏幕中央、停止滑动选项停留在屏幕中间(类似viewPager的效果),为了直观,先上最终实现效果图: 大体思路: Recyclerview item头尾添加空数据,让第一个和最后一个item也能…

计算机组成原理(第二次笔记)

各种码 真值 (书写用)&#xff1a; 将用“”、“-” 表示正负的二进制数称为真值 机器不能识别书写格式&#xff0c;故用“0/1”表示“/-”符号。 机器码 (机器内部使用)&#xff1a; 将符号和数值一起编码表示的二进制数称为机器码。 常用机器码&#xff1a;原码、 反码、 补…

Linux网络编程 --- 高级IO

前言 IO Input&&Output read && write 1、在应用层read && write的时候&#xff0c;本质把数据从用户层写给OS --- 本质就是拷贝函数 2、IO 等待 拷贝。 等的是&#xff1a;要进行拷贝&#xff0c;必须先判断读写事件成立。读写事件缓冲区空间满…

Kafka+PostgreSql,构建一个总线服务

之前开发的系统&#xff0c;用到了RabbitMQ和SQL Server作为总线服务的传输层和存储层&#xff0c;最近一直在看Kafka和PostgreSql相关的知识&#xff0c;想着是不是可以把服务总线的技术栈切换到这个上面。今天花了点时间试了试&#xff0c;过程还是比较顺利的&#xff0c;后续…

破解AI生成检测:如何用ChatGPT降低论文的AIGC率

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 降低论文的“AIGC率”是个挑战&#xff0c;但有一些策略可以尝试。使用ChatGPT逐步调整和改进内容&#xff0c;使其更加自然和原创&#xff0c;降低AI检测工具识别出高“AIGC率”的概率…

专访阿里云:AI 时代服务器操作系统洗牌在即,生态合作重构未来

编者按&#xff1a;近日&#xff0c;2024 龙蜥操作系统大会已于北京圆满举办。大会期间&#xff0c;CSDN 采访了阿里云基础软件部资深技术总监、龙蜥社区技术委员会主席杨勇&#xff0c;前瞻性宏观解读面向 AI 智算时代&#xff0c;服务器操作系统面临的挑战与机遇。以下为采访…

云曦2024秋考核

真正的hacker 进去以后一眼就能看出来&#xff0c;是ThinkphpV5漏洞&#xff0c;只是版本不能确定&#xff0c;一开始考核的时候是&#xff0c;抓包看了php的版本&#xff0c;是7.23&#xff0c;是手注了几个尝试出来的&#xff08;后面才发现报错信息里面就有&#xff09;。漏…

记录word转xml文件踩坑

word文件另存为xml文件后&#xff0c;xml文件乱码 解决方法&#xff1a; 1.用word打开.docx文件 2.另存为xml文件 3.点击工具 -> Web选项 -> 编码&#xff0c;选择UTF-8 4.点击确定 5.使用notpad打开xml文件 6.使用xml tool进行xml格式化即可。

【免费资料推荐】数据资产管理实践白皮书(6.0版)

荐言&#xff1a;随着数字经济的快速发展&#xff0c;数据已成为企业最重要的资产之一。为有效管理和利用数据资产&#xff0c;各行业纷纷推出数据管理框架和标准。数据资产管理实践白皮书&#xff08;6.0版&#xff09;由中国信息通信研究院联合相关企业共同编写&#xff0c;是…