Pytorch-自动微分模块

news2024/12/26 22:45:57

49739c720cb4452c9336253d032fc756.gif

🥇接下来我们进入到Pytorch的自动微分模块torch.autograd~

自动微分模块是PyTorch中用于实现张量自动求导的模块。PyTorch通过torch.autograd模块提供了自动微分的功能,这对于深度学习和优化问题至关重要,因为它可以自动计算梯度,无需手动编写求导代码。torch.autograd模块的一些关键组成部分:

  1. 函数的反向传播torch.autograd.function 包含了一系列用于定义自定义操作的函数,这些操作可以在反向传播时自动计算梯度。
  2. 计算图的反向传播torch.autograd.functional 提供了一种构建计算图并自动进行反向传播的方式,这类似于其他框架中的符号式自动微分。
  3. 数值梯度检查torch.autograd.gradcheck 用于检查数值梯度与自动微分得到的梯度是否一致,这是确保正确性的一个有用工具。
  4. 错误检测模式torch.autograd.anomaly_mode 在自动求导时检测错误产生路径,有助于调试。
  5. 梯度模式设置torch.autograd.grad_mode 允许用户设置是否需要梯度,例如在模型评估时通常不需要计算梯度。
  6. 求导方法:PyTorch提供backward()torch.autograd.grad()两种求梯度的方法。backward()会将梯度填充到叶子节点的.grad字段,而torch.autograd.grad()直接返回梯度结果。
  7. requires_grad属性:在创建张量时,可以通过设置requires_grad=True来指定该张量是否需要进行梯度计算。这样在执行操作时,PyTorch会自动跟踪这些张量的计算过程,以便后续进行梯度计算。

梯度基本计算

def func1():
    x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    f = x ** 2 +10
    # 自动微分求导
    f.backward()   # 反向求导
    # backward 函数计算的梯度值会存储在张量的 grad 变量中
    print(x.grad)
def func2():
    x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
    # 变量经过中间计算
    f1 = x ** 2 + 10
    
    # f2 = f1.mean()  # 平均损失,相当于每个值/4
    f2 = f1.sum()  # 求和损失,相当于每个值*1
    f2.backward()
    print(x.grad)
def func3():
    x1 = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor(20, requires_grad=True, dtype=torch.float64)
    y = x1 ** 2 + x2 ** 2 + x1 * x2
    y = y.sum()
    y.backward()
    print(x1.grad, x2.grad)

def func4():
    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor([30, 40], requires_grad=True, dtype=torch.float64)

    y = x1 ** 2 + x2 ** 2 + x1 * x2
    y = y.sum()
    y.backward()
    print(x1.grad,x2.grad)

func1func2,它们分别处理标量张量和向量张量的梯度计算。

  • func1中,首先创建了一个标量张量x,并设置requires_grad=True以启用自动微分。然后计算f = x ** 2 + 10,接着使用f.backward()进行反向求导。最后,通过打印x.grad输出梯度值。
  • func2中,首先创建了一个向量张量x,并设置requires_grad=True以启用自动微分。然后计算f1 = x ** 2 + 10,接着使用f1.sum()对向量张量进行求和操作,得到一个标量张量f2。最后,使用f2.backward()进行反向求导。
  • func3func4分别求多个标量和向量的情况,与上面相似。

控制梯度计算

我们可以通过一些方法使 requires_grad=True 的张量在某些时候计算时不进行梯度计算。 

  1. 第一种方式是使用torch.no_grad()上下文管理器,在这个上下文中进行的所有操作都不会计算梯度。
  2. 第二种方式是通过装饰器@torch.no_grad()来装饰一个函数,使得这个函数中的所有操作都不会计算梯度。
  3. 第三种方式是通过torch.set_grad_enabled(False)来全局关闭梯度计算功能,之后的所有操作都不会计算梯度,直到下一次再次调用此方法torch.set_grad_enabled(True)开启梯度计算功能。
x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
print(x.requires_grad)

# 第一种方式: 对代码进行装饰
with torch.no_grad():
    y = x ** 2
print(y.requires_grad)

# 第二种方式: 对函数进行装饰
@torch.no_grad()
def my_func(x):
    return x ** 2
print(my_func(x).requires_grad)


# 第三种方式
torch.set_grad_enabled(False)
y = x ** 2
print(y.requires_grad)

默认张量的 grad 属性会累计历史梯度值,如果需要重复计算每次的梯度,就需要手动清除。

x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)

for _ in range(3):

    f1 = x ** 2 + 20
    f2 = f1.mean()

    if x.grad is not None:
        x.grad.data.zero_()   # 本身来改动

    f2.backward()
    print(x.grad)

x.grad不是x,因为x是一个tensor张量,而x.grad是x的梯度。在PyTorch中,张量的梯度是通过自动求导机制计算得到的,而不是直接等于张量本身。

梯度下降优化最优解

x = torch.tensor(10, requires_grad=True, dtype=torch.float64)

for _ in range(5000):

     
    f = x ** 2

    # 梯度清零
    if x.grad is not None:
        x.grad.data.zero_()

    # 反向传播计算梯度
    f.backward()

    # 更新参数
    x.data = x.data - 0.001 * x.grad

    print('%.10f' % x.data)

更新参数相当于通过学习率对当前数值进行迭代。

f.backward()是PyTorch中自动梯度计算的函数,用于计算张量`f`关于其所有可学习参数的梯度。在这个例子中,`f`是一个标量张量,它只有一个可学习参数`x`。当调用f.backward()`时,PyTorch会自动计算`f`关于`x`的梯度,并将结果存储在`x.grad`中。这样,我们就可以使用这个梯度来更新`x`的值,以便最小化损失函数`f`。

梯度计算注意

当对设置 requires_grad=True 的张量使用 numpy 函数进行转换时, 会出现如下报错:

Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

此时, 需要先使用 detach 函数将张量进行分离, 再使用 numpy 函数。detach 之后会产生一个新的张量, 新的张量作为叶子结点,并且该张量和原来的张量共享数据, 但是分离后的张量不需要计算梯度。

import torch

def func1():

    x = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
    # print(x.numpy())  # 错
    print(x.detach().numpy())  


def func2():

    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # x2 作为叶子结点
    x2 = x1.detach()

    # 两个张量的值一样: 140421811165776 140421811165776
    print(id(x1.data), id(x2.data))
    x2.data = torch.tensor([100, 200])
    print(x1)
    print(x2)

    # x2 不会自动计算梯度: False
    print(x2.requires_grad)

7017d1cccb2c45cd845fefae64ed1947.gif

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

行人属性AI识别/人体结构化属性AI识别算法的原理及应用场景介绍

行人属性AI识别技术是一种基于人工智能技术的图像识别技术,通过对行人的图像或视频进行处理和分析,提取出其中的结构化信息,如人体姿态、关键点位置、行人属性(性别、年龄、服装等)等。 行人结构化数据分析的方法包括…

什么是边缘计算?它为何如此重要?-天拓四方

随着信息技术的快速发展,数据处理和计算的需求日益增大,特别是在实时性要求极高的场景中,传统的云计算模式面临着巨大的挑战。在这样的背景下,边缘计算作为一种新兴的计算模式,正逐渐受到业界的广泛关注。那么&#xf…

【创建型模式】单例模式

一、单例模式概述 单例模式的定义:又叫单件模式,确保一个类只有一个实例,并提供一个全局访问点。(对象创建型) 要点: 1.某个类只能有一个实例;2.必须自行创建这个实例;3.必须自行向整…

【nginx代理和tengine的启动-重启等命令】

在nginx成功启动后[任务管理器有nginx.exe进程],运行vue项目,在浏览器访问http://localhost:10001/,提示:访问拒绝(调试中network某些地址403); 解决方案: localhost改为ip&#xff…

自动化测试Selenium(4)

WebDriver相关api 定位一组元素 webdriver可以很方便地使用findElement方法来定位某个特定的对象, 不过有时候我们需要定位一组对象, 这时候就要使用findElements方法. 定位一组对象一般用于一下场景: 批量操作对象, 比如将页面上的checkbox都勾上. 先获取一组对象, 再在这组…

【代码随想录】【回文子串】day57:● 647. 回文子串 ● 516.最长回文子序列 ● 动态规划总结篇

回文子串 def countSubstrings(self, s):# 动态规划解法# dp[i][j] s[i-j]区间的回文子串的数目 dp[i][j]取决于dp[i1]和dp[j-1]count0dp[[False]*len(s) for _ in range(len(s))]for i in range(len(s)-1,-1,-1):for j in range(i,len(s)):if s[i]s[j] :if j-i<1:count1dp[…

全新升级轻舟知识付费系统引流变现至上利器

知识付费系统&#xff1a;引流变现至上利器 本系统参考各大主流知识付费系统&#xff0c;汇总取其精华&#xff0c;自主研发&#xff0c;正版授权系统。 我们给你搭建搭建一个独立运营的知识付费平台&#xff0c;搭建好之后&#xff0c;你可以自由的运营管理。网站里面的名称…

嵌入式软件考试——网络基础知识

1 主要知识点 OSI/RMTCP/IPIP地址与网络划分DNS与DHCP网络规划与设计网络故障诊断 2 OSI/RM 2.1 OSI七层模型 OSI七层模型 Bit流&#xff1a;物理层(集中器/中继器) 帧&#xff1a;数据链路层(网桥/交换机) 包&#xff1a;网络层(路由器) 段&#xff1a;传输层 报文&#xf…

SpringBoot框架——7.整合MybatisPlus

这篇主要介绍Springboot整合MybatisPlus&#xff0c;另外介绍一个插件JBLSpringbootAppGen,以及一个经常用于测试的基于内存的h2数据库。 Mybatisplus是mybatis的增强工具&#xff0c;和tk-mybatis相似&#xff0c;但功能更强大&#xff0c;可避免重复CRUD语句&#xff0c;先来…

uniapp_微信小程序_预约时间组件的使用

一、官方文档 DatetimePicker 选择器 | uView 2.0 - 全面兼容 nvue 的 uni-app 生态框架 - uni-app UI 框架 (uviewui.com) 二、完成的效果 之前使用的是Calendar 日历 这个太耗性能了&#xff0c;直接页面卡顿&#xff0c;所以就换成以上选择器了 三、代码 <u-datetime-p…

graphviz使用

安装 brew install graphviz测试 https://github.com/martisak/dotnets?tabreadme-ov-file

图文教程 | Git安装配置、常用命令大全以及常见问题

前言 因为多了一台电脑&#xff0c;平时写一些代码&#xff0c;改一些文件&#xff0c;用U盘存着转来转去特别麻烦。于是打算用Git管理我的文件&#xff0c;方便在两个终端之间传输数据啥的。也正好给新电脑装好Git。 &#x1f4e2;博客主页&#xff1a;程序源⠀-CSDN博客 &…

3d模型渲染怎么会没材质---模大狮模型网

在进行3D模型渲染时&#xff0c;有时会遇到材质丢失的问题&#xff0c;这可能会给设计师们带来一些困扰。材质是渲染的重要组成部分&#xff0c;它们赋予了模型真实感和视觉吸引力。然而&#xff0c;当模型在渲染过程中出现没有材质的情况时&#xff0c;可能会导致最终效果不如…

Spring Boot 2.x 将 logback 1.2.x 升级至 1.3.x

场景 安全部门针对代码进行漏洞扫描时&#xff0c;发现 logback-core 和 logback-classic 都属于 1.2.x 版本&#xff0c;这个版本存在 CVE 漏洞&#xff0c;并且建议升级到 1.3.x 版本。 问题 将两个包直接升级到 1.3.x 版本时&#xff0c;Spring Boot Web 服务启动直接出现…

基于springboot+vue+Mysql的地方废物回收机构管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

C语言中的数据结构--双向链表

前言 上一节我们已经学习完了单链表&#xff08;单向不带头不循环链表&#xff09;的所有内容&#xff0c;我们在链表的分类里面知道了&#xff0c;链表分为单向的和双向的&#xff0c;那么本节我们就来进行双向链表&#xff08;带头双向循环链表&#xff09;的学习&#xff0c…

Java 的注释

文章目录 java 的注释共有三种形式单行注释多行注释文档注释文档注释的文档需要命令进行生成GBK 不可映射问题 与大多数的编程语言一样&#xff0c;Java 中的注释也不会出现在可执行程序中。 因此我们可以在源程序中根据需要添加任意多的注释&#xff0c;而不必担心可执行代码受…

优秀Burp插件 提取JS、HTML中URL插件

Burp Js Url Finder 攻防演练过程中&#xff0c;我们通常会用浏览器访问一些资产&#xff0c;但很多接口/敏感信息隐匿在html、JS文件中&#xff0c;通过该Burp插件我们可以&#xff1a; 1、发现通过某接口可以进行未授权/越权获取到所有的账号密码 2、发现通过某接口可以枚举用…

【数据结构与算法】贪心算法及例题

目录 贪心算法例题一&#xff1a;找零问题例题二&#xff1a;走廊搬运物品最优方案问题输入样例例题三&#xff1a;贪心自助餐 贪心算法 贪心算法是一种在每一步选择中都采取当前状态下最优的选择&#xff0c;以期望最终达到全局最优解的算法。它的核心思想是每次都选择当前最…

python语言零基础入门——变量与简单数据类型

目录 一、变量 1.创建变量 2.变量的修改 3.变量的命名 &#xff08;1&#xff09;常量 &#xff08;2&#xff09;标识符 &#xff08;3&#xff09;关键字 &#xff08;4&#xff09;命名规则 二、简单数据类型 1.变量的数据类型 2.数据类型 3.整型&#xff08;In…