TensorRT学习笔记--基本概念和推理流程

news2025/1/12 20:50:55

目录

前言

1--Tensor RT基本概念

2--推理流程

3--实例代码


前言

        以下 Tensor RT 的基本概念和推理流程均为博主自我的理解,可能部分内存会存在错误或偏差,仅供参考!

1--Tensor RT基本概念

① Logger:日志记录器,可用于记录模型编译的过程;

② Builder:可用于创建 Network,对模型进行序列化生成engine;

③ Network:由 Builder 的创建,最初只是一个空的容器;

④ Parser:用于解析 Onnx 等模型;

⑥ context:上接 engine,下接 inference,因此解释为上下文;

2--推理流程

        结合之前博主的相关笔记,根据博主的个人理解,将Tensor RT的推理流程表达为下图:

         ① 基于 Tensor RT 进行模型的推理,从宏观上可分为三个阶段,即输入数据的前处理(preprocess)、模型推理(inference)和推理结果的后处理(postprocess);

        ② 在模型推理中,一般会将 Onnx 等格式的模型通过编译(build)生成推理引擎(engine),engine 可以通过序列化(serialize)的方式进行永久存储,永久存储的 engine 则可以通过反序列化(deserialize)的方式进行加载;

        ③ 在正式推理(inference)前,需要手动申请 Cuda 的内存(memory),以存储输入和输出数据流,数据将以流(stream)的形式进行传递;推理前,输入数据(input_data)需要从主机(host)转移到 Cuda(device)中,推理结束后的推理结果则需要从 Cuda(device)转移到主机(host)的内存当中;

        ④ 编译模型时,必须需要定义 logger,并使用定义的 parser 进行解析模型,config用于配置模型,如通过 profile 设定模型的动态输入尺寸,通过 builder 可以创建序列化的 engine;

        ⑤ 序列化的 engine 需要经过反序列化,才能用于创建 context,最后进行推理;

3--实例代码

        以下提供一个使用 Tensor RT 进行完整推理的代码,使用的 Onnx 模型可参考博主之前的博客(导出动态的Onnx模型)。

import pycuda
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import torch
import numpy as np

# 前处理
def preprocess(data):
    data = np.asarray(data)
    return data
    
# 后处理
def postprocess(data):
    data = np.reshape(data, (B, 256, H, W))
    return data
    
# 创建build_engine类
class build_engine():
    def __init__(self, onnx_path):
        super(build_engine, self).__init__()
        self.onnx = onnx_path
        self.engine = self.onnx2engine() # 调用 onnx2engine 函数生成 engine
        
    def onnx2engine(self):
        # 创建日志记录器
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        
        # 显式batch_size,batch_size有显式和隐式之分
        EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        
        # 创建builder,用于创建network
        builder = trt.Builder(TRT_LOGGER)
        network = builder.create_network(EXPLICIT_BATCH) # 创建network(初始为空)
        
        # 创建config
        config = builder.create_builder_config()
        profile = builder.create_optimization_profile() # 创建profile
        profile.set_shape("input", (1,3,128,128), (3,3,256,256), (5,3,512,512))  # 设置动态输入,分别对应:最小尺寸、最佳尺寸、最大尺寸
        config.add_optimization_profile(profile)
        config.max_workspace_size = 1<<30 # 允许TensorRT使用1GB的GPU内存,<<表示左移,左移30位即扩大2^30倍,使用2^30 bytes即 1 GB
        
        # 创建parser用于解析模型
        parser = trt.OnnxParser(network, TRT_LOGGER)
        
        # 读取并解析模型
        onnx_model_file = self.onnx # Onnx模型的地址
        model = open(onnx_model_file, 'rb')
        if not parser.parse(model.read()): # 解析模型
            for error in range(parser.num_errors):
                print(parser.get_error(error)) # 打印错误(如果解析失败,根据打印的错误进行Debug)

        # 创建序列化engine
        engine = builder.build_serialized_network(network, config)
        return engine
        
    def get_engine(self):
        return self.engine # 返回 engine
    
# 分配内存缓冲区
def Allocate_memory(engine, context):
    bindings = []
    for binding in engine:
        binding_idx = engine.get_binding_index(binding) # 遍历获取对应的索引
        
        size = trt.volume(context.get_binding_shape(binding_idx))
        # context.get_binding_shape(binding_idx): 获取对应索引的Shape,例如input的Shape为(1, 3, H, W)
        # trt.volume(shape): 根据shape计算分配内存 
        
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # engine.get_binding_dtype(binding): 获取对应index或name的类型
        # trt.nptype(): 映射到numpy类型
        
        if engine.binding_is_input(binding): # 当前index为网络的输入input
            input_buffer = np.ascontiguousarray(input_data) # 将内存不连续存储的数组转换为内存连续存储的数组,运行速度更快
            input_memory = cuda.mem_alloc(input_data.nbytes) # cuda.mem_alloc()申请内存
            bindings.append(int(input_memory))
        else:
            output_buffer = cuda.pagelocked_empty(size, dtype)
            output_memory = cuda.mem_alloc(output_buffer.nbytes)
            bindings.append(int(output_memory))
            
    return input_buffer, input_memory, output_buffer, output_memory, bindings

        
if __name__ == "__main__":
    
    # 设置输入参数,生成输入数据
    Batch_size = 3
    Channel = 3
    Height = 256
    Width = 256
    input_data = torch.rand((Batch_size, Channel, Height, Width))
    
    # 前处理
    input_data = preprocess(input_data)
    
    # 生成engine
    onnx_model_file = "./Dynamics_InputNet.onnx"
    engine_build = build_engine(onnx_model_file)
    engine = engine_build.get_engine()
    
    # 生成context
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine)
    context = engine.create_execution_context()
    
    # 绑定上下文
    B, C, H, W = input_data.shape
    context.set_binding_shape(engine.get_binding_index("input"), (B, 3, H, W))
    
    # 分配内存缓冲区
    input_buffer, input_memory, output_buffer, output_memory, bindings = Allocate_memory(engine, context)
    
    # 创建Cuda流
    stream = cuda.Stream()
    # 拷贝数据到GPU (host -> device)
    cuda.memcpy_htod_async(input_memory, input_buffer, stream) # 异步拷贝数据
    
    # 推理
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    
    # 将GPU得到的推理结果 拷贝到主机(device -> host)
    cuda.memcpy_dtoh_async(output_buffer, output_memory, stream)
    
    # 同步Cuda流
    stream.synchronize()
    
    # 后处理
    output_data = postprocess(output_buffer)
    print("output.shape is : ", output_data.shape)

运行结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156592.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ssm:spring定时任务Task和CronExpression表达式

开发一个定时任务&#xff1a;每天晚上23点执行数据归集任务 首先Spring配置文件&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.or…

Java中的常用的代理模式

本文介绍在Java种常用的3种动态代理。 代理模式是23种模式中的一种&#xff0c;属于结构型设计模式。这种模式的作用就是要创建一个中间对象&#xff08;相当于中介或者代理对象&#xff09;&#xff0c;通过操作中间对象来间接调用目的对象的方法&#xff0c;字段等&#xff0…

Everything搜索知识总结

1.只知道那个文件以 .txt结尾 .*\.txt$ ($表示以什么结尾) 2.搜索某个路径下的文件 D:\ configure.bat (搜索D盘下的该文件,注意要用这种类型的"\",和被搜索的文件之间有空格;要先打出路径,再打出搜索文件.) 3.搜索指定路径下的多个文件 路径\ 文件1 | …

Halcon亚像素边缘缺陷检测案例

一、下面的案例是总结的Halcon边缘缺陷检测的一种情况。本案例是利用阈值分割获取金属区域&#xff0c;并利用boundary和edges_sub_pix获取到亚像素边缘。然后综合利用fit_rectangle2_contour_xld拟合出金属对应的放射矩形&#xff0c;最后利用dist_rectangle2_contour_points_…

【小白课程】openKylin用户手册原理解析,一招教你学会自定义!

openKylin用户手册是详细描述openKylin操作系统的功能和用户界面&#xff0c;让用户了解如何使用该软件的说明书。通过阅读openKylin用户手册&#xff0c;能够更快更好的上手和使用openKylin操作系统。今天就带大家简单了解下openKylin用户手册的实现原理以及如何自定义用户手册…

用EditPlus编译Fortran

一、EditPlus配置 语法点亮 安装好EditPlus后&#xff0c;点击Tool->Prefenrences&#xff0c;在File->Setting&syntex下&#xff0c;点击Add按钮&#xff0c;填Frotran。 到EditPlus官网上 EditPlus - User Files (other files) 下载Fortran语法文件 ​ 二、配置…

设计模式学习(四):Strategy策略模式

一、什么是Strategy模式 Strategy的意思是“策略”&#xff0c;指的是与敌军对垒时行军作战的方法。在编程中&#xff0c;我们可以将它理解为“算法”。无论什么程序&#xff0c;其目的都是解决问题。而为了解决问题&#xff0c;我们又需要编写特定的算法。使用Strategy模式可以…

Redis- 主从复制原理

1、概述 Master节点在平时提供服务&#xff0c;另外一个或多个Slave节点在平时不提供服务&#xff08;或只提供数据读取服务&#xff09;。当Master节点由于某些原因停止服务后&#xff0c;再人工/自动完成Slave节点到Master节点的切换工作&#xff0c;以便整个Redis集群继续向…

Spring依赖注入源码分析

1. 前言 Spring的核心之一就是依赖注入&#xff0c;Spring提供了Autowired注解来给bean注入依赖。除了注入最基本的bean之外&#xff0c;Spring还做了一些扩展&#xff0c;例如你可以注入Optional&#xff0c;以此来判断依赖的bean是否存在&#xff1b;你还可以注入Map来获得所…

Leetcode:617. 合并二叉树(C++)

目录 问题描述&#xff1a; 实现代码与解析&#xff1a; 递归&#xff1a; 原理思路&#xff1a; 迭代&#xff1a; 原理思路&#xff1a; 问题描述&#xff1a; 给你两棵二叉树&#xff1a; root1 和 root2 。 想象一下&#xff0c;当你将其中一棵覆盖到另一棵之上时&am…

leetcode 399. 除法求值-java题解

题目所属分类 flod最短路算法 原题链接 给你一个变量对数组 equations 和一个实数值数组 values 作为已知条件&#xff0c;其中 equations[i] [Ai, Bi] 和 values[i] 共同表示等式 Ai / Bi values[i] 。每个 Ai 或 Bi 是一个表示单个变量的字符串。 另有一些以数组 queri…

编译metabase

Linux Centos7 配置Metabase编译打包环境 安装Oracle JDK1.8&#xff08;如果已经安装&#xff0c;则可以省略此步骤&#xff0c;必须是Oracle JDK&#xff09; 在线下载Oracle JDK 1.8 将下载好的tar包放入linux目录下 2、解压tar进行安装 tar -zxvf jdk-8u212-linux-x64.t…

SSL/TLS协议信息泄露漏洞(CVE-2016-2183)

最近服务器扫描出SSL/TLS协议信息泄露漏洞(CVE-2016-2183) TLS是安全传输层协议&#xff0c;用于在两个通信应用程序之间提供保密性和数据完整性。 TLS, SSH, IPSec协商及其他产品中使用的DES及Triple DES密码存在大约四十亿块的生日界&#xff0c;这可使远程攻击者通过Sweet…

总结几个常用的Git命令的使用方法

目录 1、Git的使用越来越广泛 2、设置Git的用户名和密码并查看 3、建立自己的 Git 仓库 4、将自己的代码提交到远程 (origin) 仓库 5、同步远程仓库的更新到本地仓库 6、分支管理 7、获取远程仓库的内容 1、Git的使用越来越广泛 现在很多的公司或者机构都在使用Git进行项目和代…

Elasticsearch基础1——搜索引擎发展史和工作流程、es\es-head\kibana的基础安装

文章目录一、搜索引擎1.1 搜索引擎的发展背景1.2 Lucene和Elasticsearch1.3 Solr和Elasticsearch对比1.4 数据搜索方式1.5 搜索引擎1.5.1 搜索引擎工作流程1.5.2 网络爬虫原理流程1.5.3 网页分析1.5.4 正排索引和倒排索引二、Elasticsearch基础安装1.2 概述简介2.2 安装2.2.1 W…

tensorflow算子注册以及op详解

在自定义的算子时&#xff0c;经常遇到一些函数和宏&#xff0c;这里介绍一下常见的函数和宏 REGISTER_OP 首先我们来思考REGISTER_OP的作用是什么&#xff1f;当我们定义一个tensorflow的算子&#xff0c;首先我们需要tensorflow知道这个算子&#xff0c;也就是说我们要把这…

WeLink的使用

我这里是注册的企业端 流程>手机号验证码 注册成功后登陆 进入首页面 按操作逐步完成信息需求 因个体使用情况不同 在角色分类和组织架构中可根据自己部门或单位的分工分类 【拉人】&#xff1a; 三种方式 主要就是网址超链接和企业码 前提需要用户先注册 【加入审核】是根…

Nginx——反向代理解决跨域问题(Windows)

这个破玩意是真麻烦&#xff0c;必须写一篇文章避避坑了。一、先看看大佬的解释&#xff0c;了解反向代理和跨域问题吧&#xff1a;Nginx反向代理什么是跨域问题二、OK&#xff0c;直接开工&#xff0c;装Nginx下载地址: http://nginx.org/en/download.html如图所示, 选择相应的…

Flink多流转换(Flink Stream Unoin、Flink Stream Connect、Flink Stream Window Join)

文章目录多流转换1、分流操作1.1、在flink 1.13版本中已弃用.split()进行分流1.2、使用&#xff08;process function&#xff09;的侧输出流&#xff08;side output&#xff09;进行分流2、基本合流操作2.1、联合&#xff08;Flink Stream Union&#xff09;2.2、连接&#x…

【Go】实操使用go连接clickhouse

前言 近段时间业务在一个局点测试clickhouse&#xff0c;用java写的代码在环境上一直连接不上clickhouse服务&#xff0c;报错信息也比较奇怪&#xff0c;No client available&#xff0c;研发查了一段时间没查出来&#xff0c;让运维这边继续查&#xff1a; 运维同学查了各种…