【李沐深度学习笔记】自动求导实现

news2024/11/15 10:44:19

课程地址和说明

自动求导实现p2
本系列文章是我学习李沐老师深度学习系列课程的学习笔记,可能会对李沐老师上课没讲到的进行补充。

自动求导

# 创建变量
import torch
x = torch.arange(4, dtype=torch.float32) #只有浮点数才能求导
# 计算y关于x的梯度之前,需要一个地方存储梯度
x.requires_grad_(True)
print("x为:",x)
# 计算y
y = 2 * torch.dot(x,x) # y=2倍的x向量的内积,即y=2x^2
print("y为:",y)
# 通过调用反向传播函数来自动计算y关于x每个分量的梯度
y.backward() # 求导,y'=4x
print("x的梯度为:",x.grad)
print("y的导数是否为y'=4x:",x.grad == 4 * x)

运行结果:

x为: tensor([0., 1., 2., 3.], requires_grad=True)
y为: tensor(28., grad_fn=<MulBackward0>)
x的梯度为: tensor([ 0., 4., 8., 12.])
y的导数是否为y’=4x: tensor([True, True, True, True])

默认情况下PyTorch会把梯度累积起来

所以我们在算下一步的时候需要讲梯度清零

  • 举一个例子,之前的章节说过,当 y = s u m ( x → ) = ∑ i = 1 m x i = x 1 + x 2 + ⋯ + x m y=sum(\overrightarrow x)=\sum\limits_{i=1}^{m} x_{i}=x_{1}+x_{2}+\dots +x_{m} y=sum(x )=i=1mxi=x1+x2++xm a a a为任意常数, u = g ( x → ) u=g(\overrightarrow x) u=g(x )时,有:
    ∂ y ∂ x → = [ ∂ f ( x → ) ∂ x 1 ∂ f ( x → ) ∂ x 2 ⋮ ∂ f ( x → ) ∂ x m ] m × 1 = [ 1 1 ⋮ 1 ] m × 1 \frac{\partial {y}}{\partial\overrightarrow x}=\begin{bmatrix} \frac{\partial {f(\overrightarrow x)}}{\partial{x_{1}}}\\ \frac{\partial {f(\overrightarrow x)}}{\partial{x_{2}}}\\ \vdots \\ \frac{\partial {f(\overrightarrow x)}}{\partial{x_{m}}} \end{bmatrix}_{m\times 1}=\begin{bmatrix} 1\\ 1\\ \vdots \\ 1 \end{bmatrix}_{m\times 1} x y= x1f(x )x2f(x )xmf(x ) m×1= 111 m×1最终结果是单位向量。
    用PyTorch复现为:
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
# 重新定义,y=x1+x2+...+xn
y = sum(x)
# 向后传播求导
y.backward()
# 打印求导结果
print(x.grad)

运行结果:
tensor([1., 1., 1., 1.])

深度学习中,我们的目的不是计算微分矩阵,而是批量中每个样本单独计算的偏导数之和。

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
# 重新定义,y=x1+x2+...+xn
y = x * x # 变为了标量
print("y=",y)
# 向后传播求导
y.sum().backward() # 变量需要先让y求和再向后传播求导
print("y.sum()=",y.sum())
# 打印求导结果
print(x.grad)

运行结果:
y= tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)
y.sum()= tensor(14., grad_fn=<SumBackward0>)
tensor([0., 2., 4., 6.])

将某些计算移动到记录的计算图之外

# 清空累积的梯度
x.grad.zero_()
# 定义y = x * x
y = x* x
# 定义u为常数,不再是y = x * x这样的函数,其常数值为x * x
u = y.detach()
print(u)
# z = u * x,相当于常数u乘以向量x
z = u * x
# z = [u1,u2,u3,u4](常数向量)×[x1,x2,x3,x4]^T(变量向量)
# z = u1x1+u2x2+u3x3+u4x4
# 所以z'x1=u1,以此类推,其偏导数的向量最终就和u相等
z.sum().backward()
print(x.grad  == u)

运行结果:
tensor([0., 1., 4., 9.])
tensor([True, True, True, True])

# 清空累积的梯度
x.grad.zero_()
# 现在y为 x^2(x1^2+x2^2+x3^2+x4^2),y'=2x(y'x1=2x1,以此类推)
# 对y求导
y.sum().backward()
# 证明y的导数是否是2x
print(x.grad ==2 *x)

运行结果:
tensor([True, True, True, True])

【提问】请用通俗的语言告诉我在PyTorch中, y.sum().backward()与y.backward()这两种向后求导方式有什么区别?
【chatGPT回答】在PyTorch中,y.sum().backward()和y.backward()是两种不同的向后求导方式,它们在计算梯度的方式上有一些区别。

  1. y.sum().backward()
  • 这个方式首先对y进行了求和操作,将一个张量中的所有元素相加得到一个标量(一个单一的数值)。
  • 然后对这个标量进行反向传播,计算相对于y的梯度。
  • 这种方式通常用于处理损失函数的梯度计算,因为损失函数通常返回一个标量,你希望计算相对于模型参数的梯度。
  1. y.backward():
  • 这个方式直接对张量y进行反向传播,计算相对于y的梯度。
  • 这种方式通常用于计算中间变量的梯度,而不是损失函数的梯度。
    总的来说,主要区别在于
  • y.sum().backward()用于计算损失函数相对于模型参数的梯度,因为损失函数通常返回一个标量值。
  • y.backward()用于计算中间变量相对于其他变量的梯度,通常在计算模型的中间层时使用。
    选择哪种方式取决于你的具体需求和计算图的结构。通常,当你需要计算损失函数相对于模型参数的梯度时,你会使用y.sum().backward(),而在其他情况下,你会使用y.backward()来计算相对于中间变量的梯度。

即使构建函数的计算图需要通过Python控制流(例如,条件循环或任意函数调用),我们仍然可以计算得到的变量的梯度

# 清空累积的梯度
x.grad.zero_()
# 设定一个函数
def f(a):
    b = a * 2
    # 如果b向量的模长(L2范式)小于1000,就不断给它乘以2
    while b.norm() < 1000:
        b = b * 2
    if b.sum() > 0:
        c = b
    else:
        c = 100 * b
    return c
# 取一个随机数标量作为其值
a = torch.randn(size=(),requires_grad=True)
print("a为:",a)
d = f(a)
print("f(a)为:",d)
d.backward()
# 此函数相当于f(a)=ba,b是系数,如果是向量也是如此
# 所以其导数(或向量对应梯度向量)一定是斜率b
print(a.grad == d/a)

运行结果:
a为: tensor(1.6346, requires_grad=True)
f(a)为: tensor(1673.8451, grad_fn=)
tensor(True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1035138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VUE之模板解析(v-for)

v-for的多种写法 1. item in list 2. (item, index) in list 3. (item, name, index) in object forAliasRE 非贪婪模式匹配 &#xff1f; 正则表达式默认都是贪婪匹配&#xff0c;添加&#xff1f;后将其变成非贪婪模式 由下面例子可以看出如果没有添加&#xff1f;正则…

虚拟机软件Parallels Desktop 18 mac中文新增功能(PD18虚拟机)

Parallels Desktop 18可以在 Mac 计算机上下载并安装 Windows 操作系统。在 Mac 与 Windows 之间无缝复制和粘贴文本或拖放对象。在 Mac 虚拟机中跨多个操作系统开发和测试。毫不费力地运行 Windows 应用程序&#xff0c;不会减慢 Mac 的运行速度。 Parallels Desktop 18 for M…

lv5 嵌入式开发-5 线程的创建和参数传递

目录 1 线程基础 2 Linux线程库 2.1 线程创建 – pthread_create 2.2 线程结束 – pthread_exit 2.3 线程查看tid函数 2.4 线程间参数传递&#xff08;重点&#xff09; 2.4.1 练习 2.5 线程查看命令&#xff08;多线程&#xff09; 2.6 线程回收 – pthread_join 2.…

【重新定义matlab强大系列十四】基于问题求解有/无约束非线性优化

&#x1f517; 运行环境&#xff1a;Matlab &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…

玩转Mysql系列 - 第22篇:mysql索引原理详解

这是Mysql系列第22篇。 背景 使用mysql最多的就是查询&#xff0c;我们迫切的希望mysql能查询的更快一些&#xff0c;我们经常用到的查询有&#xff1a; 按照id查询唯一一条记录 按照某些个字段查询对应的记录 查找某个范围的所有记录&#xff08;between and&#xff09; …

如何解决python安装Crypto失败的问题

文章目录 准备安装相关链接准备 Python Crypto模块是一个第三方库,它提供了常见的加密算法和协议,比如AES、RSA、DES等。 PyCrypto 是加密工具包 Crypto 的 python 版,该模块实现了各种算法和协议的加密,提供了各种加密方式对应的算法的实现,包括单向加密、对称加密及公钥…

Linux(ubuntu)系统更新后不能进入图形界面

最近需要跑一个深度学习的程序&#xff0c;把许久没用的ubuntu系统调了出来&#xff0c;手欠的我更新了一下系统&#xff0c;结果再启动&#xff0c;系统就只停留在光标闪动那里&#xff0c;不能看到图形界面了。网上查了一下&#xff0c;说是因为更新后&#xff0c;显卡驱动没…

MySQL 8.0数据库主从搭建和问题处理

错误处理&#xff1a; 在从库通过start slave启动主从复制时出现报错 Last_IO_Error: error connecting to master slaveuser10.115.30.212:3306 - retry-time: 60 retries: 1 message: Authentication plugin caching_sha2_password reported error: Authentication require…

【draw】draw.io怎么设置默认字体大小

默认情况下draw里面的字体大小是12pt&#xff0c;我们可以将字体改成我们想要的大小&#xff0c;例如18pt。 然后点击样式&#xff0c;设置为默认样式。 下一次当我们使用文字大小时就是18pt了。

html和css相关操作

html第一个网页 <!DOCTYPE html> <!--html文档声明&#xff0c;声明此文档是一个html5的文档--> <html> <!--html文档开头标签--><head><!--html文档的设置标签&#xff0c;文档的设置及资源的引用都写在这个标签中--><meta charset&q…

[论文笔记]Prefix Tuning

引言 今天带来微调LLM的第二篇论文笔记Prefix-Tuning。 作者提出了用于自然语言生成任务的prefix-tuning(前缀微调)的方法,固定语言模型的参数而优化一些连续的任务相关的向量,称为prefix。受到了语言模型提示词的启发,允许后续的token序列注意到这些prefix,当成虚拟toke…

1793_树莓派杂志第一期MagPi01阅读

全部学习汇总&#xff1a; GreyZhang/little_bits_of_raspberry_pi: my hacking trip about raspberry pi. (github.com) 给自己的产品起一个好听的名称&#xff0c;我觉得这个是国外的企业中很好的一种文化。这里提到的苹果、黑莓等全都是一系列的水果。树莓派也有这样的风格&…

命令行程序测试自动化

【软件测试面试突击班】如何逼自己一周刷完软件测试八股文教程&#xff0c;刷完面试就稳了&#xff0c;你也可以当高薪软件测试工程师&#xff08;自动化测试&#xff09; 这几天有一个小工具需要做测试&#xff0c;是一个命令行工具&#xff0c;这个命令行工具有点类似mdbg等命…

【Vue+Element-UI】实现登陆注册界面及axios之get、post请求登录功能实现、跨域问题的解决

目录 一、实现登陆注册界面 1、前期准备 2、登录静态页实现 2.1、创建Vue组件 2.2、静态页面实现 2.3、配置路由 2.4、更改App.vue样式 2.5、效果 3、注册静态页实现 3.1、静态页面实现 3.2、配置路由 3.3、效果 二、axios 1、前期准备 1.1、准备项目 1.2、安装…

pytorch学习------数据集的使用方式

一、前言 在深度学习中&#xff0c;数据量通常是都非常多&#xff0c;非常大的&#xff0c;如此大量的数据&#xff0c;不可能一次性的在模型中进行向前的计算和反向传播&#xff0c;经常我们会对整个数据进行随机的打乱顺序&#xff0c;把数据处理成一个个的batch&#xff0c…

【力扣1464】数组中两元素的最大乘积

&#x1f451;专栏内容&#xff1a;力扣刷题⛪个人主页&#xff1a;子夜的星的主页&#x1f495;座右铭&#xff1a;前路未远&#xff0c;步履不停 目录 一、题目描述二、题目分析1、排序2、最值模拟 一、题目描述 题目链接&#xff1a;数组中两元素的最大乘积 给你一个整数数…

针对 SAP 的增强现实技术

增强现实技术是对现实世界的一种交互式模拟。这种功能受到各种企业和制造商的欢迎&#xff0c;因为它可以减少生产停机时间、快速发现问题并维护流程&#xff0c;从而提高运营效率。许多安卓应用都在探索增强现实技术。 使用增强现实技术&#xff08;AR&#xff09;的Liquid U…

【Vue入门】语法 —— 事件处理器、自定义组件、组件通信

目录 一、事件处理器 1.1 样式绑定 1.2 事件修饰符 1.3 按键修饰符 1.4 常用控制符 1.4.1 常用字符综合案例 1.4.2 修饰符 二、自定义组件 2.1 组件介绍及定义 2.2 组件通信 2.2.1 组件传参&#xff08;父 -> 子&#xff09; 2.2.1 组件传参&#xff08;子 ->…

【Leuanghing】ANSA入门——GUI界面介绍

【Leuanghing】ANSA入门——软件介绍 大家好&#xff01;今天为大家推荐一款软件 —— ANSA。ANSA是目前公认的全球最快捷的CAE前处理软件之一&#xff0c;也是一个功能强大的通用CAE前处理软件。 一、相关介绍 BETA CAE Systems S.A公司总部位地希腊的赛萨罗尼奇市(Thessalon…

水平基准和垂直基准

水平基准 针对不同的地理空间测量&#xff0c;水平基准可以是一个参考点或一个参考点集。在地球上&#xff0c;相同的位置上的点可能有不同的坐标&#xff0c;取决于使用了哪种基准。 水平基准用于测量地球上的位置&#xff08;position&#xff09;。因为地球不是一个完美的球…