基于__torch_dispatch__机制的dump方法

news2024/12/28 19:05:27

基于__torch_dispatch__机制的dump方法

  • 1.参考链接
  • 2.原理
  • 3.代码
  • 4.效果

之前拦截torch和torch.Tensor的办法,在处理backward时,不能看到aten算子的细节.以下基于__torch_dispatch__机制的方案更节约代码,且能看到调用栈

1.参考链接

[原理] (https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)

2.原理

在这里插入图片描述

3.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading

device="cuda"
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
import traceback
from dataclasses import dataclass
from typing import Any

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

lock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):
    if isinstance(args,torch.Tensor):
        print(name,index,args.shape)
        global gindex
        lock.acquire()
        torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))
        gindex+=1
        lock.release()
    if isinstance(args,tuple):
        for idx,x in enumerate(args):
            save_tensor(name,x,index+idx)

class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self,parent):
        super().__init__()
        self.parent=parent
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        func_packet = func._overloadpacket        
        if kwargs is None:
            kwargs = {}        
        enable_dump=False
        if func_packet.__name__ not in ["detach"]:
            enable_dump=True
            print(f"Profiling {func_packet.__name__}") 
            for idx,stack in enumerate(inspect.stack()):
                print(f'{"*"*idx}{stack.filename}{stack.lineno}')
        if enable_dump:     
            save_tensor(f"{func_packet.__name__}-input",args)
        ret= func(*args, **kwargs)
        if enable_dump:
            save_tensor(f"{func_packet.__name__}-output",args)
        return ret

class TorchDumper:
    _CURRENT_Dumper = None
    def __init__(self,schedule: Any):
        self.p= _ProfilerState(schedule) 

    def __enter__(self):
        assert TorchDumper._CURRENT_Dumper is None
        TorchDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)

class Attention(nn.Module):
    def __init__(self,max_seq_len,head_dim,flash):
        super().__init__()
        self.flash = flash
        self.dropout=0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim=head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)		
    def forward(
            self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,
                                                                       attn_mask=None, 
                                                                       dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            _xk=xk.clone()
            t=_xk.transpose(2, 3)
            scores = torch.matmul(xq,t)
            scores = scores/math.sqrt(self.head_dim)
            a=self.mask[:, :, :seqlen, :seqlen]
            scores = scores+a
            scores = F.softmax(scores.float(), dim=-1)
            scores = scores.type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  
        return output

def main(flash,bs, n_local_heads, seqlen, head_dim):
    torch.random.manual_seed(1)

    q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)

    q.data.normal_(0, 0.1)
    k.data.normal_(0, 0.1)
    v.data.normal_(0, 0.1)

    q=Variable(q, requires_grad=True).to(device)
    k=Variable(k, requires_grad=True).to(device)
    v=Variable(v, requires_grad=True).to(device)

    gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)
    loss_func=nn.CrossEntropyLoss().to(device)

    model=Attention(seqlen,head_dim,flash).half().to(device)
    optim = torch.optim.SGD([q,k,v], lr=1.1)

    with TorchDumper(TorchDumpDispatchMode):
        for i in range(1):
            output = model(q,k,v)
            loss=loss_func(output.reshape(-1,head_dim),gt)
            loss.backward()  
            optim.step()
            print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))

bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

4.效果

Profiling clone
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py109
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
Profiling transpose
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py110
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 512, 64])
expand-output 0 torch.Size([8, 8, 512, 64])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 512, 64])
view-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 64, 512])
expand-output 0 torch.Size([8, 8, 64, 512])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 64, 512])
view-output 0 torch.Size([8, 8, 64, 512])
Profiling bmm
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
bmm-input 0 torch.Size([64, 512, 64])
bmm-input 1 torch.Size([64, 64, 512])
bmm-output 0 torch.Size([64, 512, 64])
bmm-output 1 torch.Size([64, 64, 512])
Profiling _unsafe_view

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1619781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【vue2】实现微信截图(复制图片)在项目内可粘贴

需求 后台管理在上传图片地方需要将复制的图片粘贴上传 一、添加事件 在原有上传组件的基础上添加 paste事件 二、方法 onPaste(e) {const items (e.clipboardData || window.clipboardData).items;let blob null;for (let i 0; i < items.length; i) {if (items[i].ty…

related_name和related_query_name属性

在Django模型继承中&#xff0c;假如在外键或多对多字段中使用了related_name属性或related_query_name属性&#xff0c;则必须为该字段提供一个独一无二的反向名字和查询名字。但是&#xff0c;这样在抽象基类中一般会引发问题&#xff0c;因为基类中的字段都被子类继承并且保…

WEB攻防-IIS中间件PUT漏洞

IIS6.0 server在web服务扩展中开启了WebDAV&#xff08;Web-based Distributed Authoring and Versioning&#xff09;。WebDAV是一种HTTP1.1的扩展协议。它扩展了HTTP 1.1&#xff0c;在GET、POST、HEAD等几个HTTP标准方法以外添加了一些新的方法&#xff0c;如PUT&#xff0c…

【js】解决自动生成颜色时相邻颜色视觉相似问题的技术方案

解决自动生成颜色时相邻颜色视觉相似问题的技术方案 在进行大规模颜色生成时&#xff0c;特别是在数据可视化、用户界面设计等应用领域&#xff0c;一个常见的挑战是确保相邻颜色在视觉上具有足够的区分度。本文介绍的方法通过结合黄金分割比与饱和度、亮度的周期性变化&#…

Axure实现tab页面切换功能

1. 实现效果 2. 实现原理 创建两个标签&#xff0c;并实现点击时选中状态点击时&#xff0c;设置面板状态 3. 实现步骤 3.1 实现可切换的标签 在页面上拖拽两个矩形作为两个tab标签&#xff0c;并命名 tab1 和 tab2 设置每个矩形的边框显示&#xff0c;只显示下边框即可 …

Oracle使用内部包自定义创建表空间和用户

如果之前有类似的表空间,可以使用dbms自动生成对应的表空间和数据文件 select dbms_metadata.get_ddl(TABLESPACE,ts.tablespace_name) from dba_tablespaces ts; 可以使用类似的 SQL> set echo off SQL> spool /data/logs/create_tablespace.log SQL> select dbms…

【安卓13】解决带GMS编译报super分区空间不足错误

1、错误信息 2、解决方案 不同供应商修改分区大小的文件路径不一样&#xff0c;但是万变不离其宗&#xff0c;根据报错信息全局搜索关键词BOARD_SUPER_PARTITION_SIZE 这里以RK供应商和AML供应商修改为例&#xff1a; &#xff08;1&#xff09;RK改法&#xff1a; 根目录下…

uniapp配置了pages.json 的 tabbar 国际化,小程序切换语言没有实时切换

如上图&#xff0c;按照uniapp官方文档配置了tabbar的国际化 但是微信小程序实时切换语言没有实时刷新 解决方案&#xff1a; 在App.vue中加入以下代码&#xff1a; 在onLaunch中执行方法即可

Springboot+Vue项目-基于Java+MySQL的图书馆管理系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

ctfshow——XSS

文章目录 XSS介绍什么是xss&#xff1f;XSS危害XSS的分类常用XSSpayload web316——反射型XSSweb317——过滤<script> web318——过滤script、imgweb319——不止过滤script、imgweb320——过滤空格web321——不止过滤空格web322——不止过滤空格web323web324web 325web32…

代码质量与可维护性的重要性都有哪些?

目录 一、为了提高代码质量&#xff0c;可以采取以下几种方法&#xff1a; 二、如何制定和执行有效的代码编写规范&#xff1f; 三、设计模式和设计原则在提高代码质量中的具体应用案例有哪些&#xff1f; 四、代码审查的最佳实践和技巧是什么&#xff1f; 五、如何有效地…

面包屑新玩法,ReactRouter+Ant Design实现动态渲染

在Ant Design中,可以通过Breadcrumb组件结合react-router库实现动态生成面包屑导航。具体步骤如下: 定义路由配置数据结构 我们需要在路由配置中添加额外的面包屑相关信息,例如面包屑标题、icon等。例如: const routes [{path: /,breadcrumbName: 首页},{path: /users,brea…

echarts图表柱状图实现左右滑动

柱状图中实现下边或右边的左右滑动效果

【电控笔记5.7】Notch-Filter滤波器

Notch-Filter滤波器 通过阻尼比&#xff0c;限制陡峭程度 阻尼比小&#xff0c;比较陡峭&#xff0c;对周围信号干扰比较小&#xff0c;衰减度小 总结 实现&#xff1a;转换成Z转换进行伯德图验证

UE5、CesiumForUnreal实现建筑白模生成及白模美化功能

1.实现目标 在专栏上篇文章基于GeoJson文件生成城市级白模(本文建筑白模数量12w+)的基础上修改,计算法线和纹理坐标,并基于特定材质进行美化,美化后的白模GIF动图如下所示: 文章目录 1.实现目标2.实现过程2.1 基于Cesium材质美化2.1.1实现原理2.1.2 C++代码2.1.3 蓝图应…

Whatsapp在中国下架了?这招教你解决!

今天有一个紧急的消息要告诉大家&#xff0c;根据最新的电信办要求&#xff0c;苹果手机的中国应用商店已经下架了WhatsApp&#xff01;这意味着&#xff0c;如果你的苹果设备是在中国大陆地区注册的&#xff0c;那么你将无法直接在App Store搜索到WhatsApp。 但是&#xff0c;…

SQL基础(关系模型)

目录 SQL及定义域概念 SQL是什么 定义域 关系简介 关系的定义 关系的封闭性 关系模型简介 关系模型 谓词逻辑 运算基础 SQL的加减乘除 SQL的除法1 SQL的除法2 SQL的除法3 三值逻辑 NULL的危害 消除NULL SQL及定义域概念 SQL是什么 Structured Query Languag…

2024.4.23

const char *p; 指针变量地址可改变&#xff0c;指向的地址的值不可变 const (char *) p; 指针变量地址可改变&#xff0c;指向的地址的值不可变 char *const p; 指针变量地址不可改变&#xff0c;指向的地址的值可变 const char* const p; 地址…

TDengine高可用探讨

提到数据库&#xff0c;不可避免的要考虑高可用HA&#xff08;High Availability&#xff09;。但是很多人对高可用的理解并不是很透彻。 要搞清高可用需要回答以下几个问题&#xff1a; 什么是高可用&#xff1f;为什么需要高可用&#xff1f;高可用需要达到什么样的目标&am…

刷代码随想录有感(45):二叉树的最大深度

题干&#xff1a; 力扣这里给了定义&#xff1a;二叉树的最大深度指的是从根节点开始&#xff0c;到最远叶子所经过的节点数。 代码&#xff1a; class Solution {//递归实现 public:int maxDepth(TreeNode* root) {if(root NULL)return NULL;int leftheight maxDepth(root…