Pytorch入门(6)—— 梯度计算控制

news2025/1/16 4:04:22
  • 前文 PyTorch入门(2)—— 自动求梯度 介绍过 Pytorch 中的自动微分机制,这是实现神经网络反向传播的基础,也是所有深度学习框架最重要的基础设施之一
  • 梯度计算是需要占用计算资源的,而我们并不总是需要计算梯度(比如做评估时),Pytorch 提供了几种方式来控制梯度计算,本文对这些方法进行梳理
  • 参考自 pytorch 文档:Locally disabling gradient computation

文章目录

  • 1. 回顾 pytorch 的自动微分机制
  • 2. 局部梯度控制
    • 2.1 通过设置 `requires_grad` 实现精确梯度控制
    • 2.2 三种梯度计算模式
      • 2.2.1 梯度模式 (Grad Mode)
      • 2.2.2 无梯度模式 (No-grad Mode)
      • 2.2.3 推断模式 (Inference Mode)
    • 2.3 容易混淆的模型评估模式(Evaluation Mode)
  • 3. 总结

1. 回顾 pytorch 的自动微分机制

  • PyTorch 提供的 autograd 是一个反向自动微分系统,它能根据对 tensor 的操作过程自动构建计算图。具体而言,计算图是一个有向无环图,记录了前向传播过程的全部数据操作,图中的根节点是输出张量(Output Tensor),叶节点是输入张量(Input Tensor)。沿着这个图即可使用链式法则计算得到中间的梯度。下面给出一个例子
    • x = [ x 1 , x 2 , x 3 , x 4 ] \pmb{x} = [x_1,x_2,x_3,x_4] x=[x1,x2,x3,x4] 是一个 1 × 4 1\times 4 1×4 的向量,求下式梯度 f ( x ) = 1 4 ∑ i = 1 4 3 ( x i + 2 ) 2 f(\pmb{x})=\frac{1}{4}\sum_{i=1}^43(x_i+2)^2 f(x)=41i=143(xi+2)2。注意向量的梯度也将是一个同尺寸的向量
    • 可以把计算图如下画出,可见,有 ∂ f ∂ x i = ∂ d ∂ x i = ∂ d ∂ c i ∂ c i ∂ b i ∂ b i ∂ a i ∂ a i ∂ x i = 1 4 ⋅ 3 ⋅ ( 2 x i + 4 ) ⋅ 1 = 1.5 x i + 3 \frac{\partial f}{\partial x_i} =\frac{\partial d}{\partial x_i}= \frac{\partial d}{\partial c_i}\frac{\partial c_i}{\partial b_i}\frac{\partial b_i}{\partial a_i}\frac{\partial a_i}{\partial x_i}=\frac{1}{4}·3·(2x_i+4)·1=1.5x_i+3 xif=xid=cidbiciaibixiai=413(2xi+4)1=1.5xi+3
      在这里插入图片描述
  • pytorch 中计算的对象都是 tensor,所以计算图中每个节点也都是一个 tensor 对象。Pytorch 使用了动态计算图机制,每次前向传播过程中都会从头构造一次计算图,我们可以在每一次迭代中改变计算过程或禁用部分梯度计算,从而改变计算图的形状和大小
  • 关于梯度计算的更多细节,请参考前文 PyTorch入门(2)—— 自动求梯度

2. 局部梯度控制

  • 有几种机制可以在Python中临时禁用梯度计算:
    1. 要在一段代码块中禁用梯度计算,可以使用 no-grad 模式和 inference模式等上下文管理器
    2. 要更精确地控制梯度,比如从计算图中剔除部分子图,可以通过设置计算图中节点 tensor 的 requires_grad 字段来实现。这样可以打断有向无环图中的一些有向边,从而选择性地排除某些子图不参与梯度计算
  • PyTorch 中还有一个针对 nn.Module 的评估模式方法 nn.Module.eval(),它实际上并不用于禁用梯度计算。但是由于其名称的误导性,经常与上述机制混淆使用

2.1 通过设置 requires_grad 实现精确梯度控制

  • Tensor.requires_grad 是 tensor 对象的一个标志变量,默认为 False,它在前向传播和反向传播中都起作用,允许对梯度计算中的子图进行精细排除。

    1. 正向传递过程中,一个操作只有在其输入张量中至少有一个 requires_grad=True 时才会记录在计算图中
    2. 在向后传递期间,只有 requires_grad=True叶张量才会将梯度累积到它们的 .grad 字段中

    有几种方式可以将 Tensor.requires_grad 设置为 True

    1. 定义 tensor 时设置参数,如 torch.ones(2,2,requires_grad=True)
    2. 使用 requires_grad_() 进行 in-place 设置,如 a.requires_grad_(True)
    3. 使用 nn.Parameter 对 tensor 进行包装,如 nn.Parameter(torch.zeros(2,2))。定义神经网络时如果需要优化一个张量(比如定义Transformer的可训练位置编码),通常使用这种方法
  • 值得注意的是,尽管每个张量都有这个标志,但设置它只对 leaf tensor 有意义;non-leaf tensor 是有可以记录其计算过程的 .grad_fn 方法,有一部分反向图与之相关的 tensor,它们自动具有 require_grad=True,因为要计算 leaf tensor 的梯度时必须借助相关的 non-leaf tensor 的梯度作为中间结果

  • 设置 requires_grad 是控制模型进行部分梯度计算的主要方法。例如考虑函数: y 3 = y 1 + y 2 = x 2 + x 3 y_3 = y_1+y_2=x^2+x^3 y3=y1+y2=x2+x3,有
    ∂ y 3 ∂ x = ∂ y 1 ∂ x + ∂ y 2 ∂ x = 2 x + 3 x = 5 x \frac{\partial y_3}{\partial x} = \frac{\partial y_1}{\partial x} + \frac{\partial y_2}{\partial x} = 2x+3x = 5x xy3=xy1+xy2=2x+3x=5x 如果将其中的 y 2 y_2 y2 设置为 requires_grad=False,梯度就无法从 y 2 y_2 y2 往回传播,这时有
    ∂ y 3 ∂ x = ∂ y 1 ∂ x = 2 x \frac{\partial y_3}{\partial x} = \frac{\partial y_1}{\partial x}= 2x xy3=xy1=2x 计算图如下
    在这里插入图片描述

    x = torch.tensor(1.0, requires_grad=True)
    y1 = x ** 2 
    with torch.no_grad():
        y2 = x ** 3
    y3 = y1 + y2
    
    print(x.requires_grad)		# True
    print(y1, y1.requires_grad) # tensor(1., grad_fn=<PowBackward0>) True
    print(y2, y2.requires_grad) # tensor(1.) False
    print(y3, y3.requires_grad) # tensor(2., grad_fn=<AddBackward0>) True
    
    y3.backward()
    print(x.grad)               # tensor(2.)
    
    #y2.backward() # 报错: element 0 of tensors does not require grad and does not have a grad_fn
    

    精细梯度控制在模型微调期间比较常用,比如若想在微调时冻结部分预训练模型,只需对不想更新的参数应用 .requires_grad_(False) 即可,这样使用这些参数作为输入的计算就不会被记录在向前传递中,它们不再成为计算图的一部分,也就不会在向后传递时更新它们的 .grad 字段了

  • 另外,requires_grad 也可以在模块级别通过 nn.Module.requires_grad_() 进行设置,这对模块内的所有参数生效 (默认情况下requires_grad=True)

2.2 三种梯度计算模式

  • 除了设置 requires_grad 之外,Pytorch 还提供了三种可以影响 autograd 内部梯度计算的模式:默认模式/梯度模式(Grad Mode)无梯度模式(No-grad Mode)推理模式(Inference Mode),所有这些模式都可以通过python语法中的上下文管理器和装饰器进行切换

2.2.1 梯度模式 (Grad Mode)

  • 这是 Pytorch 工作的默认模式,是我们在没有启用其他模式时隐含的模式。为了与 “无梯度模式” 形成对比,有时也被称为 “梯度模式”。梯度模式是 requires_grad=True 生效的唯一模式,requires_grad 在其他两种模式中总是被设置为 False

2.2.2 无梯度模式 (No-grad Mode)

  • 在无梯度模式下即使有 require_grad=True 的输入,也不会在反向图中记录。有两种常用的进入无梯度模式的方法:使用上下文管理器(with语法)和函数装饰器
    with torch.no_grad():
        do_something()
    
    @torch.no_grad()
    def do_something_func():
        do_something()
    
    这两种方法可以方便地禁用代码块或函数的梯度。另外还有一种手动设置的方法
    torch.set_grad_enabled(False)
    do_something()
    torch.set_grad_enabled(True)
    
  • 无梯度模式适用于有一些操作无需记录梯度,但需要中间计算结果用于后续(梯度模式下)梯度计算的情况(可以理解成将计算图的一部分变成一个常数)
    1. 编写优化器时可能很适合使用无梯度模式:每轮迭代中,优化器要就地更新模型参数,这些更新操作不应被记录梯度,之后在下一轮的前向传递中要使用更新后的参数进行梯度模式的计算,例如
      def sgd(params, lr, batch_size): 
          """小批量随机梯度下降"""
          with torch.no_grad():
              for param in params:    # param 是一个list,如果模型是y=Xw+b,则param=[w,b]
                  param -= lr * param.grad / batch_size
                  param.grad.zero_()
      
    2. torch.nn.init 方法的实现也依赖于无梯度模式,以避免在就地初始化参数时就自动跟踪梯度了
    3. 做模型验证或评估时也常常使用无梯度模式
  • 由于无梯度模式下不会生成反向计算图,显存占用和计算资源的消耗都大大减少了,体现在代码上就是计算验证损失时 batch size 可以远大于计算训练损失时的 batch size,而且计算更快

2.2.3 推断模式 (Inference Mode)

  • 推断模式是无梯度模式的极端版本,这种模式下也不会记录反向图,它的执行速度更快,但缺点在于推断模式下创建的 tensor 将无法用于后续(在梯度模式下)梯度计算。进入推断模式也有上下文管理器(with语法)和函数装饰器两种方法
    with torch.inference_mode():
    	do_something()
    
    @torch.inference_mode()
    def func(x):
      	do_something()
    
  • 建议在代码中不需要自动梯度跟踪的部分 (如数据处理和模型评估阶段) 尝试推理模式,相比过去使用无梯度模式这可以无成本地提升性能。如果在启用推理模式后遇到错误,请检查是否在退出推理模式后由 Autograd 记录的计算中使用了在推理模式下创建的 tensor,如果无法避免这种情况,你可以随时切换回无梯度模式
  • 需要注意的是,推断模式是从 pytorch 1.10 开始引入的新特性,使用前需要确保 Pytorch 版本支持

2.3 容易混淆的模型评估模式(Evaluation Mode)

  • (模型)评估模式 nn.Module.eval() 实际上不是一种影响 autograd 内部梯度计算的模式,但它有时会被混淆成这样一种机制。 在功能上,module.eval()/ module.train() 与 2.2 节介绍的三种模式完全正交。model.eval() 如何影响模型完全取决于模型中使用的特定模块,以及它们是否定义了任何特定于训练模式的行为。具体而言:model.eval() 的作用是不启用 Batch Normalization 和 Dropout,即
    1. Dropout 层会让所有的激活单元都通过,不会随机失能
    2. Batch Normalization 层会停止计算和更新 mean 和 var,直接使用在训练阶段已经学出的 mean 和 var 值
  • 建议无论模型定义中是否涉及上述操作,都在训练时始终使用 model.train(),在评估模型(验证/测试)时始终使用 model.eval(),以免受到任何潜在导致这些操作的模型更新的影响
  • 注意在梯度模式下,即使调用了 module.eval(),所有梯度还是会被计算

3. 总结

  • Pytorch 使用 Autograd 机制自动追踪对 tensor 的各种操作,并实时生成可以用于计算梯度的反向计算图。可以通过设置 Tensor.requires_grad 参数来打断计算图中的某些边,以实现对梯度计算的精确控制。在微调模型时可以用这种方式冻结某些参数
  • Pytorch 还提供了三种可以影响梯度计算的模式
    1. 梯度模式:仅在这种模式下 requires_grad=True 生效,会进行计算图构建,这也是默认模式
    2. 无梯度模式:这时即使有 require_grad=True 的输入也不会在反向图中记录,适用于有一些操作无需记录梯度,但需要中间计算结果用于后续(梯度模式下)梯度计算的情况。这种模式下显存占用和计算资源的消耗都大大减少了
    3. 推断模式:无梯度模式的极端版本,也不会记录反向计算图,执行速度更快,但推断模式下创建的 tensor 将无法用于后续(在梯度模式下)梯度计算
  • nn.Module 单独有一种评估模式,它的功能和以上三种梯度计算模式是正交的,它仅影响 Dropout 和 Batch Normalization 的行为模式而和梯度计算无关。在训练应时始终使用 model.train(),在评估模型 (验证/测试) 时应始终使用 model.eval()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/998366.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

锯片检测示例

1.锯片检测 1.1 应用示例目的与思路 (1) 提取并筛选锯齿的轮廓&#xff1b; (2) 对筛选后的锯齿轮廓进行直线拟合&#xff1b; (3) 统计正常锯齿的角度和缺陷锯齿的个数。 1.2 应用示例相关算子介绍 (1) threshold_sub_pix(Image : Border : Threshold : ) 功能&#xf…

应用开发平台集成工作流系列之10——流程建模功能环节业务逻辑处理的设计与实现

背景 基于工作流的表单流转&#xff0c;在某些特定的环节&#xff0c;需要执行一些业务逻辑处理。例如动态分配节点处理人、发送邮件或短信给待办用户、统计流程处理时长判断是否超时&#xff0c;以及业务层面数据处理&#xff08;例如&#xff0c;在请假流程中将部门领导审批…

Unity之Android项目的打包

一 Unity里面配置Android运行环境 1.1 首先unity需要集成android编译环境&#xff0c;点击FIle->Build Settings 1.2 没是否有Android模块&#xff0c;没的话先下载Android模块 1.3 按下面的操作&#xff0c;下载Android支持&#xff0c;SDK&#xff0c;NDK&#xff0c;和J…

15 - 多线程调优(上):哪些操作导致了上下文切换?

1、初识上下文切换 我们首先得明白&#xff0c;上下文切换到底是什么。 其实在单个处理器的时期&#xff0c;操作系统就能处理多线程并发任务。处理器给每个线程分配 CPU 时间片&#xff08;Time Slice&#xff09;&#xff0c;线程在分配获得的时间片内执行任务。 CPU 时间…

【图解RabbitMQ-6】说说交换机在RabbitMQ中的四种类型以及使用场景

&#x1f9d1;‍&#x1f4bb;作者名称&#xff1a;DaenCode &#x1f3a4;作者简介&#xff1a;CSDN实力新星&#xff0c;后端开发两年经验&#xff0c;曾担任甲方技术代表&#xff0c;业余独自创办智源恩创网络科技工作室。会点点Java相关技术栈、帆软报表、低代码平台快速开…

自然语言处理: 第十二章LoRA解读

论文地址:[2106.09685] LoRA: Low-Rank Adaptation of Large Language Models (arxiv.org) 理论基础 自从GPT-3.5问世以来&#xff0c;整个AI界基本都走向了大模型时代&#xff0c;而这种拥有数亿参数的大模型对于普通玩家来说作全量微调基本是不可能的事。从而微软公司提出了…

指令延迟隐藏

一、指令延迟隐藏 1. 延迟和延迟隐藏 指令延迟指计算指令从调度到指令完成所需的时钟周期如果在每个时钟周期都有就绪的线程束可以被执行&#xff0c;此时GPU处于满符合状态指令延迟被GPU满负荷计算状态所掩盖的现象称为延迟隐藏延迟隐藏对GPU编程开发很重要&#xff0c;GPU设…

BeanFactory 和 FactoryBean傻傻分不清楚

&#x1f935;‍♂️ 个人主页&#xff1a;香菜的个人主页&#xff0c;加 ischongxin &#xff0c;备注csdn ✍&#x1f3fb;作者简介&#xff1a;csdn 认证博客专家&#xff0c;游戏开发领域优质创作者,华为云享专家&#xff0c;2021年度华为云年度十佳博主 &#x1f40b; 希望…

HTTPS双向认证

双向认证&#xff0c;指的是客户端和服务器端都需要验证对方的身份&#xff0c;在建立HTTPS连接的过程中&#xff0c;握手的流程相对于单向认证多了几步。 单向认证的过程&#xff0c;客户端从服务器端下载服务器端公钥证书进行验证&#xff0c;然后建立安全通信通道。 双向通信…

java的数据类型与变量(超详细每个都有小结论,习题巩固)

【本文章的目标】 1.字面常量 2.数据类型 3.变量 文章最后有习题等来帮助巩固&#xff0c;加深印象&#xff0c;相信看完这篇文章&#xff0c;大家会有收获 1.字面常量 在上节课HelloWorld程序中&#xff0c;System.Out,println(Hello World"); 语句&#xff0c;不论…

算法[动态规划]---买卖股票最佳时机

1、题目&#xff1a; 给你一个整数数组 prices&#xff0c;其中 prices[i] 表示某支股票第 i 天的价格。 在每一天&#xff0c;你可以决定是否购买和/或出售股票。你在任何时候最多只能持一股股票。你也可以先购买&#xff0c;然后在同一天出售。 返回你能获得的最大利润 。 2…

PLSQL

文章目录 基本pl/sql语法流程控制条件判断&#xff08;两种&#xff09;循环结构&#xff08;三种&#xff09;goto&#xff0c;exit关键字 游标的使用异常的处理存储过程&#xff08;无返回值&#xff09;&#xff0c;存储函数&#xff08;有返回值&#xff09;触发器 命令行窗…

苹果手机远程控制安卓手机,为什么不能发起控制?

这位用户想要用iOS设备远程控制安卓设备&#xff0c;在被控端安装好AirDroid之后&#xff0c;就在控制端的苹果手机上也安装了AirDroid&#xff0c;然而打开控制端的软件&#xff0c;却没有在手机界面上看到【远程控制】按钮&#xff0c;于是提出了以上疑问。 解答 想要让iOS设…

A,B,C , D, E类地址的划分及子网划分汇总的详解

一、 A类地址 &#xff08;1&#xff09;A类地址第1字节为网络地址&#xff0c;其它3个字节为主机地址。它的第1个字节的第一位固定为0. &#xff08;2&#xff09;A类地址范围&#xff1a;1.0.0.1—126.255.255.254 &#xff08;3&#xff09;A类地址中的私有地址和保留地…

苹果电脑快捷键集合

苹果电脑Windows系统下的ALT键是组合键。苹果电脑键盘左下角的Fnoption是Windows的alt键。同时按下两个键是ALT键的功能。在非组合状态下&#xff0c;单独按Option键。 补充&#xff1a; 1. 按controlalt&#xff08;选项&#xff09;delete 启动任务管理器。 2. Option-Del…

nrf52832 使用ADC点LED

#define SAMPLES_IN_BUFFER 5 volatile uint8_t state 1;/*** brief UART events handler.*/void saadc_callback(nrf_drv_saadc_evt_t const * p_event) { // }//saadc的初始化 void saadc_init(void) {ret_code_t err_code;nrf_saadc_channel_config_t channel_config NR…

C#,数值计算——柯西微分(Cauchy deviates)的计算方法与源代码

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// Cauchy deviates /// </summary> public class Cauchydev : Ran { private double mu { get; set; } private double sig { get; set; } public…

C++ -- 学习系列 static 关键字的使用

static 是 C 中常用的关键字&#xff0c;被 static 修饰的变量只会在 静态存储区&#xff08;常量数据也存放在这里&#xff09; 被分配一次内存&#xff0c;生命周期与整个程序一样&#xff0c;随着程序的消亡而消亡。 一 static 有以下几种用法&#xff1a; 1. 在文件中定义…

管理类联考——数学——汇总篇——知识点突破——应用题——交叉比例法/杠杆原理

读书笔记 甲有&#xff1a;x个a&#xff0c;乙有&#xff1a;y个b&#xff0c;甲乙的平均值为c&#xff0c;根据总数相等&#xff0c;得&#xff1a;axbyc(xy)&#xff0c;即ax-cxcy-by&#xff0c;则 x y c − b a − c \frac{x}{y}\frac{c-b}{a-c} yx​a−cc−b​ &#…

【Vue2.0源码学习】生命周期篇-初始化阶段(initState)

文章目录 1. 前言2. initState函数分析3. 初始化props3.1 规范化数据3.2 initProps函数分析3.3 validateProp函数分析3.4 getPropDefaultValue函数分析3.5 assertProp函数分析 4. 初始化methods5. 初始化data6. 初始化computed6.1 回顾用法6.2 initComputed函数分析6.3 defineC…