4.2.tensorRT基础(1)-第一个trt程序,实现模型编译的过程

news2025/1/14 18:33:39

目录

    • 前言
    • 1. hello案例
    • 2. 补充知识
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 基础-第一个 trt 程序,实现模型编译的过程

课程大纲可看下面的思维导图

在这里插入图片描述

1. hello案例

学习使用 TensorRT-CPP 的 API 构建模型,并进行编译的流程

案例代码如下:


// tensorRT include
#include <NvInfer.h>
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>

class TRTLogger : public nvinfer1::ILogger{
public:
    virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override{
        if(severity <= Severity::kVERBOSE){
            printf("%d: %s\n", severity, msg);
        }
    }
};

nvinfer1::Weights make_weights(float* ptr, int n){
    nvinfer1::Weights w;
    w.count = n;
    w.type = nvinfer1::DataType::kFLOAT;
    w.values = ptr;
    return w;
}

int main(){
    // 本代码主要实现一个最简单的神经网络 figure/simple_fully_connected_net.png 
     
    TRTLogger logger; // logger是必要的,用来捕捉warning和info等

    // ----------------------------- 1. 定义 builder, config 和network -----------------------------
    // 这是基本需要的组件
    //形象的理解是你需要一个builder去build这个网络,网络自身有结构,这个结构可以有不同的配置
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    // 创建一个构建配置,指定TensorRT应该如何优化模型,tensorRT生成的模型只能在特定配置下运行
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    // 创建网络定义,其中createNetworkV2(1)表示采用显性batch size,新版tensorRT(>=7.0)时,不建议采用0非显性batch size
    // 因此贯穿以后,请都采用createNetworkV2(1)而非createNetworkV2(0)或者createNetwork
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);

    // 构建一个模型
    /*
        Network definition:

        image
          |
        linear (fully connected)  input = 3, output = 2, bias = True     w=[[1.0, 2.0, 0.5], [0.1, 0.2, 0.5]], b=[0.3, 0.8]
          |
        sigmoid
          |
        prob
    */

    // ----------------------------- 2. 输入,模型结构和输出的基本信息 -----------------------------
    const int num_input = 3;   // in_channel
    const int num_output = 2;  // out_channel
    float layer1_weight_values[] = {1.0, 2.0, 0.5, 0.1, 0.2, 0.5}; // 前3个给w1的rgb,后3个给w2的rgb 
    float layer1_bias_values[]   = {0.3, 0.8};

    //输入指定数据的名称、数据类型和完整维度,将输入层添加到网络
    nvinfer1::ITensor* input = network->addInput("image", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(1, num_input, 1, 1));
    nvinfer1::Weights layer1_weight = make_weights(layer1_weight_values, 6);
    nvinfer1::Weights layer1_bias   = make_weights(layer1_bias_values, 2);
    //添加全连接层
    auto layer1 = network->addFullyConnected(*input, num_output, layer1_weight, layer1_bias);      // 注意对input进行了解引用
    //添加激活层 
    auto prob = network->addActivation(*layer1->getOutput(0), nvinfer1::ActivationType::kSIGMOID); // 注意更严谨的写法是*(layer1->getOutput(0)) 即对getOutput返回的指针进行解引用
    
    // 将我们需要的prob标记为输出
    network->markOutput(*prob->getOutput(0));

    printf("Workspace Size = %.2f MB\n", (1 << 28) / 1024.0f / 1024.0f); // 256Mib
    config->setMaxWorkspaceSize(1 << 28);
    builder->setMaxBatchSize(1); // 推理时 batchSize = 1 

    // ----------------------------- 3. 生成engine模型文件 -----------------------------
    //TensorRT 7.1.0版本已弃用buildCudaEngine方法,统一使用buildEngineWithConfig方法
    nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    if(engine == nullptr){
        printf("Build engine failed.\n");
        return -1;
    }

    // ----------------------------- 4. 序列化模型文件并存储 -----------------------------
    // 将模型序列化,并储存为文件
    nvinfer1::IHostMemory* model_data = engine->serialize();
    FILE* f = fopen("engine.trtmodel", "wb");
    fwrite(model_data->data(), 1, model_data->size(), f);
    fclose(f);

    // 卸载顺序按照构建顺序倒序
    model_data->destroy();
    engine->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();
    printf("Done.\n");
    return 0;
}

运行效果如下:

在这里插入图片描述

图1-1 hello案例运行效果

上述示例代码演示了使用 tensorRT 构建一个简单神经网络模型的过程。通过定义模型结构、设置输入和输出信息,并使用 TensorRT 的 Builder 对象和网络定义对象构建 Engine 模型。然后,将生成的模型序列化并保存为文件。示例代码中还展示了如何配置 TensorRT 的优化参数和限制条件。通过这个示例代码,可以了解 TensorRT 构建模型的基本工作流程和相关对象的使用方法。

首先引入了必要的头文件和库,代码中引入了 TensorRT 和 CUDA 的相关头文件和库。然后定义了一个日志类,继承自 nvinfer1::ILogger,用于捕获 TensorRT 的警告和信息。

定义输入、模型结构和输出的基本信息:指定输入张量的名称、数据类型和维度,添加全连接层和激活层,并将输出标记为网络的输出。接下来配置 TensorRT,通过设置最大工作空间大小和最大批处理大小来配置,随后构建 Engine 模型,使用 buildEngineWithConfig 方法构建 Engine 模型。序列化并保存模型,将模型序列化为二进制数据,并将其保存为文件。最后释放资源,按照创建顺序的倒序释放 TensorRT 的相关资源。

TensorRT 构建模型的大致工作流程可以分为四个部分:

  • 1.定义 builder,config 和 network
  • 2.构建网络所需输入,模型结构和输出的基本信息
  • 3.生成 engine 模型文件
  • 4.序列化模型文件并存储

关于该示例代码的重点提炼

  1. 必须使用 createNetworkV2,并指定为 1(表示显性 batch)。createNetwork 已经废弃,非显性 batch 官方不推荐。这个方式直接影响推理时 enqueue 还是 enqueueV2
  2. builder、config 等指针记得释放,否则会有内存泄露,使用 ptr->destroy() 释放
  3. markOutput 表示该模型的输出节点,mark 几次就有几个输出,addInput 几次就有几个输入,这与推理时相呼应
  4. workspaceSize 是工作空间大小,某些 layer 需要使用额外存储时,不会自己分配空间,而是为了内存复用,直接找 tensorRT 要 workspace 空间。指的这个意思
  5. 一定要记住,保存的模型只能适配编译时的 trt 版本、编译时指定的设备。也只能保证在这种配置下是最优的。如果用 trt 跨不同设备执行,有时可以运行,但不是最优的,也不推荐

2. 补充知识

关于第一个 trt 程序的相关知识点:(from 杜老师)

  • main.cpp 构建了一个最简单全连接网络
  • tensorrt 的工作流程如下图:
    • 首先定义网络
    • 优化 builder 参数
    • 通过 builder 生成 engine 用于模型保存、推理等
    • engine 可以通过序列化和反序列化转化模型数据类型(转化为二进制 byte 文件,加快传输速率),再一步推动模型由输入张量到输出张量的推理

在这里插入图片描述

  • code struct
    • 1.定义 builder, config 和network,其中 builder 表示所创建的构建器,config 表示创建的构建配置(指定 TensorRT 应该如何优化模型),network 为创建的网络定义。
    • 2.输入,模型结构和输出的基本信息(如下图所示)
    • 3.生成 engine 模型文件
    • 4.序列化模型文件并存储

在这里插入图片描述

  • 官方文档参考部分 C++ API

总结

本次课程学习了使用 tensorRT 的 C++ 接口来搭建一个简单的神经网络结构,整体流程可分为:builder、config、network 定义;输入、模型结构和输出信息;engine 模型文件生成;序列化模型文件并存储四个部分。

一些细节需要大家自行看代码进行分析,比如 builder、config 等指针记得释放,在 tensorRT 构建网络前需要定义日志 Logger 类用于捕获 tensorRT 的信息等等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/759660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SlickGrid学习

options&#xff1a; 选项 设置 enableCellNavigation 启用单元格导航&#xff0c;可以点单元格 enableColumnReorder 启动拖拽列 example-colspan.html 跨列实例 AutoTooltips plugin 隐藏列文字时自动显现列标题全文 Checkbox row select column 增加选择列来选择行…

STM32入门学习之core_cm3问题

1.安装了keil之后&#xff0c;新建工程出现几百个关于core_cm3的问题&#xff0c;百思不得其解。后在网上查阅资料后&#xff0c;了解到可能是keil版本的问题&#xff0c;是因为我下载的keill版本太高了&#xff0c;内部不支持ARM5.06的编译器。出现很多关于core_cm3的问题是因…

使用java语言制作一个窗体(弹窗),用来收集用户输入的内容

前言 最近做的一个需求&#xff0c;有个逻辑环节是&#xff1a;需要从本地保存的xml文件中取出一个值&#xff0c;这个值是会变化的。然后项目经理就给我说&#xff0c;你能不能做一个小工具&#xff0c;让用户可以直接通过界面化操作将这个变化的值写入文件&#xff0c;不用再…

rv1126交叉编译

目录 一、解压sdk二、交叉编译出动态库sqlite3交叉编译opencv交叉编译一、解压sdk tar xzvf rv1126_rv1109_linux_sdk_v1.8.0_PureVersion.tar.gz 查看交叉编译工具链 pwd查看绝对路径/home

正斜杠“/” 和反斜杠 “\” 傻傻分不清?

Note, on your keyboard, the location of two different punctuation marks—/, or forward slash, and \, or backward slash (also known as backslash). As you read from left to right, the forward slash falls forward while the backslash falls backward; 引用自 《…

js逆向思路-区分各个瑞数版本vmp/3/4/5/6代

目录 一、如何区分是最新瑞数vmp反爬二、3/4/5/6代/vmp版本的瑞数网站特征举例三、瑞数反爬的解决思路四、推荐相关瑞数文章五、一些心得一、如何区分是最新瑞数vmp反爬 前言:本篇文章不会介绍详细的解决反爬的算法扣代码过程,只是一些经验闲谈,文章的末尾有相关的好的质量的…

油猴脚本-Bilibili剧场模式仿Youtube

对比某个不存在的视频网站&#xff08;YouTube&#xff09;&#xff0c;以及B站的播放模式&#xff0c;普通模式以及网页全屏之间都有一个“中间档”&#xff0c;油管的叫 剧场模式&#xff0c;B站的叫 宽屏模式。 剧场模式 宽屏模式 相比之下&#xff0c;还是更喜欢油管的剧…

10.带你入门matlab频率表、盒图(matlab程序)

1.简述 相关概念介绍 以信号为例&#xff0c;信号在时域下的图形可以显示信号如何随着时间变化&#xff0c;而信号在频域下的图形&#xff08;一般称为频谱&#xff09;可以显示信号分布在哪些频率及其比例。频域的表示法除了有各个频率下的大小外&#xff0c;也会有各个频率的…

《遗留系统现代化》读书笔记(基础篇)

你现在所写的每一行代码&#xff0c;都是未来的遗留系统 为什么要对遗留系统进行现代化&#xff1f; 什么是遗留系统&#xff1f; 判断遗留系统的几个维度&#xff1a;代码、架构、测试、DevOps 以及技术和工具。时间长短并不是衡量遗留系统的标准。代码质量差、架构混乱、没…

【C++】多线程编程系列总纲

本系列文章主要介绍C11并发编程&#xff0c;主要包括如下篇目&#xff1a; 1、多线程编程一&#xff08;初识并发和多线程&#xff09; https://blog.csdn.net/Jacky_Feng/article/details/131751373?csdn_share_tail{"type"%3A"blog"%2C"rType&qu…

idea快捷方式,(主要是自己用的,持续收集)

编辑器 代码光标扩大选中 快捷键名字 使用方式 此时光标在这这里,按一下发现选中面积扩大, 如果再按快捷键,选中面积还扩大

带你【玩转Linux命令】➾ find cut 每天2个day06

带你【玩转Linux命令】➾ find & cut 每天2个day06 &#x1f53b; 一、文件管理命令1.1 find-查找文件或目录1.2 cut-指定欲显示的文件内容&#xff0c;输出到标准输出设备 &#x1f53b; 总结—温故知新 &#x1f53b; 一、文件管理命令 1.1 find-查找文件或目录 &#x…

【Nacos】基于k8s容器化部署Nacos集群

近期&#xff0c;在机器上部署了三个节点的nacos集群服务用于几个小型微服务的注册配置中心&#xff0c;并使用了Nginx简单代理了一下&#xff0c;随即简单研究了下集群部署分布式部署稍微提高可用性。部署完后能够正常使用&#xff0c;但是发现一个问题&#xff0c;刷新Nacos集…

制作文件间链接

制作文件间链接 管理文件间链接 硬/软链接 创建指向同一个文件的多个名称。 创建硬链接 从初始名称到文件系统的数据&#xff0c;每个文件都以一个硬链接开始。当创建指向文件的新硬链接时,也会创建另一个指向同一数据的名称。新硬链接与原始文件名作用相同。一经创建&…

Maya适合哪个工作站?

Autodesk Maya 提供多种功能&#xff0c;可以适应电影、游戏和建筑等不同行业的需求。定制的 Autodesk Maya 工作站可以帮助您提高行业领先的 3D 计算机动画、建模、模拟和渲染软件的工作效率和用户体验。 根据您的特定需求定制的快速、强大的工作站可以帮助您充分利用 Maya 工…

高薪Offer收割机之Redis分布式锁

锁在应用开发中使用非常广泛,哪些场景需要使用锁呢? 我们先来看抢购优惠卷的场景,代码如下: public void rushToPurchase() throws InterruptedException {//获取优惠券数量Integer num = (Integer) redisTemplate.opsForValue().get(“num”);//判断是否抢完if (null == n…

[Ipsc2009]Let there be rainbows!

Description HY Star是一个处处充满和谐&#xff0c;人民安居乐业的星球&#xff0c;但是HY Star却没有被评上宇宙文明星球&#xff0c;很大程度上是因为 星球的形象问题。HY Star由N个国家组成,并且在一些国家之间修建了道路以方便交流。由于HY Star是一个和谐的 星球&#x…

【运维】第04课:入口网关服务注册发现-Openrety 动态 uptream

本课时,我将带你一起了解入口网关服务的注册发现,并使用 OpenResty 实现一套动态 Upstream。 课前学习提示 基于本课时我们将要学习的内容,我建议你课前先了解一下 Nginx 的基础,同时熟悉基础的 Lua 语言语法,另外再回顾一下 HTTP 的请求过程,对于 Nginx 的负载均衡基本…

按键控制流水灯方向——FPGA

文章目录 前言一、按键二、系统设计1、模块框图2、RTL视图 三、源码四、效果五、总结六、参考资料 前言 环境&#xff1a; 1、Quartus18.0 2、vscode 3、板子型号&#xff1a;EP4CE6F17C8 要求&#xff1a; 按键1按下&#xff0c;流水灯从右开始向左开始流动&#xff0c;按键2按…

习题-Java网络编程

目录 1.TCP-对象 2.UDP​​​​​​​​​​​​​​ 1.TCP-对象 利用TCP传输对象信息&#xff0c;需要对对象进行实例化 User类&#xff1a; package dh09.demo02;import java.io.Serializable;public class User implements Serializable {private String name;private St…