《动手学深度学习 Pytorch版》 4.7 前向传播、反向传播和计算图

news2024/11/30 2:33:22

4.7.1 前向传播

整节理论,详见书本。

4.7.2 前向传播计算图

整节理论,详见书本。

4.7.3 反向传播

整节理论,详见书本。

4.7.4 训练神经网络

整节理论,详见书本。

练习

(1)假设一些标量函数 X X X 的输入 X X X n × m n\times m n×m 矩阵。 f f f 相对于 X X X 的梯度的维数是多少?

还是 n × m n\times m n×m,多少个变量就是多少个导数嘛。


(2)向本节中描述的模型的隐藏层添加偏置项(不需要再正则化项中包含偏置项)。

    a.绘制出相应的计算图。

    b.推导前向传播和反向传播方程。

b. 仍假设输入样本是 x ∈ R d \boldsymbol{x}\in\mathbb{R}^d xRd,则前向传播为:

z = W ( 1 ) x + b h = ϕ ( z ) o = W ( 2 ) h + b L = l ( o , y ) s = λ 2 ( ∣ ∣ W ( 1 ) ∣ ∣ F 2 + ∣ ∣ W ( 2 ) ∣ ∣ F 2 ) J = L + s \begin{align} \boldsymbol{z}&=\boldsymbol{W}^{(1)}\boldsymbol{x}+b\\ \boldsymbol{h}&=\phi(\boldsymbol{z})\\ \boldsymbol{o}&=\boldsymbol{W}^{(2)}\boldsymbol{h}+b\\ L&=l(\boldsymbol{o},y)\\ s&=\frac{\lambda}{2}(||\boldsymbol{W}^{(1)}||^2_F+||\boldsymbol{W}^{(2)}||^2_F)\\ J&=L+s \end{align} zhoLsJ=W(1)x+b=ϕ(z)=W(2)h+b=l(o,y)=2λ(∣∣W(1)F2+∣∣W(2)F2)=L+s

反向传播为:

∂ J ∂ L = 1 , ∂ J ∂ s = 1 ∂ J ∂ o = ∂ J ∂ L ∂ L ∂ o = ∂ L ∂ o ∈ R q ∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) ∂ J ∂ W ( 2 ) = ∂ J ∂ o ∂ o ∂ W ( 2 ) + ∂ J ∂ s ∂ s ∂ W ( 2 ) = ∂ J ∂ o h T + λ W ( 2 ) ∂ J ∂ h = ∂ J ∂ o ∂ o ∂ h = W ( 2 ) T ∂ J ∂ o ∂ J ∂ z = ∂ J ∂ h ∂ h ∂ z = ∂ J ∂ h ⊙ ϕ ′ ( z ) ∂ J ∂ W ( 1 ) = ∂ J ∂ z ∂ z ∂ W ( 1 ) + ∂ J ∂ s ∂ s ∂ W ( 1 ) = ∂ J ∂ z x T + λ W ( 1 ) \begin{align} \frac{\partial J}{\partial L}&=1,\frac{\partial J}{\partial s}=1\\ \frac{\partial J}{\partial\boldsymbol{o}}&=\frac{\partial J}{\partial L}\frac{\partial L}{\partial\boldsymbol{o}}=\frac{\partial L}{\partial\boldsymbol{o}}\in\mathbb{R}^q\\ \frac{\partial s}{\partial\boldsymbol{W}^{(1)}}&=\lambda\boldsymbol{W}^{(1)},\frac{\partial s}{\partial\boldsymbol{W}^{(2)}}=\lambda\boldsymbol{W}^{(2)}\\ \frac{\partial J}{\partial\boldsymbol{W}^{(2)}}&=\frac{\partial J}{\partial\boldsymbol{o}}\frac{\partial\boldsymbol{o}}{\partial\boldsymbol{W}^{(2)}}+\frac{\partial J}{\partial s}\frac{\partial s}{\partial\boldsymbol{W}^{(2)}}=\frac{\partial J}{\partial\boldsymbol{o}}\boldsymbol{h}^T+\lambda\boldsymbol{W}^{(2)}\\ \frac{\partial J}{\partial\boldsymbol{h}}&=\frac{\partial J}{\partial\boldsymbol{o}}\frac{\partial\boldsymbol{o}}{\partial\boldsymbol{h}}=\boldsymbol{W}^{(2)T}\frac{\partial J}{\partial\boldsymbol{o}}\\ \frac{\partial J}{\partial\boldsymbol{z}}&=\frac{\partial J}{\partial\boldsymbol{h}}\frac{\partial\boldsymbol{h}}{\partial\boldsymbol{z}}=\frac{\partial J}{\partial\boldsymbol{h}}\odot\phi'(\boldsymbol{z})\\ \frac{\partial J}{\partial\boldsymbol{W}^{(1)}}&=\frac{\partial J}{\partial\boldsymbol{z}}\frac{\partial\boldsymbol{z}}{\partial\boldsymbol{W}^{(1)}}+\frac{\partial J}{\partial s}\frac{\partial s}{\partial\boldsymbol{W}^{(1)}}=\frac{\partial J}{\partial\boldsymbol{z}}\boldsymbol{x}^T+\lambda\boldsymbol{W}^{(1)} \end{align} LJoJW(1)sW(2)JhJzJW(1)J=1,sJ=1=LJoL=oLRq=λW(1),W(2)s=λW(2)=oJW(2)o+sJW(2)s=oJhT+λW(2)=oJho=W(2)ToJ=hJzh=hJϕ(z)=zJW(1)z+sJW(1)s=zJxT+λW(1)

a. 计算图为:

在这里插入图片描述


(3)计算本节所描述的模型用于训练和预测的内存空间。

不会,略。


(4)假设想计算二阶导数。计算图会发生什么变化?预计计算需要多长时间?

二阶计算图应该是在保留一阶计算图的基础上继续拓展出来的,需要的时间大抵是二倍吧。


(5)假设计算图对于当前的GPU来说太大了。

    a. 请尝试把它划分到多个GPU上。
    b. 这与小批量训练相比,有哪些优点和缺点。

a. 应使用 torch.nn.DataParallel 进行并行运算。

b.

batch_size够大则会由于并行计算而加快速度

batch_size不够大时反而会因为多卡之间的通信以及数据拆分与合并的额外开销导致效率反而更低。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1007919.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Wireshark把DDoS照原形

1 前言 MTU、 传输速度、 拥塞控制,还是各种重传,TCP传输相关的核心概念: 学习了RFC规范和具体的Linux实现通过案例,把这些知识灵活运用了起来 这种种还是在协议规范这大框架内的讨论,默认前提就是通信两端是遵照TC…

Activating More Pixels in Image Super-Resolution Transformer(HAT)超分

摘要 基于Transformer的方法在低级视觉任务(如图像超分辨率)上表现出令人印象深刻的性能。然而,我们发现这些网络只能通过归因分析利用有限的输入信息空间范围。这意味着Transformer的潜力在现有网络中仍未得到充分利用。为了激活更多输入像…

yolov7添加 iRMB模块

复制过来 yolo.py添加 yaml文件随便换,建议换3x3的 pip install timm0.6.5,版本问题记得搞一下

DNG格式详解,DNG是什么?为何DNG可以取代RAW统一单反相机、苹果安卓移动端相机拍摄输出原始图像数据标准

返回图像处理总目录:《JavaCV图像处理合集总目录》 前言 在DNG格式发布之前,我们先了解一下之前单反相机、苹果和安卓移动端相机拍照输出未经处理的原始图像格式是什么? RAW 什么是RAW? RAW是未经处理、也未经压缩的格式。可以…

基于开源模型搭建实时人脸识别系统(六):人脸识别(人脸特征提取)

文章目录 人脸识别的几个发展阶段基于深度学习的人脸识别技术的流程闭集和开集(Open set)识别人脸识别的损失Insightface人脸识别数据集模型选型参考文献结语人脸识别系统项目源码 前面我们讲过了人脸检测、人脸质量、人脸关键点、人脸跟踪,接…

微分中值定理

目录 费马定理 罗尔定理 拉格朗日中值定理 柯西中值定理 几个常用的泰勒公式 微分中值定理是微积分中的一个重要定理,它用于描述一个函数在某个区间内的平均变化率与该区间内某一点的瞬时变化率之间的关系。微分中值定理有两个主要形式:拉格朗日中值…

Kotlin Files Paths write ByteArray writeString写多行BufferedWriter

Kotlin Files Paths write ByteArray writeString写多行BufferedWriter import java.nio.file.Files import java.nio.file.Paths import java.nio.file.StandardOpenOptionfun main(args: Array<String>) {val filePath "./myfile.txt"val path Paths.get(…

【报错】springboot3启动报错

报错内容&#xff1a;Cannot load driver class: org.h2.Driver Error starting ApplicationContext. To display the condition evaluation report re-run your application with debug enabled. 解决; 通过源码分析&#xff0c;druid-spring-boot-3-starter目前最新版本是1…

微信小程序 写一个接口不会掉就不会停止的加载动画

我们可以在接口调用前执行 wx.showLoading({title: 加载中,mask: true })这个加载会在这一直转 显示这加载的动画 它不会自己停下来 而是需要你执行 wx.hideLoading()之后 这个加载动画才会停止 那么我们完全可以将wx.hideLoading()放在接口返回的回调中 这样 就达到了一个 …

LeetCode每日一题:2596. 检查骑士巡视方案(2023.9.13 C++)

目录 2596. 检查骑士巡视方案 题目描述&#xff1a; 实现代码与解析&#xff1a; bfs模拟 原理思路&#xff1a; 2596. 检查骑士巡视方案 题目描述&#xff1a; 骑士在一张 n x n 的棋盘上巡视。在有效的巡视方案中&#xff0c;骑士会从棋盘的 左上角 出发&#xff0c;并…

利用Semaphore实现多线程调用接口A且限制接口A的每秒QPS为10

前段时间在群里面发现有个群友抛出一个实际需求&#xff1a;需要通过一个接口拉取数据&#xff0c;这个接口有每秒10QPS限制&#xff0c;请问如何实现数据拉去效率最大化且限制调用拉取接口每秒10PQPS&#xff1f;我觉得这个需求挺有意思的&#xff0c;跟某群友讨论&#xff0c…

CopyOnWriteArrayList源码分析

其中唯一的线程安全 List 实现就是 CopyOnWriteArrayList。 特点 由于读取操作不会对原有数据进行修改&#xff0c;因此&#xff0c;对于每次读取都进行加锁其实是一种资源浪费。相比之下&#xff0c;我们应该允许多个线程同时访问 List 的内部数据&#xff0c;毕竟对于读取操…

企业邮箱选择指南:最适合跨境贸易的解决方案推荐

随着全球贸易的不断发展&#xff0c;外贸公司越来越依赖高效的沟通和协作工具。在众多企业邮箱选择中&#xff0c;哪一种最适合外贸公司的需求呢&#xff1f;让我们一起来看看外贸公司常用的企业邮箱解决方案。 对于外贸公司而言&#xff0c;可靠性是选择企业邮箱的首要考虑因…

LC1798. 你能构造出连续值的最大数目(JAVA)

LC1798. 你能构造出连续值的最大数目 题目描述贪心算法代码演示 题目描述 难度 - 中等 Leetcode - 1798. 你能构造出连续值的最大数目 给你一个长度为 n 的整数数组 coins &#xff0c;它代表你拥有的 n 个硬币。第 i 个硬币的值为 coins[i] 。如果你从这些硬币中选出一部分硬币…

前端构建工具 webpack 笔记

1、了解 webpack 1、定义&#xff1a;本质上&#xff0c;webpack 是一个用于现代 JavaScript 应用程序的静态模块打包工具&#xff0c;当 webpack 处理应用它会在内部从一个或多个入口点构建一个依赖图(dependency graph)&#xff0c;然后将你项目中所程序时&#xff0c;需的…

YOLO物体检测系列3:YOLOV3改进解读

&#x1f388;&#x1f388;&#x1f388;YOLO 系列教程 总目录 YOLOV1整体解读 YOLOV2整体解读 YOLOV3提出论文&#xff1a;《Yolov3: An incremental improvement》 1、YOLOV3改进 这张图讲道理真的过分了&#xff01;&#xff01;&#xff01;我不是针对谁&#xff0c;在…

《C++ Primer》第3章 字符串、向量和数组(二)

参考资料&#xff1a; 《C Primer》第5版《C Primer 习题集》第5版 3.3 标准库类型vector&#xff08;P86&#xff09; vector 表示对象的序列&#xff0c;其中所有对象的类型相同&#xff0c;每个对象都有一个与之对应的索引。vector 容纳着其他对象&#xff0c;所以常被称…

Linux内核4.14版本——drm框架分析(11)——DRM_IOCTL_MODE_ADDFB2(drm_mode_addfb2)

目录 1. drm_mode_addfb2 2. drm_internal_framebuffer_create 3. drm_fb_cma_create->drm_gem_fb_create->drm_gem_fb_create_with_funcs 4. drm_gem_fb_alloc 4.1 drm_helper_mode_fill_fb_struct 4.2 drm_framebuffer_init 5. 调用流程图 书接上回&#xff0c;使…

springboot对接postgres

安装postgres 注意:下述链接方式会自动创建数据库steven_russell,若需要创建其他数据库&#xff0c;可以手动执行命令创建数据库 docker run --name postgres \ -p 5432:5432 \ -e POSTGRES_USERsteven_russell \ -e POSTGRES_PASSWORD123456 \ -itd --privilegedtrue postgre…

【卖出看涨期权策略(Short Call)】

卖出看涨期权策略&#xff08;Short Call) 卖出看涨期权策略又称为卖出无备兑看涨期权&#xff0c;如果一个投资者在不持有标的资产价格的情况下卖出看涨期权&#xff0c;那么这种策略就是卖出无备兑看涨期权策略。这个策略潜在盈利有限&#xff0c;但是亏损无限。 卖出看涨期…