PyTorch求导相关

news2024/10/24 12:45:58

PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果;而TensorFlow是静态图。

在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation)

运算包括了:加减乘除、开方、幂指对、三角函数等可求导运算

数据可分为:叶子节点(leaf node)和非叶子节点;叶子节点是用户创建的节点,不依赖其它节点;它们表现出来的区别在于反向传播结束之后,非叶子节点的梯度会被释放掉,只保留叶子节点的梯度,这样就节省了内存。如果想要保留非叶子节点的梯度,可以使用retain_grad()方法。

torch.tensor 具有如下属性:

  • 查看 是否可以求导 requires_grad
  • 查看 运算名称 grad_fn
  • 查看 是否为叶子节点 is_leaf
  • 查看 导数值 grad

针对requires_grad属性,自己定义的叶子节点默认为False,而非叶子节点默认为True,神经网络中的权重默认为True。判断哪些节点是True/False的一个原则就是从你需要求导的叶子节点到loss节点之间是一条可求导的通路。

当我们想要对某个Tensor变量求梯度时,需要先指定requires_grad属性为True,指定方式主要有两种:

x = torch.tensor(1.).requires_grad_() # 第一种

x = torch.tensor(1., requires_grad=True) # 第二种

PyTorch提供两种求梯度的方法:backward() and torch.autograd.grad() ,他们的区别在于前者是给叶子节点填充.grad字段,而后者是直接返回梯度给你,我会在后面举例说明。还需要知道y.backward()其实等同于torch.autograd.backward(y)

一个简单的求导例子是:y=(x+1)∗(x+2) ,计算 ∂y/∂x ,假设给定 x=2
先画出计算图

手算:∂y/∂x=(x+2)*1+(x+1)*1->7

使用backward()

x = torch.tensor(2., requires_grad=True)

a = torch.add(x, 1)
b = torch.add(x, 2)
y = torch.mul(a, b)

y.backward()
print(x.grad)
>>>tensor(7.)

看一下这几个tensor的属性

print("requires_grad: ", x.requires_grad, a.requires_grad, b.requires_grad, y.requires_grad)
print("is_leaf: ", x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("grad: ", x.grad, a.grad, b.grad, y.grad)

>>>requires_grad:  True True True True
>>>is_leaf:  True False False False
>>>grad:  tensor(7.) None None None

使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True。所有满足条件的变量梯度会自动保存到对应的grad属性里。

使用autograd.grad()

x = torch.tensor(2., requires_grad=True)

a = torch.add(x, 1)
b = torch.add(x, 2)
y = torch.mul(a, b)

grad = torch.autograd.grad(outputs=y, inputs=x)
print(grad[0])
>>>tensor(7.)

因为指定了输出y,输入x,所以返回值就是 ∂x/∂y 这一梯度,完整的返回值其实是一个元组,保留第一个元素就行,后面元素是

二阶求导

求一阶导可以用backward()

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)

z = x * x * y

z.backward()
print(x.grad, y.grad)
>>>tensor(12.) tensor(4.)

也可以用autograd.grad()

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

grad_x = torch.autograd.grad(outputs=z, inputs=x)
print(grad_x[0])
>>>tensor(12.)

为什么不在这里面同时也求对y的导数呢?因为无论是backward还是autograd.grad在计算一次梯度后图就被释放了,如果想要保留,需要添加retain_graph=True

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

grad_x = torch.autograd.grad(outputs=z, inputs=x, retain_graph=True)
grad_y = torch.autograd.grad(outputs=z, inputs=y)

print(grad_x[0], grad_y[0])
>>>tensor(12.) tensor(4.) 

再来看如何求高阶导,理论上其实是上面的grad_x再对x求梯度,试一下看

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

grad_x = torch.autograd.grad(outputs=z, inputs=x, retain_graph=True)
grad_xx = torch.autograd.grad(outputs=grad_x, inputs=x)

print(grad_xx[0])
>>>RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

报错了,虽然retain_graph=True保留了计算图和中间变量梯度, 但没有保存grad_x的运算方式,需要使用creat_graph=True在保留原图的基础上再建立额外的求导计算图,也就是会把 ∂z/∂x=2xy 这样的运算存下来

# autograd.grad() + autograd.grad()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

grad_x = torch.autograd.grad(outputs=z, inputs=x, create_graph=True)
grad_xx = torch.autograd.grad(outputs=grad_x, inputs=x)

print(grad_xx[0])
>>>tensor(6.)

grad_xx这里也可以直接用backward(),相当于直接从 ∂z/∂x=2xy 开始回传

# autograd.grad() + backward()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

grad = torch.autograd.grad(outputs=z, inputs=x, create_graph=True)
grad[0].backward()

print(x.grad)
>>>tensor(6.)

 也可以先用backward()然后对x.grad这个一阶导继续求导

# backward() + autograd.grad()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

z.backward(create_graph=True)
grad_xx = torch.autograd.grad(outputs=x.grad, inputs=x)

print(grad_xx[0])
>>>tensor(6.)

那是不是也可以直接用两次backward()呢?第二次直接x.grad从开始回传,我们试一下

# backward() + backward()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

z.backward(create_graph=True) # x.grad = 12
x.grad.backward()

print(x.grad)
>>>tensor(18., grad_fn=<CopyBackwards>)

发现了问题,结果不是6,而是18,发现第一次回传时输出x梯度是12。这是因为PyTorch使用backward()时默认会累加梯度,需要手动把前一次的梯度清零

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()

z = x * x * y

z.backward(create_graph=True)
x.grad.data.zero_()
x.grad.backward()

print(x.grad)
>>>tensor(6., grad_fn=<CopyBackwards>)

向量求导

有没有发现前面都是对标量求导,如果不是标量会怎么样呢?

x = torch.tensor([1., 2.]).requires_grad_()
y = x + 1

y.backward()
print(x.grad)
>>>RuntimeError: grad can be implicitly created only for scalar outputs

x = torch.tensor([1., 2.]).requires_grad_()
y = x * x

y.sum().backward()
print(x.grad)
>>>tensor([2., 4.])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2222396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Psychophysiology:脑-心交互如何影响个体的情绪体验?

摘要 情绪的主观体验与对身体(例如心脏)活动变化的情境感知和评估相关。情绪唤醒增加与高频心率变异性(HF-HRV)降低、EEG顶枕区α功率降低以及心跳诱发电位(HEP)振幅较高有关。本研究使用沉浸式虚拟现实(VR)技术来研究与情绪唤醒相关的脑心相互作用&#xff0c;以实现自然而可…

SSM考研科目学习APP-计算机毕业设计源码90377

摘 要 基于Android的考研科目学习系统的设计与实现&#xff0c;旨在为广大考研学子提供一个便捷、高效的学习平台。该系统充分利用Android操作系统的广泛普及与灵活定制性&#xff0c;结合考研科目的特点和需求&#xff0c;实现了个性化的学习方案、丰富的题库资源以及智能化…

【个人同步与备份】电脑(Windows)与手机/平板(Android)之间文件同步

文章目录 1. syncthing软件下载2. syncthing的使用2.1. 添加设备2.1.1. syncthing具备设备发现功能&#xff0c;因此安装好软件&#xff0c;只需确认设备信息是否对应即可2.1.2. 如果没有发现到&#xff0c;可以通过设备ID连接2.1.3. 设置GUI身份验证用户&#xff0c;让无关设备…

LeetCode: 3274. 检查棋盘方格颜色是否相同

一、题目 给你两个字符串 coordinate1 和 coordinate2&#xff0c;代表 8 x 8 国际象棋棋盘上的两个方格的坐标。   以下是棋盘的参考图。   如果这两个方格颜色相同&#xff0c;返回 true&#xff0c;否则返回 false。   坐标总是表示有效的棋盘方格。坐标的格式总是先…

大模型技术学习过程梳理,零基础入门到精通,收藏这一篇就够了

“ 学习是一个从围观到宏观&#xff0c;从宏观到微观的一个过程 ” 今天整体梳理一下大模型技术的框架&#xff0c;争取从大模型所涉及的理论&#xff0c;技术&#xff0c;应用等多个方面对大模型进行梳理。 01 — 大模型技术梳理 这次梳理大模型不仅仅是大模型本身的技术…

接口测试(八)jmeter——参数化(CSV Data Set Config)

一、CSV Data Set Config 需求&#xff1a;批量注册5个用户&#xff0c;从CSV文件导入用户数据 1. 【线程组】–>【添加】–>【配置元件】–>【CSV Data Set Config】 2. 【CSV数据文件设置】设置如下 3. 设置线程数为5 4. 运行后查看响应结果

vue3项目页面实现echarts图表渐变色的动态配置

完整代码可点击vue3项目页面实现echarts图表渐变色的动态配置-星林社区 https://www.jl1mall.com/forum/PostDetail?postId202410151031000091552查看 一、背景 在开发可配置业务平台时&#xff0c;需要实现让用户对项目内echarts图表的动态配置&#xff0c;让用户脱离代码也…

基于Matlab 人脸识别技术

Matlab 人脸识别技术 算法流程&#xff1a; 本系统运用PCA算法来实现人脸特征提取&#xff0c;然后通过计算欧式距离来判别待识别测试人脸&#xff0c;本个系统框架图如下&#xff1a; 图&#xff1a; 人脸识别系统框架图 整个系统的流程是这样的&#xff0c;首先通过图像采…

给哔哩哔哩bilibili电脑版做个手机遥控器

前言 bilibili电脑版可以在电脑屏幕上观看bilibili视频。然而&#xff0c;电脑版的bilibili不能通过手机控制视频翻页和调节音量&#xff0c;这意味着观看视频时需要一直坐在电脑旁边。那么&#xff0c;有没有办法制作一个手机遥控器来控制bilibili电脑版呢&#xff1f; 首先…

基于SpringBoot+Vue+uniapp的时间管理小程序的详细设计和实现(源码+lw+部署文档+讲解等)

详细视频演示 请联系我获取更详细的演示视频 项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念&#xff0c;提供了一套默认的配置&#xff0c;让开发者可以更专注于业务逻辑而不…

【文件加密系统】华企盾DSC服务程序启动失败解决办法

问题原因&#xff1a; 1.sa账户密码错误导致连接数据数据库失败无法启动DSC服务 解决方法&#xff1a; 用windows身份验证进入数据库更改sa用户密码&#xff1a;安全性>登录名>sa>右键属性>更改密码 ※如果显示请输入秘钥更改&#xff0c;使用更改完密码的sa账户登…

从0开始深度学习(16)——暂退法(Dropout)

上一章的过拟合是由于数据不足导致的&#xff0c;但如果我们有比特征多得多的样本&#xff0c;深度神经网络也有可能过拟合 1 扰动的稳健性 经典泛化理论认为&#xff0c;为了缩小训练和测试性能之间的差距&#xff0c;应该以简单的模型为目标&#xff0c;即模型以较小的维度的…

机器学习与神经网络:科技的星辰大海

前提 近日&#xff0c;2024年诺贝尔物理学奖颁发给了机器学习与神经网络领域的研究者&#xff0c;这是历史上首次出现这样的情况。这项奖项原本只授予对自然现象和物质的物理学研究作出重大贡献的科学家&#xff0c;如今却将全球范围内对机器学习和神经网络的研究和开发作为了一…

RSocket vs WebSocket:Spring Boot 3.3 中的两大实时通信利器

RSocket vs WebSocket&#xff1a;Spring Boot 3.3 中的两大实时通信利器 随着现代互联网应用的不断发展&#xff0c;实时通信已经成为许多应用程序不可或缺的功能。无论是社交网络、在线游戏还是数据监控系统&#xff0c;实时通信都能提供快速、无缝的信息交换。而实现实时通…

“主升筹码”,底部建仓信号+主升加仓位置,不错过任何行情

使用技巧 指标分为主图和副图 其中&#xff0c;主图主升筹码信号较多&#xff0c;副图的信号较少。这里&#xff0c;我说一个选股思路&#xff0c;就是底部主升筹码共振进场&#xff0c;上升过程中主图信号当作加仓信号。 选股&#xff0c;提供一个主升筹码共振选股&#xff0…

Redis 5.0 安装配置(Windows)

Redis 5.0之后支持Redis Stream等功能 下载地址&#xff1a;Releases tporadowski/redis GitHub 点击运行redis-server.exe 此外&#xff1a;Redis 6.0及以后版本目前都没有Windows版

【越狱插件】内网穿透 frpc、frps插件

内网穿透、frp、frpc、frps https://zhaoboy9692.github.io/repo 越狱源 https://zhaoboy9692.github.io/repo 苦于在ios越狱下没有frp穿透使用 特地开发了的越狱插件 基于最新frp0.48编译 ios14.6测试没问题 有问题及时反馈

ubuntu中使用cmake编译报错No CMAKE_CXX_COMPILER could be found.的解决方法

ubuntu中使用cmake编译报错No CMAKE_CXX_COMPILER could be found.的解决方法 No CMAKE_CXX_COMPILER could be found.Could NOT find CUDA (missing: CUDA_NVCC_EXECUTABLE CUDA_CUDART_LIBRARY)Could not find a package configuration file provided by "OpenCV" …

【SQL|大数据|数据清洗|过滤】where条件中 “ != “ 和 “ NOT IN() ” 对NULL的处理

对数据进行清洗过滤的时候&#xff0c;NULL往往是一个很特殊的存在&#xff0c;对NULL值的存在通常有以下三种方式 1、保留NULL 2、过滤掉NULL 3、将NULL替换为其他符合业务需求的默认常量 下面是一些常用处理NULL的方式&#xff1a; 如下图所示数据源&#xff1a; car_vin&…

android openGL ES详解——缓冲区VBO/VAO/EBO/FBO

目录 一、缓冲区对象概念 二、分类 三、顶点缓冲区对象VBO 1、概念 2、为什么使用VBO 3、如何使用VBO 生成缓冲区对象 绑定缓冲区对象 输入缓冲区数据 更新缓冲区中的数据 删除缓冲区 4、VBO应用 四、顶点数组对象VAO 1、概念 2、为什么使用VAO 3、如何使用VAO…