pytorch中backward()函数与gradient 参数详解

news2024/11/28 5:34:54

矩阵乘法的例子1

以下例来说明backward中参数gradient的作用

注意在本文中@表示矩阵乘法,*表示对应元素相乘

求A求偏导

试运行代码1

import torch

# [1,2]@[2*3]=[1,3]
A = torch.tensor([[1., 2.]], requires_grad=True)
B = torch.tensor([[10., 20., 30.], [100., 200., 300.]], requires_grad=True)
Z = A @ B

Z.backward()  # dz1/dx1, dz2/dx1

运行结果1

Traceback (most recent call last):
  File "D:\SoftProgram\JetBrains\PyCharm Community Edition 2023.1\plugins\python-ce\helpers\pydev\pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "D:\SoftProgram\JetBrains\PyCharm Community Edition 2023.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:\program\python\Bili_deep_thoughts\RNN\help\h_backward_5.py", line 12, in <module>
    Z.backward()  # dz1/dx1, dz2/dx1
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\autograd\__init__.py", line 166, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\autograd\__init__.py", line 67, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

查到错误原因

如果 Tensor 是一个标量(即它包含一个元素的数据),则不需要为 backward() 指定任何参数,但是如果它有更多的元素,则需要指定一个 gradient 参数,该参数是形状匹配的张量。

看完有一点了解,但是不太明白“该参数是形状匹配的张量”。

尝试修改代码2

import torch

# [1,2]@[2*3]=[1,3]
A = torch.tensor([[1., 2.]], requires_grad=True)
B = torch.tensor([[10., 20., 30.], [100., 200., 300.]], requires_grad=True)
Z = A @ B

Z.backward(torch.ones_like(A))  # dz1/dx1, dz2/dx1

运行结果2

Traceback (most recent call last):
  File "D:\SoftProgram\JetBrains\PyCharm Community Edition 2023.1\plugins\python-ce\helpers\pydev\pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "D:\SoftProgram\JetBrains\PyCharm Community Edition 2023.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:\program\python\Bili_deep_thoughts\RNN\help\h_backward_5.py", line 12, in <module>
    Z.backward(torch.ones_like(A))  # dz1/dx1, dz2/dx1
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\autograd\__init__.py", line 166, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
  File "D:\SoftProgram\JetBrains\anaconda3_202303\lib\site-packages\torch\autograd\__init__.py", line 50, in _make_grads
    raise RuntimeError("Mismatch in shape: grad_output["
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 2]) and output[0] has a shape of torch.Size([1, 3])

看到这里明白了点,“该参数是形状匹配的张量”更具体的应该是“该参数是和输出Z形状匹配的张量”。

解释

由于pytorch只能对标量进行求导,对于非标量的求导,其思想是将其转换成标量,即将其各个元素相加,然后求导。
但是如果只是简单的把全部元素相加,就无法知道z2对a2的导数,因此引入gradient参数,其形状和输出Z的形状一致,gradient的参数是累加时Z中对应参数的系数,以一开始的矩阵乘法的例子1为例,即

 

 因此矩阵Z对A的导数,就转换成了标量R对A导数,即

注意:对A求导的结果是A的形状

 验证

只求z1对A的导数,即G=[1,0,0],运行代码3

import torch

# [1,2]@[2*3]=[1,3],gradient=[1,3]
A = torch.tensor([[1., 2.]], requires_grad=True)
B = torch.tensor([[10., 20., 30.], [100., 200., 300.]], requires_grad=True)
Z = A @ B

# Z.backward(torch.ones_like(A))  # dz1/dx1, dz2/dx1
Z.backward(gradient=torch.tensor([[1., 0., 0.]])) # dz1/dx1, dz2/dx1
print('Z=', Z)
print('A.grad=', A.grad)
print('B.grad=', B.grad)

运行结果3

Z= tensor([[210., 420., 630.]], grad_fn=<MmBackward0>)

A.grad= tensor([[ 10., 100.]])
B.grad= tensor([[1., 0., 0.],
        [2., 0., 0.]])

 尝试解释一下对B的导数

哈达马积的例子

代码4

import torch

# [1,3]*[2*3]=[1,3],gradient=[2,3]
A1 = torch.tensor([[1., 2., 3.]], requires_grad=True)
B1 = torch.tensor([[10., 20., 30.], [100., 200., 300.]], requires_grad=True)
Z1 = A1 * B1 
G1 = torch.zeros_like(Z1)
G1[1, 0] = 1
Z1.backward(gradient=G1)  # dz1/dx1, dz2/dx1
print('Z1=', Z1)
print('Z1.shape=', Z1.shape)
print('A1.grad=', A1.grad)
print('B1.grad=', B1.grad)
pass

运行结果4

Z1= tensor([[ 10.,  40.,  90.],
        [100., 400., 900.]], grad_fn=<MulBackward0>)
Z1.shape= torch.Size([2, 3])
A1.grad= tensor([[100.,   0.,   0.]])
B1.grad= tensor([[0., 0., 0.],
        [1., 0., 0.]])

 矩阵乘法的例子2

代码5

import torch

A11 = torch.tensor([[0.1, 0.2], [1., 2.]], requires_grad=True)
B11 = torch.tensor([[10., 20., 30.], [100., 200., 300.]], requires_grad=True)
Z11 = torch.mm(A11, B11)
J11 = torch.zeros((2, 2))
G11 = torch.zeros_like(Z11)
G11[1, 0] = 1
Z11.backward(G11, retain_graph=True)  # dz1/dx1, dz2/dx1
print('Z11=', Z11)
print('Z11.shape=', Z11.shape)
print('A11.grad=', A11.grad)
print('B11.grad=', B11.grad)
pass

运行结果5

Z11= tensor([[ 21.,  42.,  63.],
        [210., 420., 630.]], grad_fn=<MmBackward0>)
Z11.shape= torch.Size([2, 3])
A11.grad= tensor([[  0.,   0.],
        [ 10., 100.]])
B11.grad= tensor([[1., 0., 0.],
        [2., 0., 0.]])

参考

Pytorch中的.backward()方法 - 知乎

 Pytorch autograd,backward详解 - 知乎

PyTorch中的backward - 知乎

grad can be implicitly created only for scalar outputs_一只皮皮虾x的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/528744.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

『赠书活动 | 第五期』《人工智能数学基础》

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; 『赠书活动 &#xff5c; 第五期』 本期书籍&#xff1a;《人工智能数学基础》 赠书规则&#xff1a;评论区&#xff1a;点赞&#xff5c;收藏&#xff5c;留言 评论…

DAP-seq技术鉴定葡萄MADS-box转录因子VvMADS28的靶基因,揭示葡萄胚珠发育的调控机制

葡萄是一种营养丰富、美味可口的水果&#xff0c;深受世界各地消费者的喜爱。近年来&#xff0c;无籽葡萄也越来越受大家的欢迎&#xff0c;因此&#xff0c;无籽葡萄品种的培育成为一个重要的育种目标&#xff0c;而了解葡萄种子发育的分子遗传调控机制对于无籽栽培品种的培育…

Kyligence Zen 产品体验-好用的指标平台

文章目录 一、Kyligence Zen概念一、BI发展历史1.以报表为核心的IT响应式服务2.以宽表为核心的自助可视化分析3.以指标为核心的可视化分析、增强分析 二、什么是Kyligence Zen&#xff1f;1.官网的介绍2.个人的理解 二、产品体验一、创建账号二、指标1.直接创建指标2.导入指标数…

postgresql集群编译安装

postgreSQL集群部署1环境准备&#xff08;三台服务器全部执行&#xff09; 2.1.1 准备三台虚拟机 服务器名称 服务器IP 描述 Pgsql-0 xxx.xxx.xxx.xxx master节点 Pgsql-1 xxx.xxx.xxx.xxx slave1节点 Pgsql-2 xxx.xxx.xxx.xxx slave2节点 2.1.2 安装编译需要的相…

ANR实战案例 - 通用方法总结

系列文章目录 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 例如&#xff1a;第一章 Python 机器学习入门之pandas的使用 文章目录 系列文章目录前言一、业务耗时1.登录Dialog优化2.子线程更新通知栏 二、频繁调用1.底部Tab资源初始…

【LLM系列之LLaMA】LLaMA: Open and Efficient Foundation Language Models

论文题目&#xff1a;《LLaMA: Open and Efficient Foundation Language Models》 论文链接&#xff1a;https://arxiv.org/pdf/2302.13971.pdf github链接&#xff1a;https://github.com/facebookresearch/llama/tree/main huggingface链接&#xff1a;https://huggingface.c…

[离散数学] 函数

文章目录 函数判断函数的条件复合函数复合函数的性质 逆函数 函数 判断函数的条件 dom F A ⇔ \Leftrightarrow ⇔所有x 都有 F&#xff08;x&#xff09;与之对应 有唯一的与其对应 < x , y > ∈ f ∧ < y , z > ∈ f ⇒ y z <x,y>\in f \land <y,z…

【C++】2. 进入面向对象 - 类和对象的初步认识

专栏导读 &#x1f341;作者简介&#xff1a;余悸&#xff0c;在读本科生一枚&#xff0c;致力于 C 方向学习。 &#x1f341;收录于 C专栏&#xff0c;本专栏主要内容为 C 初阶、C 进阶、STL 详解等&#xff0c;持续更新中&#xff01; &#x1f341;相关专栏推荐&#xff1a;…

TikTok新手做什么账号,Tiktok类目怎么选

有很多刚入驻TikTok的小白不知道要选择什么类目才比较容易起量。看完这篇后&#xff0c;相信你们的疑惑就会烟消云散。选择对了类目对以后的产品带货有很大的促进作用&#xff0c;今天我给你们分享6种适合TikTok小店运营的账号类型&#xff0c;以及一些比较推荐的类目。 TikTok…

测试和调试之Python高级篇

测试和调试 在软件开发过程中&#xff0c;测试和调试是非常重要的环节。测试用于验证代码的正确性和可靠性&#xff0c;而调试则是为了找到并解决代码中存在的问题。下面将会详细介绍单元测试、集成测试、断言、测试框架、调试工具和技巧。 单元测试 单元测试是指对软件中的…

linux环境安装使用redis详解

Redis 1. NoSQL的引言 NoSQL ( Not Only SQL )&#xff0c;意即 不仅仅是SQL , 泛指非关系型的数据库。Nosql这个技术门类,早期就有人提出,发展至2009年趋势越发高涨。 2. 为什么是NoSQL 随着互联网网站的兴起&#xff0c;传统的关系数据库在应付动态网站&#xff0c;特别是超大…

OpenPCDet系列 | 5.1 PointPillars算法——PillarVFE特征构建与编码模块

文章目录 PillarVFE模块1. PillarVFE初始化2. PillarVFE数据处理2.1 特征构造2.2 掩码构造2.3 特征编码 OpenPCDet的整个结构图&#xff1a; PillarVFE模块属于VFE结构的其中一种&#xff0c;所以可以在PCDet中的backbone_3d目录下&#xff0c;可以找到vfe目录结构。在OpenPCDe…

【JOSEF约瑟 JL-8GA/12端子排电流继电器 整定范围宽、功耗低】

JL-8GA/12端子排电流继电器名称:端子排电流继电器型号:JL-8GA/12品牌:JOSEF约瑟功率消耗:≤5W触点容量:250V5A额定电压:58,100,110,220V 系列型号&#xff1a; JL-8GA/11端子排电流继电器&#xff1b; JL-8GA/12端子排电流继电器&#xff1b; JL-8GA/13端子排电流继电器&am…

MySQL(1) ---- 数据库介绍与MySQL概述

介绍 1、什么是数据库&#xff1f; 数据库&#xff1a;DateBase&#xff08;DB&#xff09;&#xff0c;是存储和管理数据的仓库。数据库管理系统&#xff1a;DataBase Management System&#xff08;DBMS&#xff09;&#xff0c;操纵和管理数据库的大型软件。SQL&#xff1…

【C语言】手把手教你文件操作

文章目录 一、前言二、文件的打开和关闭1. fopen函数2. fclose函数 三、文件的顺序读写四、文件的随机读写1. fseek函数2. ftell函数3. fwind函数 一、前言 程序运行时&#xff0c;数据存放在内存中&#xff0c;而当程序退出后&#xff0c;数据也就不复存在。 想做到数据持久化…

数据库管理-第七十五期 手把手教你搭19c RAC(20230516)

数据库管理 2023-05-16 第七十五期 手把手教你搭19c RAC1 基础环境2 操作系统配置2.1 /etc/hosts2.2 配置系统挂载2.3 配置本地yum源2.4 操作系统配置2.5 安装预安装RPM包并配置&#xff1a;2.6 创建对应目录2.7 配置时间同步 3 存储挂载3.1 存储环境3.2 存储识别3.3 多路径聚合…

生成一个手绘图为底图的导游图

1 前言 上一篇演示了制作一个简版导游图。简版导游图的优点是制作简单、快速&#xff0c;不需要第三方软件&#xff0c;缺点是略显简陋、不够专业。 本编介绍制作专业导游图的步骤&#xff0c;用手绘图为地图&#xff0c;用图形展现景区信息&#xff0c;能表现出丰富的景区细…

ChatGPT:使用Edge浏览器获取ChatGPT以及如何使用ChatGPT帮你制作PPT

一&#xff1a;前言 ChatGPT&#xff1a;智能AI助你畅聊天地 在现代人日益忙碌的生活中&#xff0c;难免需要一些轻松愉快的聊天来放松身心。而现在&#xff0c;有了 ChatGPT&#xff0c;轻松愉快的聊天变得更加智能、有趣且不受时间、地点限制&#xff01; 什么是 ChatGPT&…

NSSCTF-[深育杯 2021]Press

下载链接&#xff1a;下载 载入IDA&#xff0c;查看内容 首先进入一个函数进行初始化&#xff0c;进入查看 unsigned __int64 sub_4007B6() {int v1; // [rsp8h] [rbp-48h]int i; // [rspCh] [rbp-44h]char src[56]; // [rsp10h] [rbp-40h] BYREFunsigned __int64 v4; // [r…

【可乐荐书】有趣的矩阵:看得懂又好看的线性代数

本栏目将推荐一些经典的、有趣的、有启发性的书籍&#xff0c;这些书籍涵盖了各个领域&#xff0c;包括文学、历史、哲学、科学、技术等等。相信这些书籍不仅可以让你获得知识&#xff0c;还可以让你感受到阅读的乐趣和魅力。 今天给大家推荐的书籍是&#xff1a;《有趣的矩阵…