pytorch为自己的extension backend添加profiler功能

news2025/1/6 20:25:21

pytorch为自己的extension backend添加profiler功能

  • 1.参考文档
  • 2.your-extension-for-pytorch需要增加的代码
  • 3.pytorch demo及如何调整chrome trace json文件
  • 4.[可视化](https://ui.perfetto.dev/)

本文演示了pytorch如何为自己的extension backend添加profiler功能
背景介绍

  • 1.没有CNLight、Profiling AscendCL API、ROC Trace之类Profing功能,无法trace runtime,drive,kernel,也无法获取设备的metrics
  • 2.只有event功能,可以统计kernel耗时
  • 3.本文只是一种尝试,并不合理.
  • 4.torch原生的profiler框架,依赖kineto,kineto目前支持CUPTI和ROC Tracer,如果不修改torch源码,第三方设备不方便使用
  • 5.华为、寒武纪、habana都是采用torch.profile的接口形式及at::addThreadLocalCallback功能,但不依赖torch.profiler框架
    profing原始数据都是私有格式,并且修改TensorBoard的插件,可于可视化

实施步骤

  • 1.调用torch::profiler::impl::registerPrivateUse1Methods注册
  • 2.因为没有correlation ID去关联host api与kernel,因此export_chrome_trace出来的数据没有kernel信息
  • 3.获取prof.profiler.function_events里的数据,通过{ev.name}{ev.id}{ev.thread}拼成uuid与上面chrome trace中的events关联
  • 4.因为只有一个stream。可以根据Host lanuch时间、kernel耗时、launch latency(先验),推断出kernel的开始、结束时间,并用flow event进行关联(虽然并不准确)
  • 5.最后把kernel event以及flow event追加到chrome trace中

1.参考文档

  • ROC Tracer
  • CUPTI
  • 华为profiler_npu
  • Profiling AscendCL API
  • 寒武纪profile_mlu
  • 寒武纪CNLight
  • habana torch
  • intel_extension_for_pytorch
  • Make the kineto extendable for other runtime than CUD
  • pytorch_open_registration_example
  • rename_privateuse1_backend
  • Trace Event Format

2.your-extension-for-pytorch需要增加的代码

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
 
using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;
  
namespace torch {
namespace profiler {
namespace impl {
 
struct NPUMethods : public ProfilerStubs {
   void record(
        int* device,
        ProfilerVoidEventStub* event,
        int64_t* cpu_ns) const override
    {
      if (device) {
          TORCH_CHECK(xpurtGetDevice((uint32_t*)device));
      }
      xpurtEvent_t xpurt_event;
      TORCH_CHECK(xpurtEventCreate(&xpurt_event));
      *event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {
          TORCH_CHECK(xpurtEventDestroy(ptr));
      });
      auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());
      if (cpu_ns) {
          *cpu_ns = getTime();
      }
      TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream)); 
    } 
    float elapsed(
        const ProfilerVoidEventStub* event1_,
        const ProfilerVoidEventStub* event2_) const override
    {
 
        auto event1 = static_cast<xpurtEvent_t>(event1_->get());
        TORCH_CHECK(xpurtEventSynchronize(event1));
        auto event2 = static_cast<xpurtEvent_t>(event2_->get());
        TORCH_CHECK(xpurtEventSynchronize(event2));
        int64_t time_ms = 0;
        TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));
        return time_ms*1.0;
    } 
    void onEachDevice(std::function<void(int)> op) const override
    {
        uint32_t device = 0;
        TORCH_CHECK(xpurtGetDevice(&device));
        op(device);
    } 
    void synchronize() const override { } 
    bool enabled() const override {return true;} 
    void mark(const char*name) const override { } 
    void rangePush(const char*name) const override { } 
    void rangePop() const override {}
};
 
struct RegisterNPUMethods {
    RegisterNPUMethods()
    {
        static NPUMethods methods;
        torch::profiler::impl::registerPrivateUse1Methods(&methods);
    }
};
RegisterNPUMethods reg;
}}}

3.pytorch demo及如何调整chrome trace json文件

import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
import tqdm

def is_valid_kernel(name,duration,valid_kernel_threshold=100):
    '''通过算子的名字和耗时判断是否是Device Kernel'''
    invalid_kernels=["aten::view","aten::reshape",
                    "aten::t","aten::empty",
                    "aten::transpose",
                    "aten::as_strided",
                    "aten::item",
                    "aten::_local_scalar_dense",
                    "aten::result_type",
                    "aten::_unsafe_view",
                    "aten::expand"]
    for k in invalid_kernels:
        if name.find(k)>=0:
            return False
    if duration<valid_kernel_threshold:
        return False    
    return True

def filter_ev(ev):
    '''过滤Kernel'''
    if 'args' in ev and "External id" in ev['args']:
        return True
    return False

def get_uuid(ev,tid_map):
    return f"{ev['name']}_{ev['args']['External id']}_{tid_map[ev['tid']]}"

def get_valid_kernels(traceEvents,kernel_event,tid_map):
    valid_kernels=[]
    device_memory_usage=0
    for ev in traceEvents:
        if filter_ev(ev):
            uuid=get_uuid(ev,tid_map)
            if uuid not in kernel_event:
                continue
            duration=kernel_event[uuid]['kernel_time']
            kernel_name=ev['name']
            if kernel_event[uuid]['device_memory_usage']>0:
                device_memory_usage=kernel_event[uuid]['device_memory_usage']
            if is_valid_kernel(kernel_name,duration):
                launch_beg=ev['ts']
                launch_end=ev['ts']+ev['dur']            
                valid_kernels.append({"name":kernel_name,
                                      "launch_beg":launch_beg,
                                      "launch_end":launch_end,
                                      "kernel_duration":duration,
                                      "host_pid":ev['pid'],
                                      "host_tid":ev['tid'],
                                      "device_memory_usage":device_memory_usage,
                                      "is_leaf_kernel":False})
                                      
    return sorted(valid_kernels,key=lambda x:x['launch_beg'])
    
def is_leaf_kernel(kernel,valid_kernels):
    '''判断是否是叶子Kernel'''
    ret=True
    for k in valid_kernels:
        if k['is_leaf_kernel']:
            continue
        #自己的时间跨度内还有别的Kernel
        if k['launch_beg']>kernel['launch_beg'] and k['launch_end']<kernel['launch_end']:
            ret=False
            break
    return ret

def create_tid_map(traceEvents):
    tids=set()
    for ev in traceEvents:
        if filter_ev(ev):
            tid=ev['tid']
            tids.add(tid)
    tid_map={}
    tids=sorted(tids,reverse=False)
    for i,v in enumerate(tids):
        tid_map[v]=i+1
    return tid_map
                                      
def merge_prof_timeline(prof_json,kernel_event_json,output_json):
    
    kernel_lanuch_latency=0
    with open(prof_json,'r',encoding='utf-8') as f:
        prof = json.load(f)

    with open(kernel_event_json,'r',encoding='utf-8') as f:
        kernel_event = json.load(f)   
    
    traceEvents=prof['traceEvents']
    tid_map=create_tid_map(traceEvents)
    print(tid_map)
    #获取所有kernel
    valid_kernels=get_valid_kernels(traceEvents,kernel_event,tid_map)
    print(len(valid_kernels))
    #筛出所有会在device上执行的kernel
    on_device_kernels=[]
    for kernel in tqdm.tqdm(valid_kernels):
        if is_leaf_kernel(kernel,valid_kernels):
            on_device_kernels.append(kernel)
    
    kernel_start_offset=0
    kernel_index=0

    for kernel in on_device_kernels:
        name=kernel['name']
        kernel_duration=kernel["kernel_duration"]
        lanuch_time=kernel["launch_beg"]
        host_pid=kernel['host_pid']
        host_tid=kernel['host_tid']
        device_memory_usage=kernel['device_memory_usage']
        
        if kernel_start_offset==0:
            kernel_start_offset=lanuch_time+kernel_start_offset
            
        if lanuch_time>kernel_start_offset: #kernel 队列空闲
            kernel_start_offset=lanuch_time
        
        #增加kernel事件
        traceEvents.append({"ph": "X", "cat": "device_kernel", "name":name, "pid": 10, "tid": 10,"ts": kernel_start_offset, "dur": kernel_duration})
        
        #增加内存事件
        traceEvents.append({"ph": "C", "cat": "memory", "name":"memory", "pid": 11, "tid": 11,"ts": lanuch_time, "args": {"value":device_memory_usage}})
        
        #增加flow event
        traceEvents.append({"ph": "s", "id": kernel_index, "pid": host_pid, "tid": host_tid, "ts": lanuch_time,"cat": "ac2g", "name": "ac2g"})
        traceEvents.append({"ph": "f", "id": kernel_index, "pid": 10,  "tid": 10,"ts": kernel_start_offset,"cat": "ac2g", "name": "ac2g", "bp": "e"})
        
        kernel_index+=1
        kernel_start_offset+=(kernel_duration+kernel_lanuch_latency)
    
    #保存最终的结果
    with open(output_json,'w',encoding='utf-8') as f:
        json.dump(prof, f,ensure_ascii=False,indent=4)
		
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
 
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
 
    def forward(self,query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = query@key.transpose(-2,-1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e20)
        p_attn = F.softmax(scores, dim = -1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn@value, p_attn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()
 
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query=self.linears[0](query).view(nbatches, -1, self.h, self.d_k)
        query=query.transpose(1, 2)
        key=self.linears[1](key).view(nbatches, -1, self.h, self.d_k)
        key=key.transpose(1, 2)
        value=self.linears[2](value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        x, self.attn = self.attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
 
use_cuda=True
try:
    import torch_xpu
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    torch.profiler.ProfilerActivity.PrivateUse1="xpu"
    use_cuda=False
except:
    pass
 
import os
os.environ['LOCAL_RANK']="0"
os.environ['RANK']="0"
os.environ['WORLD_SIZE']="1"
os.environ['MASTER_ADDR']="localhost"
os.environ['MASTER_PORT']="6006"

import torch.distributed as dist
dist.init_process_group(backend='vccl')
local_rank=int(os.environ['LOCAL_RANK'])
rank=torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")
 
cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64),dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()
out = cross_attn.forward(q1,k1,v1).sum()
torch.cuda.synchronize()
 
activities=[torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)
 
with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=3,
                repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
   ) as prof:
        for i in range(10):
            out = cross_attn.forward(q1,k1,v1).sum()
            prof.step()
        torch.cuda.synchronize()
 
if not use_cuda:
    kernel_event={}
    for ev in prof.profiler.function_events:
        if ev.privateuse1_time>0:
            uuid=f"{ev.name}_{ev.id}_{ev.thread}"
            #print(uuid,ev.id,ev.name,ev.privateuse1_time,ev.time_range.start,ev.time_range.end-ev.time_range.start,ev.privateuse1_memory_usage)
            kernel_event[uuid]={"kernel_time":ev.privateuse1_time,
								"device_memory_usage":ev.privateuse1_memory_usage,
								"start_us":ev.time_range.start,
								"host_dur":ev.time_range.end-ev.time_range.start,
								"thread":ev.thread} 
    import json
    with open(f"kernel_event_{rank}.json",'w',encoding='utf-8') as f:
        json.dump(kernel_event, f,ensure_ascii=False,indent=4)

    prof.export_chrome_trace(f"prof_{rank}.json")
    merge_prof_timeline(f"prof_{rank}.json",f"kernel_event_{rank}.json",f"prof_{rank}.json")
else:
    #print(prof.key_averages().table(sort_by="self_cpu_time_total"))
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")

4.可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1870102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HarmonyOS Next开发学习手册——代码混淆

代码混淆简介 针对工程源码的混淆可以降低工程被破解攻击的风险&#xff0c;缩短代码的类与成员的名称&#xff0c;减小应用的大小。 DevEco Studio提供代码混淆的能力并默认开启&#xff0c;API 10及以上版本的Stage模型、 编译模式为release 时自动进行代码混淆。 使用约束…

React_创建一个项目

目录 一、React&#xff08;js 版&#xff09; 二、React&#xff08;ts 版&#xff09; 使用react创建一个项目,前提是确保你已经安装了Node.js和npm。 如果没有安装Node.js和npm&#xff0c;查看这个文件&#xff1a; 安装node.js和npmhttps://blog.csdn.net/zxy1993106…

Arcgis地统计分析工具灰色不可用 解决方法

使用Arcmap&#xff0c;调用地统计分析工具&#xff08;Geostatistical Analyst&#xff09;下的探索数据&#xff08;Explore Data&#xff09;&#xff0c;发现工具呈灰色不可用。这是由于扩展模块中没有将该模块做勾选设置导致的。下面介绍一下如何解决地统计分析工具不可用…

学生成绩管理系统带8000字文档学生选课管理系统java项目javaweb项目ssm项目jsp项目java课程设计java毕业设计

文章目录 学生选课成绩管理系统一、项目演示二、项目介绍三、8500字项目文档四、部分功能截图五、部分代码展示六、底部获取项目源码带8500字文档&#xff08;9.9&#xffe5;带走&#xff09; 学生选课成绩管理系统 一、项目演示 选课成绩管理系统 二、项目介绍 语言: Java …

【从0实现React18】 (五) 初探react mount流程 完成核心递归流程

更新流程的目的&#xff1a; 生成wip fiberNode树标记副作用flags 更新流程的步骤&#xff1a; 递&#xff1a;beginWork归&#xff1a;completeWork 在 上一节 &#xff0c;我们探讨了 React 应用在首次渲染或后续更新时的整体更新流程。在 Reconciler 工作流程中&#xff…

硬盘空间告急?监控服务器容量,钉钉及时提醒!

在日常的服务器维护中&#xff0c;硬盘容量的监控是非常重要的。如果硬盘容量超过某个阈值&#xff0c;可能会导致服务器无法正常运行&#xff0c;影响业务的正常运作。为了避免这种情况&#xff0c;我们可以编写一个Shell脚本&#xff0c;定期检查硬盘容量&#xff0c;当超过设…

springboot + Vue前后端项目(第二十记)

项目实战第二十记 写在前面1. 高德地图官网2. 开发文档3. 集成高德地图3.1 在public文件夹下创建config.js3.2 index.html&#xff08;在项目启动文件中引入外部的js&#xff09;3.3 点标记&#xff08;用点标记当前位置&#xff09;3.4 信息窗体&#xff08;点击当前位置&…

如何预防和处理他人盗用IP地址?

IP地址的定义及作用 解释 IP 地址在互联网中的作用。它是唯一标识网络设备的数字地址&#xff0c;类似于物理世界中的邮政地址。 1、IP地址盗窃的定义 解释一下什么是IP地址盗用&#xff0c;即非法使用他人的IP地址或者伪造IP地址的行为&#xff0c;这种行为可能引发法律和安…

DLMS/COSEM协议—(Green-Book)Gateway protocol

DLMS/COSEM协议 — Gateway protocol 10.10 Gateway protocol &#xff08;网关协议&#xff09;10.10.1 概述10.10.2 网关协议 &#xff08;The gateway protocol&#xff09;10.10.3 HES在WAN/NN中作为发起者&#xff08;拉取操作&#xff09;10.10.4 LAN中的终端设备作为发起…

数据库物理结构设计-定义数据库模式结构(概念模式、用户外模式、内模式)、定义数据库、物理结构设计策略

一、引言 如何基于具体的DBMS产品&#xff0c;为数据库逻辑结构设计的结果&#xff0c;即关系数据库模式&#xff0c;制定适合应用要求的物理结构 1、在设计数据库物理结构前&#xff0c;数据库设计人员首先 要充分了解所用的DBMS产品的功能、性能和特点&#xff0c;包括提供…

【最新综述】基于伪标签的半监督语义分割

Semi-Supervised Semantic Segmentation Based on Pseudo-Labels: A Survey 摘要&#xff1a; 语义分割是计算机视觉领域的一个重要而热门的研究领域&#xff0c;其重点是根据图像中像素的语义对其进行分类。然而&#xff0c;有监督的深度学习需要大量数据来训练模型&#xff…

Asm动态生成类和get and set方法

asm在解析文件的时候是按照特定顺序进行分析的&#xff0c;首先是visit方法&#xff0c;做类相关的解析&#xff0c;然后是注解&#xff0c;然后是属性&#xff0c;最后才是方法&#xff0c;属性是在所有方法分析前面进行&#xff0c;也就是只有当class文件中的所有属性都遍历完…

【Android11】开机启动日志捕捉服务

一、前言 制作这个功能的原因是客户想要自动的记录日志中的报错和警告到设备的内存卡里面。虽然开发者模式中有一个“bug report” 会在/data/user_de/0/com.android.shell/files/bugreports/目录下生成一个zip包记录了日志。但是客户觉得这个日志很难获取到他们需要的信息&am…

无需劳师动众,让石油化工DCS集散控制系统轻松实现无线传输!

石油化工中,为了保证较高的可靠性和安全性,大量使用的是DCS集散控制系统。与FCS现场总线的“现场采集,转换为数字信号来集中传输”不同,DCS系统为了避免由于线缆断裂或者节点问题导致整个控制系统失灵,采用“分散传输,集中采集”的方式,即每个传感器通过4-20mA的模拟量通…

el-upload 上传图片及回显照片和预览图片,文件流和http线上链接格式操作

<div v-for"(info, index) in zsjzqwhxqList.helicopterTourInfoList" :key"info.id" >编辑上传图片// oss返回线上地址http链接格式&#xff1a;<el-form-itemlabel"巡视结果照片":label-width"formLabelWidth"><el…

【MySQL】 -- 用户管理

1. 权限 如果我们只能使用root用户&#xff0c;这样存在安全隐患。这时&#xff0c;就需要使用MySQL的用户管理。创建出非root用户&#xff0c;限制其权限。 权限这个概念拿出来就是用来限制非root用户的。这样从技术手段上保证了数据的安全性和完整性&#xff0c;防止有人删库…

LeetCode11. 盛最多水的容器题解

LeetCode11. 盛最多水的容器题解 题目链接&#xff1a; https://leetcode.cn/problems/container-with-most-water 示例 思路 暴力解法 定住一个柱子不动&#xff0c;然后用其他柱子与其围住面积&#xff0c;取最大值。 代码如下&#xff1a; public int maxArea1(int[]…

麒麟系统安装Redis

一、背景 如前文&#xff08;《麒麟系统安装MySQL》&#xff09;所述。 二、下载Redis源码 官方未提供麒麟系统的Redis软件&#xff0c;须下载源码编译。 下载地址&#xff1a;https://redis.io/downloads 6.2.14版本源码下载地址&#xff1a;https://download.redis.io/re…

ADI-DSP|在指定内存写入数据

一、LDF文件设置内存空间 user_data_test { TYPE(BW RAM) START(0x00380010) END(0x0039bfff) WIDTH(8) }//usr data dxe_user_data_bw BW{INPUT_SECTION_ALIGN(4)INPUT_SECTIONS( $OBJS_LIBS(user_data) )} > user_data_test 二、在C文件中设置数据 /************…

2.x86游戏实战-跨进程读取血量

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 接下来会写C/C代码&#xff0c;C/C代码不是很难&#xff0c;然后为了快速掌握逆向这个技能&#xff0c;我…