Pytorch学习笔记(二)官方60min入门教程之自动微分

news2024/9/22 11:10:11

目录

一.相关包及函数介绍

二.雅各比向量积

三.练习代码


一.相关包及函数介绍

autograd 包是 PyTorch 中所有神经网络的核心。首先让我们简要地介绍它,然后我们将会去训练我们的第一个神经网络。该 autograd 软件包为 Tensors 上的所有操作提供自动微分。它是一个由运行定义的框架,这意味着以代码运行方式定义你的后向传播,并且每次迭代都可以不同。我们从 tensor 和 gradients 来举一些例子。

1、TENSOR

torch.Tensor 是包的核心类。如果将其属性 .requires_grad 设置为 True,则会开始跟踪针对 tensor 的所有操作。完成计算后,您可以调用 .backward() 来自动计算所有梯度。该张量的梯度将累积到 .grad 属性中。

要停止 tensor 历史记录的跟踪,您可以调用 .detach(),它将其与计算历史记录分离,并防止将来的计算被跟踪。

要停止跟踪历史记录(和使用内存),您还可以将代码块使用 with torch.no_grad(): 包装起来。在评估模型时,这是特别有用,因为模型在训练阶段具有 requires_grad = True 的可训练参数有利于调参,但在评估阶段我们不需要梯度。

还有一个类对于 autograd 实现非常重要那就是 Function。Tensor 和 Function 互相连接并构建一个非循环图,它保存整个完整的计算过程的历史信息。每个张量都有一个 .grad_fn 属性保存着创建了张量的 Function 的引用,(如果用户自己创建张量,则grad_fn 是 None )。

如果你想计算导数,你可以调用 Tensor.backward()。如果 Tensor 是标量(即它包含一个元素数据),则不需要指定任何参数backward(),但是如果它有更多元素,则需要指定一个gradient 参数来指定张量的形状。

可参考这篇博客:

PyTorch自动求导:Autograd

二.雅各比向量积

从数学上讲,autograd类只是一个雅可比向量积计算工具。简而言之,雅可比矩阵就是表示两个向量的所有可能偏导数的矩阵。它是一个向量相对于另一个向量的梯度。
注:在这个过程中,PyTorch从未显式地构造整个雅可比矩阵。直接计算JVP(雅可比向量积)通常更简单、更有效。
如果一个向量X = [x1, x2,…xn]用于计算其他向量f(X) = [f1, f2, …fn] 通过函数f,则雅可比矩阵(J)简单地包含了所有偏导数组合,如下所示:

 以上矩阵表示f(X)对X的梯度
设PyTorh支持梯度的tensor为
X = [x1, x2, …… xn](假设这是某个机器学习模型的权重)
X经过一些运算得到向量Y
Y = f(X) = [y1, y2, …. ym]
然后用Y来计算标量损失l。假设向量v恰好是标量损失l对向量Y的梯度,如下所示

 向量v被称为grad_tensor并作为参数传递给backward()函数
为了得到损失l对权值X的梯度,将雅可比矩阵J与向量v相乘

这种计算雅可比矩阵并将其与向量v相乘的方法使PyTorch能够轻松地提供外部梯度,即使是非标量输出。
个人理解:X是权重向量,Y是假设函数(Hypothesis function,比如交叉熵或线性函数),l则是整体的损失函数(比如均方误差)。
目的是要计算l关于X的梯度,但是可能直接计算不太方便或者代价大或者存在其他弊端。所以采用先计算Y关于X的梯度,再计算l关于Y的梯度,再利用结果计算l关于X的梯度,这样做应该是有某些好处。

三.练习代码

import torch

#创建一个张量,设置 requires_grad=True 来跟踪与它相关的计算
# x=torch.ones(2,2,requires_grad=True)
# print('x:',x)
# #
# # #针对张量做一个操作
# y = x + 2
# print('y:',y)
#
# #y 作为操作的结果被创建,所以它有 grad_fn
# # 每个张量都有一个 .grad_fn 属性保存着创建了张量的 Function 的引用,(如果用户自己创建张量,则grad_fn 是 None )
# # print(y.grad_fn)
#
# #针对 y 做更多的操作:
# z = y * y * 3
# out = z.mean()
# print('z:',z, 'out:',out)

#.requires_grad_( ... ) 会改变张量的 requires_grad 标记。输入的标记默认为 False ,如果没有提供相应的参数。
#2行2列的张量
# a = torch.randn(2, 2)
# a = ((a * 3) / (a - 1))
# print(a.requires_grad)
# a.requires_grad_(True)
# print(a.requires_grad)
# b = (a * a).sum()
# print(b.grad_fn)

'''
梯度:
我们现在后向传播,因为输出包含了一个标量,out.backward() 等同于out.backward(torch.tensor(1.))。
'''
# 调用 Tensor.backward()来计算导数
# out.backward()
# # 打印梯度 d(out)/dx
# print(x.grad)

# 现在让我们看一个雅可比向量积的例子:
#1行3列
x = torch.randn(3, requires_grad=True)
y = x * 2
'''
data.norm()对张量y每个元素进行平方,然后对它们求和,最后取平方根。 这些操作计算就是所谓的L2范数或欧几里德范数 。
L1范数是指向量中各个元素绝对值之和。
'''
while y.data.norm() < 1000:
    y = y * 2
# print(y)
'''
现在在这种情况下,y不再是一个标量(只有一个元素才叫做标量)。torch.autograd 不能够直接计算整个雅可比矩阵,
但是如果我们只想要雅可比向量积,只需要简单的传递向量给 backward 作为参数。 
'''
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)
print(x.grad)

'''
你可以通过将代码包裹在 with torch.no_grad(),来停止对从跟踪历史中 的 .requires_grad=True 的张量自动求导。 
'''
# print(x.requires_grad)
# print((x ** 2).requires_grad)
# with torch.no_grad():
#     print((x ** 2).requires_grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/32831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

显卡天梯图2022年11月新版 显卡性能排行榜天梯图

1 RTX 3090Ti 2 RTX 3090 3 RX 6900 XT水冷版 我用的显卡就是活动时8折抢购的太划算了 http://www.adiannao.cn/dq 4 RTX 3080 Ti 5 RX 6900 XT 6 Titan RTX 7 RTX 3080 8 RX 6800 XT 9 RX 6800 10 RTX 3070 Ti

C++数据结构X篇_01_数据结构的基本概念

从本篇开始学习数据结构相关概念。 数据结构的基本概念1 数据结构的相关概念1.1 为什么要学习数据结构1.2 数据结构中的基本概念2 算法2.1 算法的概念2.2 算法和数据结构的区别2.3 算法特性2.4 算法效率的度量2.4.1 事后统计法2.4.2 事前分析估算2.4.3 大O表示法2.4.3.1采用大O…

从事先进计算的工程师对此都有什么感想?

电子计算机最初诞生于二十世纪&#xff0c;体积庞大的初代机型运算能力有限&#xff0c;随着计算技术的升级完善&#xff0c;现在多样小巧的计算机及手机的计算能力呈指数级增长&#xff0c;更是成为人们生活密不可分的综合性助手。 先进计算是在计算的基础上诞生的全新概念&a…

Python3安装及基础语法

Python 官网&#xff1a;Welcome to Python.org Python安装&#xff1a;进入官网Download找到对应版本安装包&#xff0c;下载后双击安装&#xff0c;一直下一步即可&#xff1b;注意&#xff1a;安装最后一步勾选&#xff08;Add Python to PATH&#xff09;&#xff0c;默认…

用nginx作反向代理时,请求头中含波浪线无法转发请求的解决方法

请求头如下 POST /CDGServer3/s/rs/uni HTTP/1.1 Content-Type: text/html; charsetUTF-8 method~name: upgradePatchService user~userId: admin.local user~clientId: 343834353230344334424431 user~SessionID: 0 data~packageNo: 618 data~packageState: 1 User-Agent: Ra…

Android -- 每日一问:怎么理解 Activity 的生命周期?

典型回答 如果一个 Activity 在用户可见时才处理某个广播&#xff0c;不可见时注销掉&#xff0c;那么应该在哪两个生命周期的回调方法去注册和注销 BroadcastReceiver 呢&#xff1f; Activity 的可见生命周期发生在 onStart调用与 onStop调用之间。在这段时间&#xff0c;用户…

nginx(六十四)proxy模块(五)接收上游响应

一 接收上游的响应 前提&#xff1a; nginx与上游建立连接,把nginx生成的请求(line、header、body)信息发送给上游补充&#xff1a; 上游解析处理完之后,会发送响应​核心&#xff1a; nginx如何接收、解析、处理上游响应行、响应头、响应体 下载大文件失败 &#xff08;…

一文了解 Go 的复合数据类型(数组、Slice 切片、Map)

一文了解 Go 的复合数据类型[数组、切片 Slice、Map]前言数组数组的创建方式数组的遍历Slice 切片切片的创建方式切片的遍历向切片追加元素MapMap 的创建方式Map 的基本操作插入和修改删除查找操作遍历操作删除操作小结耐心和持久胜过激烈和狂热。 前言 上一篇文章一文熟悉 Go…

CMake Cookbook by Eric

I. Basics 关键字&#xff1a;CMake中的构建指令 指令的书写是大小写无关的&#xff1b; II. Project&#xff1a;指定项目名称和语言类型 命令格式&#xff1a;project(<PROJECT-NAME> [<language-name>...]) Note 项目名称<PROJECT-NAME>不需要与项目根…

论文阅读【7】HHM隐马尔科夫模型

1.隐马尔科夫模型&#xff08;HMM&#xff09;的介绍 隐马尔科夫模型有两个序列&#xff0c;上面一层序列的值称之为影藏值(隐式变量)&#xff0c;下面一个序列中的值被称为观察值&#xff0c;想这个的序列模型被称为生成模型&#xff08;Generate model&#xff09;。z表示的是…

Linux - lsof显示 tcp,udp 的端口和进程

文章目录功能语法示例lsof -i 显示 tcp&#xff0c;udp 的端口和进程等相关查看服务器 80 端口的占用情况使用 -p 查看指定进程打开的文件更多命令功能 lsof&#xff08;list open files&#xff09;是一个列出当前系统打开文件的工具。 lsof 需要访问核心内存和各种文件&…

【区块链技术与应用】(八)

https://blog.csdn.net/lakersssss24/article/details/125762826?spm1001.2014.3001.5501 https://blog.csdn.net/lakersssss24/article/details/126434147 https://blog.csdn.net/lakersssss24/article/details/126671408?spm1001.2101.3001.6650.3&utm_mediumdistribut…

[附源码]java毕业设计医院仪器设备管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

cubeIDE开发, stm32的OLED点亮及字符显示设计(基于SPI通信)

一、SPI 通信技术 显示屏&#xff08;LCD、OLED&#xff09;接口一般有I2C、SPI、UART、RGB、LVDS、MIPI、EDP和DP等。一般3.5寸以下的小尺寸LCD屏&#xff0c;显示数据量比较少&#xff0c;普遍采用低速串口&#xff0c;如I2C、SPI、UART。SPI&#xff08;Serial Peripheral I…

通用后台管理系统前端界面Ⅹ——前端数据表格的删除、查找

删除操作 1、拿到id或者行数据对象 2、查看后端接口方法&#xff0c;写api方法&#xff0c;将操作连接上后端 后端请求操作成功&#xff0c;但是前端数据表格未更新&#xff0c;最简单的一种方法数据删除后要重新获取数据》 依旧显示成功&#xff0c;但是前端数据表格未变化&…

Bert and its family

Bert没有办法一次性读入特别长的文本的问题。自注意力机制非常消耗时间和空间。 概率值最大取argmax&#xff0c;对应的下标 整体全部更新&#xff0c;所有参数都更新&#xff0c;比固定住pre-trained要好很多。 不做预训练&#xff0c;loss下降比较慢&#xff0c;收敛比较慢&a…

BIM在工程中的20种典型功能

1、BIM模型维护 根据项目建设进度建立和维护BIM模型&#xff0c;实质是使用BIM平台汇总各项目团队所有的建筑工程信息&#xff0c;消除项目中的信息孤岛&#xff0c;并且将得到的信息结合三维模型进行整理和储存&#xff0c;以备项目全过程中项目各相关利益方随时共享。 由于…

Java 微信关注/取消关注事件

Java 微信关注/取消关注事件一、需求、思路二、文档、配置配置步骤1配置步骤2三、代码1、引入依赖包2、controller3、封装消息对象4、service、解密5、工具包一、需求、思路 需求&#xff1a;用户订阅/取消订阅公众号时接收消息并保存到数据库中以便后续功能的处理。 思路&…

【分类-SVDD】基于支持向量数据描述 (SVDD) 的多类分类算法附matlab代码

✅作者简介&#xff1a;热爱科研的Matlab仿真开发者&#xff0c;修心和技术同步精进&#xff0c;matlab项目合作可私信。 &#x1f34e;个人主页&#xff1a;Matlab科研工作室 &#x1f34a;个人信条&#xff1a;格物致知。 更多Matlab仿真内容点击&#x1f447; 智能优化算法 …

机器学习-回归模型相关重要知识点

目录01 线性回归的假设是什么&#xff1f;02 什么是残差&#xff0c;它如何用于评估回归模型&#xff1f;03 如何区分线性回归模型和非线性回归模型&#xff1f;04 什么是多重共线性&#xff0c;它如何影响模型性能&#xff1f;05 异常值如何影响线性回归模型的性能&#xff1f…