《动手学深度学习(PyTorch版)》笔记4.7

news2024/11/16 16:48:07

Chapter4 Multilayer Perceptron

4.7 Forward/Backward Propagation and Computational Graphs

本节将通过一些基本的数学和计算图,深入探讨反向传播的细节。首先,我们将重点放在带权重衰减( L 2 L_2 L2正则化)的单隐藏层多层感知机上。

4.7.1 Forward Propagation

前向传播(forward propagation或forward pass)指的是按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

我们将一步步研究单隐藏层神经网络的机制,为了简单起见,我们假设输入样本是 x ∈ R d \mathbf{x}\in \mathbb{R}^d xRd,并且我们的隐藏层不包括偏置项。这里的中间变量是:

z = W ( 1 ) x , \mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。将中间变量 z ∈ R h \mathbf{z}\in \mathbb{R}^h zRh通过激活函数 ϕ \phi ϕ后,我们得到长度为 h h h的隐藏激活向量:

h = ϕ ( z ) . \mathbf{h}= \phi (\mathbf{z}). h=ϕ(z).

隐藏变量 h \mathbf{h} h也是一个中间变量。假设输出层的参数只有权重 W ( 2 ) ∈ R q × h \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,我们可以得到输出层变量,它是一个长度为 q q q的向量:

o = W ( 2 ) h . \mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}. o=W(2)h.

假设损失函数为 l l l,样本标签为 y y y,我们可以计算单个数据样本的损失项,

L = l ( o , y ) . L = l(\mathbf{o}, y). L=l(o,y).

根据 L 2 L_2 L2正则化的定义,给定超参数 λ \lambda λ,正则化项为

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

∥ X ∥ F \|\mathbf{X}\|_F XF表示矩阵的Frobenius范数:
∥ X ∥ F = ∑ i = 1 m ∑ j = 1 n x i j 2 . \|\mathbf{X}\|_F = \sqrt{\sum_{i=1}^m \sum_{j=1}^n x_{ij}^2}. XF=i=1mj=1nxij2 .
最后,模型在给定数据样本上的正则化损失为:

J = L + s . J = L + s. J=L+s.

在下面的讨论中,我们将 J J J称为目标函数(objective function)。

下图是与上述简单网络相对应的计算图,其中正方形表示变量,圆圈表示操作符。

在这里插入图片描述

4.7.2 Backward Propagation

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法,该方法根据链式规则,按相反的顺序从输出层到输入层遍历网络。该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。
假设我们有函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),其中输入和输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z是任意形状的张量。利用链式法则,我们可以计算 Z \mathsf{Z} Z关于 X \mathsf{X} X的导数:

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) . \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right). XZ=prod(YZ,XY).

在这里,我们使用 prod \text{prod} prod运算符在执行必要的操(如换位和交换输入位置)后将其参数相乘。对于高维张量,我们使用适当的对应项。

在上面的计算图中单隐藏层简单网络的参数是 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2),反向传播的目的是计算梯度 ∂ J / ∂ W ( 1 ) \partial J/\partial \mathbf{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \mathbf{W}^{(2)} J/W(2),计算的顺序与前向传播中执行的顺序相反,具体如下:

∂ J ∂ L = 1    and    ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1 \; \text{and} \; \frac{\partial J}{\partial s} = 1. LJ=1andsJ=1.

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o ∈ R q . \frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q. oJ=prod(LJ,oL)=oLRq.

∂ s ∂ W ( 1 ) = λ W ( 1 )    ,    ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{,} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) ∈ R q × h . \frac{\partial J}{\partial \mathbf{W}^{(2)}}= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right)= \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}\in \mathbb{R}^{q \times h}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2)Rq×h.

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o ∈ R h . \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}\in \mathbb{R}^h. hJ=prod(oJ,ho)=W(2)oJRh.

由于激活函数 ϕ \phi ϕ是按元素计算的,计算中间变量 z \mathbf{z} z的梯度需要使用按元素乘法运算符,我们用 ⊙ \odot 表示:

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) ∈ R h . \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\in \mathbb{R}^h. zJ=prod(hJ,zh)=hJϕ(z)Rh.

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) = ( W ( 2 ) ⊤ ∂ J ∂ o ) ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) . \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(1)}} &= \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) \\ &= \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= ({\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}})\odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)}. \end{align*} W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1)=hJϕ(z)x+λW(1)=(W(2)oJ)ϕ(z)x+λW(1).

4.7.3 Training Neural Networks

在训练神经网络时,前向传播和反向传播相互依赖。以上述简单网络为例:一方面,在前向传播期间计算正则项取决于模型参数 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2)的当前值。它们是由优化算法根据最近迭代的反向传播给出的。另一方面,反向传播期间参数的梯度计算,取决于由前向传播给出的隐藏变量 h \mathbf{h} h的当前值。

因此,在训练神经网络时,我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。这带来的影响之一是我们需要保留中间值,直到反向传播完成,这也是训练比单纯的预测需要更多的内存(显存)的原因之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1414239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【docker】linux系统docker的安装及使用

一、docker应用的安装 1.1 安装方式 Docker的自动化安装,即使用提供的一键安装的脚本,进行安装。 官方的一键安装方式:curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun 国内 daocloud一键安装命令:curl -s…

VSCode Debug 参数设置说明

如果想在vscode中debug一个项目,比如python3 run.py --args 这个时候你需要着重关注几个参数,参数用两个双引号分开,不能有空格。 cwd :运行代码的基础目录env: 设置环境变量 PYTHONPATH: 设置项目用到的模块搜索路径&#xff…

数学建模论文笔记

一、概述 1. 数学建模论文组成 论文电子版:摘要页、正文、参考文献、附录支撑材料:源程序代码以及调用说明、中间结果、支撑数据等首页:论文题目、摘要、关键词论文正文:问题重述、问题分析、模型假设、符号说明、模型建立与求解…

@JsonIgnore的使用及相关问题的解决

目录 1 前言 2 对比及其使用方法 3 遇到的相关问题及解决方法 1 前言 在我们编写的后端项目中,有时候可能需要将某个实体类以JSON格式传送给前端,但是其中可能有部分内容我们并不想传送,这时候我们选择将这部分内容变成Null,这…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例5-1事件处理

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>事件处理</title> </head><body> <input id"btn" type"button" name"btn" value"提交" /> <…

Backtrader 文档学习-Bracket Orders

Backtrader 文档学习-Bracket Orders 1. 概述 组合订单类型是一个非常宽泛的订单类别&#xff0c;只要brokder支持的订单类型都可以&#xff0c; 包括(Market, Limit, Close, Stop, StopLimit, StopTrail, StopTrailLimit, OCO)。 该功能用于回测&#xff0c;交互broker Brac…

Java集合-Map接口(key-value)

Map接口的特点&#xff1a;①KV键值对方式存储②Key键唯一&#xff0c;Value允许重复③无序。 Map有四个实现类&#xff1a;1.HashMap类2.LinkedHashMap类3.TreeMap类4.Hashtable类 1.HashMap类&#xff1a; 存储结构&#xff1a;哈希表 数组Node[ ] 链表&#xff08;红黑…

雨云美国二区云服务器评测

雨云美国二区云服务器评测 官网直接百度搜索雨云就行 我买的时候比较便宜&#xff0c;三个月3.4元&#xff0c;1C1G对于我这种小网站来说够用了 本期测评服务器配置 CPU&#xff1a;1核 内存&#xff1a;1G 硬盘&#xff1a;Linux系统20G&#xff0c;win系统30G 流量&…

Qt中Widget样式表实现圆弧边框

第一步 第二步 第三步 第四步 //插入border-radius: 10px; border: 2px solid #000; 效果图

Elasticsearch介绍以及基本操作

目录 一、Elasticsearch介绍 二、关于Elasticsearch的基本操作 &#xff08;1&#xff09;索引操作 &#xff08;2&#xff09;文档操作 三、域的属性 &#xff08;1&#xff09;index &#xff08;2&#xff09;type &#xff08;3&#xff09;store 一、Elasticsearc…

vue3+elementPlus pc和小程序ai聊天文生图

websocket封装可以看上一篇文章 //pc端 <template><div class"common-layout theme-white"><el-container><el-aside><div class"title-box"><span>AI Chat</span></div><div class"chat-list&…

使用vue_cli脚手架创建Vue项目(cmd和图形化方式)

使用vue_cli脚手架创建Vue项目&#xff08;cmd和图形化方式&#xff09; 创建项目(cmd方式) vue create vue_cli1.方向键选择manually select feature(手动选择方式创建)&#xff0c;回车 2.按空格键选择需要的组件&#xff1a;Babel、PWA、Router、Vuex、CSS&#xff0c;回…

【LeetCode】112. 路径总和(简单)——代码随想录算法训练营Day18

题目链接&#xff1a;112. 路径总和 题目描述 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径&#xff0c;这条路径上所有节点值相加等于目标和 targetSum 。如果存在&#xff0c;返回 true &#xff1b;否则&…

Pandas.Series.product() 乘积(累乘积) 详解 含代码 含测试数据集 随Pandas版本持续更新

关于Pandas版本&#xff1a; 本文基于 pandas2.2.0 编写。 关于本文内容更新&#xff1a; 随着pandas的stable版本更迭&#xff0c;本文持续更新&#xff0c;不断完善补充。 传送门&#xff1a; Pandas API参考目录 传送门&#xff1a; Pandas 版本更新及新特性 传送门&…

以太网与PON网络的巅峰对决

在这网络的江湖中&#xff0c;各路江湖豪侠都神色匆忙地往同一个地方赶&#xff0c;豪侠们脸上都充满期待和焦虑&#xff0c;生怕错过了什么。这个地方就是传说中的园区网&#xff0c;因为在那里万众期待已久的以太网与PON网络的巅峰对决“将在今天上演。 一方是以太网大侠&am…

500行Python代码构建的AI搜索工具!

一个500行Python代码构建的AI搜索工具&#xff0c;而且还会开源&#xff0c;试了一下麻雀虽小该有的都有。 后端是Mixtral-8x7b 模型&#xff0c;托管在 LeptonAI 上&#xff0c;输出速度能达到每秒大约200个 token&#xff0c;用的搜索引擎是 Bing 的搜索 API。 作者还写了一…

【昕宝爸爸小模块】什么是POI,为什么它会导致内存溢出?

➡️博客首页 https://blog.csdn.net/Java_Yangxiaoyuan 欢迎优秀的你&#x1f44d;点赞、&#x1f5c2;️收藏、加❤️关注哦。 本文章CSDN首发&#xff0c;欢迎转载&#xff0c;要注明出处哦&#xff01; 先感谢优秀的你能认真的看完本文&…

六、Kotlin 类型进阶

1. 类的构造器 & init 代码块 1.1 主构造器 & 副构造器在使用时的注意事项 & 注解 JvmOverloads 推荐在类定义时为类提供一个主构造器&#xff1b; 在为类提供了主构造器的情况下&#xff0c;当再定义其他的副构造器时&#xff0c;要求副构造器必须调用到主构造器…

2024年预制菜行业市场发展趋势分析(2021-2023年预制菜行业数据分析)

近期&#xff0c;老干妈被称为预制菜、预制菜国标报送稿出炉等事件再次引起大众对于预制菜市场的讨论。随着国家对预制菜审核标准的严格化&#xff0c;预制菜市场未来走向将会如何&#xff1f;鲸参谋带大家从数据角度来了解。 首先来看下预制菜市场的行业发展情况。 根据鲸参…

Linux 驱动开发基础知识—— LED 驱动程序框架(四)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;V…