PyTorch翻译官网教程6-AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

news2025/1/13 2:48:21

官网链接

Automatic Differentiation with torch.autograd — PyTorch Tutorials 2.0.1+cu117 documentation

使用TORCH.AUTOGRAD 自动微分

当训练神经网络时,最常用的算法是方向传播算法。在该算法中,根据损失函数与给定参数的梯度来调整模型参数(权重)。

为了计算这些梯度,PyTorch有一个内置的微分引擎,名为torch.autograd。它支持任何计算图的梯度自动计算。

考虑最简单的单层神经网络,输入x,参数w和b,以及一些损失函数。它可以在PyTorch中以以下方式定义:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数与计算图

这段代码定义了以下计算图:

在这个网络中,w和b是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

注意:

您可以在创建张量时设置requires_grad的值,或者稍后使用x.requires_grad_(True)方法设置。

我们使用张量来构造计算图的函数实际上是Function类的对象,该对象知道如何在正向方向上计算函数,以及如何在反向传播步骤中计算其导数。反向传播函数的引用存储在张量的grad_fn 属性中,你可以在文档中找到Function 的更多信息。Automatic differentiation package - torch.autograd — PyTorch 2.0 documentation

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

输出

Gradient function for z = <AddBackward0 object at 0x114113f70>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x114113f70>


计算梯度

为了优化神经网络中参数的权重,我们需要计算损失函数对于参数的导数,即我们需要\frac{∂loss}{∂w} 和 \frac{∂loss}{∂b} 在x和y的固定值下。为了计算这些导数,我们调用loss.backward(),然后从w.gradb.grad中检索值:

loss.backward()
print(w.grad)
print(b.grad)

输出

tensor([[0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530]])
tensor([0.3313, 0.0626, 0.2530])

注意

  • 我们只能获得计算图的叶子节点的grad属性,当它的requires_grad属性设置为True时。对于图中的所有其他节点,梯度将不可用。
  • 出于性能原因,我们只能在给定的图上使用一次backward梯度计算。如果需要对同一个图进行多次backward调用,则需要将retain_graph=True 传递给backward调用。

禁用梯度跟踪

默认情况下,所有的requires_grad=True 的张量会自动跟踪它们的计算历史并支持梯度计算。然而,在一些情况下我们不需要这样做,例如,当我们完成了模型的训练,只想将其应用于一些测试数据时,即我们只想通过网络进行前向计算。我们可以通过使用torch.no_grad() 块包围我们的计算代码来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)

输出

True
False

实现相同结果的另一种方法是在张量上使用detach() 方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

输出

False

你可能想要禁用梯度跟踪的原因如下:

  • 将神经网络中的一些参数标记为冻结参数
  • 当你只做正向传递时,为了加快计算速度,在不跟踪梯度的张量上的计算会更有效率。

更多关于计算图的知识

从概念上讲,autograd在由Function 的大小应该等于原始张量的大小,为了计算其乘积。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
 

输出

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

注意,当我们使用相同的参数第二次调用backward时,梯度的值是不同的。这是因为在进行backward传播时,PyTorch会累积梯度的,即计算梯度的值被添加到计算图的所有叶节点的grad属性中。如果你想计算合适的梯度,你需要在此之前将grad属性归零。在现实训练中,优化器可以帮助我们做到这一点

注意

以前我们调用没有参数的backward()函数。这基本上相当于调用backward(torch.tensor(1.0)),这是在标量值函数的情况下计算梯度的有用方法,例如神经网络训练期间的损失。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/756272.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习---定义、用途、算法的分类、假设空间与归纳偏好、奥卡姆剃刀原则

1. 机器学习的定义 基于历史经验的&#xff0c;描述和预测的理论、方法和算法。 从历史数据中&#xff0c;发现某些模式或规律&#xff08;描述&#xff09;&#xff0c;利用发现的模式和规律进行预测。 2. 机器学习能做什么 机器学习已经有了十分广泛的应用&#xff0c;例…

pdf文件大小如何压缩?pdf文件怎么压缩得更小?

日常生活和工作中&#xff0c;经常用到图片&#xff0c;但是有时候需要将图片压缩指定大小来符合各种规定&#xff0c;比如图片压缩到200kb&#xff0c;那么有没有简单方便的图片压缩&#xff08; https://www.yasuotu.com/imagesize&#xff09;的方法呢&#xff1f;下面就拿压…

【测试开发】案例分析

目录 一. 模拟弱网 二. 接口测试 三. 对冒泡排序进行测试 四. 对于 Linux 命令进行测试 五. 微信发送朋友圈设计测试用例 六. 补充 一. 模拟弱网 模拟弱网环境可以借助 Fiddler 来进行&#xff1b; 1. 先要打开 Simulate Modem Speeds 选项&#xff1b; 2. 打开 Customize R…

一起学SF框架系列5.8-模块Beans-注解bean解析1-解析入口

前面跟踪了Spring框架如何解析xml模式配置的bean解析&#xff08;参见“一起学SF框架系列5.7-模块Beans-BeanDefinition解析”&#xff09;&#xff0c;本文主要解析注解bean&#xff08;详见“一起学SF框架系列5.2-模块Beans-bean的元数据配置”&#xff09;是如何被Spring框架…

scripy其他

持久化 # 爬回来&#xff0c;解析完了&#xff0c;想存储&#xff0c;有两种方案 ## 方案一&#xff1a;一般不用 parse必须有return值&#xff0c;必须是列表套字典形式--->使用命令&#xff0c;可以保存到json格式中&#xff0c;csv中scrapy crawl cnblogs -o cnbogs.j…

IEEE WCCI-2020电动汽车路由问题进化计算竞赛的基准集

引言 交通一直是二氧化碳排放的主要贡献者。由于全球变暖、污染和气候变化&#xff0c;联邦快递、UPS、DHL和TNT等物流公司对环境变得更加敏感&#xff0c;他们正在投资于减少作为其日常运作的一部分而产生的二氧化碳排放的方法。毫无疑问&#xff0c;使用电动汽车&#xff08;…

JavaWeb——Linux的常用命令

目录 一、Linux优点 二、Linux常用命令 1、ls &#xff08;1&#xff09;、语法 &#xff08;2&#xff09;、功能 &#xff08;3&#xff09;、常用选项 例: 2、pwd &#xff08;1&#xff09;、语法 &#xff08;2&#xff09;、功能 例: 3、cd &#xff08;1&am…

Doc as Code (1):起源

作为技术传播从业者&#xff0c;你一定听说过Doc as Code&#xff0c;中文大家叫做文档代码化。 近年来&#xff0c;这个词在技术传播行业传开了。也许是在某个大会上&#xff0c;也许是在某篇文章中&#xff0c;再或者是在与同行的讨论群里&#xff0c;不管是从哪里&#xff…

DAY47:动态规划(九)完全背包理论基础

文章目录 完全背包示例与01背包的区别&#xff1a;遍历顺序常规遍历写法DP状态图-为什么背包正序就能放进来重复物品 for循环的嵌套&#xff0c;外层物品内层背包能否颠倒&#xff1f;for嵌套顺序颠倒的遍历写法 测试示例面试题目总结 课程链接&#xff1a; 代码随想录 (progr…

自动生成spring-configuration-metadata.json文件

在开发过程中为避免重复修改代码&#xff0c;往往将代码中容易发生变更的值提取出来放到配置文件中。例如数据库连接信息&#xff0c;使用Http调用第三方应用的网关地址等信息。 使用Sprin Boot的ConfigurationPropertie 从配置文件中读取属性值方法多样&#xff0c;这里介绍…

【反向代理】反向代理及其作用

反向代理及其作用 一、什么是正向代理 在介绍反向代理之前我们先介绍什么是正向代理 首先要明确的是&#xff0c;在http协议中正向代理一般被称为代理&#xff0c;在web服务中我们可以通过主动配置代理服务器的方式来发送请求&#xff0c;并通过代理服务器接收服务器的响应。…

自学网络安全(成为黑客)

一、前言 黑客这个名字一直是伴随着互联网发展而来&#xff0c;给大家的第一印象就是很酷&#xff0c;而且技术精湛&#xff0c;在网络世界里无所不能。目前几乎所有的公司企业甚至国家相关部门都会争相高薪聘请技术精湛的黑客作为互联网机构的安全卫士&#xff0c;所以黑客也…

umi框架的使用

umi框架的使用 安装npm i -g yrm 查看yarn镜像源yrm ls 切换源 yrm use taobao 创建项目 yarn create umijs/umi-app 安装依赖yarn 启动项目yarn start 路由组件还可以进行children进行子路由渲染 打个比方&#xff0c;现在有头部导航跟侧边是一致的我们只希望修改每个应…

Mybatis-Plus详解

目录 一、Mybatis-Plus简介 &#xff08;一&#xff09;什么是Mybatis-Plus &#xff08;二&#xff09;Mybatis-Plus的优势 &#xff08;三&#xff09;Mybatis-Plus的框架结构 二、SpringBoot整合Mybatis-Plus入门 &#xff08;一&#xff09;创建maven工程&#xff0c;…

爬虫+Flask+Echarts搭建《深度学习》书评显示大屏

爬虫FlaskEcharts搭建《深度学习》书评显示大屏 1、前言2、实现2.1 挑选想要采集的书籍2.2 构建爬虫2.2.1 采集书籍信息2.2.2 采集书评 2.3 数据清洗2.3.1 清洗书籍信息2.3.2 清洗书评信息 2.4 统计分析&#xff0c;结果持久化存储2.5 搭建flask框架2.6 数据传值2.7 完整代码&a…

什么是Nginx的反向代理与正向代理详解

文章目录 1、什么是正向代理2、什么是反向代理3、反向代理的作用 1、什么是正向代理 正向代理&#xff0c;“它代理的是客户端”&#xff0c;是一个位于客户端和目标服务器之间的服务器&#xff0c;为了从目标服务器取得内容&#xff0c;客户端向代理服务器发送一个请求并指定…

汽车网卡驱动之TJA1101B

TJA1101B汽车网卡驱动(汽车以太网) 1总体描述 2特点和优点 2.1通用 2.2针对汽车用例优化

酷炫音乐盒: python打造自己的音乐播放器

目录标题 前言代码实现尾语 前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! Python的Tkinter&#xff08;Tk接口&#xff09;是一个用于创建图形用户界面&#xff08;GUI&#xff09;的标准库。 它是Python的内置模块&#xff0c;无需额外安装即可使用。Tkinter提供了一组…

【字节青训pre】后端笔试练兵

文章目录&#xff1a; 零、前言一、选择题二、编程题1、36进制转换a) 题目b) 解题思路 零、前言 好久没更博客了 &#xff0c; 暑假参加字节青训营&#xff0c;记录一下备战经历&#xff0c;水水博客 。 因该博客持续更新&#xff0c;文中部分链接是写该博客时预存占坑位的&…

【雕爷学编程】Arduino动手做(147)---QMC5883L三轴罗盘模块2

37款传感器与执行器的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&am…