五. TensorRT API的基本使用-load-model

news2024/9/22 21:19:38

目录

    • 前言
    • 0. 简述
    • 1. 案例运行
    • 2. 代码分析
      • 2.1 main.cpp
      • 2.2 model.hpp
      • 2.3 model.cpp
      • 2.4 其它
    • 总结
    • 下载链接
    • 参考

前言

自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》,链接。记录下个人学习笔记,仅供自己参考

本次课程我们来学习课程第五章—TensorRT API 的基本使用,一起来学习手动实现一个 build

课程大纲可以看下面的思维导图

在这里插入图片描述

0. 简述

本小节目标:手动实现 build 完成模型的序列化

今天我们来讲第五章节第二小节—5.2-load-model 这个案例,上个小节我们主要是通过官方 MNIST 案例让大家熟悉 TensorRT 的一些 API 的使用,这个小节我们主要是模仿官方案例来自己手写一个 build

下面我们开始本次课程的学习🤗

1. 案例运行

在正式开始课程之前,博主先带大家跑通 5.2-load-model 这个小节的案例🤗

源代码获取地址:https://github.com/kalfazed/tensorrt_starter.git

首先大家需要把 tensorrt_starter 这个项目给 clone 下来,指令如下:

git clone https://github.com/kalfazed/tensorrt_starter.git

也可手动点击下载,点击右上角的 Code 按键,将代码下载下来。至此整个项目就已经准备好了。也可以点击 here 下载博主准备好的源代码(注意代码下载于 2024/7/14 日,若有改动请参考最新

整个项目后续需要使用的软件主要有 CUDA、cuDNN、TensorRT、OpenCV,大家可以参考 Ubuntu20.04软件安装大全 进行相应软件的安装,博主这里不再赘述

假设你的项目、环境准备完成,下面我们来一起运行 5.2 小节案例代码

开始之前我们需要创建几个文件夹,在 tensorrt_starter/chapter5-tensorrt-api-basics/5.2-load-model 小节中创建一个 models 文件夹,接着在 models 文件夹下创建 onnx 和 engine 文件夹,总共三个文件夹需要创建

创建完后 5.2 小节整个目录结构如下:

在这里插入图片描述

接着我们需要执行 python 文件创建一个 ONNX 模型,先进入到 5.2 小节中:

cd tensorrt_starter/chapter5-tensorrt-api-basics/5.2-load-model

执行如下指令:

python src/python/generate_onnx.py

Note:大家需要准备一个虚拟环境,安装好 torch、onnx、onnxsim 等第三方库

输出如下:

在这里插入图片描述

生成好的 onnx 模型文件保存在 models/onnx 文件夹下,大家可以查看

接着我们需要利用 ONNX 生成对应的 engine,在此之前我们需要修改下整体的 Makefile.config,指定一些库的路径:

# tensorrt_starter/config/Makefile.config
# CUDA_VER                    :=  11
CUDA_VER                    :=  11.6
    
# opencv和TensorRT的安装目录
OPENCV_INSTALL_DIR          :=  /usr/local/include/opencv4
# TENSORRT_INSTALL_DIR        :=  /mnt/packages/TensorRT-8.4.1.5
TENSORRT_INSTALL_DIR        :=  /home/jarvis/lean/TensorRT-8.6.1.6

Note:大家查看自己的 CUDA 是多少版本,修改为对应版本即可,另外 OpenCV 和 TensorRT 修改为你自己安装的路径即可

接着我们就可以来执行编译,指令如下:

make -j64

输出如下:

在这里插入图片描述

接着执行:

./trt-infer

输出如下:

在这里插入图片描述

可以看到输出了很多日志信息,该案例主要是通过自定义 build 构建一个 engine 并保存到 models/engine/sample.engine 中,最后打印输入和输出的维度信息

如果大家能够看到上述输出结果,那就说明本小节案例已经跑通,下面我们就来看看具体的代码实现

2. 代码分析

2.1 main.cpp

我们先从 main.cpp 看起:

#include <iostream>
#include <memory>

#include "model.hpp"
#include "utils.hpp"

using namespace std;

int main(int argc, char const *argv[])
{
    Model model("models/onnx/sample.onnx");
    if(!model.build()){
        LOGE("ERROR: fail in building model");
        return 0;
    }
    return 0;
}

通过传入 ONNX 模型文件路径创建一个 Model 实例,之后调用 build 函数构建 engine

2.2 model.hpp

我们来看 Model 类的定义:

class Model{
public:
    Model(std::string onnxPath);
    bool build();
private:
    std::string mOnnxPath;
    std::string mEnginePath;
    nvinfer1::Dims mInputDims;
    nvinfer1::Dims mOutputDims;
    std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
    bool constructNetwork();
    bool preprocess();
};

Model 类中公有方法主要是 build,而其它的比如 mEngine,constructNetwork,preprocess 等都是私有方法,没有必要暴露给用户的

2.3 model.cpp

接着我们来看 Model 的构造函数:

Model::Model(string onnxPath){
    if (!fileExists(onnxPath)) {
        LOGE("%s not found. Program terminated", onnxPath.c_str());
        exit(1);
    }
    mOnnxPath   = onnxPath;
    mEnginePath = getEnginePath(mOnnxPath);
}

首先它会去检查传入的 onnxPath 文件是否存在,如果不存在则打印错误信息并退出,接着把 onnxPath 赋值给私有成员变量 mOnnxPath,通过 getEnginePath 函数拿到对应的 mEnginePath。另外这里的 LOGE 是通过宏定义实现的一个打印函数,它可以用来控制不同的输出日志等级

我们再来看 build 的实现,首先是检查 mEnginePath 是否存在:

if (fileExists(mEnginePath)){
    LOG("%s has been generated!", mEnginePath.c_str());
    return true;
} else {
    LOG("%s not found. Building engine...", mEnginePath.c_str());
}

如果存在则不用再重新 build,如果不存在则需要通过下面的流程进行 build

首先我们实例化一个 logger:

class Logger : public nvinfer1::ILogger{
public:
    virtual void log (Severity severity, const char* msg) noexcept override{
        string str;
        switch (severity){
            case Severity::kINTERNAL_ERROR: str = RED    "[fatal]:" CLEAR;
            case Severity::kERROR:          str = RED    "[error]:" CLEAR;
            case Severity::kWARNING:        str = BLUE   "[warn]:"  CLEAR;
            case Severity::kINFO:           str = YELLOW "[info]:"  CLEAR;
            case Severity::kVERBOSE:        str = PURPLE "[verb]:"  CLEAR;
        }
        if (severity <= Severity::kINFO)
            cout << str << string(msg) << endl;
    }
};

Logger logger;

我们上节课讲过在创建一个 builder 的时候需要绑定一个 logger,因此我们这里自己手动实现了一个 Logger 类,它继承自 nvinfer1::ILogger,在 Logger 类中我们必须自己手动来实现 log 虚函数。Severity 是一个枚举类,用于控制日志消息的等级,然后将不同的 str 附加不同颜色,如果 severity 级别小于或等于 kINFO,则会通过 cout 将带有前缀的 str 日志信息打印出来

创建完 logger 之后我们接着往下看:

auto builder       = make_unique<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
auto network       = make_unique<nvinfer1::INetworkDefinition>(builder->createNetworkV2(1));
auto config        = make_unique<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
auto parser        = make_unique<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));

config->setMaxWorkspaceSize(1<<28);

if (!parser->parseFromFile(mOnnxPath.c_str(), 1)){
    LOGE("ERROR: failed to %s", mOnnxPath.c_str());
    return false;
}

其实和上节课讲的流程一样,我们先创建一个 builder,然后通过 builder 创建 network、config,接着把 network 和 logger 丢到 nvonnxparser::createParser 函数中创建一个 parser

接着通过 config 设置了最大的 workspace size,其实 config 可以设置非常多的参数,包括 setCalibrationProfile 设置校准文件,setInt8Calibrator 设置校准器等等,这些都是跟模型创建相关的东西,大家自己可以看下

另外这些 API 的说明在官方文档中描述都比较详细,大家也可以参考:tensorrt/developer-guide

config 设置完成之后,通过 parserFromFile 函数将 onnx parser 到 network 里面去

上面这些都是准备工作,接着我们就可以来创建 engine:

auto engine        = make_unique<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network, *config));
auto plan          = builder->buildSerializedNetwork(*network, *config);
auto runtime       = make_unique<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));

auto f = fopen(mEnginePath.c_str(), "wb");
fwrite(plan->data(), 1, plan->size(), f);
fclose(f);

通过 builder->buildEngineWithConfig 把 network 和 config 丢进去创建 engine,之后把创建好的 network 做一个序列化保存到 plan 中去,其中 plan 是一个 IHostMemory 的指针,然后我们创建了一个 runtime 方便后续反序列化测试

接着我们把序列化好的 plan 文件通过 fwrite 写入保存到指定路径,方便下次加载使用

下面我们打印了模型的一些基本信息:

mEngine            = shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()));
mInputDims         = network->getInput(0)->getDimensions();
mOutputDims        = network->getOutput(0)->getDimensions();
LOG("Input dim is %s", printDims(mInputDims).c_str());
LOG("Output dim is %s", printDims(mOutputDims).c_str());
return true;

通过 runtime->deserializeCudaEngine 来反序列化拿到我们的 engine,其中的 mEngine 是 ICudaEnigne 的指针,是一个推理引擎,然后我们可以通过 network 将输入输出的一些维度信息打印出来

这里有一个小技巧,大家在学习 API 的时候可以通过一些名字大概猜测其主要实现的功能,比如 network 它其中的以 getXXX 为例的 API 一般来说都是去获取网络的一些信息,比如 getLayer、getName 等等,再比如 engine 也有类似于 getXXX 的 API,比如 getDeviceMemorySize、getNbOptimizationProfiles 等等

2.4 其它

在 src/python 文件夹下还有一个 generate_onnx.py 的脚本文件,其内容如下:

import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx
import os

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=10, out_features=5, bias=False)
    
    def forward(self, x):
        x = self.linear(x)
        return x

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def export_norm_onnx():
    current_path = os.path.dirname(__file__)
    file = current_path + "/../../models/onnx/sample.onnx"

    input   = torch.rand(1, 10)
    model   = Model()
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 15)
    print("Finished normal onnx export")

    # check the exported onnx model
    model_onnx = onnx.load(file)
    onnx.checker.check_model(model_onnx)

    # use onnx-simplifier to simplify the onnx
    print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "assert check failed"
    onnx.save(model_onnx, file)

def infer():
    setup_seed(1)
    model  = Model()
    input  = torch.tensor([[0.0193, 0.2616, 0.7713, 0.3785, 0.9980, 0.9008, 0.4766, 0.1663, 0.8045, 0.6552]])
    output = model(input)
    print(input)
    print(output)

if __name__ == "__main__":
    export_norm_onnx()
    infer()

它就是创建了一个非常简单的 ONNX 模型,其中包含一个 Linear 节点,如下所示:

在这里插入图片描述

总结

本次课程我们主要模仿官方案例自己手动实现了一个 builder,和官方流程类似,先创建一个 logger,然后创建 builder,通过 builder 创建 network、config,然后创建 parser,通过 parseFromFile 将 ONNX parser 到 network 中,接着创建完 engine,通过 buildSerializedNetwork 进行序列化生成 plan,并将 plan 保存下来,最后调用一个 API 来打印一些输入输出维度信息。总的来说,实现还是比较简单的,关于一些 API 的使用大家可以多尝试尝试

OK,以上就是 5.2 小节案例的全部内容了,下节我们来学习 5.3 小节自己构建一个 infer 来推理模型,敬请期待😄

下载链接

  • tensorrt_starter源码
  • 5.2-load-model案例文件

参考

  • Ubuntu20.04软件安装大全
  • https://github.com/kalfazed/tensorrt_starter.git
  • https://docs.nvidia.com/deeplearning/tensorrt/developer-guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1939818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

达梦数据库审计日志采集

目录 1. 审计功能简介2. dm8官方技术参考文档3. dm8审计功能配置3.1 登录审计用户3.2 开启审计开关3.3 查询审计日志3.4 审计设置3.4.1 配置语句级审计3.4.2 取消语句级审计3.5 审计日志查阅4. python获取dm8审计日志1. 审计功能简介 审计机制是 DM 数据库管理系统安全管理的重…

第十一课:综合项目实践

下面是我们搭建的一个综合实践的拓扑图&#xff1a; 我们要完成以下目标&#xff1a; 网络中有3个不同部门&#xff0c;均可自动获取地址各部门可相互访问&#xff0c;也可访问内部服务网172.16.100.1&#xff0c;PC1不允许访问互联网&#xff0c;PC1和PC3可以访问互联网内网服…

nginx配置文件说明

Nginx的配置文件说明 Nginx配置文件的主要配置块可以分为三个部分&#xff1a;全局配置块&#xff08;events和http块&#xff09;&#xff0c;events块和http块。这三个部分共同定义了Nginx服务器的整体行为和处理HTTP请求的方式。 全局配置块&#xff1a; 包含了影响Nginx服…

Vue3组件样式

在 Vue3开发中&#xff0c;我们经常需要对元素的类和样式进行动态控制。本文将详细介绍如何使用 Vue.js 的特性来实现这一目标。 class 绑定 绑定对象&#xff1a; 在 Vue.js 中&#xff0c;我们可以使用对象语法来绑定 class。例如&#xff1a; <div :class"{ act…

卸载顽固的驱动或软件

在Windows系统&#xff0c;有些软件或驱动&#xff0c;为了防止被卸载&#xff0c;特地在C:\Windows\System32\drivers目录里&#xff0c;生成xxx.sys文件。这些xxx.sys文件&#xff0c;无法直接删除&#xff0c;用杀毒软件也很难卸载。     这里介绍一种方法&#xff0c;可以…

基于GTX的64B66B编码的自定义接收模块(高速收发器二十二)

点击进入高速收发器系列文章导航界面 1、自定义PHY顶层模块 前文设计了64B66B自定义PHY的发送模块&#xff0c;本文完成自定义PHY剩余的模块的设计&#xff0c;整体设计框图如下所示。 其中phy_tx是自定义PHY的发送数据模块&#xff0c;scrambler是加扰模块&#xff0c;rx_slip…

阿里云OS Copilot:解锁操作系统运维与编程的智能助手

目录 引言 OS Copilot简介 OS Copilot的环境准备 创建实验资源 安全设置 设置安全组端口 创建阿里云AccessKey 准备系统环境 OS Copilot的实操 场景一、用OS Copilot写脚本和注释代码 场景二、使用OS Copilot进行对话问答 场景三、使用OS Copilot辅助编程学习 清理…

P15-P18-随机梯度下降-自适应学习率-超参数筛选-正则化

文章目录 随机梯度下降和自适应学习率超参数筛选交叉验证 正则化权重衰减Dropout 简介 本文主要讨论了机器学习中随机梯度下降&#xff08;SGD&#xff09;和自适应学习率算法的原理及应用。SGD通过随机选择小批量样本计算损失值&#xff0c;减少了计算量&#xff0c;加快了训练…

国内访问Docker Hub慢问题解决方法

在国内访问Docker Hub时可能会遇到一些困难&#xff0c;但幸运的是&#xff0c;有多种解决方案可以帮助你顺利下载Docker镜像。以下是一些有效的解决方案&#xff1a; 配置Docker镜像源&#xff1a;你可以通过配置Docker的daemon.json文件来使用国内镜像源&#xff0c;比如DaoC…

Spring Web MVC(一篇带你了解并入门,附带常用注解)

一&#xff0c;什么是Spring Web MVC 先看一下官网怎么说&#xff1a; 也就是Spring Web MVC一开始就是包含在Spring框架里面的&#xff0c;但通常叫做Spring MVC。 也可以总结出一个信息&#xff0c;这是一个Web框架。后面我就简称为Spring MVC了。 1.1MVC MVC也就是Mode…

202496读书笔记|《飞花令·菊(中国文化·古典诗词品鉴)》——荷尽已无擎雨盖,菊残犹有傲霜枝

202496读书笔记|《飞花令菊&#xff08;中国文化古典诗词品鉴&#xff09;》——荷尽已无擎雨盖&#xff0c;菊残犹有傲霜枝 《飞花令菊&#xff08;中国文化古典诗词品鉴&#xff09;》素心落雪 编著。飞花令得名于唐代诗人韩翃《寒食》中的名句“春城无处不飞花”&#xff0c…

食南之徒~马伯庸

◆ 第一章 >> 老赵&#xff0c;这你就不懂了。过大于功&#xff0c;要受罚挨打&#xff0c;不合算&#xff1b;功大于过&#xff0c;下回上司有什么脏活累活&#xff0c;第一时间会想到你&#xff0c;也是麻烦多多。只有功过相抵&#xff0c;上司既挑不出你的错&#xf…

Unity 调试死循环程序

如果游戏出现死循环如何调试呢。 测试脚本 我们来做一个测试。 首先写一个死循环代码&#xff1a; using System.Collections; using System.Collections.Generic; using UnityEngine;public class dead : MonoBehaviour {void Start(){while (true){int a 1;}}}Unity对象设…

Flowable-SpringBoot项目集成

在前面的介绍中&#xff0c;虽然实现了绘制流程图&#xff0c;然后将流程图存储到数据库中&#xff0c;然后从数据库中获取流程信息&#xff0c;并部署和启动流程&#xff0c;但是部署的流程绘制器是在tomcat中部署的&#xff0c;可能在部分的项目中&#xff0c;需要我们将流程…

微信小程序数组绑定使用案例(一)

微信小程序数组绑定案例&#xff0c;修改数组中的值 1.Wxml 代码 <view class"list"><view class"item {{item.ischeck?active:}}" wx:for"{{list}}"><view class"title">{{item.name}} <text>({{item.id}…

武忠祥李永乐强化笔记

高等数学 函数 极限 连续 函数 复合函数奇偶性 f[φ(x)]内偶则偶&#xff0c;内奇则同外 奇函数 ln ⁡ ( x 1 x 2 ) \ln(x \sqrt{1 x^{2}}) ln(x1x2 ​) 单调性 一点导数>0不能得出邻域单调增&#xff0c;加上导函数连续则可以得出邻域单调增 极限 等价无穷小…

达梦数据库的系统视图v$utsk_info

达梦数据库的系统视图v$utsk_info 查询守护进程向服务器发送请求的执行情况。 升级到 V3.0 版本后&#xff0c;此视图仅用于查看当前服务器的命令执行情况&#xff0c;在 CMD 字段值不为 0 时&#xff0c;说明是有效的命令信息&#xff1b;此时如果 CODE 字段值是 100&#xf…

202495读书笔记|《红楼梦(插图本)(童年书系·书架上的经典)》——荣辱自古周而复始,岂是人力所能保的?

202495读书笔记|《红楼梦&#xff08;插图本&#xff09;&#xff08;童年书系书架上的经典&#xff09;》——荣辱自古周而复始&#xff0c;岂是人力所能保的&#xff1f; 摘录人物关系&#xff1a; 《红楼梦&#xff08;插图本&#xff09;&#xff08;童年书系书架上的经典&…

02互联网行业的产品方向(2)

数字与策略产品 大数据时代&#xff0c;数据的价值越来越重要。大多数公司开始对内外全部数据进行管理与挖掘&#xff0c;将业务数据化&#xff0c;数据资产化&#xff0c;资产业务化&#xff0c;将数据产品赋能业务&#xff0c;通过数据驱动公司业务发展&#xff0c;支撑公司战…

学习周报:文献阅读+Fluent案例+水动力学方程推导

目录 摘要 Abstract 文献阅读&#xff1a;物理信息神经网络学习自由表面流 文献摘要 讨论|结论 预备知识 浅水方程SWE&#xff08;Shallow Water Equations&#xff09; 质量守恒方程&#xff1a; 动量守恒方程&#xff1a; Godunov通量法&#xff1a; 基本原理&…