【转载+修改】pytorch中backward求梯度方法的具体解析

news2025/1/22 13:08:14

原则上,pytorch不支持张量对张量的求导,它只支持标量对张量的求导
我们先看标量对张量求导的情况

import torch
x=torch.ones(2,2,requires_grad=True)
print(x)
print(x.grad_fn)

输出,由于x是被直接创建的,也就是说它是一个叶子节点,所以它的grad_fn属性的值为None
tensor([[1., 1.], [1., 1.]], requires_grad=True) None
接下来对叶子节点x进行第一步操作,y=x+2

y=x+2
print(y)
print(y.grad_fn)

输出:这里可以看到y的grad_fn属性变成了AddBackward,所以grad_fn属性记录的是该张量的上一步操作。
tensor([[3., 3.], [3., 3.]], grad_fn=) <AddBackward0 object at 0x00000206A7EE2108>
然后进行操作z=y * y * 3,再对z求平均值

z=y*y*3 
out=z.mean()
print(z,out)

输出结果:
tensor([[27., 27.], [27., 27.]], grad_fn=) tensor(27., grad_fn=)
此时我们利用backward()函数来求x的梯度,由于out是求平均值得到的一个标量,所以我们可以不用向backward函数传递一个张量,而是直接计算。

out.backward()
print(x.grad)

输出:tensor([[4.5000, 4.5000], [4.5000, 4.5000]])

我们来手动计算,看得到的结果是否与backword函数得到的结果一致
在这里插入图片描述
显然,结果是一致的。

再来看张量对张量求导的情况
前面已经强调过,pytorch不允许张量对张量求导,所以在使用张量对张量求导的时候,必须要传入一个与被求导张量同形的张量,然后pytorch根据传入的张量与被求导张量作加权求和将其转化为标量,这里比较晦涩难懂,没关系,接下来我们用例子来解释
首先创建一个叶子节点x

x=torch.tensor([[1.0,2.0],[3.0,4.0]],requires_grad=True)
print(x)

tensor([[1., 2.], [3., 4.]], requires_grad=True)
接下来计算y=3*x

y=3*x
print(y)

tensor([[ 3., 6.], [ 9., 12.]], grad_fn=)
接下来我们直接用y求导y.backward()
毫无意外,直接报错,这就印证了前面说过的pytorch不支持张量对张量直接求导

RuntimeError: grad can be implicitly created only for scalar outputs

于是我们构建一个与y同形的张量z,(一般是传入一个单位张量,可以参考y.backward(torch.ones_like(y)),这样计算得到的参数梯度没有张量z的影响)把z作为y.backward()的参数求y对x的导。

z=torch.tensor([[1.0,0.1],[0.01,0.001]],dtype=torch.float)
y.backward(z)
print(x.grad)

输出结果:tensor([[3.0000, 0.3000], [0.0300, 0.0030]])(张量的梯度是一个与原张量同形的张量)
事实上,到这里我们仍一头雾水,不知道这个结果是如何得出的,下面给出他的通用计算公式(需要注意的是,该公式只是用来方便计算的,属于计算技巧)
在这里插入图片描述
至于y对x的导数,如果有多层复合函数,利用链式法则计算即可。上面的例子比较简单,y对x求导的结果是3,再乘以张量z,很容验证得到同样的结果。
接下来我们推导上述计算公式,上文已经提到,对于表达式y.backward(z) (y、z为同形张量)的计算过程,实际上将y与z加权求和得到标量m,然后用m对x求导得到结果,也就是说实际上有这样一步计算m=torch.sum(y*z)
我们可以来验证这一步计算的正确性

m = torch.sum(y*z)
print(m)

输出tensor(3.7020, grad_fn=)可以看到的是m是一个标量。
接着再用m对x求导

m.backward()
print(x.grad)

很容易得到上述结果tensor([[3.0000, 0.3000], [0.0300, 0.0030]])
下面给出上述计算公式的推导:
在这里插入图片描述
参考链接:https://blog.csdn.net/weixin_45021364/article/details/105194187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/779495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue.js uni-app 混合模式原生App webview与H5的交互

在现代移动应用开发中&#xff0c;原生App与H5页面之间的交互已经成为一个常见的需求。本文将介绍如何在Vue.js框架中实现原生App与H5页面之间的数据传递和方法调用。我们将通过一个简单的示例来展示如何实现这一功能。附完整源码下载地址:https://ext.dcloud.net.cn/plugin?i…

Java集成openAi的ChatGPT实战

效果图&#xff1a; 免费体验地址&#xff1a;AI智能助手 具体实现 public class OpenAiUtils {private static final Log LOG LogFactory.getLog(OpenAiUtils.class);private static OpenAiProxyService openAiProxyService;public OpenAiUtils(OpenAiProxyService openAiP…

【C++】入门 --- 命名空间

文章目录 &#x1f36a;一、前言&#x1f369;1、C简介&#x1f369;2、C关键字 &#x1f36a;二、命名冲突&#x1f36a;三、命名空间&#x1f369;1、命名空间定义&#x1f369;2、命名空间的使用 &#x1f36a;四、C输入&输出 &#x1f36a;一、前言 本篇文章是《C 初阶…

Data Transfer Object-DTO,数据传输对象,前端参数设计多个数据表对象

涉及两张表的两个实体对象 用于在业务逻辑层和持久层&#xff08;数据库访问层&#xff09;之间传输数据。 DTO的主要目的是将多个实体&#xff08;Entity&#xff09;的部分属性或多个实体关联属性封装成一个对象&#xff0c;以便在业务层进行数据传输和处理&#xff0c;从而…

八、HAL_UART(串口)的接收和发送

1、开发环境 (1)Keil MDK: V5.38.0.0 (2)STM32CubeMX: V6.8.1 (3)MCU: STM32F407ZGT6 2、UART和USART的区别 2.1、UART (1)通用异步收发收发器&#xff1a;Universal Asynchronous Receiver/Transmitter)。 2.2、USART (1)通用同步异步收发器&#xff1a;Universal Syn…

【《R4编程入门与数据科学实战》——一本“能在日常生活中使用统计学”的书】

《R 4编程入门与数据科学实战》的两名作者均为从事编程以及教育方面的专家&#xff0c;他们用详尽的语言&#xff0c;以初学者的角度进行知识点的讲解&#xff0c;每个细节都手把手教学,以让读者悉数掌握所有知识点&#xff0c;在每章的结尾都安排理论与实操相结合的习题。与同…

banner轮播图实现、激活状态显示和分类列表渲染、解决路由缓存问题、使用逻辑函数拆分业务(一级分类)【Vue3】

一级分类 - banner轮播图实现 分类轮播图实现 分类轮播图和首页轮播图的区别只有一个&#xff0c;接口参数不同&#xff0c;其余逻辑完成一致 适配接口 export function getBannerAPI (params {}) {// 默认为1 商品为2const { distributionSite 1 } paramsreturn httpIn…

VTK是如何显示一个三维立体图像的

VTK是如何显示一个三维立体图像的 1、文字描述2、图像演示 1、文字描述 2、图像演示

MySQL-事务-介绍与操作

思考 假设在一个场景中&#xff0c;学工部解散了&#xff0c;需要删除该部门及该部门下的员工对应的SQL语句涉及的数据表信息如下 员工表 部门表 实现的SQL语句 -- todo 事务 -- 删除学工部 -- 删除1号部门 delete from tb_dept where id 1; -- 删除学工部下的员工 delete …

SPEC CPU 2006 docker gcc:4 静态编译版本 Ubuntu 22.04 LTS 测试报错Invalid Run

runspec.sh #!/bin/bash source shrc ulimit -s unlimited runspec -c gcc41.cfg -T all -n 1 int fp > runspec.log 2>&1 & tail -f runspec.log runspec.log 由于指定了-T all&#xff0c;导致-n 1 失效&#xff0c;用例运行了三次&#xff08;后续验证&…

【LeetCode 75】 第十题(283)移动零

目录 题目: 示例: 分析: 代码运行结果: 题目: 示例: 分析: 给一个数组,要求将数组中的零都移动到数组的末尾. 首先我们可以遍历一边数组,遇到0的时候就在数组中把0删除,并且统计0的数量. 遍历完成以后数组中就没有0了,这时我们再在数组的后面添上之前统计的0的数量个0. …

IntelliJ IDEA Copyright添加

IDEA代码文件的版权(copyright)信息配置 1. 快速创建Copyright 版权配置文件 1.1 创建copyright文件 依次点击 File > Settings… > Editor > Copyright > 点击 “” 号或 “Add profile”***&#xff0c;弹出创建 Copyright Profile 操作窗口&#xff0c;在***文…

【iOS】App仿写--网易云音乐

文章目录 前言一、首页界面二、我的界面三、账号界面总结 前言 在暑假之前仿写了网易云app&#xff0c;一直没总结。 网易云app主要让我熟悉了视图之间的相互嵌套的用法与关系以及自定义cell的用法&#xff0c;特此撰写以下博客进行总结。 一、首页界面 首先来看一下完成的效…

深度挖掘《TCP与UDP》

文章目录 UDPTCPTCP特性TCP是如何实现的可靠传输&#xff1f;序号和确认序号为啥网络上会后发先至 什么是丢包&#xff0c;如何解决丢包&#xff1f;TCP建立连接&#xff1a;三次握手四次交互&#xff0c;为什叫三次握手&#xff1f;三次握手起到什么效果&#xff1f;达到什么目…

YZ06:加载项是否加载的判断

【分享成果&#xff0c;随喜正能量】人生&#xff0c;因有缘而聚&#xff0c;因情而暖&#xff1b;人生&#xff0c;因不珍惜而散&#xff0c;因恨而亡&#xff1b;活着就要善待自己&#xff0c;不属于自己的不强求&#xff0c;不是真心的不必喜欢&#xff0c;时间在变&#xf…

Spring初识(三)

文章目录 前言一.存储 Bean 对象1.1 类注解的用法1.2 为什么要使用这么多类注解1.2.1 为什么需要五大类注解 1.3 各个类注解的关系1.4 Bean的命名规则1.5 方法注解的使用 二.取出 Bean 对象2.1 属性注入2.2 Setter注入2.3 构造方法注入 三.总结 前言 经过前面的学习,我们已经学…

【C++】STL使用仿函数控制优先级队列priority_queue

文章目录 前言一、priority_queue的底层实现二、使用仿函数控制priority_queue的底层总结 前言 本文章讲解CSTL的容器适配器&#xff1a;priority_queue的实现&#xff0c;并实现仿函数控制priority_queue底层。 一、priority_queue的底层实现 priority_queue叫做优先级队列&…

uview2.0使用u-calendar 的formatter属性,在formatter方法里无法访问this的bug,解决办法!!!!

uview 版本2.0.36 文档 使用该文档的案例&#xff0c;在 formatter打印this也会是undefined。 自己写了个demo 父给子传值v-bind传一个函数&#xff0c;然后在这个函数里面打印this&#xff0c;this是子组件的实例&#xff0c;但是不知道为什么formatter里会打印undefined。希…

pytorch工具——使用pytorch构建一个神经网络

目录 构建模型模型中的可训练参数假设输入尺寸为32*32损失函数反向传播更新网络参数 构建模型 import torch import torch.nn as nn import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()#定义第一层卷积层&#xff0c;输入维…

配置NFS服务

环境 环境 ubuntu 10.4 vm 7.1 终端 ifconfig 得到 ubuntu资料 INET ADDR 192.168.0.4 BCAST 192.168.0.255 MASK 255.255.255.0 操作前先关闭防火墙 关闭防火墙&#xff1a; 命令&#xff1a;sudo ufw disable 打开防火墙 命令&#xff1a;sudo ufw enable 配置过程 一 安…