【PyTorch】教程:学习基础知识-(6) Autograd

news2024/9/25 19:19:36

AUTOMATIC DIFFERENTIATION WITH torch.autograd

在训练神经网络时,最常用的算法是反向传播算法,在该算法中,参数根据损失函数相对于给定参数的梯度进行调整。

为了计算这些梯度, PyTorch 有一个内置的微分引擎 torch.autograd 。它智慧任何计算图的梯度自动计算。

考虑最简单的单层神经网络,输入 x, 参数 w 和 b, 以及一些损失函数。 它可以在 PyTorch 中以以下方式定义:

import torch 
x = torch.ones(5) # 输入
y = torch.zeros(3) # 输出
w = torch.randn(5, 3, requires_grad=True) # 权重
b = torch.randn(3, requires_grad=True) # 偏置
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
print(f"loss = {loss}")
loss = 0.7214622497558594

Tensors, Functions and Computational graph

上述代码定义了以下计算图

在这里插入图片描述

这个网络中,wb 是需要优化的参数。因此,我们需要能够计算关于这些变量的损失函数的梯度。为了做到这一点,我们设置了这些 tensorrequires_grad 属性。

你可以在创建一个 tensor 时设置 requires_grad 值,也可以在以后使用时利用 x.requires_grad(True) 方法。

应用于 tensor 来构建计算图的函数实际上 function 类的一个对象,该对象知道如何前向计算,也知道在反向传播时计算导数。反向传播的引用存储在 tensor 的 grad_fn 属性中。

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")
Gradient function for z = <AddBackward0 object at 0x00000230B70C84F0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x00000230B70C8D30>

Computing Gradients

为了计算神经网络的权重参数,我们需要计算 loss function 对参数的导数。即,我们需要在给定 x x x y y y 计算 ∂ l o s s ∂ w \frac{\partial loss }{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss }{\partial b} bloss 。为了计算这些导数,我们调用 loss.backward(), 然后从 w.gradb.grad 提取值。

loss.backward()
print(w.grad)
print(b.grad)
tensor([[0.0494, 0.0666, 0.2772],
        [0.0494, 0.0666, 0.2772],
        [0.0494, 0.0666, 0.2772],
        [0.0494, 0.0666, 0.2772],
        [0.0494, 0.0666, 0.2772]])
tensor([0.0494, 0.0666, 0.2772])

我们只能获得计算图的叶节点的 grad 属性,它们的 requires_grad 属性设置为 True 。对于图中的所有其他节点,梯度将不可用。
出于性能考虑,我们只能在给定的图上反向执行一次梯度计算。如果我们需要对同一个图执行多个反向调用,我们需要将 retain_graph=True 传递给反向调用。

Disabling Gradient Tracking (不跟踪梯度)

默认情况下,所有 requires_grad=Truetensor 都在跟踪它们的计算历史并支持梯度计算。然而,有些情况下我们不需要这样做,例如,当我们训练了模型,只想对输入数据进行计算,我们只需要通过网络进行 forward 计算。我们可以用 torch.no_grad() 停止跟踪梯度计算。

z = torch.matmul(x, w) + b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w) + b

print(z.requires_grad)
True
False

实现相同结果的另一种方法是在 tensor 上使用 detach() 方法

z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)
False

为什么要禁用梯度跟踪呢?

    1. 将神经网络中的一些参数标记为冻结参数,这是一个非常常见的场景finetuning a pretrained network
    1. 为了加速计算。因为在不跟踪梯度的 tensor 上计算会更有效。

More on Computational Graphs

从概念上讲,autograd 在一个由 Function 对象组成的有向无环图 ( DAG ) 中保存数据( tensors )和所有执行的操作( 以及由此产生的新张量 )的记录。在这个 DAG 中,叶子节点是输入 tensors ,根是输出 tensors 。通过从根到叶跟踪这个图,您可以使用链式法则自动计算梯度。

在向前传播中,autograd 同时做两件事:

  1. 根据要求的操作计算结果 tensor
  2. DAG 中维护操作的梯度函数。

当在 DAG 根上调用 .backward() 时,后向传播开始,autograd 然后

  • 从每个 .grad_fn 中计算梯度;
  • 在各自的 tensor.grad 属性中累积它们
  • 利用链式法则,一直传播到叶子 tensor

PyTorch 中的 DAG 是动态的,需要注意的一件重要的事情是,图是从头开始重新创建的,在每次 .backward() 调用之后, autograd 开始填充一个新的图。 这正是允许您在模型中使用控制流语句的原因,如果需要,您可以在每次迭代中更改形状、大小和操作等。

Optional Reading: Tensor Gradients and Jacobian Products

在很多情况下,我们有一个标量损失函数,我们需要计算关于一些参数的梯度,然后,也有输出函数是任意 tensor 的情况,在这种情况下,PyTorch 允许你计算所谓的雅克比矩阵积,而不是实际的梯度。

对于一个向量函数

y → = f ( x → ) \overrightarrow{y} = f(\overrightarrow{x}) y =f(x ) , 当 $ \overrightarrow{x}=<x_1,…x_n>$, y → = < y 1 , . . . y m > \overrightarrow{y}=<y_1,...y_m> y =<y1,...ym>, y → \overrightarrow{y} y x → \overrightarrow{x} x 的导数可以利用 Jacobian matrix (雅克比矩阵):

J = ( ∂ y 1 ∂ x 1 ⋯ ∂ y 1 ∂ x n ⋮ ⋱ ⋮ ∂ y m ∂ x 1 ⋯ ∂ y m ∂ x n ) \begin{equation*} J = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \\ \end{pmatrix} \end{equation*} J= x1y1x1ymxny1xnym

PyTorch 允许你在给定 v = ( v 1 , . . . v m ) v=(v_1,...v_m) v=(v1,...vm) 向量时计算雅克比矩阵的点积 v T ⋅ J v^T \cdot J vTJ, 而不是雅克比矩阵本身。可以通过 backward v v v 作为参数而获得。 v v v 的大小应该与原始的 tensor 大小相同。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")

out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")

inp.grad.zero_()

out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

注意: 当我们用相同的参数第二次调用 backward() 时,梯度值是不同的。这是因为在进行反向传播时,PyTorch 会累积梯度,即计算梯度的值被添加到计算图的所有叶节点的 grad 属性中,如果逆向计算合适的梯度,你需要在之前讲梯度属性归零。在现实训练中,优化器可以帮助我们做到这一点。
以前我们调用 backward() 函数时不带参数。这本质上相当于调用 backward(torch.tensor(1.0)),这是一种有用的方法,可以在标量值函数的情况下计算梯度,例如神经网络训练期间的损失。

【参考】

Automatic Differentiation with torch.autograd — PyTorch Tutorials 1.13.1+cu117 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2022秋招算法岗面经题:训练模型时loss除以10和学习率除以10真的等价吗(SGD等价,Adam不等价)

问题描述&#xff1a;训练深度学习模型时loss除以10和学习率除以10等价吗&#xff1f; 先说结论 这个问题的答案与优化器有关 使用Adam、Adagrad、RMSprop等带有二阶动量vtv_tvt​的优化器训练时&#xff0c;当我们将loss除以10&#xff0c;对训练几乎没有影响。使用SGD、Mo…

Streamlit自定义组件开发教程

在这篇文章中&#xff0c;我们将学习如何构建Streamlit组件以及如何发布streamlit组件供其他人使用。 使用 3D场景编辑器快速搭建三维数字孪生场景 1、什么是Streamlit组件&#xff1f; Streamlit 组件是一个可共享的 Streamlit 插件&#xff0c;可让你为应用程序添加新的视觉…

Java——子集

题目链接 leetcode在线oj题——子集 题目描述 给你一个整数数组 nums &#xff0c;数组中的元素 互不相同 。返回该数组所有可能的子集&#xff08;幂集&#xff09;。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 题目示例 输入&#xff1a;nums [1,2,3] …

MySQL管理

1&#xff1a;MySQL管理1.1&#xff1a;系统数据库Mysql数据库安装完成后&#xff0c;自带了一下四个数据库&#xff0c;具体作用如下&#xff1a; 数据库 含义 mysql 存储MySQL服务器正常运行所需要的各种信息 &#xff08;时区、主从、用 户、权限等&#xff09; information…

性能测试/实战演示 H5 性能分析

W3C标准是浏览器标准&#xff0c;一般浏览器都支持W3C标准&#xff0c;它规定使用者可以通过api查询性能信息&#xff0c;可借用W3C协议完成自动化H5性能测试。 W3C官网&#xff1a;Navigation Timing 使用chrome浏览器对webview进行手工查看&#xff0c;伴随着业务增多&#x…

mysql:索引的数据结构,B树,B+树浅聊

mysql&#xff1a;索引的数据结构 什么是索引&#xff1f; 索引&#xff08;Index&#xff09;是帮助MySQL高效获取数据的数据结构 为什么学索引&#xff1f; 之前应该有概念说&#xff0c;把索引理解为目录&#xff0c;比如通过s就可以查询到s开头的汉子从哪也开始&#xff…

[网鼎杯 2020 青龙组]AreUSerialz

目录 信息收集 代码审计 前提知识 思路分析 绕过检测 方法一 poc payload 方法二 poc payload 信息收集 进入页面给出了源代码如下&#xff0c;是一道PHP的反序列化题目 <?phpinclude("flag.php");highlight_file(__FILE__);class FileHandler {pro…

Linux多线程 线程概念 | 线程VS进程 | 线程控制【万字精讲】

线程 一、线程概念 1. 知识支持及回顾 在我们一开始学习进程的时候。我们总说进程在内部执行时&#xff0c;是OS操作系统调度的基本单位。其实并不严谨&#xff0c;今天&#xff0c;我们要重新完善这个说法——线程在进程内部运行&#xff0c;线程是OS操作系统调度的基本单位…

WorkPlus移动办公平台,助力企业随时随地“指尖办公”

近年来&#xff0c;随着移动互联网的发展&#xff0c;越来越多的人习惯于随时随地通过移动设备完成工作、购物、游戏等。移动办公应用就是基于移动终端的信息化办公应用&#xff0c;利用企微、钉钉、WorkPlus等移动办公平台&#xff0c;实现企业与员工间的随时随地工作、沟通&a…

技术分享 | ClickHouse StarRocks 使用经验分享

作者&#xff1a;许天云 本文来源&#xff1a;原创投稿 *爱可生开源社区出品&#xff0c;原创内容未经授权不得随意使用&#xff0c;转载请联系小编并注明来源。 一. 大纲 本篇分享下个人在实时数仓方向的一些使用经验&#xff0c;主要包含了ClickHouse 和 StarRocks 这两款目…

ASP.NET MVC解决方案的搭建(.NET Framework)——C#系列(一)

一、新建项目 1、控制器新建 2、Service层新建 3、Business数据层新建 4、Commons公共层新建 5、Models实体层新建 二、调用接口 1、接口建立 Web API 2 控制器新建 2、调用 三、Swagger接口调试配置 1、添加NuGet包 在启动项中添加Swashbuckle NuGet包 2、访问 htt…

Vivado2018.3安装及注册指南-安装包获取

一、vivado 介绍 vivado设计套件 是FPGA 厂商赛灵思&#xff08;Xilinx&#xff09;公司最新的为其产品定制的集成开发环境&#xff0c;支持Block Design、Verilog、VHDL等多种设计输入方式&#xff0c;内嵌综合器以及仿真器&#xff0c;可以完成从设计输入、综合适配、仿真到…

mysql存储过程的基础知识

本文来简单说下存储过程的基础知识 文章目录概述什么是存储过程存储过程的优缺点概述 mysql官网提供的储存过程&#xff1a;https://www.mysqlzh.com/doc/225/499.html 什么是存储过程 简单的说&#xff0c;存储过程是一条或者多条SQL语句的集合&#xff0c;可视为批文件&…

SAP采购中不基于收货的发票校验的价差计算过程实例

年前最后一天上班了&#xff0c;我还在帮财务分析一个价差问题。源于财务用户头天的一个请求&#xff0c;在不基于收货的发票校验中&#xff0c;计算倒不难&#xff0c;难的是梳理数量的逻辑关系。不过总算时间也没白花&#xff0c;记录下来。下次就不算了&#xff0c;能解释清…

理解使用并查集

目录 一.并查集原理 1.概念&#xff1a; 2.性质 3.形式 4. 合并方式 二.并查集实现 1.成员变量 2.构造函数 3.查找根 4.合并集合 5.判断是否在一个集合 6.查看集合数量 三.并查集总代码 前言&#xff1a;理解并查集是为了接下来学习图要用&#xff0c;而会使用并查…

Linux:系统性能监控工具-tsar安装和使用

在上家公司做性能压力测试时就用过tsar&#xff0c;但总结文档留在了内部&#xff0c;正好借着最近工作内容又用上了tsar&#xff0c;总结起来 目录前言tsar介绍总体架构安装tasrtsar配置介绍配置文件定时任务配置日志文件tsar使用tsar实际使用参考查看可用的监控模块列表查看C…

本松新材创业板IPO终止:业绩下滑,客户较集中,周永松为实控人

撰稿|汤汤 来源|贝多财经 近日&#xff0c;深圳证券交易所披露的信息显示&#xff0c;杭州本松新材料技术股份有限公司&#xff08;下称“本松新材”&#xff09;提交了撤回上市申请文件的申请&#xff0c;保荐人财通证券也撤回对该公司的保荐。因此&#xff0c;深交所终止了…

目标检测:YOLOV3技术详解

目标检测&#xff1a;YOLOV3技术详解前言主要改进DarkNet53新的分类器正负样本的匹配损失函数前言 YOLOV3是V2的升级版&#xff0c;也是原作者的绝笔&#xff0c;V3主要还是把当时一些有用的思想融入了进来&#xff0c;没有什么创新型的突破&#xff0c;具体细节我们下面介绍。…

【安卓逆向】Frida入门与常用备忘

【安卓逆向】Frida入门与常用备忘前置知识什么是hook&#xff1f;hook的作用常见的逆向工具Frida使用入门简介与资料参考备忘前置环境配置执行hook常用hook脚本/API前置知识 什么是hook&#xff1f; hook&#xff0c;译为“钩子”&#xff0c;是指将方法/函数勾住&#xff0c;…

2022-CSDN的一年

前言 马上要到兔年的春节&#xff0c;年前最后一个版本顺利上线&#xff0c;闲下来两天&#xff0c;可以对过往一年进行一下总结&#xff0c;说起来这是入职CSDN之后第一次自己将自己所思所想以以博客的形式发布在CSDN网站上&#xff0c;也是比较奇特的体验。语言表达能力不强&…