【Pytorch】自定义autograd函数,使用graphviz画出计算图

news2025/1/6 19:49:49

使用pytorch.autograd.Function构建一个自动求导层

  • 1. 手工设计一个线性运算层
  • 2. 使用pytorch.autograd.Function编码实现
  • 3. graphviz进行可视化

1. 手工设计一个线性运算层

设输入为 x \bold{x} x,参数为 w \bold{w} w b \bold{b} b,运算如下:
y = w ⊙ x + b \bold{y=w\odot x+b} y=wx+b
其中, ⊙ \odot 是矩阵的Hadmard积运算。

f f f来表示接下来的所有层的运算,有:
m o d e l ( x ) = f ( y ) model(\bold{x})=f(\bold{y}) model(x)=f(y)
这里, m o d e l model model表示模型的全部运算。

在反向传播求导中,给出损失函数对于 y \bold{y} y的导数 ∂ L ∂ y \frac{\partial{L}}{\partial{\bold{y}}} yL,则求参数 w \bold{w} w b \bold{b} b对于损失函数的导数有:
∂ L ∂ w = ∂ L ∂ y ∂ y ∂ w = ∂ L ∂ y ⊙ x \frac{\partial{L}}{\partial{\bold{w}}}=\frac{\partial{L}}{\partial{\bold{y}}}\frac{\partial{\bold{y}}}{\partial{\bold{w}}}=\frac{\partial{L}}{\partial{\bold{y}}}\odot\bold{x} wL=yLwy=yLx ∂ L ∂ b = ∂ L ∂ y ∂ y ∂ b = ∂ L ∂ y ⊙ 1 \frac{\partial{L}}{\partial{\bold{b}}}=\frac{\partial{L}}{\partial{\bold{y}}}\frac{\partial{\bold{y}}}{\partial{\bold{b}}}=\frac{\partial{L}}{\partial{\bold{y}}}\odot\bold{1} bL=yLby=yL1

至此,我们已经求出了这个线性层的前向和反向传播的公式。

2. 使用pytorch.autograd.Function编码实现

pytorch.autograd.Function是实现自动求导类的基类。为了实现自定义类,实现forward()backward()两个静态方法。对于上文的运算,代码如下:

from torch.autograd import Function
class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(x,)
        output = w*x +b
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_w = grad_output * x
        grad_b = grad_output * 1
        return grad_w, None, grad_b

在这里,ctx可以理解为一个连接前后运算的对象,用ctx.save_for_backward存储反向传播的使用用到的参数。grad_output是这一层的输出的导数,即 ∂ L ∂ y \frac{\partial{L}}{\partial{\bold{y}}} yL

给出w,x,b

w = torch.tensor([[1.,2],[3,4]], requires_grad=True)
x = torch.rand(2, 2)
b = torch.tensor([[4.,3],[2,1]], requires_grad=True)
w,x,b
(tensor([[1., 2.],
         [3., 4.]], requires_grad=True),
 tensor([[0.0534, 0.8366],
         [0.9568, 0.1293]]),
 tensor([[4., 3.],
         [2., 1.]], requires_grad=True))

在创建参数的tensor的时候,参数requires_grad=True,使得该张量可以计算梯度(默认为False)。

ag_func = MultiplyAdd()
out = ag_func.apply(w, x, b)
out.backward(torch.ones(2,2), retain_graph=True)
w.grad, b.grad

在前向传播的时候,使用apply()方法而不是直接调用forward()方法,具体可以参见PyTorch文档。

(tensor([[0.0534, 0.8366],
         [0.9568, 0.1293]]),
 tensor([[1., 1.],
         [1., 1.]]))

使用 grad_fn可以得到当前张量计算的计算图。 grad_fn.next_functions存储了上一层的计算单元。这里存储了三个单元,但是由于没有求x的梯度,所以其对应的是None。可以看到存储的单元中存放的是w和b。

print(out.grad_fn)
print(out.grad_fn.next_functions)
print(out.grad_fn.next_functions[0][0].variable)
print(out.grad_fn.next_functions[2][0].variable)
<torch.autograd.function.MultiplyAddBackward object at 0x0000014F2AFB52E0>
((<AccumulateGrad object at 0x0000014F2AF82040>, 0), (None, 0), (<AccumulateGrad object at 0x0000014F2A53F250>, 0))
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)
tensor([[4., 3.],
        [2., 1.]], requires_grad=True)

3. graphviz进行可视化

graphviz是一个常用的画图工具包,具体的安装可以参考网上的教程(记得添加环境变量)。

from graphviz import Digraph

node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()

def size_to_str(size):
    return '(' + (', ').join(['%d' % v for v in size]) + ')'

def add_nodes(var):

    if var not in seen:
        if torch.is_tensor(var):
            # note: this used to show .saved_tensors in pytorch0.2, but stopped
            # working as it was moved to ATen and Variable-Tensor merged
            dot.node(str(id(var)), size_to_str(var.size()), fillcolor='yellow')
        elif hasattr(var, 'variable'):
            u = var.variable
            node_name = size_to_str(u.size())
            dot.node(str(id(var)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))

            seen.add(var)
        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.edge(str(id(t)), str(id(var)))
                add_nodes(t)
dot.node('Output', 'out\n'+size_to_str(out.size()))
dot.edge( str(id(out.grad_fn)),'Output')
add_nodes(out.grad_fn)
dot.render(('graph'), view=False)

最后得到的PDF如下所示:
在这里插入图片描述

其中,淡蓝色表示需要求导的单元,也就是w和b。黄色单元表示用于反向求导的参数,也就是x。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/186498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tips for Confluence Administrators: Part 3

上一篇Part 2中&#xff0c;我们谈到了Confluence的一些已知Bug&#xff0c;例如&#xff1a;从垃圾桶中恢复同名页面不应覆盖现有页面复制/粘贴表格单元格无法按预期工作无法复制表头嵌套表格无法调整大小请不要吝惜你的投票&#xff0c;这有助于提高Atlassian修复它们的优先级…

四、Maven详细教程

本教程相关资料&#xff1a;https://www.aliyundrive.com/s/wMiqbd4Zws6 Maven是专门用于管理和构建Java项目的工具&#xff0c;它的主要功能有&#xff1a; 提供了一套标准化的项目结构 提供了一套标准化的构建流程&#xff08;编译&#xff0c;测试&#xff0c;打包&#xf…

MacOs安装Redis并设置为开机、后台启动

前言 最近闲来无事&#xff0c;将自己的MBP系统重装里&#xff0c;导致里面原来安装的软件都需要重新安装&#xff0c;今天记录一下MacOs安装Redis并设置为开机启动、后台启动的步骤&#xff0c;安装过程略有波折&#xff0c;参考里几篇文章才搞定。 一、安装Redis 两种方式…

【JavaEE】线程池

目录 线程池概念 线程池优点 使用线程池 Executor接口&#xff1a; ThreadPoolExecutor类&#xff1a; 构造方法 Executors工厂类&#xff1a; 工厂方法 线程池中的常用方法 线程池的工作流程 线程池的状态 RUNNING SHUTDOWN STOP TIDYING TERMINATED 简单实现…

Android10 开机向导流程

最近在弄开机向导&#xff0c;网上查了查&#xff0c;基本都是参照系统的​​​​​​Provision应用来做的&#xff0c;而且还要将apk打包到系统目录下的pri-app目录下&#xff0c;打包到其他目录下不行&#xff0c;参照着做是没问题&#xff0c;但是好奇为什么要这么做&#x…

esp32连接阿里云物联网平台进行MQTT通信

前提&#xff1a;IDE是采用arduino IDE&#xff0c;arduino使用的库是pubsubclient 开发板可以使用esp32&#xff08;esp8266也是一样的&#xff09; 已经学会pubsubclient库的基本使用 使用pubsubclient 库连接阿里云物联网平台 const char* ssid "........"; c…

java ssm校园勤工俭学助学志愿者兼职系统 idea maven

本论文主要论述了如何使用JSP技术开发一个校园勤工俭学兼职系统&#xff0c;本系统将严格按照软件开发流程进行各个阶段的工作&#xff0c;采用B/S架构&#xff0c;面向对象编程思想进行项目开发。在引言中&#xff0c;作者将论述校园勤工俭学兼职系统的当前背景以及系统开发的…

使用CCProxy+Proxifier实现代理

目录1.使用场景2.什么是网络代理&#xff1f;3.CCProxy3.1 说明3.2 下载安装3.3 使用说明4.Proxifier4.1 说明4.2 下载安装4.3 使用说明4.4 Proxifier CPU占用率高问题解决1.使用场景 很多时候当我们访问某个网络&#xff08;例如&#xff1a;校园网、企业网&#xff09;&#…

射频识别技术|期末考试知识点|重点题目|第1讲_RFID

课堂笔记 1.RFID技术 标签(芯片、天线、封装) 读写器 中间件和系统软件 公共服务体系 2.IC&ID

使用原始命令编译打包部署springboot-demo项目

目录简介源文件介绍编译编译restful-common编译manual-springboot打包&部署&执行jar命令介绍不打包直接运行打普通jar包&#xff0c;通过java -jar运行打fat jar通过java -jar打war&#xff0c;通过部署至tomcat运行纯手工命令开发打包部署的缺点参考简介 本文将使用j…

CUDA编程笔记(7)

文章目录前言共享内存的合理使用数组归约计算使用全局内存的计算引入线程块中的同步函数使用共享内存计算静态共享内存使用动态共享内存性能比较避免共享内存的bank冲突使用共享内存进行数组转置bank概念性能比较总结前言 cuda共享内存的合理使用。 共享内存的合理使用 共享…

TF-A移植

1.对tf-a源码进行解压 $> tar xfz tf-a-stm32mp-2.2.r2-r0.tar.gz 2.打补丁 $> for p in ls -1 ../*.patch; do patch -p1 < $p; done 3.配置交叉编译工具链 将Makefile.sdk中EXTRA_OEMAKE修改为 EXTRA_OEMAKECROSS_COMPILEarm-linux-gnueabihf- 4.复制设备树…

linux 部署jmeter

一、linux 安装jdk Java Downloads | Oracle 二、 linux上传jmeter 2.1 上传jmeter jmeter 下载地址&#xff1a; Apache JMeter - Download Apache JMeter 注意&#xff1a; 我先在我本地调试脚本&#xff08;mac环境&#xff09;&#xff0c;调试完成后&#xff0c;再在…

前端首屏优化

一. 打包分析 在 package.json 中添加命令 “report”: “vue-cli-service build --report” 然后命令行执行 npm run report&#xff0c;就会在dist目录下生成一个 report.html 文件&#xff0c;右键浏览器中打开即可看到打包分析报告。 二. 路由懒加载 component: () >…

MacOS - steam 蒸汽平台安装教程,带你躲避高仿网站的陷阱

MacOS - steam 蒸汽平台安装教程 MacOS 其实也是可以安装 Steam 平台的&#xff0c;虽然 steam 上的大多游戏暂时都不支持 MacOS&#xff0c;但还是有一些游戏可以玩的&#xff0c;而且近几年支持 MacOS 的游戏也是越来越多了。另外现在高仿网站特别多&#xff0c;所以才有了这…

transformer库使用

Transformer库简介 是一个开源库&#xff0c;其提供所有的预测训练模型&#xff0c;都是基于transformer模型结构的。 Transformer库 我们可以使用 Transformers 库提供的 API 轻松下载和训练最先进的预训练模型。使用预训练模型可以降低计算成本&#xff0c;以及节省从头开…

Grafana 系列文章(三):Tempo-使用 HTTP 推送 Spans

&#x1f449;️URL: https://grafana.com/docs/tempo/latest/api_docs/pushing-spans-with-http/ &#x1f4dd;Description: 有时&#xff0c;使用追踪系统是令人生畏的&#xff0c;因为它似乎需要复杂的应用程序仪器或 span 摄取管道&#xff0c;以便 ... 有时&#xff0c;使…

SurfaceFlinger学习笔记(七)之SKIA

关于Surface请参考下面文章 SurfaceFlinger学习笔记(一)应用启动流程 SurfaceFlinger学习笔记(二)之Surface SurfaceFlinger学习笔记(三)之SurfaceFlinger进程 SurfaceFlinger学习笔记(四)之HWC2 SurfaceFlinger学习笔记(五)之HWUI SurfaceFlinger学习笔记(六)之View Layout Dr…

react 实现表格列表拖拽排序

问题描述 在项目开发中&#xff0c;遇到这样一个需求&#xff1a;需要对表格里面的数据进行拖拽排序。 效果图如下所示&#xff1a; 思路 安装两个插件&#xff1a; react-sortable-hoc &#xff08;或者 react-beautiful-dnd&#xff09;array-move npm install --save r…

59 多线程环境普通变量作为标记循环不结束

前言 最近看到这篇例子的时候, [讨论] 内存可见性问题, 吧其中的 demo 拿到本地来跑 居然 和楼主一样, testBasicType 这里面的这个子线程 居然 不结束了, 卧槽 我还以为 只是可能 用的时间稍微长一点 哪知道 直接 无限期执行下去了, 然后 另外还有一个情况就是 加上了 -…