【Pytorch】AutoGrad个人理解

news2025/1/11 6:00:33

前提知识:[Pytorch] 前向传播和反向传播示例_友人小A的博客-CSDN博客

目录

简介

叶子节点

Tensor AutoGrad Functions


简介

torch.autograd是PyTorch的自动微分引擎(自动求导),为神经网络训练提供动力。torch.autograd需要对现有代码进行最少的更改——声明需要计算梯度的Tensor的属性requires_grad=True。截至目前,PyTorch仅支持 FloatTensor类型(half、float、double和bfloat16)和 ComplexTensor(cfloat、cdouble)的autograd。【信息来自官网】

叶子节点

叶子结点是离散数学中的概念。一棵树当中没有子结点(即度为0)的结点称为叶子结点,简称“叶子”。 叶子是指出度为0的结点,又称为终端结点。

在pytorch中,什么是叶子节点?根据官方定义理解如下。

  • 所有requires_grad为False的张量,都约定俗成地归结为叶子张量
  • requires_grad为True的张量, 如果他们是由用户创建的,则它们是叶张量(leaf Tensor), 表明不是运算的结果,因此grad_fn=None

示例1

def test_training_pipeline2():
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input = torch.tensor(input_data, dtype=torch.float32)  # requires_grad=False
    output = torch.sqrt(input)
    
    target_data = [1, 2, 3, 4]
    target = torch.tensor(target_data, dtype=torch.float32)  # requires_grad=False
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(input=output, target=target)
    print("\ninput.is_leaf:", input.is_leaf)
    print("output.requires_grad:", output.requires_grad)
    print("output.is_leaf:", output.is_leaf)
    print("target.is_leaf:", target.is_leaf)
    print("loss.requires_grad:", loss.requires_grad)
    print("loss.is_leaf:", loss.is_leaf)

样例2

def test_training_pipeline2():
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input = torch.tensor(input_data, dtype=torch.float32)  # requires_grad=False
    output = torch.sqrt(input)
    output.requires_grad_(True) # requires_grad=True
    
    target_data = [1, 2, 3, 4]
    target = torch.tensor(target_data, dtype=torch.float32)  # requires_grad=False
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(input=output, target=target)
    
    print("\ninput.is_leaf:", input.is_leaf)
    print("output.requires_grad:", output.requires_grad)
    print("output.is_leaf:", output.is_leaf)
    print("target.is_leaf:", target.is_leaf)
    print("loss.requires_grad:", loss.requires_grad)
    print("loss.is_leaf:", loss.is_leaf)

样例3

 

def test_training_pipeline5():
    input = torch.rand(1, requires_grad=True)
    output = torch.unique(
        input=input, 
        sorted=True, 
        return_inverse=False, 
        return_counts=False, 
        dim=None
    )
    
    print("\ninput.is_leaf:", input.is_leaf)
    print("output.requires_grad:", output.requires_grad)
    print("output.is_leaf:", output.is_leaf)
    
    output.backward()

样例4

def test_training_pipeline3():
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input_a = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
    input_b = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
    output = torch.ne(input_a, input_b)

    print("\ninput_a.is_leaf:", input_a.is_leaf)
    print("input_b.is_leaf:", input_b.is_leaf)
    print("output.dtype:", output.dtype)
    print("output.requires_grad:", output.requires_grad)
    print("output.is_leaf:", output.is_leaf)

    output.backward()   # 报错

 

 

样例5

def test_training_pipeline7():
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input_a = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
    input_b = torch.tensor(input_data, dtype=torch.float32)    
    output = torch.add(input_a, input_b)
    print("\ninput_a.requires_grad:", input_a.requires_grad)
    print("input_b.requires_grad:", input_b.requires_grad)
    print("output.requires_grad:", output.requires_grad)
    print("output.is_leaf:", output.is_leaf)
    
    grad = torch.ones_like(output)
    
    input_b[0][0] = 10 
    input_a[0][0] = 10 
    output.backward(grad)

 样例6

def test_training_pipeline9():
    x = torch.tensor([1.0], requires_grad=True)
    y = x + 2
    z = 2 * y		# <-- dz/dy=2
    y[0] = -2.0
    
    print("\nx.is_leaf:", x.is_leaf)
    print("y.is_leaf:", y.is_leaf)
    print("z.is_leaf:", z.is_leaf)
    print("\nx.requires_grad:", x.requires_grad)
    print("y.requires_grad:", y.requires_grad)
    print("z.requires_grad:", z.requires_grad)
    z.backward()


def test_training_pipeline9():
    x = torch.tensor([1.0], requires_grad=True)
    y = x + 2
   z = y * y  # <-- dz/dy= 2*y
    y[0] = -2.0
    
    print("\nx.is_leaf:", x.is_leaf)
    print("y.is_leaf:", y.is_leaf)
    print("z.is_leaf:", z.is_leaf)
    print("\nx.requires_grad:", x.requires_grad)
    print("y.requires_grad:", y.requires_grad)
    print("z.requires_grad:", z.requires_grad)
    z.backward()

 

Tensor AutoGrad Functions

  1. Tensor.grad

  2. Tensor.requires_grad

  3. Tensor.is_leaf

  4. Tensor.backward(gradient=None, reqain_graph=None, create_graph=False)

  5. Tensor.detach()

  6. Tensor.detach_()

  7. Tensor.retain_grad()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线性表的合并

线性表的应用 线性表的合并 问题描述&#xff1a; 假设利用两个线性表La和Lb表示两个集合A和B&#xff0c;现要求一个新的集合AA∪B 即&#xff1a; La(7,5,3,11) Lb(2,6,3) A(7,8,3,11,2,6) 算法步骤&#xff08;A既是参数&#xff0c;也是操作的结果&#xff09; 依次…

全志V853芯片 Tina Linux下网络ADB内存泄露如何修复?

1.主题 Tina Linux 网络ADB内存泄露修复 2.问题背景 硬件&#xff1a;V853 软件&#xff1a;Tina4.0 Linux-4.9 背景&#xff1a;使用网络adb时&#xff0c;反复connect disconnect&#xff0c;会发生内存泄露的问题。 3.问题描述 3.1复现步骤 1、首先使能网络ADB功能。 …

Async注解使用和CompletableFuture注解获取返回值

举栗个现实问题&#xff1a; 需求&#xff1a;拉取 业务数据不能超过 5秒。 拉取第三方数据 &#xff0c;分别需要拉取 A业务数据&#xff08;需要2秒&#xff09; 、拉取 B业务数据&#xff08;需要2秒&#xff09;、拉取 C业务数据&#xff08;需要2秒&#xff09; &#xff…

模电学习11 运算放大器学习入门

一、基本概念 运算放大器简称运放&#xff0c;是一种模拟电路实现的集成电路&#xff0c;可以对信号进行很高倍数的放大。一般有正相输入端、反相输入端、输出端口、正电源、负电源等接口。 运放可工作在饱和区、放大区&#xff0c;其中放大区极其陡峭&#xff0c;因为运放的放…

【深度学习】BERT变体—SpanBERT

SpanBERT出自Facebook&#xff0c;就是在BERT的基础上&#xff0c;针对预测spans of text的任务&#xff0c;在预训练阶段做了特定的优化&#xff0c;它可以用于span-based pretraining。这里的Span翻译为“片段”&#xff0c;表示一片连续的单词。SpanBERT最常用于需要预测文本…

c++11 标准模板(STL)(std::unordered_map)(四)

定义于头文件 <unordered_map> template< class Key, class T, class Hash std::hash<Key>, class KeyEqual std::equal_to<Key>, class Allocator std::allocator< std::pair<const Key, T> > > class unordered…

【3.6】链表、操作系统CPU是如何执行程序的、Redis数据类型及其应用

链表 题目题型203. 移除链表元素 - 力扣&#xff08;LeetCode&#xff09;辅助头节点解决移出head问题707. 设计链表 - 力扣&#xff08;LeetCode&#xff09;辅助头节点206. 反转链表 - 力扣&#xff08;LeetCode&#xff09;迭代 / 递归19. 删除链表的倒数第 N 个结点 - 力扣…

什么?年终奖多发1块钱竟要多缴9.6W的税

对于大多数的工薪阶级来说&#xff0c;目前现行的个人所得税适用于全年累计收入一次性税收优惠。 有可能有的人不理解一次性税收优惠是什么意思&#xff0c;所以这里我首先解释下什么是一次性税收优惠&#xff0c;然后在讲一下为什么明明公司多发了钱&#xff0c;到手反而会更…

Kotlin中的destructuring解构声明

开发中有时只是想分解一个包含多个字段的对象来初始化几个单独的变量。要实现这一点&#xff0c;可以使用Kotlin的解构声明。本文主要了解&#xff1a;“1、如何使用解构声明这种特性 2、底层是如何实现的 3、如何在你自己的类中实现它1、解构声明的使用解构声明&a…

hutool XML反序列化漏洞(CVE-2023-24162)

漏洞简介 Hutool 中的XmlUtil.readObjectFromXml方法直接封装调用XMLDecoder.readObject解析xml数据&#xff0c;当使用 readObjectFromXml 去处理恶意的 XML 字符串时会造成任意代码执行。 漏洞复现 我们在 maven 仓库中查找 Hutool ​https://mvnrepository.com/search?…

基于EB工具的TC3xx_MCAL配置开发01_WDG模块配置介绍

目录 1.概述2. WDG 配置2.1 General部分配置2.2 WdgSettingsConfig配置2.2.1 配置概述2.2.2 CPU WDG具体配置2.3 WdgDemEventParameterRefs3. WDG配置注意事项1.概述 本篇开始我们基于EB Tresos工具对英飞凌TC3xx系列MCU的MCAL开发进行介绍,结合项目经验对各MCAL外设的开发及…

C++回顾(七)—— 面向对象模型

7.1 静态成员变量和静态成员函数 7.1.1 静态成员变量 关键字 static 可以用于说明一个类的成员&#xff1b;静态成员提供了一个同类对象的共享机制&#xff1b;把一个类的成员说明为 static 时&#xff0c;这个类无论有多少个对象被创建&#xff0c;这些对象共享这个 static …

ubuntu C++调用python

普通 目录结构 main.py 等会用c调用func() #!/usr/bin/env python # _*_ coding:utf-8 _*_ import osdef func():print(hello world)if __name__ __main__:func()main.cpp 其中Py_SetPythonHome的路径是anaconda中环境的路径&#xff0c;最开始的L一定要加&#xff08;因为…

基于 Rainbond 的 Pipeline(流水线)插件

背景 Rainbond 本身具有基于源码构建组件的能力&#xff0c;可以将多种编程语言的代码编译成 Docker 镜像&#xff0c;但是在持续集成的过程中&#xff0c;往往会需要对提交的代码进行静态检查、构建打包以及单元测试。之前由于 Rainbond 并没有 Pipeline 这种可编排的机制&am…

Git-学习笔记02【Git连接远程仓库】

Java后端 学习路线 笔记汇总表【黑马-传智播客】Git-学习笔记01【Git简介及安装使用】Git-学习笔记02【Git连接远程仓库】Git-学习笔记03【Git分支】目录 01-使用github创建一个远程仓库 02-推送到远程仓库介绍 03-创建ssh密钥及在github上配置公钥 04-使用ssh方式将本地仓…

MySQL基本查询

文章目录表的增删查改Create&#xff08;创建&#xff09;单行数据 全列插入多行数据 指定列插入插入否则更新替换Retrieve&#xff08;读取&#xff09;SELECT列全列查询指定列查询查询字段为表达式查询结果指定别名结果去重WHERE 条件基本比较BETWEEN AND 条件连接OR 条件连…

SpringBoot With IoC,DI, AOP,自动配置

文章目录1 IoC&#xff08;Inverse Of Controller&#xff09;2 DI&#xff08;Dependency Injection&#xff09;3 AOP&#xff08;面向切面编程&#xff09;3.1 什么是AOP&#xff1f;3.2 AOP的作用&#xff1f;3.3 AOP的核心概念3.4 AOP常见通知类型3.5 切入点表达式4 自动配…

计算机网络的166个概念 你知道几个第七部分

计算机网络传输层 可靠数据传输&#xff1a;确保数据能够从程序的一端准确无误的传递给应用程序的另一端。 容忍丢失的应用&#xff1a;应用程序在发送数据的过程中可能会存在数据丢失的情况。 非持续连接&#xff1a;每个请求/响应会对经过不同的连接&#xff0c;每一个连接…

vue3+ts:约定式提交(git husky + gitHooks)

一、背景 Git - githooks Documentation https://github.com/typicode/husky#readme gitHooks: commit-msg_snowli的博客-CSDN博客 之前实践过这个配置&#xff0c;本文在vue3 ts 的项目中&#xff0c;再记录一次。 二、使用 2.1、安装 2.1.1、安装husky pnpm add hus…

python学习——【第三弹】

前言 上一篇文章 python学习——【第二弹】中学习了python中的运算符内容&#xff0c;这篇文章接着学习python中的流程控制语句。 流程控制指的是代码运行逻辑、分支走向、循环控制&#xff0c;是真正体现我们程序执行顺序的操作。流程控制一般分为顺序执行、条件判断和循环控…