pytorch的动态计算图机制

news2024/11/16 1:46:28

pytorch的动态计算图机制

一,动态计算图简介

在这里插入图片描述

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。

Pytorch中的计算图是动态图。这里的动态主要有两重含义。

第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

1,计算图的正向传播是立即执行的。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))

print(loss.data)
print(Y_hat.data)
tensor(17.8969)
tensor([[3.2613],
        [4.7322],
        [4.5037],
        [7.5899],
        [7.0973],
        [1.3287],
        [6.1473],
        [1.3492],
        [1.3911],
        [1.2150]])

2,计算图在反向传播后立即销毁。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))

#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) 

#loss.backward() #如果再次执行反向传播将报错

二,计算图中的Function

计算图中的另外一种节点是Function, 实际上就是 Pytorch中各种对张量操作的函数。

这些Function和我们Python中的函数有一个较大的区别,那就是它同时包括正向计算逻辑和反向传播的逻辑。

我们可以通过继承torch.autograd.Function来创建这种支持反向传播的Function

class MyReLU(torch.autograd.Function):

    #正向传播逻辑,可以用ctx存储一些值,供反向传播使用。
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    #反向传播逻辑
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.tensor([[-1.0,-1.0],[1.0,1.0]])
Y = torch.tensor([[2.0,3.0]])

relu = MyReLU.apply # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b)
loss = torch.mean(torch.pow(Y_hat-Y,2))

loss.backward()

print(w.grad)
print(b.grad)
tensor([[4.5000, 4.5000]])
tensor([[4.5000]])
# Y_hat的梯度函数即是我们自己所定义的 MyReLU.backward

print(Y_hat.grad_fn)
<torch.autograd.function.MyReLUBackward object at 0x1205a46c8>

三,计算图与反向传播

了解了Function的功能,我们可以简单地理解一下反向传播的原理和过程。理解该部分原理需要一些高等数学中求导链式法则的基础知识。

import torch 

x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

loss.backward()

loss.backward()语句调用后,依次发生以下计算过程。

1,loss自己的grad梯度赋值为1,即对自身的梯度为1。

2,loss根据其自身梯度以及关联的backward方法,计算出其对应的自变量即y1和y2的梯度,将该值赋值到y1.grad和y2.grad。

3,y2和y1根据其自身梯度以及关联的backward方法, 分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加。

(注意,1,2,3步骤的求梯度顺序和对多个梯度值的累加规则恰好是求导链式法则的程序表述)

正因为求导链式法则衍生的梯度累加规则,张量的grad梯度不会自动清零,在需要的时候需要手动置零。

四,叶子节点和非叶子节点

执行下面代码,我们会发现 loss.grad并不是我们期望的1,而是 None。

类似地 y1.grad 以及 y2.grad也是 None.

这是为什么呢?这是由于它们不是叶子节点张量。

在反向传播过程中,只有 is_leaf=True 的叶子节点,需要求导的张量的导数结果才会被最后保留下来。

那么什么是叶子节点张量呢?叶子节点张量需要满足两个条件。

1,叶子节点张量是由用户直接创建的张量,而非由某个Function通过计算得到的张量。

2,叶子节点张量的 requires_grad属性必须为True.

Pytorch设计这样的规则主要是为了节约内存或者显存空间,因为几乎所有的时候,用户只会关心他自己直接创建的张量的梯度。

所有依赖于叶子节点张量的张量, 其requires_grad 属性必定是True的,但其梯度值只在计算过程中被用到,不会最终存储到grad属性中。

如果需要保留中间计算结果的梯度到grad属性中,可以使用 retain_grad方法。
如果仅仅是为了调试代码查看梯度值,可以利用register_hook打印日志。

import torch 

x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

loss.backward()
print("loss.grad:", loss.grad)
print("y1.grad:", y1.grad)
print("y2.grad:", y2.grad)
print(x.grad)
loss.grad: None
y1.grad: None
y2.grad: None
tensor(4.)
print(x.is_leaf)
print(y1.is_leaf)
print(y2.is_leaf)
print(loss.is_leaf)
True
False
False
False

利用retain_grad可以保留非叶子节点的梯度值,利用register_hook可以查看非叶子节点的梯度值。

import torch 

#正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()

#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)
y2 grad:  tensor(4.)
y1 grad:  tensor(-4.)
loss.grad: tensor(1.)
x.grad: tensor(4.)

五,计算图在TensorBoard中的可视化

可以利用 torch.utils.tensorboard 将计算图导出到 TensorBoard进行可视化。

from torch import nn 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Parameter(torch.randn(2,1))
        self.b = nn.Parameter(torch.zeros(1,1))

    def forward(self, x):
        y = x@self.w + self.b
        return y

net = Net()
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('../data/tensorboard')
writer.add_graph(net,input_to_model = torch.rand(10,2))
writer.close()
%load_ext tensorboard
#%tensorboard --logdir ../data/tensorboard
from tensorboard import notebook
notebook.list() 
#在tensorboard中查看模型
notebook.start("--logdir ../data/tensorboard")

在这里插入图片描述


Reference:

https://jackiexiao.github.io/eat_pytorch_in_20_days/2.%E6%A0%B8%E5%BF%83%E6%A6%82%E5%BF%B5/2-3%2C%E5%8A%A8%E6%80%81%E8%AE%A1%E7%AE%97%E5%9B%BE/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156355.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“吉林一号”宽幅02B系列卫星

离轴四反光学成像系统 1.光学系统参数&#xff1a; 焦距&#xff1a;77.5mm&#xff1b; F/#&#xff1a;7.4&#xff1b; 视场&#xff1a;≥56゜&#xff1b; 光谱范围&#xff1a;400nm&#xff5e;1000nm。 2.说明&#xff1a; 光学系统采用离轴全反射式结构&#xff0c;整…

解密的军事卫星图像在各种民用地理空间研究中都有应用

一、美军光学侦察卫星计划概述 国家侦察局 &#xff08;NRO&#xff09; 负责开发和操作太空侦察系统&#xff0c;并为美国国家安全开展情报相关活动。NRO 开发了几代机密锁眼 &#xff08;KH&#xff09; 军事光学侦察卫星&#xff0c;这些卫星一直是美国国防部 &#xff08;D…

人工智能不是人工“制”能

文/孟永辉 如果你去过今年在上海举办的世界人工智能大会&#xff0c;就会知道当下的人工智能行业在中国是多么火爆。 的确&#xff0c;作为第四次工业革命的重要组成部分&#xff0c;人工智能愈发引起越来越多的重视。 不仅仅是在中国&#xff0c;当今世界的很多工业强国都在将…

python爬虫案例——异步加载网站数据抓取,post请求(6)

文章目录 前言1、任务目标2、抓取流程2.1 分析网页2.2 编写代码2.3 思路分析前言 本篇案例主要讲解异步加载网站如何分析网页接口,以及如何观察post请求URL的参数,网站数据并不难抓取,主要是将要抓取的数据接口分析清楚,才能根据需求编写想要的代码。 1、任务目标 目标网…

Win10 安装Node.js 以及 Vue项目的创建

一、Node.js和Vue介绍 1. Node.js Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行环境。它允许你在服务器端运行 JavaScript&#xff0c;使得你能够使用 JavaScript 来编写后端代码。以下是 Node.js 的一些关键特点&#xff1a; 事件驱动和非阻塞 I/O&#xff1a;Node…

list(一)

list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点中通过指针指向 其前一个元素和后一个元素。 支持 -- 但是不支持…

Linux:终端(terminal)与终端管理器(agetty)

终端的设备文件 打开/dev目录可以发现其中有许多字符设备文件&#xff0c;例如对于我的RedHat操作系统&#xff0c;拥有tty0到tty59&#xff0c;它们是操作系统提供的终端设备。对于tty1-tty12使用ctrlaltF*可以进行快捷切换&#xff0c;下面的命令可以进行通用切换。 sudo ch…

校园热捧的“人气新贵”,D 咖智能饮品机器人

在 2024 年的校园中&#xff0c;一股全新的潮流正在悄然兴起。D 咖智能饮品机器人以其独特的魅力&#xff0c;成功入驻多个校园&#xff0c;迅速成为学生们热烈追捧的对象&#xff0c;在长江大学、荆州职业技术学院、中医高专等多个大学校园&#xff0c;都能发现他们靓丽的身姿…

calibre-web报错:File type isn‘t allowed to be uploaded to this server

calibre-web报错&#xff1a;File type isnt allowed to be uploaded to this server 最新版的calibre-web在Upload时候会报错&#xff1a; File type isnt allowed to be uploaded to this server 解决方案&#xff1a; Admin - Basic Configuration - Security Settings 把…

投资学 01 定义,投资

02. 03. 3.1 直接投资&#xff1a;使用方和提供方是一个人

VUE3学习---【一】【从零开始的VUE学习】

目录​​​​​​​ 什么是Vue 渐进式框架 创建一个Vue应用 什么是Vue应用 使用Vue应用 根组件 挂载应用 模板语法 文本插值 原始HTML Attribute绑定 简写 同名简写 布尔型Attribute 动态绑定多个值 使用JavaScript表达式 仅支持表达式 指令 Directives 指令…

COLORmap

在这段MATLAB代码中&#xff0c;surf(peaks)、map的定义以及colormap(map)的调用共同完成了以下任务&#xff1a; 1. **绘制曲面图**&#xff1a; - surf(peaks)&#xff1a;这个函数调用了MATLAB内置的peaks函数来生成数据&#xff0c;并使用surf函数将这些数据绘制成一个…

双向链表:实现、操作与分析【算法 17】

双向链表&#xff1a;实现、操作与分析 引言 双向链表&#xff08;Doubly Linked List&#xff09;是链表数据结构的一种重要形式&#xff0c;它允许节点从两个方向进行遍历。与单向链表相比&#xff0c;双向链表中的每个节点不仅包含指向下一个节点的指针&#xff08;或引用&…

蓝桥杯嵌入式的学习总结

一. 前言 嵌入式竞赛实训平台(CT117E-M4) 是北京国信长天科技有限公司设计&#xff0c;生产的一款 “ 蓝桥杯全国软件与信息技术专业人才大赛-嵌入式设计与开发科目 “ 专用竞赛平台&#xff0c;平台以STM32G431RBT6为主控芯片&#xff0c;预留扩展板接口&#xff0c;可为用户提…

数据结构篇--顺序查找【详解】

概念章 查找就是在数据集合中寻找某种条件的数据元素的过程。 查找表是指用于查找同一类型的数据元素集合。 找到了满足条件的数据元素&#xff0c;就是查找成功&#xff0c;否则就是称为查找失败。 关键字是指数据元素的某个数据项的值&#xff0c;可用于标识或者记录&…

【Java】线程暂停比拼:wait() 和 sleep()的较量

欢迎浏览高耳机的博客 希望我们彼此都有更好的收获 感谢三连支持&#xff01; 在Java多线程编程中&#xff0c;合理地控制线程的执行是至关重要的。wait()和sleep()是两个常用的方法&#xff0c;它们都可以用来暂停线程的执行&#xff0c;但它们之间存在着显著的差异。本文将详…

【AI学习笔记】初学机器学习西瓜书概要记录(二)常用的机器学习方法篇

初学机器学习西瓜书的概要记录&#xff08;一&#xff09;机器学习基础知识篇(已完结) 初学机器学习西瓜书的概要记录&#xff08;二&#xff09;常用的机器学习方法篇(持续更新) 初学机器学习西瓜书的概要记录&#xff08;三&#xff09;进阶知识篇(待更) 文字公式撰写不易&am…

Django 基础之启动命令和基础配置

Django启动 django启动一般可以通过ide或者命令启动 ide启动&#xff1a; 启动命令&#xff1a; python manage.py runserver该命令后续可以增加参数&#xff0c;如&#xff1a; python manage.py runserver 8081 python manage.py runserver 127.0.0.1:8082 注意&#xff1…

StopIteration: 迭代停止完美解决方法 ️

&#x1f504; StopIteration: 迭代停止完美解决方法 &#x1f6e0;️ &#x1f504; StopIteration: 迭代停止完美解决方法 &#x1f6e0;️摘要引言正文1. 什么是StopIteration异常&#xff1f;&#x1f4dc;2. StopIteration在for循环中的处理机制&#x1f6a6;3. 如何自定…

数仓规范:命名规范如何设计?

目录 0 前言 1 表命名规范 2 字段命名规范 3 任务命名规范 4 层级命名规范 5 自定义函数命名规范 6 视图和存储过程的命名规范 7 综合案例分析 8 常见陷阱和如何避免 9 工具和最佳实践 10 小结 想进一步了解数仓建设这门艺术的&#xff0c;可以订阅我的专栏数字化建设…