Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

news2024/11/15 21:38:53

目录

  • 1 计算图原理
  • 2 基于计算图的传播
  • 3 神经网络计算图
  • 4 自动微分机
  • 5 Pytorch中的自动微分
    • 5.1 梯度缓存
    • 5.2 参数冻结

1 计算图原理

计算图(Computational Graph)是机器学习领域中推导神经网络和其他模型算法,以及软件编程实现的有效工具。

计算图的核心是将模型表示成一张拓扑有序(Topologically Ordered)有向无环图(Directed Acyclic Graph),其中每个节点 u i u_i ui包含数值信息(可以是标量、向量、矩阵或张量)和算子信息 f i f_i fi。拓扑有序指当前节点仅在全体指向它的节点被计算后才进行计算。

在这里插入图片描述
计算图的优点在于:

  • 可以通过基本初等映射 的拓扑联结,形成复合的复杂模型,大多数神经网络模型都可以被计算图表示;
  • 便于实现自动微分机(Automatic Differentiation Machine),对给定计算图可基于链式法则由节点局部梯度进行反向传播。

计算图的基本概念如表所示,基于计算图的基本前向传播和反向传播算法如表

符号含义
n n n计算图的节点数
l l l计算图的叶节点数
L L L计算图的叶节点索引集
C C C计算图的非叶节点索引集
E E E计算图的有向边集合
u i u_i ui计算图中的第 i i i节点或其值
d i d_i di u i u_i ui 的维度
f i f_i fi u i u_i ui的算子
α i \alpha _i αi u i u_i ui的全体关联输入
J j → i \boldsymbol{J}_{j\rightarrow i} Jji节点 u i u_i ui关于节点 u j u_j uj的雅克比矩阵
P i \boldsymbol{P}_i Pi输出节点关于输入节点的雅克比矩阵

2 基于计算图的传播

基于计算图的前向传播算法如下

在这里插入图片描述
基于计算图的反向传播算法如下

在这里插入图片描述

以第一节的图为例,可知 E = { ( 1 , 3 ) , ( 2 , 3 ) , ( 2 , 4 ) , ( 3 , 4 ) } E=\left\{ \left( 1,3 \right) ,\left( 2,3 \right) ,\left( 2,4 \right) ,\left( 3,4 \right) \right\} E={(1,3),(2,3),(2,4),(3,4)}。首先进行前向传播:

{ u 3 = u 1 + u 2 = 5 u 4 = u 2 u 3 = 15 \begin{cases} u_3=u_1+u_2=5\\ u_4=u_2u_3=15\\\end{cases} {u3=u1+u2=5u4=u2u3=15

{ J 1 → 3 = ∂ u 3 / ∂ u 1 = 1 J 2 → 3 = ∂ u 3 / ∂ u 2 = 1 J 2 → 4 = ∂ u 4 / ∂ u 2 = u 3 = 5 J 3 → 4 = ∂ u 4 / ∂ u 3 = u 2 = 3 \begin{cases} \boldsymbol{J}_{1\rightarrow 3}={{\partial u_3}/{\partial u_1=}}1\\ \boldsymbol{J}_{2\rightarrow 3}={{\partial u_3}/{\partial u_2=}}1\\ \boldsymbol{J}_{2\rightarrow 4}={{\partial u_4}/{\partial u_2=}}u_3=5\\ \boldsymbol{J}_{3\rightarrow 4}={{\partial u_4}/{\partial u_3=}}u_2=3\\\end{cases} J13=u3/u1=1J23=u3/u2=1J24=u4/u2=u3=5J34=u4/u3=u2=3

接着进行反向传播:

{ P 4 = 1 P 3 = P 4 J 3 → 4 = 3 P 2 = P 4 J 2 → 4 + P 3 J 2 → 3 = 8 P 1 = P 3 J 1 → 3 = 3 \begin{cases} \boldsymbol{P}_4=1\\ \boldsymbol{P}_3=\boldsymbol{P}_4\boldsymbol{J}_{3\rightarrow 4}=3\\ \boldsymbol{P}_2=\boldsymbol{P}_4\boldsymbol{J}_{2\rightarrow 4}+\boldsymbol{P}_3\boldsymbol{J}_{2\rightarrow 3}=8\\ \boldsymbol{P}_1=\boldsymbol{P}_3\boldsymbol{J}_{1\rightarrow 3}=3\\\end{cases} P4=1P3=P4J34=3P2=P4J24+P3J23=8P1=P3J13=3

3 神经网络计算图

一个神经网络的计算图实例如下,所有参数都可以用之前的模型表示

在这里插入图片描述

L { u 1 = W 1 ∈ R n 1 × n 0 u 2 = b 1 ∈ R n 1 u 3 = x ∈ R n 0 u 4 = W 2 ∈ R n 2 × n 1 u 5 = b 2 ∈ R n 2 u 6 = y ∈ R n 2    C { u 7 = z 1 ∈ R n 1 = W 1 x + b 1 u 8 = a 1 ∈ R n 1 = σ ( z 1 ) u 9 = z 2 ∈ R n 2 = W 2 a 1 + b 2 u 10 = y ∈ R n 2 = σ ( z 2 ) u 11 = E ∈ R = 1 2 ( y − y ~ ) T ( y − y ~ ) L\begin{cases} u_1=\boldsymbol{W}^1\in \mathbb{R} ^{n_1\times n_0}\\ u_2=\boldsymbol{b}^1\in \mathbb{R} ^{n_1}\\ u_3=\boldsymbol{x}\in \mathbb{R} ^{n_0}\\ u_4=\boldsymbol{W}^2\in \mathbb{R} ^{n_2\times n_1}\\ u_5=\boldsymbol{b}^2\in \mathbb{R} ^{n_2}\\ u_6=\boldsymbol{y}\in \mathbb{R} ^{n_2}\\\end{cases}\,\, C\begin{cases} u_7=\boldsymbol{z}^1\in \mathbb{R} ^{n_1}=\boldsymbol{W}^1\boldsymbol{x}+\boldsymbol{b}^1\\ u_8=\boldsymbol{a}^1\in \mathbb{R} ^{n_1}=\sigma \left( \boldsymbol{z}^1 \right)\\ u_9=\boldsymbol{z}^2\in \mathbb{R} ^{n_2}=\boldsymbol{W}^2\boldsymbol{a}^1+\boldsymbol{b}^2\\ u_{10}=\boldsymbol{y}\in \mathbb{R} ^{n_2}=\sigma \left( \boldsymbol{z}^2 \right)\\ u_{11}=E\in \mathbb{R} =\frac{1}{2}\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right) ^T\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right)\\\end{cases} L u1=W1Rn1×n0u2=b1Rn1u3=xRn0u4=W2Rn2×n1u5=b2Rn2u6=yRn2C u7=z1Rn1=W1x+b1u8=a1Rn1=σ(z1)u9=z2Rn2=W2a1+b2u10=yRn2=σ(z2)u11=ER=21(yy~)T(yy~)

4 自动微分机

自动微分机的基本原理是:

  • 跟踪记录从输入张量到输出张量的计算过程,并生成一幅前向传播计算图,计算图中的节点与张量一一对应
  • 基于计算图反向传播原理即可链式地求解输出节点关于各节点的梯度

必须指出,Pytorch不允许张量对张量求导,故输出节点必须是标量,通常为损失函数或输出向量的加权和;为节约内存,每次反向传播后Pytorch会自动释放前向传播计算图,即销毁中间计算节点的梯度和节点间的连接结构。

5 Pytorch中的自动微分

Tensor在自动微分机中的重要属性如表所示。

属性含义
device该节点运行的设备环境,即CPU/GPU
requires_grad自动微分机是否需要对该节点求导,缺省为False
grad输出节点对该节点的梯度,缺省为None
grad_fn中间计算节点关于全体输入节点的映射,记录了前向传播经过的操作。叶节点为None
is_leaf该节点是否为叶节点

完成前向传播后,调用反向传播API即可更新各节点梯度,具体如下

backward(gradient=None, retain_graph=None, create_graph=None)

其中

  • gradient是权重向量,当输出节点 y y y不为标量时需指定与其同维的gradient,并以标量 g r a d i e n t T y gradient^Ty gradientTy为输出进行反向传播
  • retain_graph用于缓存前向传播计算图,可应用于一次传播测试多个损失函数等情形;
  • creat_graph用于构造导数计算图,可用于进一步求解高阶导数。

5.1 梯度缓存

中间计算节点的梯度需要通过retain_grad()方法进行缓存

w1 = torch.tensor([[2.], [3.]], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
x = torch.tensor([[10.], [20.]])

y = torch.mm(w1.transpose(0, 1), x) + b1
y.retain_grad()	# 若不缓存则y.grad=None
out = 3*y
out.backward()


>> tensor([[30.], [60.]]) tensor([3.]) None tensor([[3.]])

5.2 参数冻结

若希望冻结网络部分参数,只调整优化另一部分参数;或按顺序训练分支网络而屏蔽对主网络梯度的,可使用detach()方法从计算图中分离节点,阻断反向传播。分离的节点与原节点共享值内存,但不具有gradgrad_fn属性。

# 记第一层网络w1-b1为f,第二层网络w2-b2为g
w1 = torch.tensor([[2.], [3.]], requires_grad=True)
w2 = torch.tensor([3.], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
b2 = torch.tensor([2.], requires_grad=True)
x = torch.tensor([[10.], [20.]])

y = torch.mm(w1.transpose(0, 1), x) + b1
y_ = y.detach()
z = w2 * y_ + b2
out = 3*z
out.backward()

print(w1.grad, b1.grad, w2.grad, b2.grad)
>> None None tensor([243.]) tensor([3.]) # f被冻结,梯度不更新
# 若不使用detach冻结y之前的网络,则
>> tensor([[ 90.], [180.]]) tensor([9.]) tensor([243.]) tensor([3.])

🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/381896.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3 企业级项目实战:项目须知与课程约定

本节内容很重要,希望大家能够耐心看完。 Vue3 企业级项目实战 - 程序员十三 - 掘金小册Vue3 Element Plus Spring Boot 企业级项目开发,升职加薪,快人一步。。「Vue3 企业级项目实战」由程序员十三撰写,2744人购买https://s.ju…

解决方案| anyRTC 融合其他厂商视频会议系统方案

背景 视频会议市场经历疫情后,不管是硬件视频会议还是云视频会议已经在各行各业铺开使用,特别是政府行业,职能部门除了几大硬件视频会议外,也开始逐渐尝试云视频会议,视频会议的场景运用除了日常的交流、沟通、学习外…

开启互联网赚钱模式

随着互联网的发展,现在几乎会玩手机和电脑的都离不开网络,自然出现了很多网络赚钱的项目,受到了很多新人创业者和做副业兼职者的欢迎。很多朋友都想利用电脑或手机在网上赚钱。其实不管做什么项目,都有一个过程,没有什…

【监控】Linux部署postgres_exporter及PG配置(非Docker)

目录一、下载及部署二、postgres_exporter配置1. 停止脚本stop.sh2. 启动脚本start.sh3. queries.yaml三、PostgreSQL数据库配置1. 修改postgresql.conf配置文件2. 创建用户、表、扩展等四、参考一、下载及部署 下载地址 选一个amd64下载 上传至服务器,解压 tax…

$ 6 :选择、循环

if-else语句 #include <stdio.h> //判断输入值是否大于0 int main() {int i;while (scanf("%d",&i)){if (i > 0)//不要在括号后加分号{printf("i is bigger than O\n");}else {printf("i is not bigger than O\n");}}return O; } …

cglib代理解析

工作原理 使用 <dependency><groupId>cglib</groupId><artifactId>cglib</artifactId><version>3.3.0</version></dependency>对类和接口分别进行代理 DemoService package com.fanqiechaodan.user.service;/*** author fa…

itop-3568 开发板系统编程学习笔记(3)目录 IO

【北京迅为】嵌入式学习之Linux系统编程篇 https://www.bilibili.com/video/BV1zV411e7Cy/ 个人学习笔记 文章目录mkdir() 函数opendir() 和 closedir() 函数readdir() 函数综合实验mkdir() 函数 头文件&#xff1a; #include <sys/types.h> #include <sys/stat.h&g…

linux代码调试-gdb

在windows调试各类代码经常依托相关便利的IDE工具&#xff0c;如Microsoft的Visual Studio,TI的Code Composer Studio,ADI的CrossCore Embedded Studio ,ADI的VisualDSP&#xff0c;Renesas的CS for CC,NXP的S32 Design Studio…这些调试&#xff0c;或借助软、硬件仿真&#x…

JPA 之 Hibernate EntityManager 使用指南

Hibernate EntityManager 专题 参考&#xff1a; JPA – EntityManager常用API详解EntityManager基本概念 基本概念及获得 EntityManager 对象 基本概念 在使用持久化工具的时候&#xff0c;一般都有一个对象来操作数据库&#xff0c;在原生的Hibernate中叫做Session&…

排序之损失函数List-wise loss(系列3)

排序系列篇&#xff1a; 排序之指标集锦(系列1)原创 排序之损失函数pair-wise loss(系列2)排序之损失函数List-wise loss(系列3) 最早的关于list-wise的文章发表在Learning to Rank: From Pairwise Approach to Listwise Approach中&#xff0c;后面陆陆续续出了各种变形&#…

SpringBoot入门 - SpringBoot HelloWorld

我们了解了SpringBoot和SpringFramework的关系之后&#xff0c;我们可以开始创建一个Hello World级别的项目了。创建 SpringBoot Web 应用为快速进行开发&#xff0c;推荐你使用IDEA这类开发工具&#xff0c;它将大大提升你学习和开发的效率。选择 Spring InitializeSpring提供…

开源的 OA 办公系统 — 勾股 OA 4.3.01 发布

勾股 OA 办公系统是一款简单实用的开源的企业办公系统。系统集成了系统设置、人事管理、行政管理、消息管理、企业公告、知识库、审批流程设置、办公审批、日常办公、财务管理、客户管理、合同管理、项目管理、任务管理等功能模块。系统简约&#xff0c;易于功能扩展&#xff0…

【vue】图标选择(elementUI和svg结合)

目标&#xff1a;在做菜单权限的时候需要选择图标&#xff0c;如果既想要用elementUI自带的图标&#xff0c;还想要自定义的图标&#xff0c;这时就需要二者结合一下如果用的是vue-admin-template&#xff0c;那svg组件和引入elementUI是不需要操作的&#xff0c;直接使用即可。…

pytest学习和使用17-Pytest如何重复执行用例?(pytest-repeat)

17-Pytest如何重复执行用例&#xff1f;&#xff08;pytest-repeat&#xff09;1 使用场景2 pytest-repeat插件2.1 环境要求2.2 插件安装3 pytest-repeat使用3.1 重复测试直到失败3.2 用例标记执行重复多次3.3 命令行参数--repeat-scope详解3.3.1 class示例3.3.2 module示例1 使…

如何在软件测试面试中脱颖而出?(附教程)天花板都这样回答

面试软件测试工程师岗位&#xff0c;是否真的如网上所说&#xff0c;需要不停刷面试题?面试题可能掌握的技巧实际是一样的&#xff0c;只是题目形式不一样&#xff0c;那么应该如何在面试中脱颖而出呢?今天我们就来聊一聊。 我录制了一整套完整的软件测试面试的话术教程&…

拿下32k成功入职阿里软件测试面试常见问题及回答技巧

1、什么是兼容性测试&#xff1f;兼容性测试侧重哪些方面&#xff1f; 参考答案&#xff1a; 兼容测试主要是检查软件在不同的硬件平台、软件平台上是否可以正常的运行&#xff0c;即是通常说的软件的可移植性。 兼容的类型&#xff0c;如果细分的话&#xff0c;有平台的兼容…

测试结束参考标准

在软件消亡之前&#xff0c;如果没有测试的结束点&#xff0c;那么软件测试就永无休止&#xff0c;永远不可能结束。软件测试的结束点&#xff0c;要依据自己公司具体情况来制定&#xff0c;不能一概而论!个人认为测试结束点由以下几个条件决定&#xff1a; 1.基于“测试阶段”…

预览版Edge申请微软new Bing失败解决方案

文章目录1.首先需要配置科学上网2.下载预览版Edge浏览器卡它bug&#xff01;卡它bug&#xff01;卡它bug&#xff01;没有申请上ChatGPT的朋友们&#xff0c;试试new Bing吧&#xff0c;更新更强大&#xff0c;关于申请方式&#xff0c;网上已经有很多帖子了&#xff0c;其中一…

WebRTC 拥塞控制 | Trendline 滤波器

1.指数平滑1.1一次指数平滑法&#xff08;Single Exponential Smoothing&#xff09;指数平滑法&#xff08;Exponential Smoothing&#xff09; 是在移动平均法基础上发展起来的一种时间序列分析预测法&#xff0c;它是通过计算指数平滑值&#xff0c;配合一定的时间序列预测模…

fiddler抓包实战(1),模拟手机弱网测试,判断BUG来自客户端还是服务端

手机app中常见的测试之一就是弱网测试&#xff0c;什么是弱网测试呢?顾名思义就是模拟弱网的时候用户对于手机的一些操作和响应是否成功&#xff0c;在使用的过程中是否能够正常的使用 手机端常见弱网测试方法就是切换5G、4G、3G、2G这样 Web中就可以直接模拟打开F12然后就可…