TensorRT推理时,如何比对中间层的误差

news2025/1/6 19:27:02

TensorRT推理时,如何比对中间层的误差

  • 有二种方案
  • 第二种方案的实现
    • 1.运行环境的搭建
    • 2.实现代码(compare_trt_onnxrt.py)
    • 3.运行
    • 4.输出

本文演示了TensorRT推理时,如何比对中间层的误差。
在做TensorRT推理加速时,可能会遇到精度问题,希望定位到是哪一个节点引起的误差,还是累计误差。

有二种方案

  • 1.通过polygraphy run和–onnx-outputs和–trt-outputs参数,标记需要输出的节点。这种方法简单,但不灵活。
  • 2.基于polygraphy api编程实现,该方法可以灵活设置希望标记为输出的节点

第二种方案的实现

1.运行环境的搭建

  • 参考链接

2.实现代码(compare_trt_onnxrt.py)

from polygraphy.logger import G_LOGGER
import onnx
import torch
from torch.nn import functional as F
import tensorrt as trt
from polygraphy.comparator.data_loader import DataLoader
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import EngineBytesFromNetwork, EngineFromBytes, NetworkFromOnnxPath, TrtRunner
from polygraphy.comparator import Comparator
from polygraphy.exception import PolygraphyException
from polygraphy.comparator.compare import CompareFunc
import pycuda.driver as cuda
import os
import pycuda.autoinit
import glob
from PIL import Image
import imageio as imageio
import numpy as np
import sys
import time

def get_img_data(img_lq):
    '''图片预处理'''
    img_t = torch.from_numpy(np.array(img_lq)).to('cpu').permute(2, 0, 1).unsqueeze(0)
    img_t = (img_t/255.-0.5)/0.5
    img_t = F.interpolate(img_t, (512, 512)).squeeze(0)
    return img_t.numpy()

class MNISTEntropyCalibrator(trt.IInt8MinMaxCalibrator):
    def __init__(self, cache_file):
        trt.IInt8MinMaxCalibrator.__init__(self)
        self.cache_file = cache_file
        self.current_index = 0
        self.batch_size=1
        self.dataset = []

        img_preprocess=get_img_data
        for input_file in glob.glob("images/*"):
            src_img = Image.open(input_file.strip())
            src_img = img_preprocess(src_img)
            src_img = np.expand_dims(src_img, 0)
            src_tensor = torch.tensor(src_img).numpy()
            self.dataset.append(src_tensor)
        if len(self.dataset)==0:
            self.dataset.append(torch.ones((1,3,512,512),dtype=torch.float32))
        self.data=np.concatenate(self.dataset,axis=0)
        self.total_samples=len(self.dataset)
        self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        if self.current_index + self.batch_size > self.total_samples:
            return None

        batch = self.data[self.current_index : self.current_index + self.batch_size].ravel()
        cuda.memcpy_htod(self.device_input, batch)
        self.current_index += self.batch_size
        return [self.device_input]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

def main():

    onnx_model_path="yolov5n.onnx"
    onnx_model_output_path="yolov5n_tmp.onnx"
    
    # 需要导出的节点类型(也可以修改为按name过滤)
    filter_names=["Pow","Add","Sqrt","Div","Mul","LeakyRelu","Reduce","Conv","Gemm"]
    
    # 标记onnx模型的输出节点
    model = onnx.load(onnx_model_path)
    for node in model.graph.node:
        for name in filter_names:
            if node.name.find(name)>=0:
                for output in node.output[:1]:
                    model.graph.output.extend([onnx.ValueInfoProto(name=output)])

    onnx.save(model,onnx_model_output_path)

    parse_network_from_onnx = NetworkFromOnnxPath(onnx_model_path)
    builder, network, parser=parse_network_from_onnx()
	
    # TensorRT编译参数
    config = builder.create_builder_config()
    config.set_tactic_sources(1 << int(trt.TacticSource.CUBLAS))
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30)
    config.set_flag(trt.BuilderFlag.FP16)
    
    if True: #是否为int8模式
        calibration_cache = "calibration.cache"
        if os.path.exists(calibration_cache):
            os.remove(calibration_cache)
        calib = MNISTEntropyCalibrator(cache_file=calibration_cache)
        config.int8_calibrator = calib
        config.set_flag(trt.BuilderFlag.INT8)
    
    # 标记输出节点
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if layer.type in [trt.LayerType.CONSTANT]:
            continue
        for name in filter_names:
            if layer.name.find(name)>=0:
                out_name=layer.get_output(0).name
                if out_name!="skip.23":
                    layer.set_output_type(0, trt.DataType.FLOAT)
                    network.mark_output(layer.get_output(0))
                break

    build_engine = EngineBytesFromNetwork((builder, network, parser),config)
    deserialize_engine = EngineFromBytes(build_engine)
    build_onnxrt_session = SessionFromOnnx(onnx_model_output_path)

    runners = [
        TrtRunner(deserialize_engine),
        OnnxrtRunner(build_onnxrt_session),
    ]

    data_loader = DataLoader(val_range=(-1.0, 1.0))
    results = Comparator.run(runners,data_loader)
    success = True
    #compare_func = CompareFunc.simple(check_shapes=True,show_heatmaps=False,rtol=1e-2,atol=1e-2)
    
    # 忽略绝对误差,只关心相对误差
    compare_func = CompareFunc.simple(check_shapes=True,show_heatmaps=False,rtol=1e-2,atol=1000)
    success &= bool(Comparator.compare_accuracy(results,compare_func=compare_func))
    if not success:
        raise PolygraphyException('FAILED')
        
if __name__=="__main__":
    main()

3.运行

# 安装pycuda
export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64
export CPATH=$CPATH:/usr/local/cuda/include
pip install pycuda==2023.1
# 执行
python compare_trt_onnxrt.py

4.输出

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1471944.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

10.vue学习笔记(组件数据传递-props回调函数子传父+透传Attributes+插槽slot)

文章目录 1.组件数据传递2.透传Attributes&#xff08;了解&#xff09;禁用Attributes继承 3.插槽slot3.1.插槽作用域3.2.默认内容3.3.具名插槽3.4.插槽中的数据传递3.5.具名插槽传递数据 1.组件数据传递 我们之前讲解过了组件之间的数据传递&#xff0c;props 和 自定义事件…

【IDEA】java 项目启动偶现Kotlin 版本问题 error:Kotlin:module was

一、问题描述&#xff1a; error:Kotlin:module was compiled with an incompatible version of kotlin the binary version of its metadata is二、问题原因&#xff1a; jar包版本冲突 三、解决方式&#xff1a; 1、Rebuild Project&#xff08;推荐☆&#xff09; 重新构…

小迪安全30WEB 攻防-通用漏洞SQL 注入CTF二次堆叠DNS 带外

#知识点&#xff1a; 1、数据库堆叠注入 根据数据库类型决定是否支持多条语句执行 2、数据库二次注入 应用功能逻辑涉及上导致的先写入后组合的注入 3、数据库 Dnslog 注入 解决不回显(反向连接),SQL 注入,命令执行,SSRF 等 4、黑盒模式分析以上 二次注入&…

2024.2.21 模拟实现 RabbitMQ —— 实现转发规则

目录 需求分析 直接交换机&#xff08;Direct &#xff09; 主题交换机&#xff08;Topic &#xff09; 扇出交换机&#xff08;Fanout &#xff09; Topic 交换机转发规则 routingKey 组成 bindingKey 组成 匹配规则 情况一 情况二 情况三 实现 Router 类 校验 b…

【AIGC大模型】跑通wonder3D (windows)

论文链接&#xff1a;https://arxiv.org/pdf/2310.15008.pdf windows10系统 显卡&#xff1a;NVIDIA rtx 2060 一、安装anaconda 二、安装CUDA 11.7 (CUDA Toolkit 11.7 Downloads | NVIDIA Developer) 和 cudnn 8.9.7(cuDNN Archive | NVIDIA Developer)库 CUDA选择自定…

【Android】坐标系

Android 系统中有两种坐标系&#xff0c;分别为 Android 坐标系和 View 坐标系。了解这两种坐标系能够帮助我们实现 View 的各种操作&#xff0c;比如我们要实现 View 的滑动&#xff0c;你连这个 View 的位置都不知道&#xff0c;那如何去操作呢&#xff1f; 一、Android 坐标…

【Spring Cloud】高并发带来的问题及常见容错方案

文章目录 高并发带来的问题编写代码修改配置压力测试修改配置&#xff0c;并启动软件添加线程组配置线程并发数添加Http取样配置取样&#xff0c;并启动测试访问message方法观察效果 服务雪崩效应常见容错方案常见的容错思路常见的容错组件 总结 欢迎来到阿Q社区 https://bbs.c…

《极简C++学习专栏》之结束语

朋友们&#xff0c;经过这么长的时间&#xff0c;《极简C学习专栏》的文章创作就要结束了&#xff0c;感谢你们一路陪伴&#xff01; 也希望你们能支持我接下来的其他专栏的创作&#xff01; 专栏的初衷 《极简C学习》专栏的初衷源自于我个人的学习笔记&#xff0c;记录下自己…

【刷题】牛客 JZ64 求1+2+3+...+n

刷题 题目描述思路一 &#xff08;暴力递归版&#xff09;思路二 &#xff08;妙用内存版&#xff09;思路三 &#xff08;快速乘法版&#xff09;思路四 &#xff08;构造巧解版&#xff09;Thanks♪(&#xff65;ω&#xff65;)&#xff89;谢谢阅读&#xff01;&#xff01…

Java 面向对象进阶 18 JDK8、9开始新增的方法;接口的应用;适配器设计模式;内部类(黑马)

一、JDK8开始新增的方法 默认方法不是抽象方法&#xff0c;所以不强制被重写&#xff1a; 但是如果被重写&#xff0c;就要去掉default关键字&#xff1a; public可以省略&#xff0c;但是default不可以省略&#xff1a; public是灰色的&#xff0c;代表可以省略 但是default是…

怎么把pdf转换成word?

怎么把pdf转换成word&#xff1f;Pdf和word在电脑上的使用非常广泛&#xff0c;pdf和word分别是由 Adobe和Microsoft 分别开发的电脑文件格式。PDF 文件可以在不同操作系统和设备上保持一致的显示效果&#xff0c;无论是在 Windows、Mac 还是移动设备上查看&#xff0c;都能保持…

使用Docker部署MinIO并结合内网穿透实现远程访问本地数据

文章目录 前言1. Docker 部署MinIO2. 本地访问MinIO3. Linux安装Cpolar4. 配置MinIO公网地址5. 远程访问MinIO管理界面6. 固定MinIO公网地址 前言 MinIO是一个开源的对象存储服务器&#xff0c;可以在各种环境中运行&#xff0c;例如本地、Docker容器、Kubernetes集群等。它兼…

《绝地求生》提示msvcp140.dll丢失如何修复?分享5种靠谱的解决方法

在玩绝地求生&#xff08;PUBG&#xff09;游戏过程中&#xff0c;如果遇到系统弹出“提示请重新安装软件msvcp140.dll”的信息&#xff0c;这究竟是什么原因导致的呢&#xff1f;msvcp140.dll这个文件是Microsoft Visual C Redistributable Package的一部分&#xff0c;是许多…

服务器系统安全,10招教你维护服务器的安全

网络逐渐成为了我们生活中一部分。有人说&#xff0c;断WIFI是最厉害的一种惩罚手段&#xff0c;但是其实不然&#xff0c;最狠的莫过于网站的服务器遭受攻击&#xff0c;直接访问不了网页了&#xff0c;这时候就算有wifi我们也无能为力。服务器系统安全一直是管理者最关注的事…

这个元宵节,被云开发者安排了

元宵节快乐&#xff0c;同学们&#xff01;今天吃的汤圆都是什么馅儿的&#xff1f; 都说过了元宵&#xff0c;这个年才算是正式过完&#xff0c;2024年就算是正式开启。学堂君这里也准备了一份专属于开发者的小礼物&#xff0c;作为一点心意。 欢迎私信&#xff0c;发送暗号…

Ubuntu 某软件导致卡机如何 kill 掉进程

输入 top 查看现在系统的进程&#xff0c;记下该进程第一列的 pid 编号 kill [pid] 可以杀掉此进程 参考&#xff1a; Ubuntu下查看进程pid及结束无响应程序_终止3分钟内无响应的所有pid-CSDN博客

【OneAPI】节假日查询API

OneAPI新接口发布&#xff1a;节假日查询API 可查询指定月份、年份法定节假日及调休情况。 API地址&#xff1a;https://oneapi.coderbox.cn/openapi/public/holiday 请求参数 URL参数 参数名类型必须含义说明datestring否要查询的日期可按年或月查询&#xff0c;支持前缀…

Coursera吴恩达机器学习专项课程02:Advanced Learning Algorithms 笔记 Week01

Advanced Learning Algorithms Week 01 笔者在2022年7月份取得这门课的证书&#xff0c;现在&#xff08;2024年2月25日&#xff09;才想起来将笔记发布到博客上。 Website: https://www.coursera.org/learn/advanced-learning-algorithms?specializationmachine-learning-in…

如何使用Lychee+cpolar搭建本地私人图床并实现远程访问存储图片

文章目录 1.前言2. Lychee网站搭建2.1. Lychee下载和安装2.2 Lychee网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置 4.公网访问测试5.结语 1.前言 图床作为图片集中存放的服务网站&#xff0c;可以看做是云存储的一部分&#xff0c;既可…

嵌入式Qt 实现用户界面与业务逻辑分离

一.基本程序框架一般包含 二.框架的基本设计原则 三.用户界面与业务逻辑的交互 四.代码实现计算器用户界面与业务逻辑 ICalculator.h #ifndef _ICALCULATOR_H_ #define _ICALCULATOR_H_#include <QString>class ICalculator { public:virtual bool expression(const QSt…