pytest学习-pytorch单元测试

news2024/11/29 14:54:39

pytorch单元测试

  • 一.公共模块[common.py]
  • 二.普通算子测试[test_clone.py]
  • 三.集合通信测试[test_ccl.py]
  • 四.测试命令
  • 五.测试报告

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" 

device="cpu"
device_type="cpu"
device_name="cpu"

try:
    if torch.cuda.is_available():     
        device_name=torch.cuda.get_device_name().replace(" ","")
        device="cuda:0"
        device_type="cuda"
        ccl_backend='nccl'
except:
    pass

host_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()

if not os.path.exists(metric_data_root):
    os.makedirs(metric_data_root)

def device_warmup(device):
    '''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
    left = torch.rand([128,512], dtype = torch.float16).to(device)
    right = torch.rand([512,128], dtype = torch.float16).to(device)
    out=torch.matmul(left,right)
    torch.cuda.synchronize()

torch.manual_seed(1) 
np.random.seed(1)

def loop_decorator(loops,rank=0):
    '''循环装饰器,用于统计函数的执行时间,内存占用等'''
    def decorator(func):
        def wrapper(*args,**kwargs):
            latency=[]
            memory_allocated_t0=torch.cuda.memory_allocated(rank)
            for _ in range(loops):
                input_copy=[x.clone() for x in args]
                beg= datetime.now().timestamp() * 1e6
                pred= func(*input_copy)
                gt=kwargs["golden"]
                torch.cuda.synchronize()
                end=datetime.now().timestamp() * 1e6
                mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                latency.append(end-beg)
            memory_allocated_t1=torch.cuda.memory_allocated(rank)
            avg_latency=np.mean(latency[len(latency)//2:]).round(3)
            first_latency=latency[0]
            return { "first_latency":first_latency,"avg_latency":avg_latency,
                      "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                      "mse":mse}
        return wrapper
    return decorator

class TorchUtMetrics:
    '''用于统计测试结果,比较之前的最小值'''
    def __init__(self,ut_name,thresold=0.2,rank=0):
        self.ut_name=f"{ut_name}_{rank}"
        self.thresold=thresold
        self.rank=rank
        self.data={"ut_name":self.ut_name,"metrics":[]}
        self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
        try:
            with open(self.metrics_path,"r") as f:
                self.data=json.loads(f.read())
        except:
            pass

    def __enter__(self):
        self.beg= datetime.now().timestamp() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):        
        self.report()
        self.save_data()

    def save_data(self):
        with open(self.metrics_path,"w") as f:
            f.write(json.dumps(self.data,indent=4))

    def set_metrics(self,metrics):
        self.end=datetime.now().timestamp() * 1e6
        item=collections.OrderedDict()
        item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        item["sdk_version"]=sdk_version
        item["device_name"]=device_name
        item["host_name"]=host_name
        item["metrics"]=metrics
        item["metrics"]["e2e_time"]=self.end-self.beg
        self.cur_item=item
        self.data["metrics"].append(self.cur_item)

    def get_metric_names(self):
        return self.data["metrics"][0]["metrics"].keys()

    def get_min_metric(self,metric_name,devicename=None):
        min_value=0
        min_value_index=-1
        for idx,item in enumerate(self.data["metrics"]):
            if devicename and (devicename!=item['device_name']):                
                continue            
            val=float(item["metrics"][metric_name])
            if min_value_index==-1 or val<min_value:
                min_value=val
                min_value_index=idx
        return min_value,min_value_index

    def get_metric_info(self,index):
        metrics=self.data["metrics"][index]
        return f'{metrics["device_name"]}@{metrics["sdk_version"]}'

    def report(self):
        assert len(self.data["metrics"])>0
        for metric_name in self.get_metric_names():
            min_value,min_value_index=self.get_min_metric(metric_name)
            min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
            cur_value=float(self.cur_item["metrics"][metric_name])
            print(f"-------------------------------{metric_name}-------------------------------")
            print(f"{cur_value}#{device_name}@{sdk_version}")
            if min_value_index_same_dev>=0:
                print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
            if min_value_index>=0:
                print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

from common import *
class TestCaseClone(TestCase):
    #如果不满足条件,则跳过这个测试
    @unittest.skipIf(device_count>1, "Not enough devices") 
    def test_todo(self):
        print(".TODO")

    #框架会自动遍历以下参数组合
    @parametrize("shape", [(10240,20480),(128,256)])
    @parametrize("dtype", [torch.float16,torch.float32])
    def test_clone(self,shape,dtype):
        
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5)
        def run(input_dev):
            output=input_dev.clone()
            return output
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
            input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=input_host.cpu())
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        
instantiate_parametrized_tests(TestCaseClone)

if __name__ == "__main__":
    run_tests()

三.集合通信测试[test_ccl.py]

from common import *
class TestCCL(MultiProcessTestCase):
    '''CCL测试用例'''
    def _create_process_group_vccl(self, world_size, store):
        dist.init_process_group(
            ccl_backend, world_size=world_size, rank=self.rank, store=store
        )        
        pg = dist.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4
      
    #框架会自动遍历以下参数组合
    @unittest.skipIf(device_count<4, "Not enough devices") 
    @parametrize("op",[dist.ReduceOp.SUM])
    @parametrize("shape", [(1024,8192)])
    @parametrize("dtype", [torch.int64])
    def test_allreduce(self,op,shape,dtype):
        if self.rank >= self.world_size:
            return
        
        store = dist.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_vccl(self.world_size, store)
        if not torch.distributed.is_initialized():
            return
    
        torch.cuda.set_device(self.rank)
        device = torch.device(device_type,self.rank)
        device_warmup(device)
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5,rank=self.rank)
        def run(input_dev):
            dist.all_reduce(input_dev, op=op)
            return input_dev
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
            input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
            gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
            gt_=gt[0]
            for i in range(1,self.world_size):
                gt_=gt_+gt[i]
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=gt_)
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        dist.destroy_process_group(pg)
    
instantiate_parametrized_tests(TestCCL)

if __name__ == "__main__":
    run_tests()

四.测试命令

# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./

# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1603174.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python基于Django搜索的目标站点内容监测系统设计,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

Android Studio XML 预览View 底部移动到右边

以前 XML 的预览都是在右边的&#xff0c;最近不知道为什么突然到下面去了&#xff0c;很不习惯 找半天想把 预览view 移动到右边&#xff0c;一直没找到按钮。 误打误撞移回来了&#xff0c;原来只要再点击一次 split&#xff0c;就可以变动位置了&#xff0c;记录一下。

ChatGPT及GIS、生物、地球、农业、气象、生态、环境科学领域案例

以ChatGPT、LLaMA、Gemini、DALLE、Midjourney、Stable Diffusion、星火大模型、文心一言、千问为代表AI大语言模型带来了新一波人工智能浪潮&#xff0c;可以面向科研选题、思维导图、数据清洗、统计分析、高级编程、代码调试、算法学习、论文检索、写作、翻译、润色、文献辅助…

C++教你如何模拟实现string,如何实现string写时拷贝

文章目录 前言成员变量默认成员函数默认构造函数拷贝构造函数析构函数赋值运算符重载 容量相关函数&#xff08;Capacity&#xff09;reserve函数resize函数size函数capacity 函数clear函数 修改函数&#xff08;Modifiers&#xff09;swap函数insert函数字符插入字符串插入 ap…

未来城市可视化,A3D引擎支持,免费搭建全新一代数字孪生!

AMRT3D数字孪生引擎https://www.amrt3d.com/#/ 什么是未来城市&#xff1f;它是新型数字化理念的载体&#xff0c;以数字孪生与物理世界城市的融合为核心&#xff0c;通过数字孪生技术在数字空间实时构建城市&#xff0c;采用数据整合和分析预测来实时模拟、预测、控制整体城市…

uniapp之消除图片的空白占用空间

我们在使用uniapp开发的过程中一定会遇到一个情况就是我们加载的图片总有一点空白出现在不该出现的地方代码如下 <view style"background:#ff0000;"><image style"width:100%;"src"https://t7.baidu.com/it/u1819248061,230866778&fm19…

[论文笔记]Root Mean Square Layer Normalization

引言 今天带来论文Root Mean Square Layer Normalization的笔记&#xff0c;论文题目是均方根层归一化。 本篇工作提出了RMSNorm&#xff0c;认为可以省略重新居中步骤。 简介 层归一化对Transformer等模型非常重要&#xff0c;它可以帮助稳定训练并提升模型收敛性&#xf…

uniapp-小程序保存图片到相册

小程序保存图片到相册 一. 将图片保存到手机相册涉及的api 有以下几个 1. uni.getSetting (获取用户的当前设置) 2. uni.authorize&#xff08;提前向用户发起授权请求。调用后会立刻弹窗询问用户是否同意授权小程序使用某项功能或获取用户的某些数据&#xff0c;但不会实际调…

GPT国内能用吗

2022年11月&#xff0c;Open AI发布ChatGPT&#xff0c;ChatGPT展现了大型语模型在自然语言处理方面的惊人进步&#xff0c;其生成文本的流畅度和连贯性令人印象深刻&#xff0c;为AI应用打开了新的可能性。 ChatGPT的出现推动了AI技术在各个领域的应用&#xff0c;例如&#x…

『Django』创建app(应用程序)

theme: smartblue 本文简介 点赞 关注 收藏 学会了 在《『Django』环境搭建》中介绍了如何搭建 Django 环境&#xff0c;并且创建了一个 Django 项目。 在刚接触 Django 时有2个非常基础的功能是需要了解的&#xff0c;一个是“app”(应用程序)&#xff0c;另一个是 url(路由…

kafka---topic详解

一、分区与高可用 在Kafka中,事件(events 事件即消息)是以topic的形式进行组织的;同时topic是分区(partitioned)的,这意味着一个topic分布在Kafka broker上的多个“存储桶”(buckets)上。这种数据的分布式放置对于可伸缩性非常重要,因为它允许客户端应用程序同时从多个…

「 网络安全常用术语解读 」漏洞利用交换VEX详解

漏洞利用交换&#xff08;Vulnerability Exploitability eXchange&#xff0c;简称VEX&#xff09;是一个信息安全领域的标准&#xff0c;旨在提供关于软件漏洞及其潜在利用的实时信息。根据美国政府发布的用例(PDF)&#xff0c;由美国政府开发的漏洞利用交换(VEX)使供应商和用…

postman 调试 传base64字符串 原来选xml

上个图 工具类 package org.springblade.common.utils;import com.alibaba.fastjson.JSONObject; import org.springblade.modules.tc.mas.Submit;import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStrea…

tcp三次握手和四次断开以及tcpdump的基本使用

前言 最近工作中会发现有超时的问题&#xff0c;还有就是在面试的时候很多都要求深入理解TCP/IP协议。突然感觉TCP/IP协议是一个既熟悉&#xff0c;又陌生的技术。又想到上大学的时候&#xff0c;老师说过 网络的圣经&#xff1a;“TCP/IP详解” 卷一 卷二 卷三&#xff0c;三…

Spring之CGLIB和JDK动态代理底层实现

目录 CGLIB 使用示例-支持创建代理对象&#xff0c;执行代理逻辑 使用示例-多个方法&#xff0c;走不同的代理逻辑 JDK动态代理 使用示例-支持创建代理对象&#xff0c;执行代理逻辑 Spring会自动在JDK动态代理和CGLIB之间转换: 1、如果目标对象实现了接口&#xff0c;默…

基于51单片机智能鱼缸仿真LCD1602显示( proteus仿真+程序+设计报告+讲解视频)

基于51单片机智能鱼缸仿真LCD显示 1. 主要功能&#xff1a;2. 讲解视频&#xff1a;3. 仿真4. 程序代码5. 设计报告6. 设计资料内容清单&&下载链接资料下载链接&#xff1a; 基于51单片机智能鱼缸仿真LCD显示( proteus仿真程序设计报告讲解视频&#xff09; 仿真图prot…

【Web】NewStarCTF 2022 题解(全)

目录 Week1 HTTP Head?Header! 我真的会谢 NotPHP Word-For-You Week2 Word-For-You(2 Gen) IncludeOne UnserializeOne ezAPI Week3 BabySSTI_One multiSQL IncludeTwo Maybe You Have To think More Week4 So Baby RCE BabySSTI_Two UnserializeT…

IDEA 安装、基本使用、创建项目

文章目录 下载基本使用修改颜色主题Keymap插件 创建项目创建模块新建 Java 类运行新建 Package打包 Jar运行 jar 包 查看文档 下载 官方下载地址&#xff1a;https://www.jetbrains.com/zh-cn/idea/download/?sectionmac 这里我下载 macOS 社区版&#xff0c;IDEA 2024.1 (C…

mPEG-Glutaramide Acid结合了聚乙二醇(PEG)和戊二酸(GAA)的性质

【试剂详情】 英文名称 mPEG-GAA&#xff0c;Methoxy PEG GAA&#xff0c; mPEG-Glutaramide Acid 中文名称 聚乙二醇单甲醚酰胺戊二酸&#xff0c; 甲氧基-聚乙二醇-戊二酰胺酸 外观性状 由分子量决定 分子量 400,600&#xff0c;2k&#xff0c;3.4k&#xff0c;5k&…

代码随想录算法训练营第五十七天 | 647. 回文子串、516. 最长回文子序列

代码随想录算法训练营第五十七天 | 647. 回文子串、516. 最长回文子序列 647. 回文子串题目解法 516. 最长回文子序列题目解法 动态规划总结链接感悟 647. 回文子串 题目 解法 题解链接 动态规划 class Solution { public:int countSubstrings(string s) {// dp[i][j]:表示…