神经网络反向传播的数学原理

news2024/11/20 19:47:09

如果能二秒内在脑袋里解出下面的问题,本文便结束了。

已知:J=(Xw-y)^T(Xw-y)=\left | \right | Xw-y\left | \right | ^2,其中X\in R^{m\times n},w\in R^{n\times1},y\in R^{m\times1}

求:\frac{\partial J}{\partial X},\frac{\partial J}{\partial w},\frac{\partial J}{\partial y}

到这里,请耐心看完下面的公式推导,无需长久心里建设。

首先,反向传播的数学原理是“求导的链式法则” :

设f和g为x的可导函数,则(f\circ g)'(x)=f'(g(x))g'(x)

接下来介绍

  • 矩阵、向量求导的维数相容原则

  • 利用维数相容原则快速推导反向传播

  • 编程实现前向传播、反向传播

  • 卷积神经网络的反向传播

快速矩阵、向量求导

这一节展示如何使用链式法则、转置、组合等技巧来快速完成对矩阵、向量的求导

一个原则维数相容,实质是多元微分基本知识,没有在课本中找到下列内容,维数相容原则是我个人总结:

维数相容原则:通过前后换序、转置 使求导结果满足矩阵乘法且结果维数满足下式:

如果x\in R^{m \times n},f(x) \in R^1 ,那么\frac{\partial f(x)}{\partial x} \in R^{m \times n}

利用维数相容原则解上例:

step1:把所有参数当做实数来求导,J=(Xw-y)^2

依据链式法则有\frac{\partial J}{\partial X}=2(Xw-y)w, \frac{\partial J}{\partial w}=2(Xw-y)X, \frac{\partial J}{\partial y}=-2(Xw-y)

可以看出除了\frac{\partial J}{\partial y}=-2(Xw-y)\frac{\partial J}{\partial X}\frac{\partial J}{\partial w}的求导结果在维数上连矩阵乘法都不能满足。

step2:根据step1的求导结果,依据维数相容原则做调整:前后换序、转置

依据维数相容原则\frac{\partial J}{\partial X} \in R^{m \times n},但\frac{\partial J}{\partial X} \in R^{m \times n}=2(Xw-y)w(Xw-y)\in R^{m \times 1}X \in R^{m \times n},自然得调整为\frac{\partial J}{\partial X}=2(Xw-y)w^T

同理:\frac{\partial J}{\partial w} \in R^{n \times 1},但\frac{\partial J}{\partial w} \in R^{n \times 1}=2(Xw-y)X(Xw-y) \in R^{m \times 1}X \in R^{m \times n},那么通过换序、转置我们可以得到维数相容的结果2X^T(Xw-y)

对于矩阵、向量求导:

  • “当做一维实数使用链式法则求导,然后做维数相容调整,使之符合矩阵乘法原则且维数相容”是快速准确的策略;

  • “对单个元素求导、再整理成矩阵形式”这种方式整理是困难的、过程是缓慢的,结果是易出错的(不信你试试)。

如何证明经过维数相容原则调整后的结果是正确的呢?直觉!简单就是美...

快速反向传播

神经网络的反向传播求得“各层”参数W和b的导数,使用梯度下降(一阶GD、SGD,二阶LBFGS、共轭梯度等)优化目标函数。

接下来,展示不使用下标的记法(W_{ij},b_iorb_j)直接对W和b求导反向传播链式法则维数相容原则的完美体现,对每一层参数的求导利用上一层的中间结果完成。

这里的标号,参考UFLDL教程 - Ufldl

前向传播:

z^{(l+1)}=W^{(l)}a^{(l)}+b^{(l)}    (公式1)

a^{(l+1)}=f(z^{(l+1)})             (公式2)

z^{(l)}为第 l 层的中间结果,a^{(l)} 为第 l 层的激活值,其中第 l +1层包含元素:输入a^{(l)},参数W^{(l)}b^{(l)},激活函数f(),中间结果z^{(l+1)},输出a^{(l+1)}

设神经网络的损失函数为J(W,b) \in R^1(这里不给出具体公式,可以是交叉熵、MSE等),根据链式法则有:

图片

这里记 \frac{\partial J(W,b)}{\partial {z^{(l+1)}}}=\delta ^{(l+1)},其中\frac{\partial z^{(l+1)}}{\partial W^{(l)}}=a^{(l)}\frac{\partial z^{(l+1)}}{\partial b^{(l)}}=1可由 公式1 得出,a^{(l)}加转置符号(a^{(l)})^T是根据维数相容原则作出的调整。

如何求 \delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}} ?可使用如下递推(需根据维数相容原则作出调整):

图片

其中

图片

那么我们可以从最顶层逐层往下,便可以递推求得每一层的\delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}}

注意:\frac{\partial a^{(l)}}{\partial z^{(l)}}=f'(z^{(l)})是逐维求导,在公式中是点乘的形式。

反向传播整个流程如下:

1) 进行前向传播计算,利用前向传播公式,得到隐藏层和输出层 的激活值。

2) 对输出层(第 l 层),计算残差:\delta ^{(l)}=\frac{\partial J(W,b)}{\partial z^{(l)}}(不同损失函数,结果不同,这里不给出具体形式)

3) 对于l-1,l-2,\cdot \cdot \cdot ,2的隐藏层,计算:

图片

4) 计算各层参数W^{(l)},b^{(l)}偏导数:

图片

编程实现

大部分开源library(如:caffe,Kaldi/src/{nnet1,nnet2})的实现通常把W^{(l)},b^{(l)}作为一个layer,激活函数f()作为一个layer(如:sigmoid、relu、softplus、softmax)。

反向传播时分清楚该层的输入、输出即能正确编程实现,如:

图片

(1)式AffineTransform/FullConnected层,以下是伪代码:

图片

注: out_diff = \frac{\partial J}{\partial z^{(l+1)}}是上一层(Softmax 或 Sigmoid/ReLU的 in_diff)已经求得:

图片

(2)式激活函数层(以Sigmoid为例)

图片

注:out_diff = \frac{\partial J}{\partial a^{(l+1)}}是上一层AffineTransform的in_diff,已经求得,

图片

在实际编程实现时,in、out可能是矩阵(通常以一行存储一个输入向量,矩阵的行数就是batch_size),那么上面的C++代码就要做出变化(改变前后顺序、转置,把函数参数的Vector换成Matrix,此时Matrix out_diff 每一行就要存储对应一个Vector的diff,在update的时候要做这个batch的加和,这个加和可以通过矩阵相乘out_diff*input(适当的转置)得到。

如果熟悉SVD分解的过程,通过SVD逆过程就可以轻松理解这种通过乘积来做加和的技巧。

丢掉那些下标记法吧!

卷积层求导

卷积怎么求导呢?实际上卷积可以通过矩阵乘法来实现(是否旋转无所谓的,对称处理,caffe里面是不是有image2col),当然也可以使用FFT在频率域做加法。

那么既然通过矩阵乘法,维数相容原则仍然可以运用,CNN求导比DNN复杂一些,要做些累加的操作。具体怎么做还要看编程时选择怎样的策略、数据结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1225427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多因素方差分析(Multi-way Analysis of Variance) R实现

1, data0507 flower 是某种植物在两个海拔和两个气温下的开花高度,采用合适 的统计方法,检验该种植物的开花高度在不同的海拔之间和不同的气温之间有无差异?如果有差异,具体如何差异的?(说明依据、结论等关…

春秋云境靶场CVE-2022-28512漏洞复现(sql手工注入)

文章目录 前言一、CVE-2022-28512靶场简述二、找注入点三、CVE-2022-28512漏洞复现1、判断注入点2、爆显位个数3、爆显位位置4 、爆数据库名5、爆数据库表名6、爆数据库列名7、爆数据库数据 总结 前言 此文章只用于学习和反思巩固sql注入知识,禁止用于做非法攻击。…

Learning Perception Module

参考文章:自动驾驶开发者说|框架|如何单独运行apollo相机感知模块? - 知乎引言文章主要尝试了apollo框架下,视觉感知模块的单独运行,并利用离线的数据包进行检测实时展示结果。过程相对来说比较顺利。在加上已经用VScode搭建的单步…

springboot321基于java的校园服务平台设计与开发

交流学习: 更多项目: 全网最全的Java成品项目列表 https://docs.qq.com/doc/DUXdsVlhIdVlsemdX 演示 项目功能演示: ————————————————

解决:ERROR: No matching distribution found for PIL

解决:ERROR: No matching distribution found for PIL 背景 在搭建之前的代码环境时,报错: ERROR: Could not find a wersion that satisfies the requirement PIL(from versions: none) ERROR: No matching distribu…

机器视觉系统选型-定光照强度

同一个外形结构的光源,光照强度受如下影响: 单颗灯珠的亮度灯珠排列的数量和密度漫射板/防护板的材质(透明、半透明、全漫射) 在合理范围内提升光照强度,可降低对相机曝光时长的要求 外形结构尺寸相同的两款光源&am…

基于SSM的古董拍卖系统

基于SSM的古董拍卖系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringMyBatisSpringMVC工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 拍卖界面 管理员界面 摘要 古董拍卖系统是一个基于SSM框架(Spring …

Linux yum 使用时提示 获取 GPG 密钥失败Couldn‘t open file RPM-GPG-KEY-EPEL-7

资料 错误提示: no crontab for root - using an empty one 888 原因剖析: 第一次使用crontab -e 命令时会让我们选择编辑器,很多人会不小心选择默认的nano(不好用),或则提示no crontab for root - usin…

Java拼图游戏

运行出的游戏界面如下: 按住A不松开,显示完整图片;松开A显示随机打乱的图片。 User类 package domain;/*** ClassName: User* Author: Kox* Data: 2023/2/2* Sketch:*/ public class User {private String username;private String password…

进程概述

文章目录 计算机算机组成因特尔CPU型号摩尔定律衡量CPU的指标指令(Instruction)操作系统(Operating System)虚拟地址空间(Virtual Address Space)进程(Process/task)进程管理(PCB - 进程控制块)进程控制块(…

一文讲明 Spring 的使用 【全网超详细教程】

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 前言 目录结构 Spring 的相关代码 都公开在…

分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测

分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测 目录 分类预测 | Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现基于SDAE堆叠去噪自编码器的数据分类预测(完整源码和数据) 2.多…

Linux procps-ng - top

procps-ng 是一个开源的进程管理工具集,它提供了一系列用于监控和管理系统进程的命令行工具。它是 procps 工具集的一个分支,旨在改进和增强原有的 procps 工具。 procps-ng 包括了一些常用的命令行工具,例如: ps:用于…

【漏洞复现】泛微e-Weaver SQL注入

漏洞描述 泛微e-Weaver(FANWEI e-Weaver)是一款广泛应用于企业数字化转型领域的集成协同管理平台。作为中国知名的企业级软件解决方案提供商,泛微软件(广州)股份有限公司开发和推广了e-Weaver平台。 泛微e-Weaver旨在…

springBoot 配置druid多数据源 MySQL+SQLSERVER

1:pom 文件引入数据 <dependency> <groupId>com.alibaba</groupId> <artifactId>druid-spring-boot-starter</artifactId> <version>1.1.0</version> </dependency>…

前端性能优化之LightHouse

优质博文&#xff1a;IT-BLOG-CN 一、LightHouse环境搭建 LightHouse是一款由Google开发的开源工具&#xff0c;用于评估Web应用程序的性能和质量。可以将其看作是一个Chrome扩展程序运行&#xff0c;或从命令行运行。为LightHouse提供一个需要审查的网址&#xff0c;它将针对…

基于django的在线教育系统

基于python的在线教育系统 摘要 基于Django的在线教育系统是一种利用Django框架开发的现代化教育平台。该系统旨在提供高效、灵活、易用的在线学习体验&#xff0c;满足学生、教师和管理员的需求。系统包括学生管理、课程管理、教师管理、视频课程、在线测验等核心功能。系统采…

获取虎牙直播源

为了今天得LOL总决赛 然后想着下午看看 但是网页看占用高 就想起来有个直播源 也不复杂看了大概一个小时 没啥问题 进入虎牙页面只有 直接F12 网络 然后 看这个长条 一直在获取 发送 那就选中这个区间 找到都是数字这一条 如果直接访问的话会一直下载 我这都取消了 然后 打开…

Michael Jordan最新报告:去中心化机器学习中的契约、不确定性和激励

‍ ‍导读 11月3日&#xff0c;智源研究院学术顾问委员会委员、机器学习泰斗Michael Jordan在以“新一代人工智能前沿”为主题的2023北京论坛 新工科专题论坛上&#xff0c;发表了题为Contracts, Uncertainty, and Incentives in Decentralized Machine Learning&#xff08;去…

H5ke11..--2其他界面也要提取我的locatStarage

获取浏览器里面的本地缓存 localStorage就是我们的浏览器缓存在哪都可以用,调用我们的locatStarage就行 下面代码是获取打印到我们的页面上 修改在我们另一个界面得到 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8&quo…