【Pytorch笔记】4.梯度计算

news2025/2/23 16:21:51

深度之眼官方账号 - 01-04-mp4-计算图与动态图机制

前置知识:计算图
可以参考我的笔记:
【学习笔记】计算机视觉与深度学习(2.全连接神经网络)

计算图

在这里插入图片描述
以这棵计算图为例。这个计算图中,叶子节点为x和w。

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# 调用backward()方法,开始反向求梯度
y.backward()
print(w.grad)

print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

输出:

tensor([5.])
is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) None None None

由此可见,非叶子节点在最后不会被保留梯度。这是出于节省空间的需要而这样设计的。实际的计算图会非常大,如果每个节点都保留梯度,会占用非常大的存储空间,而这些节点的梯度对于我们学习并没有什么帮助。

如果非要看他们的梯度,可以这样操作:在a = torch.add(w, x)的后面加上一句a.retain_grad(),这样a的梯度就会被存储起来。
输出会变成:

tensor([5.])
is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) tensor([2.]) None None

对于节点,还可以看这些节点进行的运算。grad_fn,gradient function的缩写,表示这个节点的tensor是什么运算产生的。加一句:

print("gradient function:\n", w.grad_fn, '\n', x.grad_fn, '\n', a.grad_fn, '\n', b.grad_fn, '\n', y.grad_fn)

会输出

gradient function:
 None
 None
 <AddBackward0 object at 0x000001B1DA3651C0>
 <AddBackward0 object at 0x000001B1DA3651F0>
 <MulBackward0 object at 0x000001B1DA3515B0>

retain_graph

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
a.retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)

# 调用backward()方法,开始反向求梯度
y.backward()
y.backward()

连续两次调用backward()方法,会报这样的错误:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因是我们进行第一次backward()后,计算图就被自动释放掉了,进行第二次backward()时,没有计算图可以计算梯度,于是报错。

解决方案:backward内部添加一个参数:retain_graph=True,意思是计算完梯度后保留计算图。

# 调用backward()方法,开始反向求梯度
y.backward(retain_graph=True)
y.backward()

这样就不会报错了。

gradient

当计算图末部的节点有1个以上时,有时我们会希望他们之间的梯度有一个权重关系。这时就会用上gradient

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)

# 不难看出,y0和y1是两个互不干扰的末部节点
y0 = torch.mul(a, b)
y1 = torch.add(a, b)

# 将两个末部节点打包起来
loss = torch.cat([y0, y1], dim=0)
grad_tensors = torch.tensor([1., 2.])

# 将grad_tensors中的内容作为权重,变成y0+2y1
loss.backward(gradient=grad_tensors)

print(w.grad)

输出

tensor([9.])

如果把grad_tensors改成:

grad_tensors = torch.tensor([1., 3.])

输出变成:

tensor([11.])

torch.autograd.grad()

除了加减乘除法,我们还可以对torch进行求导操作。求的是 d ( o u t p u t s ) d ( i n p u t s ) \frac{d(outputs)}{d(inputs)} d(inputs)d(outputs)

torch.autograd.grad(outputs,
                    inputs,
                    grad_outputs=None,
                    retain_graph=None,
                    create_graph=False)

outputs和inputs已在上述定义中给出;
grad_outputs:多梯度权重;
retain_graph:保留计算图;
create_graph:创建计算图。

import torch

# y = x ** 2
x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)

# grad_1 = dy / dx = 2x = 6
grad_1 = torch.autograd.grad(y, x, create_graph=True)
print(grad_1)

# grad_2 = d(dy / dx) / dx = 2
grad_2 = torch.autograd.grad(grad_1, x)
print(grad_2)

输出

(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

autograd注意事项

1.梯度不会自动清零

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

for i in range(4):
    a = torch.add(w, x)
    b = torch.mul(w, x)
    y = torch.mul(a, b)
    y.backward()
    print("w's grad: ", w.grad)
    # w.grad.zero_()

输出:

w's grad:  tensor([8.])
w's grad:  tensor([16.])
w's grad:  tensor([24.])
w's grad:  tensor([32.])

由此可以看出,在不加上注释掉的那一行时,梯度在w处是不断累积的。而如果我们把print后面的那句w.grad.zero_()加上,输出就会变成:

w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])

w.grad.zero_()的意思就是把w处积累的梯度清零。

2.依赖于叶子节点的节点,requires_grad默认为True

可以从上面的代码中发现,我们只有在定义w和x两个tensor时,设置requires_grad为True。这个参数在定义tensor时默认为False。后面我们的a、b、y都没有设置这个参数。

如果我们定义w和x的时候不加上requires_grad=True,那么y.backward()这一步就会报错,因为我们的预设,这两个tensor不需要梯度,于是就无法求梯度。而w和x是我们计算图上的叶子节点,所以必须加上requires_grad=True。

而后面通过w和x延伸定义出的a、b、y,由于依赖的w、x的requires_grad是True,那么a、b、y的这个参数也被默认设置为了True,不需要我们手动添加。

3.叶子节点不可执行in-place操作

计算图上叶子节点处的tensor不能进行原地修改。

什么是in-place操作?
t = torch.tensor([1., 2.])
t.add_(3.)
print(t)

输出

tensor([4., 5.])

torch.Tensor.add_就是torch.add的in-place版本。所谓in-place,就是在tensor上进行原地修改。大部分的torch.tensor的运算,名字后面加一个下划线,就变成inplace操作了。

再比如求绝对值:

t = torch.tensor([-1., -2.])
t.abs_()
print(t)

输出

tensor([1., 2.])

知道什么是in-place操作后,我们尝试一下在requires_grad=True的叶子节点上原地修改,代码如下:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.mul(w, x)
y = torch.mul(a, b)

w.add_(1)

y.backward()

报错信息:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1058594.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web漏洞-PHP反序列化

目录 PHP反序列化序列化反序列化原理涉及技术利用危害CTF靶场 PHP反序列化 序列化 将对象转换成字符串 反序列化 相反&#xff0c;将字符串转换成对象。 数据格式的转换对象的序列化有利于对象的保存和传输&#xff0c;也可以让多个文件共享对象。 原理 未对用户输入的序列化字…

使用关键字interface来声明使用接口-PHP8知识详解

继承特性简化了对象、类的创建&#xff0c;增加了代码的可重用性。但是php8只支持单继承&#xff0c;如果想实现多继承&#xff0c;就需要使用接口。PHP8可以实现多个接口。 接口类通过关键字interface来声明&#xff0c;接口中不能声明变量&#xff0c;只能使用关键字const声明…

vcomp120.dll丢失的详细解决方法,全面分享5个解决方法分享

vcomp120.dll 是 Visual C Redistributable 的一个组件&#xff0c;是许多 Windows 应用程序所必需的动态链接库 (DLL) 之一。如果计算机上缺少 vcomp120.dll 文件&#xff0c;或者该文件已损坏或不正确&#xff0c;可能会导致许多应用程序无法正常运行&#xff0c;出现“无法继…

openGauss学习笔记-88 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用将磁盘表转换为MOT

文章目录 openGauss学习笔记-88 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用将磁盘表转换为MOT88.1 前置条件检查88.2 转换88.3 转换示例 openGauss学习笔记-88 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用将磁盘表转换为MOT …

数据挖掘实验(一)数据规范化【最小-最大规范化、零-均值规范化、小数定标规范化】

一、数据规范化的原理 数据规范化处理是数据挖掘的一项基础工作。不同的属性变量往往具有不同的取值范围&#xff0c;数值间的差别可能很大&#xff0c;不进行处理可能会影响到数据分析的结果。为了消除指标之间由于取值范围带来的差异&#xff0c;需要进行标准化处理。将数据…

Linux系统编程系列之线程的信号处理

一、为什么要有线程的信号处理 由于多线程程序中线程的执行状态是并发的&#xff0c;因此当一个进程收到一个信号时&#xff0c;那么究竟由进程中的哪条线程响应这个信号就是不确定的&#xff0c;只能取决于哪条线程刚好在信号达到的瞬间被调度&#xff0c;这种不确定性在程序逻…

java学生成绩管理信息系统

一、 引言 学生成绩管理信息系统是一个基于Java Swing的桌面应用程序&#xff0c;旨在方便学校、老师和学生对学生成绩进行管理和查询。本文档将提供系统的详细说明&#xff0c;包括系统特性、使用方法和技术实现。 二、 系统特性 2.1 学生管理 添加学生信息&#xff1a;录…

基于SSM农产品商城系统

基于SSM农产品商城系统的设计与实现&#xff0c;前后端分离&#xff0c;文档 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringSpringMVCMyBatisVue工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 农产品列表 产品详情 个人中心 登陆界面 管…

gici-open示例数据运行(1.1开阔环境数据运行)

1、配置数据和处理模式 下载对应的数据集后&#xff0c;首先处理1.1中的开阔环境下数据&#xff0c;将option目录下的配置文件复制到1.1数据目录下&#xff08;若采用ROS编译&#xff0c;则配置文件目录为ros_wrapper/src/gici/option/ros real time estimation xxx.yaml&…

Fiddle日常运用手册(2)-使用过滤器进行接口精准拦截

关于Fiddle的基础界面大家已经了解&#xff0c;日常工作中可以进行简单的抓包和数据分析了。 但是&#xff0c;工作中我们又会发现&#xff0c;单纯的进行批量抓包会抓取很多无效的心跳接口数据导致让我们漏掉一些重要信息。那么如果我们想精准的拦截某一个IP的接口交互数据&am…

231003-四步MacOS-iPadOS设置无线竖屏随航SideCar

Step 0&#xff1a;MacOS到iPad无线竖屏随航显示&#xff0c;最终效果 Step 1&#xff1a; 下载 Better Display Step 2&#xff1a;在设置中新建虚拟屏幕&#xff0c;创建虚拟屏幕 Step 3&#xff1a;进行如下设置 Step 4&#xff1a;注意事项 ⚠️ 设置后的虚拟屏幕与Sideca…

基于SSM的餐厅点菜管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

找不到VCRUNTIME140_1.dll怎么办,VCRUNTIME140_1.dll丢失的5个解决方法

在当今的数字时代&#xff0c;我们的生活和工作都离不开电脑。然而&#xff0c;随着科技的发展&#xff0c;我们也会遇到各种各样的问题。其中&#xff0c;VCRUNTIME140_1.dll丢失的问题是许多人都会遇到的困扰。这个问题可能会导致许多应用程序无法正常运行&#xff0c;给我们…

如何在 Google Earth 中创建轨迹、路线并制作动画

如何创建航迹 https://kurviger.de/en Google 地球飞行教程(天桥动画) 选择合适的点 &#xff08;可调整视图快照&#xff09;点击录制&#xff0c;依次点击图标即可

电子计算机核心发展(继电器-真空管-晶体管)

目录 继电器 最大的机电计算机之一——哈弗Mark1号&#xff0c;IBM1944年 背景 组成 性能 核心——继电器 简介 缺点 速度 齿轮磨损 Bug的由来 真空管诞生 组成 控制开关电流 继电器对比 磨损 速度 缺点 影响 代表 第一个可编程计算机 第一个真正通用&am…

Go 代码中的文档和注释

撰写清晰、简洁和全面的代码文档的指南 在软件开发领域&#xff0c;编写代码只占了一半的战斗。另一半则围绕着创建清晰、简洁和全面的文档展开&#xff0c;这些文档不仅有助于开发人员理解代码库&#xff0c;还充当未来开发的路线图。在本指南中&#xff0c;我们将深入探讨编…

蓝桥杯每日一题2023.10.3

杨辉三角形 - 蓝桥云课 (lanqiao.cn) 题目描述 题目分析 40分写法&#xff1a; 可以自己手动构造一个杨辉三角&#xff0c;然后进行循环&#xff0c;用cnt记录下循环数的个数&#xff0c;看哪个数与要找的数一样&#xff0c;输出cnt #include<bits/stdc.h> using na…

协议栈——收发数据(拼接网络包,自动重发,滑动窗口机制)

目录 协议栈何时发送数据&#xff5e; 数据长度 IP模块的分片功能 发送频率 网络包序号&#xff5e;利用syn拼接网络包ack确认网络包完整 确定偏移量 服务器ack确定收到数据总长度 序号作用 双端告知各自序号 协议栈自动重发机制 大致流程 ack等待时间如何调整 是…

动态链接那些事

1、为什么要动态链接 1.1 空间浪费 对于静态链接来说&#xff0c;在程序运行之前&#xff0c;会将程序所需的所有模块编译、链接成一个可执行文件。这种情况下&#xff0c;如果 Program1 和 Program2 都需要用到 Lib.o 模块&#xff0c;那么&#xff0c;内存中和磁盘中实际上就…

WEB3 solidity 带着大家编写测试代码 操作订单 创建/取消/填充操作

好 在我们的不懈努力之下 交易所中的三种订单函数已经写出来了 但是 我们只是编译 确认了 代码没什么问题 但还没有实际的测试过 这个测试做起来 其实就比较的麻烦了 首先要有两个账号 且他们都要在交易所中有存入 我们还是先将 ganache 的虚拟环境启动起来 然后 我们在项目…