【PyTorch 攻略 (4/7)】张量和梯度函数

news2024/11/15 13:53:10

一、说明

 

        W在训练神经网络时,最常用的算法是反向传播。在该算法中,参数(模型权重)根据损失函数相对于给定参数的梯度进行调整。损失函数计算神经网络产生的预期输出和实际输出之间的差异。
        目标是获得尽可能接近零的损失函数的结果。反向传播算法通过神经网络向后遍历,以调整权重和偏差以重新训练模型。这种随着时间的推移重新训练模型的来回和前进过程将损失减少到 0,称为梯度下降

        为了计算这些梯度,PyTorch有一个名为torch.autograd的内置微分引擎。它支持任何计算图的梯度自动计算。

%matplotlib inline
import torch

x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

二、张量、函数和计算图

        这段代码定义了以下计算图:
        在这个网络中,wb是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

        我们应用于张量来构造计算图的函数实际上是一个对象类函数。此对象知道如何在向前方向上计算函数,以及如何向后传播步骤中计算其导数。对向后传播函数的引用存储在张量的 grad_fn 属性中。

print('Gradient function for z =',z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)
Gradient function for z = <AddBackward0 object at 0x00000280CC630CA0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x00000280CC630310>

三、计算梯度

        为了优化神经网络中参数的权重,我们需要计算损失函数相对于参数的导数,即我们需要在xy的固定值下∂w/∂loss和∂loss/∂b。为了计算这些导数,我们调用 loss.backward(),然后从 w.grad 和 b.grad 中检索值。

loss.backward()
print(w.grad)
print(b.grad)
tensor([[0.2739, 0.0490, 0.3279],
        [0.2739, 0.0490, 0.3279],
        [0.2739, 0.0490, 0.3279],
        [0.2739, 0.0490, 0.3279],
        [0.2739, 0.0490, 0.3279]])
tensor([0.2739, 0.0490, 0.3279]) 

四、禁用渐变跟踪

        默认情况下,所有 requires_grad=True 的张量都在跟踪它们的计算历史并支持梯度计算。但是,在某些情况下,我们不需要这样做,例如,当我们训练了模型并只想将其应用于某些输入数据时,即我们只想通过网络进行前向计算。我们可以通过在计算代码周围用 torch.no_grad() 块来停止跟踪计算。

z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)
True
False

        您可能想要禁用梯度跟踪的原因如下:
- 将神经网络中的某些参数标记为冻结参数。这是微调预训练网络的一种非常常见的方案。
- 在只执行正向传递时加快计算速度,因为对不跟踪梯度的张量的计算会更有效。

        C从概念上讲,Autograd 在由函数对象组成的有向无环图 (DAG) 中记录数据(张量)和所有执行的操作(以及生成的新张量)。在此 DAG 中,叶子是输入张量,根是输出张量。通过跟踪从根到叶的图形,您可以使用链式规则自动计算梯度。

        在正向传递中,autograd 同时做两件事:
- 运行请求的操作以计算生成的张量
- 在 DAG 中维护操作的梯度函数

        在向后传递中,.back() 在 DAG 根目录上调用。然后,
autograd :- 计算每个.grad_fn
的梯度 - 将它们累积在相应张量的 .grad 属性
中 - 使用链式规则一直传播到叶张量

DAG 在 PyTorch 中是动态的。

        需要注意的重要一点是,图形是从头开始重新创建的;每次 .backward() 调用后,Autograd 开始填充一个新图形。这正是允许您在模型中使用控制流语句的原因。
如果需要,您可以在每次迭代时更改形状、大小和操作。

五、张量梯度和雅可比积

        在许多情况下,我们有一个标量损失函数,我们需要计算相对于某些参数的梯度。但是,在某些情况下,输出函数是任意张量。在这种情况下,PyTorch 允许您计算所谓的雅可比乘积,而不是实际的梯度。

        对于向量函数 y* = f(x*),其中 x* = (x1, ..., xn) 和 y* = (y1, ..., ym),
y* 相对于 x* 的梯度由雅可比矩阵给出其元素 J 包含 ∂xi/∂yj

        PyTorch 不是计算雅可比矩阵本身,而是允许您计算雅可比乘积。J 对于给定的输入向量 v = (v1, ..., vm)。
这是通过使用 v 作为参数向调用来实现的。v 的大小应该与原始张量的大小相同,我们想要计算乘积。

inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.ones_like(inp), retain_graph=True)
print("First call\n", inp.grad)

out.backward(torch.ones_like(inp), retain_graph=True)
print("\nSecond call\n", inp.grad)

inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print("\nCall after zeroing gradients\n", inp.grad)
First call
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

Second call
 tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

Call after zeroing gradients
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

        请注意,当我们使用相同的参数第二次向后调用时,梯度的值是不同的。发生这种情况是因为在进行向后传播时,PyTorch 会累积梯度,即计算梯度的值被添加到计算图的所有叶节点的 grad 属性中。如果要计算正确的梯度,则需要在之前将 grad 属性归零。在现实生活中的训练中,优化器可以帮助我们做到这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1019362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一款 Linux 邮件客户端—Nylas Mail

导读Linux 上面有许多邮件客户端&#xff0c;Geary、Empathy、Evolution 和 Thunderbird 本身已经为很多用户提供了很好的服务&#xff0c;但是我发现了值得一试的软件&#xff1a; Nylas Mail。 为什么使用 Nylas&#xff1f; ​很多人因为种种原因选择了 Nylas Mail。让我们…

雷池社区WAF:保护您的网站免受黑客攻击 | 开源日报 0918

keras-team/keras Stars: 59.2k License: Apache-2.0 Keras 是一个用 Python 编写的深度学习 API&#xff0c;运行在机器学习平台 TensorFlow 之上。它 简单易用&#xff1a;减少了开发者认知负荷&#xff0c;使其能够更关注问题中真正重要的部分。灵活性强&#xff1a;通过逐…

Learn Prompt-ChatGPT 精选案例:代码助理

你可以使用 ChatGPT 进行代码生成、生成测试用例、注释、审查和漏洞检测。 代码生成​ 我们可以让 ChatGPT 自动生成一个排序算法中的快速排序的Python代码。 简单的代码生成对于 ChatGPT 来说小事一桩。 测试用例​ 用例来源出自StuGRua 在待测函数函数定义清晰的情况下…

什么是气传导耳机?气传导耳机值得入手吗?

​随着生活节奏的加快&#xff0c;人们越来越关注听力健康。气传导耳机以其独特的传导方式和舒适的佩戴感受&#xff0c;逐渐成为耳机市场的新宠。气传导耳机不入耳设计听音&#xff0c;让你在享受音乐的同时&#xff0c;也能保护你的听力安全。今天我们就一起来看看几款值得大…

无涯教程-JavaScript - COMBINA函数

描述 COMBINA函数返回给定数量的项目的组合数量(重复)。 语法 COMBINA (number, number_chosen)争论 Argument描述Required/OptionalNumber 必须大于或等于0,并且大于或等于Number_chosen。 非整数值将被截断。 RequiredNumber_chosen 必须大于或等于0。 非整数值将被截断。…

国际上被广泛认可的电子邮箱服务有哪些?

随着全球化的发展&#xff0c;越来越多的企业开始涉足国际贸易。在众多的邮箱服务提供商中&#xff0c;哪些是国际上比较认可的呢&#xff1f;本文将为您详细介绍几款在全球范围内广受好评的邮箱服务&#xff1a;Gmail(谷歌邮箱)、Outlook(微软邮箱)、Yahoo Mail(雅虎邮箱)、Zo…

品牌营销|小红书母婴消费的发展趋势,与创新路径在哪里

随着社会的发展与进步&#xff0c;母婴消费已经成为现代家庭生活中的重要组成部分。人们对于孩子的关爱和需求日益增长&#xff0c;母婴市场也变得愈发繁荣。今天来分享下品牌营销&#xff0c;小红书母婴消费的发展趋势&#xff0c;与创新路径在哪里&#xff1f; 一、母婴消费的…

【Redis面试题(46道)】

文章目录 Redis面试题&#xff08;46道&#xff09;基础1.说说什么是Redis?2.Redis可以用来干什么&#xff1f;3.Redis 有哪些数据结构&#xff1f;4.Redis为什么快呢&#xff1f;5.能说一下I/O多路复用吗&#xff1f;6. Redis为什么早期选择单线程&#xff1f;7.Redis6.0使用…

抖音带货怎么找货源合作?

随着社交媒体的快速发展&#xff0c;抖音已成为销售商品的重要平台。越来越多的个人和企业开始在抖音上销售​​商品&#xff0c;但寻找合适的货源进行合作是一个很大的挑战。本文将为您介绍一些寻找合作货源的方法和技巧。 如何寻找抖音合作的货源&#xff1f; 确定你的目标市…

数据库管理-第104期 RAC上升级SSH的坑(20230918)

数据库管理-第104期 RAC上升级SSH的坑&#xff08;20230918&#xff09; 最近一些版本的OpenSSH和OpenSSL都爆出了比较严重的漏洞&#xff0c;但是Oracle数据库尤其是RAC升级SSH和SSL其实是有一定风险的&#xff0c;这里就借助我的OCM环境&#xff0c;做一次SSH升级的演示及排…

用ModelScope给大家送上中秋祝福

用ModelScope来阐述中秋的意义 第一 中秋节的背景 接下来我们继续深入一下看看ModelScope的理解 可以看当我们讨论家庭团聚时&#xff0c;ModelScope 对这个主题的理解的确十分准确。然而&#xff0c;有时候我们在表达这个概念时可能会变得有些过于正式和僵硬&#xff0c;这样…

【码银送书第七期】七本考研书籍

八九月的朋友圈刮起了一股晒通知书潮&#xff0c;频频有大佬晒出“研究生入学通知书”&#xff0c;看着让人既羡慕又焦虑。果然应了那句老话——比你优秀的人&#xff0c;还比你努力。 心里痒痒&#xff0c;想考研的技术人儿~别再犹豫了。小编咨询了一大波上岸的大佬&#xff…

论文解读 | YOLO系列开山之作:统一的实时对象检测

原创 | 文 BFT机器人 01 摘要 YOLO是一种新的目标检测方法&#xff0c;与以前的方法不同之处在于它将目标检测问题视为回归问题&#xff0c;同时预测边界框和类别概率。这一方法使用单个神经网络&#xff0c;可以从完整图像中直接预测目标边界框和类别概率&#xff0c;实现端…

二叉树的概念、存储及遍历

一、二叉树的概念 1、二叉树的定义 二叉树&#xff08; binary tree&#xff09;是 n 个结点的有限集合&#xff0c;该集合或为空集&#xff08;空二叉树&#xff09;&#xff0c;或由一个根结点与两棵互不相交的&#xff0c;称为根结点的左子树、右子树的二叉树构成。 二叉树的…

ClickHouse进阶(十七):clickhouse优化-写出查询优化

进入正文前&#xff0c;感谢宝子们订阅专题、点赞、评论、收藏&#xff01;关注IT贫道&#xff0c;获取高质量博客内容&#xff01; &#x1f3e1;个人主页&#xff1a;含各种IT体系技术,IT贫道_大数据OLAP体系技术栈,Apache Doris,Kerberos安全认证-CSDN博客 &#x1f4cc;订…

4G工业路由器,开启智能工厂,这就是关键所在

​提到工业物联网,首先联想到的就是数据传输。要把海量的工业数据从设备端传到控制中心,无线数传终端就发挥着重要作用。今天就跟着小编来看看它的“联”是怎么建立的吧! 原文&#xff1a;https://www.key-iot.com/iotlist/1838.html 一提到无线数传终端,相信大家首先想到的是…

Python 元组的常用方法

视频版教程 Python3零基础7天入门实战视频教程 下标索引用法和列表一样&#xff0c;唯一区别就是不能修改元素 实例&#xff1a; # 下标索引用法和列表一样&#xff0c;唯一区别就是不能修改元素 t1 ("java", "python", "c") # t1[1] "…

【PyTorch 攻略 (3/7)】线性组件、激活函数

一、说明 神经网络是由层连接的神经元的集合。每个神经元都是一个小型计算单元&#xff0c;执行简单的计算来共同解决问题。它们按图层组织。有三种类型的层&#xff1a;输入层、隐藏层和输出层。每层包含许多神经元&#xff0c;但输入层除外。神经网络模仿人脑处理信息的方式。…

虹科分享 | 谷歌Vertex AI平台使用Redis搭建大语言模型

文章来源&#xff1a;虹科云科技 点此阅读原文 基础模型和高性能数据层这两个基本组件始终是创建高效、可扩展语言模型应用的关键&#xff0c;利用Redis搭建大语言模型&#xff0c;能够实现高效可扩展的语义搜索、检索增强生成、LLM 缓存机制、LLM记忆和持久化。有Redis加持的大…

Docker启动Mysql容器并进行目录挂载

一、创建挂载目录 mkdir -p 当前层级下创建 mkdir -p mysql/data mkdir -p mysql/conf 进入到conf目录下创建配置文件touch hym.conf 并把配置文件hmy.conf下增加以下内容使用vim hym.conf即可添加(cv进去就行) Esc :wq 保存 [mysqld] skip-name-resolve character_set_…