DIDL4_前向传播与反向传播(模型参数的更新)

news2024/9/25 21:22:35

前向传播与反向传播

  • 前向传播与反向传播的作用
  • 前向传播及公式
    • 前向传播范例
  • 反向传播及公式
    • 反向传播范例
  • 小结
    • 前向传播计算图

前向传播与反向传播的作用

在训练神经网络时,前向传播和反向传播相互依赖。
对于前向传播,我们沿着依赖的方向遍历计算图并计算其路径上的所有变量。
然后将这些用于反向传播,其中计算顺序与计算图的相反,用于计算w、b的梯度(即神经网络中的参数)。随后使用梯度下降算法来更新参数。

因此,在训练神经网络时,在初始化模型参数后, 我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。

注意

  • 反向传播重复利用前向传播中存储的中间值,以避免重复计算。 带来的影响之一是我们需要保留中间值,直到反向传播完成。 这也是训练比单纯的预测需要更多的内存(显存)的原因之一。
  • 这些中间值的大小与网络层的数量和批量的大小大致成正比。 因此,使用更大的批量来训练更深层次的网络更容易导致内存不足(out of memory)错误。

前向传播及公式

前向传播(forward propagation或forward pass) 指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

假设输入样本是 x ∈ R d \mathbf{x}\in \mathbb{R}^d xRd, 并且我们的隐藏层不包括偏置项。 这里的中间变量是:

z = W ( 1 ) x , \mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。 将中间变量 z ∈ R h \mathbf{z}\in \mathbb{R}^h zRh通过激活函数 ϕ \phi ϕ后, 我们得到长度为 h h h的隐藏激活向量:
h = ϕ ( z ) . \mathbf{h}= \phi (\mathbf{z}). h=ϕ(z).

隐藏变量 h \mathbf{h} h也是一个中间变量。 假设输出层的参数只有权重 W ( 2 ) ∈ R q × h \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h, 我们可以得到输出层变量,它是一个长度为 q q q的向量:
o = W ( 2 ) h . \mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}. o=W(2)h.

假设损失函数为 l l l,样本标签为 y y y,我们可以计算单个数据样本的损失项
L = l ( o , y ) . L = l(\mathbf{o}, y). L=l(o,y).

根据 L 2 L_2 L2正则化的定义,给定超参数 λ \lambda λ,正则化项为
s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

其中矩阵的Frobenius范数是将矩阵展平为向量后应用的 L 2 L_2 L2范数。 最后,模型在给定数据样本上的正则化损失为:
J = L + s . J = L + s. J=L+s.

前向传播范例

反向传播及公式

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。也称“BP算法”
简言之,该方法根据微积分中的链式规则,按相反的顺序从输出层到输入层遍历网络
该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。

假设我们有函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y), 其中输入和输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z是任意形状的张量。 利用链式法则,我们可以计算 Z \mathsf{Z} Z关于 X \mathsf{X} X的导数
∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) . \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right). XZ=prod(YZ,XY).

反向传播的目的是计算梯度 ∂ J / ∂ W ( 1 ) \partial J/\partial \mathbf{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \mathbf{W}^{(2)} J/W(2). 为此,我们应用链式法则,依次计算每个中间变量和参数的梯度。 计算的顺序与前向传播中执行的顺序相反,因为我们需要从计算图的结果开始,并朝着参数的方向努力。

  1. 计算目标函数 J = L + s J=L+s J=L+s相对于损失项 L L L和正则项 s s s的梯度
    ∂ J ∂ L = 1    and    ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1 \; \text{and} \; \frac{\partial J}{\partial s} = 1. LJ=1andsJ=1.
  2. 根据链式法则计算目标函数关于输出层变量 o \mathbf{o} o的梯度:
    ∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o ∈ R q . \frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q. oJ=prod(LJ,oL)=oLRq.
  3. 计算正则化项相对于两个参数的梯度:
    ∂ s ∂ W ( 1 ) = λ W ( 1 )    and    ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{and} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}. W(1)s=λW(1)andW(2)s=λW(2).
  4. 计算最接近输出层的模型参数的梯度 ∂ J / ∂ W ( 2 ) ∈ R q × h \partial J/\partial \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} J/W(2)Rq×h。 使用链式法则得出:
    ∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) . \frac{\partial J}{\partial \mathbf{W}^{(2)}}= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right)= \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2).
  5. 为了获得关于 W ( 1 ) \mathbf{W}^{(1)} W(1)的梯度,我们需要继续沿着输出层到隐藏层反向传播。 关于隐藏层输出的梯度 ∂ J / ∂ h ∈ R h \partial J/\partial \mathbf{h} \in \mathbb{R}^h J/hRh由下式给出:
    ∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o . \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}. hJ=prod(oJ,ho)=W(2)oJ.
  6. 由于激活函数 ϕ \phi ϕ是按元素计算的, 计算中间变量 z \mathbf{z} z的梯度 ∂ J / ∂ z ∈ R h \partial J/\partial \mathbf{z} \in \mathbb{R}^h J/zRh需要使用按元素乘法运算符,我们用 ⊙ \odot 表示:
    ∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) . \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right). zJ=prod(hJ,zh)=hJϕ(z).
  7. 最后,我们可以得到最接近输入层的模型参数的梯度 ∂ J / ∂ W ( 1 ) ∈ R h × d \partial J/\partial \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} J/W(1)Rh×d。 根据链式法则,我们得到:
    ∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) . \frac{\partial J}{\partial \mathbf{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) = \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)}. W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1).

反向传播范例

假设输入x = 1.5,模型初始参数w=0.8,b=0.2。学习率为0.1,则过程如下图:
在这里插入图片描述
当有两层的时候:
在这里插入图片描述

小结

  • 前向传播在神经网络定义的计算图中按顺序计算和存储中间变量,它的顺序是从输入层到输出层。
  • 反向传播按相反的顺序(从输出层到输入层)计算和存储神经网络的中间变量和参数的梯度。
  • 在训练深度学习模型时,前向传播和反向传播是相互依赖的。
  • 训练比预测需要更多的内存。

前向传播计算图

其中正方形表示变量,圆圈表示操作符。 左下角表示输入,右上角表示输出。 注意显示数据流的箭头方向主要是向右和向上的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/354637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

# AutoSar一文概览

1.什么是AutoSar ​ AUTOSAR全称为“AUTomotive Open System ARchitecture”,译为“汽车开放系统体系结构”;AUTOSAR是由 全球各大汽车整车厂、汽车零部件供应商、汽车电子软件系统公司联合建立的一套标准协议、软件架构。 2.为什么汽车行业要定义一个…

DIDL5_数值稳定性和模型初始化

数值稳定性和模型初始化数值稳定性梯度不稳定的影响推导什么是梯度消失?什么是梯度爆炸?如何解决数值不稳定问题?——参数初始化参数初始化的几种方法默认初始化Xavier初始化小结当神经网络变得很深的时候,数值特别容易不稳定。我…

面试题67. 把字符串转换成整数

题目 写一个函数 StrToInt,实现把字符串转换成整数这个功能。不能使用 atoi 或者其他类似的库函数。 首先,该函数会根据需要丢弃无用的开头空格字符,直到寻找到第一个非空格的字符为止。 当我们寻找到的第一个非空字符为正或者负号时&#xf…

密度峰值聚类算法(DPC)

密度峰值聚类算法目录DPC算法1.1 DPC算法的两个假设1.2 DPC算法的两个重要概念1.3 DPC算法的执行步骤1.4 DPC算法的优缺点matlab代码密度计算函数计算delta寻找聚类中心点聚类算法目录 DPC算法 1.1 DPC算法的两个假设 1)类簇中心被类簇中其他密度较低的数据点包围…

kubernetes 教程

K8s 安装kubectl 下载kubectl curl -LO "https://dl.k8s.io/release/**$(**curl -L -s https://dl.k8s.io/release/stable.txt**)**/bin/linux/amd64/kubectl" 安装 sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl 验证 kubectl versi…

学习 Python 之 Pygame 开发坦克大战(二)

学习 Python 之 Pygame 开发坦克大战(二)坦克大战的需求开始编写坦克大战1. 搭建主类框架2. 获取窗口中的事件3. 创建基类4. 初始化我方坦克类5. 完善我方坦克的移动5. 完善我方坦克的显示6. 在主类中加入我方坦克并完成坦克移动7. 初始化子弹类8. 完善子…

(考研湖科大教书匠计算机网络)第五章传输层-第一、二节:传输层概述及端口号、复用分用等概念

获取pdf:密码7281专栏目录首页:【专栏必读】考研湖科大教书匠计算机网络笔记导航 文章目录一:传输层概述(1)概述(2)从计算机网络体系结构角度看传输层(3)传输层意义二&am…

MySQL行转列列转行实例解析

文档准备要求:找出所有的用户没有安装的软件。创建两个表,用户表app_install 和 app表app建表语句:# 创建app表,并插入数据 create table app(id int,app varchar(32)); insert into app(id,app) values (1,微信),(2,QQ),(3,支付宝…

二叉树理论基础知识点

二叉树的种类 在我们解题过程中二叉树有两种主要的形式:满二叉树和完全二叉树 满二叉树 满二叉树:如果一棵二叉树只有度为0的结点和度为2的结点,并且度为0的结点在同一层上,则这棵二叉树为满二叉树。 如图所示: 这…

About Oracle Database Performance Method

bottleneck(瓶颈): a point where resource contention is highest throughput(吞吐量): the amount of work that can be completed in a specified time. response time (响应时间): the time to complete a spec…

Java 日志简介

目录1、Slf4j2、Log4j3、LogBack4、Logback 优点5、ELK1、Slf4j slf4j 的全称是 Simple Loging Facade For Java,即它仅仅是一个为 Java 程序提供日志输出的统一接口,并不是一个具体的日志实现方案,就比如 JDBC 一样,只是一种规则…

解决:eclipse绿化版Resource注解报Resource cannot be resolved to a type问题

如图: 网上解决教程很多,我的eclipse是绿化版的,不需要安装 解决办法如下: 1、在eclipse中,进入到Window->Preferences->Java->Installed JREs中 默认显示如下: 2、点击Add-->Standard VM--…

分页插件

引入依赖 注意需要和SpringBoot的版本对应&#xff0c;否则分页可能不生效 使用的分页依赖&#xff1a; <!-- pagehelper 插件--><dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper-spring-boot-starter</arti…

Dockerfile详解及优化技巧

写在前面 Dockerfile的默认相对路径是Dockerfile所在的目录&#xff1b;Dockerfile中的每一行会被视为一层镜像 一、Dockerfile 原理 1.1 镜像定义 首先我们先来回顾一下 Docker 镜像&#xff0c;它由多个只读层堆叠到一起&#xff0c;每一层是上一层的增量修改。基于镜像创…

深度学习炼丹-数据标准化

前言 一般机器学习任务其工作流程可总结为如下所示 pipeline。 在工业界,数据预处理步骤对模型精度的提高的发挥着重要作用。对于机器学习任务来说,广泛的数据预处理一般有四个阶段(视觉任务一般只需 Data Transformation): 数据清洗(Data Cleaning)、数据整合(Data Integ…

【c语言进阶】深度剖析整形数据

&#x1f680;write in front&#x1f680; &#x1f4dc;所属专栏&#xff1a; &#x1f6f0;️博客主页&#xff1a;睿睿的博客主页 &#x1f6f0;️代码仓库&#xff1a;&#x1f389;VS2022_C语言仓库 &#x1f3a1;您的点赞、关注、收藏、评论&#xff0c;是对我最大的激励…

C++010-C++嵌套循环

文章目录C010-C嵌套循环嵌套循环嵌套循环举例题目描述 输出1的个数题目描述 输出n行99乘法表题目描述 求s1!2!...10!作业在线练习&#xff1a;总结C010-C嵌套循环 在线练习&#xff1a; http://noi.openjudge.cn/ https://www.luogu.com.cn/ 嵌套循环 循环可以指挥计算机重复去…

自命为缓存之王的Caffeine(6)

您好&#xff0c;我是湘王&#xff0c;这是我的CSDN博客&#xff0c;欢迎您来&#xff0c;欢迎您再来&#xff5e;之前用Caffeine替代Redis的时候&#xff0c;发现先保存KV&#xff0c;再获取key&#xff0c;过期时间为3秒。但即使过了3秒&#xff0c;还是能获取到保存的数据。…

网络爬虫简介

前言 没什么可以讲的所以就介绍爬虫吧 介绍 网络爬虫&#xff08;英语&#xff1a;web crawler&#xff09;&#xff0c;也叫网路蜘蛛&#xff08;spider&#xff09;&#xff0c;是一种用来自动浏览万维网的网络机器人。其目的一般为编纂网络索引。 网路搜索引擎等站点通过…

Windows 环境下,cmake工程导入OpenCV库

目录 1、下载 OpenCV 库 2、配置环境变量 3、CmakeLists.txt 配置 1、下载 OpenCV 库 OpenCV官方下载地址&#xff1a;download | OpenCV 4.6.0 下载完毕后解压&#xff0c;便可以得到下面的文件 2、配置环境变量 我们需要添加两个环境变量&#xff0c;一个是 OpenCVConfi…