pytorch的自动微分、计算图 | 代码解析

news2025/1/11 23:57:21

在深度学习和机器学习中,自动微分是一个关键的概念,用于计算函数相对于其输入变量的导数(梯度)从而利用各类优化算法如梯度下降降低损失函数。PyTorch中的张量(tensor)提供了自动微分功能,它使得梯度计算变得非常方便,是深度学习模型训练的关键组成部分。

而梯度下降算法是通过计算图来实现的,计算图非常重要

可以把复杂的求导过程表示成计算图

1. 计算图原理剖析

1.1计算图正向 正向传播

正向传播很重要,我们以 y = ( x 1 2 + 2 x 2 ) 2 y=(x_1^2+2x_2)^2 y=(x12+2x2)2 为例建立计算图

通过中间变量,复杂式子可以划分为一次加减乘除幂运算

y = z 2 2 y=z_2^2 y=z22

z 2 = z 1 + z 3 z_2=z_1+z_3 z2=z1+z3

z 1 = x 1 2 z_1=x_1^2 z1=x12

如图

输入蓝色x1,x2,圆圈代表运算 红色是中间变量z1,z2

在这里插入图片描述

1.2反向传播算梯度

我们现在要求最终输出y对每一个参数x1,x2的梯度

根据链式法则对y求x1的偏导

由链式法则

∂ y ∂ x 1 = ∂ y ∂ z 2 ∗ ∂ z 2 ∂ z 1 ∗ ∂ z 1 ∂ x 1 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial z_2}*\frac{\partial z_2}{\partial z_1}*\frac{\partial z_1}{\partial x_1} x1y=z2yz1z2x1z1

拆分每一部分分别求

∂ y ∂ z 2 = 2 z 2 = 20 \frac{\partial y}{\partial z_2}=2z_2=20 z2y=2z2=20

∂ z 2 ∂ z 1 = 1 \frac{\partial z_2}{\partial z_1}=1 z1z2=1

∂ z 1 ∂ x 1 = 2 x 1 = 4 \frac{\partial z_1}{\partial x_1}=2x_1=4 x1z1=2x1=4

把求得的三个累乘即可 得到结果80

∂ y ∂ x 1 = ∂ y ∂ z 2 ∗ ∂ z 2 ∂ z 1 ∗ ∂ z 1 ∂ x 1 = 80 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial z_2}*\frac{\partial z_2}{\partial z_1}*\frac{\partial z_1}{\partial x_1}=80 x1y=z2yz1z2x1z1=80

根据链式法则对y求x2的偏导

由链式法则

∂ y ∂ x 1 = ∂ y ∂ z 2 ∗ ∂ z 2 ∂ z 3 ∗ ∂ z 3 ∂ x 2 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial z_2}*\frac{\partial z_2}{\partial z_3}*\frac{\partial z_3}{\partial x_2} x1y=z2yz3z2x2z3

拆分每一部分分别求

∂ y ∂ z 2 = 2 z 2 = 20 \frac{\partial y}{\partial z_2}=2z_2=20 z2y=2z2=20

∂ z 2 ∂ z 3 = 2 \frac{\partial z_2}{\partial z_3}=2 z3z2=2

∂ z 3 ∂ x 2 = 1 \frac{\partial z_3}{\partial x_2}=1 x2z3=1

把求得的三个累乘即可 得到结果40

∂ y ∂ x 1 = ∂ y ∂ z 2 ∗ ∂ z 2 ∂ z 3 ∗ ∂ z 3 ∂ x 2 = 40 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial z_2}*\frac{\partial z_2}{\partial z_3}*\frac{\partial z_3}{\partial x_2}=40 x1y=z2yz3z2x2z3=40

如图

在这里插入图片描述

1.3 其他补充

(1)通过计算图分析我们可以知道,必须先进行前向运算,每个节点的运算结果还需要保存起来,因为反向梯度回传计算可能用到,如箭头所指

在这里插入图片描述

(2)我们这里每一个节点代表一个很小的操作,就一个乘法或加法或幂,实际操作中我们可以把多个小节点合成一个大节点存储,这样的话就可以只存储更少的正向值,计算更少次的反向传播,我们把这样的计算图称作粗粒度计算图,相反我们上面讲到的是细粒度图

既然需要保存,就涉及内存的占用

选择粗粒度或细粒度的计算图取决于具体的应用和需求:

  • 如果内存资源有限,或者计算过程非常复杂,可以考虑使用粗粒度计算图以减少内存占用和计算开销。
  • 如果需要精确的梯度计算,或者希望深度学习框架能够自动进行反向传播,那么通常会选择细粒度计算图。

2.pytorch代码解析

2.1 标量

对于我们的一些深度学习框架内置的一些数据类型,如pytorch的tensor就是通过上述的方式来实现自动微分求导的

我们来看看代码实现

import torch
#申明张量
x1=torch.tensor(2.0)
x2=torch.tensor(3.0)

#设置梯度可求
x1.requires_grad_(True)
x2.requires_grad_(True)
print("反向求梯度前:",x1.grad,x2.grad)

#前向计算
y=(x1**2+2*x2)**2

#反向传播计算
y.backward()
print("反向求梯度后::",x1.grad,x2.grad)

输出

反向求梯度前: None None
反向求梯度后:: tensor(80.) tensor(40.)

总结:

(1)tensor张量必须先通过requires_grad_属性设置为True,PyTorch才会跟踪张量上的所有操作,并构建计算图计算梯度

(2)通过grad属性查看梯度,grad在默认未反向回传求梯度情况下为None

(3)前向计算后,反向计算通过backward()函数即可

注意:

(1)在PyTorch中,只有具有浮点数类型(如float32、float64等)的张量才能够进行自动微分(Autograd)。整数类型(如int32、int64)的张量默认情况下是不支持自动微分的。上述代码中如果把x1,x2改为int类型会报错RuntimeError: only Tensors of floating point dtype can require gradients

(2)由于梯度会累积,所以在求新的一轮的梯度时候,要通过grad_zero_函数清除梯度

我们还是以上面的运算为例,我们执行两次前向传播,两次反向传播计算,可以观察这种梯度累积现象

import torch
#申明张量
x1=torch.tensor(2.0)
x2=torch.tensor(3.0)
#设置梯度可求
x1.requires_grad_(True)
x2.requires_grad_(True)
print("反向求梯度前:",x1.grad,x2.grad)
#前向计算
y=(x1**2+2*x2)**2
#反向传播计算
y.backward()
#再次前向计算
y=(x1**2+2*x2)**2
#再次反向传播计算
y.backward()
print("反向求梯度后::",x1.grad,x2.grad)

输出

反向求梯度前: None None
反向求梯度后:: tensor(160.) tensor(80.)

是刚才的两倍

2.2 向量

我们上面为了更好理解,使用了标量做解释

实际使用中,参数和最终输出往往都是高纬张量,求导结果也往往是矩阵

当输入是标量,输出是标量的时候,或者输入是向量,输出是标量的时候,上面方法都没有问题。

但是当输出向量的时候,会报错RuntimeError: grad can be implicitly created only for scalar outputs 翻译过来是只能为标量输出创建梯度

因而我们需要先进行一步sum()操作,转向量为标量

import torch

# 假设模型参数是 w
w = torch.tensor([1.0,2.0,3.0], requires_grad=True)

# 定义损失函数 y(这里是一个简单的示例)
y = w*w + 2*w + 1

# 计算损失函数 y 的总和并执行自动微分
loss = y.sum().backward()


# 现在 w.grad 包含了损失函数对 w 的梯度
print(w.grad)  # 输出为 tensor([4., 6., 8.])

2.3 分离计算

在 PyTorch 中,有时候需要使用 .detach().detach_() 方法来分离张量以进行反向传播,通常是为了控制梯度流或避免不必要的计算。一些常见的情况和原因:

  1. 避免不必要的梯度计算:在某些情况下,我们可能希望跟踪某个张量的值,但不需要计算其梯度。例如,如果我们有一个预训练的模型,并且希望固定其中的某些参数(不要更新它们的梯度),则可以将这些参数分离以避免计算它们的梯度。这可以通过在这些参数上调用 .detach().detach_() 来实现。
  2. 避免计算图过大:有时候,我们可能担心计算图会变得过大,占用过多的内存。在这种情况下,您可以使用 .detach() 来剥离计算图中的一部分,以减少内存占用。这在长时间的训练过程中可能会很有用。
  3. 阻止梯度流:有时,我们可能不希望某个张量的梯度流向更底层的张量。例如,在生成对抗网络(GAN)中,生成器和判别器的训练过程可能需要阻止生成器的梯度传播到判别器。在这种情况下,可以使用 .detach()detach_() 来分离生成器的输出。

分离计算可以把某些计算移动到计算图之外,李沐老师的动手学深度学习举了这样一个例子

假设y是作为x的函数计算的,而z则是作为yx的函数计算的。 想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数, 并且只考虑到xy被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值, 但丢弃计算图中如何计算y的任何信息。 换句话说,梯度不会向后流经ux。 因此,下面的反向传播函数计算z=u*x关于x的偏导数,同时将u作为常数处理, 而不是z=x*x*x关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x

z.sum().backward()
x.grad == u

输出

tensor([True, True, True, True])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1013582.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

游戏视频录制软件对比,哪款最适合你的需求?

随着电子竞技和游戏直播行业的迅速崛起,越来越多的玩家渴望记录并分享自己在游戏中的精彩瞬间。游戏视频录制软件正是满足这一需求的关键工具。本文将针对三款优秀的游戏视频录制软件进行对比分析,以便为读者提供选购建议。 游戏视频录制软件1&#xff1…

Y4455芯片开发的433遥控流水灯方案

越来越多的家庭通过无线遥控来控制家中的照明系统,本文将介绍一种基于Y4455芯片的433MHz无线遥控流水灯方案,创造出美丽的照明场景。 一、宇凡微Y4455芯片简介 宇凡微Y4455芯片是一款低功耗、高性能的315MHz和433MHz短距离无线通讯发射芯片。它支持ASK…

Linux exec函数族

exec并不是生成新的进程还是在原进程执行 我们通常先创建一个子进程,在子进程里面使用exec,因为调用exec成功后,原进程的资源都被取代,除了一些进程ID等,所以在子进程里面调用exec,对原进程无影响。 前六个…

forest--声明式HTTP客户端框架-spring-b oot项目整合

Forest 是一个开源的 Java HTTP 客户端框架,它能够将 HTTP 的所有请求信息(包括 URL、Header 以及 Body 等信息)绑定到您自定义的 Interface 方法上,能够通过调用本地接口方法的方式发送 HTTP 请求。 官方链接: &…

实战演练 | Navicat 常用功能之转储与运行 SQL 文件

数据库管理工作中,"转储 SQL 文件"和"运行 SQL 文件"是两个极为常见操作。一般来说,用户使用数据库管理工具或命令行工具来完成。Navicat 管理开发工具中的“转储 SQL 文件”和“运行 SQL 文件”功能具有直观易用的界面、多种文件格…

北斗高精度定位,破解共享单车停车乱象

如今,共享单车已经成为了许多人出行的首选方式,方便了市民们的“最后一公里”,给大家的生活带来了很多便利。然而,乱停乱放的单车也给城市治理带来了难题。在这种情况下,相关企业尝试将北斗导航定位芯片装载到共享单车…

企业如何拓展市场,获取客源并进行降本增效?

对于企业来说,在降低成本和提高效率的同时拓展市场和获取客户是一项复杂的挑战。以下是实现这一目标的一些策略和方法: 1.市场研究和细分:进行彻底的市场研究,以确定您的产品或服务最有前途的细分市场。将您的精力集中在最有利可…

【PickerView案例09-上午内容复习 Objective-C预言】

一、好,我们把前面两个案例:点餐系统、城市选择界面、复习一下,然后继续讲第三个案例:国旗选择界面 1.我们就直接照着这个Demo去说了啊, 先来看一下这个,点餐系统: 首先,我们说,点餐系统,整个界面儿呢,分几部分:三部分 1)顶部呢:一个View 2)中间呢:一个Pic…

EndNote21 | 安装及库的创建

EndNote21 | 安装及库的创建 一、EndNote21安装二、EndNote21库的创建 一、EndNote21安装 软件安装界面,双击“EndNote 21.exe”程序; 图1 安装软件界面点击next,选择30天试用,点击next; 图2 安装过程点击next&…

数据结构——图(图的存储及基本操作)

文章目录 前言一、邻接矩阵法(顺序存储)1.无向图存储邻接矩阵算法2.有向图存储邻接矩阵算法 二、邻接表法(图的链式存储结构)总结 前言 邻接矩阵法(图的顺序存储结构) 1.1 无向图邻接矩阵算法 1.2 有向图邻接矩阵算法邻接表法(图的一种链式存储结构) 一…

软文推广在企业中运用的优势有哪些?

随着互联网的发展,越来越多的企业在推广方式上已经逐渐脱离于传统媒体,软文推广已经成为了企业宣传的主要方式。也有不少企业来找盒子进行推广,接下来媒介盒子就来告诉大家,企业进行软文推广的优势有哪些? 成本低 传统…

neon常用指令(updating)

函数参考手册: https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics 并在左侧选择neon\ Neon 128bit寄存器,所以可支持并行运算 加快运算速度 减少循环 CPU运算比加载数据快,速度瓶颈在加载数据这里。 指令集命名…

为什么你觉得Odoo二次开发难?如何切入?

先说结论,学习Odoo开发,我建议从Odoo的开发者模式切入。事实上在Odoo官网很多问题的解决方案就是基于开发者模式的。 前天有位学了《Odoo开发者模式必知必会》课程的网友跟我说,他之前也花钱买了其他的Odoo开发、前端开发课程,但…

无涯教程-JavaScript - XOR函数

描述 XOR函数返回所有参数的逻辑异或。如果所提供条件的奇数判断为TRUE,则XOR函数返回TRUE,否则返回FALSE。 语法 XOR (logical1, [logical2],…)争论 Argument描述Required/Optionallogical1logical1 is required and subsequent logical values are optional.1 to 254 co…

【1++的C++进阶】之emplace详解

👍作者主页:进击的1 🤩 专栏链接:【1的C进阶】 在前面C11系列的文章里,我们漏掉了几个知识点,这篇文章对其中一个知识点进行讲解,关于剩余的知识点的文章在后面会相继出炉。 C11中,针…

网络广播模块2*30W 智能4G广播终端开发模块

SV-704UG 4G网络广播模块2*30W 智能4G广播终端开发模块 一、描述 SV-704UG网络音频模块是一款带2*30W功放输出的4G广播音频模块,采用高性能ARM处理器及专业Codec,能接收4G广播音频数据流,转换成音频模拟信号输出。带有一路line in输入&#…

分布式事务解决方案之可靠消息最终一致性

分布式事务解决方案之可靠消息最终一致性 什么是可靠消息最终一致性事务 可靠消息最终一致性方案是指当事务发起方执行完成本地事务后并发出一条消息,事务参与方(消息消费者)一定能 够接收消息并处理事务成功,此方案强调的是只要消息发给事务参与方最终…

java项目线上cpu过高如何排查

1、查看进程 # 查看cpu过高的进程 top -c2、拿着pid查找cpu过高的线程 # 查找 ps H -eo pid,tid,%cpu | grep 19235可以看到19236过高 3、线程转换16进制 printf "%x\n" 192364、查看代码地址 # 19235 进程 # 4b24 线程16进制 # -A20 前20行 jstack 19235 | gr…

回顾2023百度云智大会:人工智能的未来之路

原创 | 文 BFT机器人 在2023年的百度云智大会上,各界的科技专家、学者、企业家和创新者再次齐聚一堂,共同探讨和分享最新的人工智能、大数据、云计算等前沿技术和行业趋势。此次大会以"探索未来科技趋势"为主题,旨在引领行业对未来…

奥威BI系统:时刻跟着需求走,随需分析

面对同一张报表,不同浏览者有不同的需求,那怎么办?有能够时刻跟着浏览者需求走的数据分析报表吗?还真有,奥威BI系统随需分析,随时跟着需求走。 奥威BI系统中的报表就约等于一个平台,可随时展开…