AI学习(5):PyTorch-核心模块(Autograd):自动求导

news2025/1/20 11:02:13

1.介绍

在深度学习中,自动求导是一项核心技术,它使得我们能够方便地计算梯度并优化模型参数。PyTorch 提供了一个强大的自动求导模块(Autograd),它可以自动计算张量的导数得出梯度信息,同时也支持高阶导数计算。

1.1 概念词

在学习PyTorch的过程中,经常会看到这些词汇: 自动求导、梯度计算、前向传播、反向传播、动态计算图等,下面是一些简单介绍:

  • 自动求导PyTorch Autograd 模块负责自动计算张量的梯度。当我们在 PyTorch 中定义了一个张量,并设置了 requires_grad=True 时,PyTorch 会自动跟踪对该张量的所有操作,并构建一个动态计算图。
  • 梯度计算:梯度是函数在某一点上的导数,表示函数在该点的变化率。在深度学习中,梯度可以告诉我们在参数空间中,哪些方向可以使得损失函数值减小最快PyTorchAutograd 模块通过构建计算图并使用反向传播算法,自动计算张量的梯度。
  • 前向传播:前向传播是指数据从输入层经过隐藏层传递到输出层的过程。在前向传播过程中,每一层的输入经过权重和偏置的线性变换,然后经过激活函数计算得到输出。
  • 反向传播:反向传播是训练神经网络时使用的一种优化算法。它利用链式法则计算损失函数对模型参数的梯度,从而实现模型参数的更新。在 PyTorch 中,反向传播算法通过计算动态计算图的梯度来实现。
  • 动态计算图:动态计算图是 PyTorch 中的一个重要特性,它与静态计算图不同,可以根据代码的执行情况动态构建计算图。动态计算图使得 PyTorch更加灵活,可以处理各种动态的模型结构和数据流动。

他们之间的依赖关系:

  • 自动求导依赖于动态计算图,因为动态计算图记录了张量之间的依赖关系,从而使得 PyTorch 能够跟踪对张量的操作;
  • 梯度计算依赖于自动求导和动态计算图,因为梯度是通过自动求导和反向传播算法在动态计算图中计算得到的。
  • 前向传播和反向传播是损失函数优化的过程,依赖于梯度计算和动态计算图。

2.导数

2.1 导数定义

在学习自动求导模块(Autograd)之前,我们先简单回忆下高数中是如何定义导数的:

2.2 导数作用

从导数的定义上来看,不但理解起来比较费劲,也很难看出导数在深度学习中有什么作用,针对大部分场景的求导,本质上都是求某个函数在某一点的切线。如下图是一个经典的切线模型,求的是 x 0 x_0 x0处的导数:

来自百度百科

看到这里,可能还是没有想明白,导数在深度学习中到底有什么作用?在学习AI时,经常会听到道士下山的故事,故事里最后抛出的问题是: 怎么样让道士快速下山? 最快的办法就是顺着坡度最陡峭的地方走下去。那么怎么样找到最陡峭的地方呢? 答案就是: 求导; 上面说了求导的本质就是某点的切线,切线则有斜率,斜率越大的地方也就是越陡峭的点,然后沿着相反的方向进行,这也是梯度下降算法的原理。

3.梯度计算

@注: 求导后得到的结果,在深度学习中,被称为梯度。

只有体会到复杂操作后的过程,才能真实感受到工具的便捷性,下面分别使用两种方式对函数 f ( x ) = 3 x 2 + 2 x + 1 f(x) = 3x^2+2x+1 f(x)=3x2+2x+1进行求导;下图是列举一些常见函数对应的的求导函数公式,方便后续手动计算时,进行参考

常见求导函数

更多常见函数的求导函数示例:https://baike.baidu.com/item/导数/579188#3

3.1 手动计算

3.2 自动计算

import torch

# 定义函数
def myfunction(x):
    return 3 * x ** 2 + 2 * x + 1

if __name__ == '__main__':
    # 定义变量,并为其指定需要计算梯度
    t = torch.tensor(2.0, requires_grad=True)
    # 计算函数的值
    result = myfunction(t)
    # 反向传播,进行梯度计算
    result.backward()
    # 打印梯度
    print('打印梯度:', t.grad)
    

# 打印梯度:tensor(14.)

调用 backward() 方法时,PyTorch会从张量的节点开始,沿着计算图反向传播,计算所有叶子节点相对于该张量的梯度。需要特别注意的是: 在每次调用 backward() 方法之后,PyTorch 会自动清空计算图中的梯度信息。因此,多次调用 backward() 方法会尝试在没有梯度信息的情况下进行反向传播,从而导致运行时错误。

@注: 从上面示例可以看出Autograd便捷性,如果没有自动求导包Autograd的存在,想想当函数变的复杂时,该怎么去计算某点的导数…

4.梯度累积

PyTorch 中,反向传播函数 backward() 只能在一个张量(或者一系列张量)对应的图中被调用一次,因为它会计算当前图中所有叶子节点的梯度。如果多次调用backward(),会发生梯度累积,导致数据不准确;

4.1 错误示例

修改【3.2】代码示例:

def doBackward(var: torch.tensor):
    # 计算函数的值
    result = myfunction(var)
    # 反向传播,进行梯度计算
    result.backward()
    print('打印梯度:', var.grad)


if __name__ == '__main__':
    # 定义变量,并为其指定需要计算梯度
    t = torch.tensor(2.0, requires_grad=True)
    # 请求多次
    for i in range(3):
        doBackward(t)
             
"""
打印梯度: tensor(14.)
打印梯度: tensor(28.)
打印梯度: tensor(42.)
"""

通过上面运行输出,发现自动求导的结果(梯度)进行了累积,为了避免这种问题的出现,通常需要我们在模型训练过程中,手动清除之前计算的梯度。

4.2 清除梯度

通常情况下,在每次进行反向传播之前,需要调用 optimizer.zero_grad() 来清空之前计算的梯度。这样可以避免梯度累积,确保每次反向传播都是基于当前的梯度计算。修改上面示例中的部分代码:

def doBackward(var: torch.tensor):
    # 计算函数的值
    result = myfunction(var)
    # ------- 假设有个优化器:optimizer -------
    # 在每次迭代之前清零梯度
    optimizer.zero_grad()
    # 反向传播,进行梯度计算
    result.backward()
    print('计算结果:', var.grad)

4.3 累积影响

为什么梯度不能累积呢?根据资料查询可以发现,梯度累积可能会导致几个问题,尤其是在训练深度神经网络时:

  • 减慢收敛速度:梯度累积会导致每个参数的梯度在多次迭代中被累积起来。如果梯度一直累积而不进行更新,可能会导致收敛速度减慢,因为参数更新的幅度变小了。
  • 数值不稳定性:梯度累积可能导致数值不稳定性,尤其是在使用较大的学习率时。由于梯度的累积,更新的幅度可能会变得非常大,导致数值溢出或梯度爆炸的问题。
  • 内存占用:梯度累积会增加内存的占用,因为需要保存多次迭代中的梯度信息。在内存受限的情况下,梯度累积可能导致内存不足的问题,从而无法完成训练。
  • 局部最优解陷阱:梯度累积可能会导致模型陷入局部最优解,而无法跳出。由于梯度的累积,模型可能会固定在一个局部最优解附近,无法继续搜索更好的解决方案。

因此,在训练深度神经网络时,通常建议避免梯度累积,确保每次迭代都使用当前的梯度进行更新,以保证训练的稳定性和收敛速度。

5.局部禁用

  • 什么场景用: 当需要在训练过程中固定某些参数或者临时关闭梯度计算时;
  • 怎么使用: 可以使用 torch.no_grad() 上下文管理器或者在张量上调用 .detach() 方法来实现局部禁用梯度计算。

下面列举一些情况下,可能需要使用局部禁用梯度计算的具体示例:

5.1 固定模型参数禁用

在迁移学习或者模型微调中,通常会冻结预训练模型的一部分参数,只更新其中的部分参数。为了实现这一目的,可以使用 torch.no_grad() 上下文管理器来禁用梯度计算。

# 示例:冻结预训练模型的一部分参数
with torch.no_grad():
    for param in model.parameters():
        param.requires_grad = False
    # 只对新添加的层的参数进行训练
    optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001)

5.2 模型推断时禁用

在模型推断时,不需要计算梯度,因此可以使用 torch.no_grad() 上下文管理器来禁用梯度计算,以提高推断速度和减少内存占用。

# 示例:在前向推断时禁用梯度计算
with torch.no_grad():
    output = model(input)

5.3 计算某些指标时禁用

在计算模型的性能指标(如准确率、损失值等)时,不需要计算梯度,因此可以使用 torch.no_grad() 上下文管理器来禁用梯度计算,以提高计算效率。

# 示例:在计算指标时禁用梯度计算
with torch.no_grad():
    loss = criterion(output, target)

通过局部禁用梯度计算,可以灵活地控制梯度计算的范围,提高训练和推断的效率,并且可以避免不必要的梯度计算和内存消耗。

本文由mdnice多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1475332.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微服务-商城订单服务项目

文章目录 一、需求二、分析三、设计四、编码4.1 商品服务4.2 订单服务4.3 分布式事务4.4 订单超时 商品、购物车 商品服务: 1.全品类购物平台 SPU:Standard Product Unit 标准化产品单元。是商品信息聚合的最小单位。是一组可复用、易检索的标准化信息的集合&#x…

EMR StarRocks实战——Mysql数据实时同步到SR

文章摘抄阿里云EMR上的StarRocks实践:《基于实时计算Flink使用CTAS&CDAS功能同步MySQL数据至StarRocks》 前言 CTAS可以实现单表的结构和数据同步,CDAS可以实现整库同步或者同一库中的多表结构和数据同步。下文主要介绍如何使用Flink平台和E-MapRed…

【程序员英语】【美语从头学】初级篇(入门)(笔记)Lesson 16 At the Shoe Store 在鞋店

《美语从头学初级入门篇》 注意:被 删除线 划掉的不一定不正确,只是不是标准答案。 文章目录 Lesson 16 At the Shoe Store 在鞋店对话A对话B笔记会话A会话B替换 Lesson 16 At the Shoe Store 在鞋店 对话A A: Do you have these shoes in size 8? B:…

如何运行github上的项目

为了讲明白这个过程,特意做了一个相当来说比较好读懂的原理图,希望和我一样初学的小伙伴也能很快上手哈😊 在Github中找到想要部署的项目,这里以BartoszJarocki/CV(线上简历📄)项目为例 先从头…

经典Go知识点总结

开篇推荐 来来来,老铁们,男人女人都需要的技术活 拿去不谢:远程调试,发布网站到公网演示,远程访问内网服务,游戏联机 推荐链接 1.无论sync.Mutex还是其衍生品都会提示不能复制,但是能够编译运行 加锁后复制变量,会将锁的状态也复制,所以 mu1 其实是已…

4核8G服务器并发数多少?性能如何?

腾讯云4核8G服务器支持多少人在线访问?支持25人同时访问。实际上程序效率不同支持人数在线人数不同,公网带宽也是影响4核8G服务器并发数的一大因素,假设公网带宽太小,流量直接卡在入口,4核8G配置的CPU内存也会造成计算…

React Switch用法及手写Switch实现

问&#xff1a;如果注册的路由特别多&#xff0c;找到一个匹配项以后还会一直往下找&#xff0c;我们想让react找到一个匹配项以后不再继续了&#xff0c;怎么处理&#xff1f;答&#xff1a;<Switch>独特之处在于它只绘制子元素中第一个匹配的路由元素。 如果没有<Sw…

[极客大挑战 2019]LoveSQL1 题目分析与详解

一、题目简介&#xff1a; 二、通关思路&#xff1a; 1、首先查看页面源代码&#xff1a; 我们发现可以使用工具sqlmap来拿到flag&#xff0c;我们先尝试手动注入。 2、 打开靶机&#xff0c;映入眼帘的是登录界面&#xff0c;首先尝试万能密码能否破解。 username: 1 or 11…

IDEA如何开启Dashboard

普通的面板 Run Dashboard面板 修改配置文件 找到项目的.idea文件夹 点击编辑workspace.xml文件 添加下方代码 <component name"RunDashboard"><option name"ruleStates"><list><RuleState><option name"name" valu…

雾锁王国服务器配置怎么选择?阿里云和腾讯云

雾锁王国/Enshrouded服务器CPU内存配置如何选择&#xff1f;阿里云服务器网aliyunfuwuqi.com建议选择8核32G配置&#xff0c;支持4人玩家畅玩&#xff0c;自带10M公网带宽&#xff0c;1个月90元&#xff0c;3个月271元&#xff0c;幻兽帕鲁服务器申请页面 https://t.aliyun.com…

通过QScrollArea寻找最后一个弹簧并且设置弹簧大小

项目原因&#xff0c;最近需要通过QScrollArea寻找其中最后一个弹簧并且设置大小和策略&#xff0c;因为无法直接调用UI指针&#xff0c;所以只能用代码寻找。 直接上代码&#xff1a; if (m_scrollArea){int iScrollWidth m_labelSelectedTitle->width();m_scrollArea-&g…

数据结构与算法(数组,栈,队列,链表,哈希表,搜索算法,排序算法,查找算法,策略算法,递归算法,二叉搜索树BST,动态规划算法)

文章目录 1 课程介绍1.1 前置知识1.2 为什么要学习算法1.3 大厂面试常见数据结构题目(基础)1.4 数据结构和算法的关系 2 数据结构2.1 数据结构概述2.1.1 数据结构是什么2.1.2 数据结构分类2.1.2.1 线性结构2.1.2.2 非线性结构2.1.2.3 小总结 2.1.3 数据结构范围 2.2 数组Array2…

前端基础面试题(二)

摘要&#xff1a;最近&#xff0c;看了下慕课2周刷完n道面试题&#xff0c;记录下... 1. offsetHeight scrollHeight clientHeight 区别 计算规则&#xff1a; offsetHeight offsetWidth : border padding content clientHeight clientWidth: padding content scrollHeight…

unity学习(42)——创建(create)角色脚本(panel)——UserHandler(收)+CreateClick(发)——服务器收包2

1.解决上一次留下的问题&#xff1a; log和reg的时候也有session&#xff0c;输出看一下这两个session是同一个不&#xff1a; 实测结果reg log accOnline中的session都是同一个对象&#xff0c;但是getAccid时候的session就是另一个了。 测试结果&#xff0c;说明在LogicHan…

lv21 QT对话框3

1 内置对话框 标准对话框样式 内置对话框基类 QColorDialog, QErrorMessage QFileDialog QFontDialog QInputDialog QMessageBox QProgressDialogQDialog Class帮助文档 示例&#xff1a;各按钮激发对话框实现基类提供的各效果 第一步&#xff1a;实现组件布局&…

水电表远程集中抄表管理系统

水电表远程集中抄表管理系统是当前水电行业智能化发展的关键技术之一&#xff0c;为水电企业和用户提供了便捷、高效的抄表管理解决方案。该系统结合了远程监控、自动抄表、数据分析等多种功能&#xff0c;实现了水电抄表的智能化和精准化&#xff0c;为用户节省了大量人力物力…

The authenticity of host ‘github.com (20.205.243.166)‘ can‘t be established.

1、运行git clone报错&#xff1a; The authenticity of host github.com (20.205.243.166) cant be established. ECDSA key fingerprint is SHA256:p2QAC1TJYererOttrVc98/R1BWERWu3/LiyFdHfQM. Are you sure you want to continue connecting (yes/no/[fingerprint])? 这个…

【六袆-Golang】Golang:安装与配置Delve进行Go语言Debug调试

安装与配置Delve进行Go语言Debug调试 一、Delve简介二、win-安装Delve三、使用Delve调试Go程序[命令行的方式]四、使用Golang调试程序 Golang开发工具系列&#xff1a;安装与配置Delve进行Go语言Debug调试 摘要&#xff1a; 开发环境中安装和配置Delve&#xff0c;一个强大的G…

2024年腾讯云4核8G12M配置的轻量服务器同时支持多大访问量?

腾讯云4核8G服务器支持多少人在线访问&#xff1f;支持25人同时访问。实际上程序效率不同支持人数在线人数不同&#xff0c;公网带宽也是影响4核8G服务器并发数的一大因素&#xff0c;假设公网带宽太小&#xff0c;流量直接卡在入口&#xff0c;4核8G配置的CPU内存也会造成计算…

HUAWEI 华为交换机 配置基于VLAN的MAC地址学习限制接入用户数量 配置示例

组网需求 如 图 2-15 所示&#xff0c;用户网络 1 通过 LSW1 与 Switch 相连&#xff0c; Switch 的接口为 GE0/0/1 。用户网络2通过 LSW2 与 Switch 相连&#xff0c; Switch 的接口为 GE0/0/2 。 GE0/0/1 、 GE0/0/2 同属于 VLAN2。为控制接入用户数&#xff0c;对 VLAN2 进…