【TensorRT部署】pytorch模型(pt/pth)转onnx,onnx转engine(tensorRT)

news2024/9/20 1:13:05

1. 单帧处理

1. pt2onnx

import torch
import numpy as np
from parameters import get_parameters as get_parameters
from models._model_builder import build_model
TORCH_WEIGHT_PATH = './checkpoints/model.pth'
ONNX_MODEL_PATH = './checkpoints/model.onnx'
torch.set_default_tensor_type('torch.FloatTensor')
torch.set_default_tensor_type('torch.cuda.FloatTensor')
def get_numpy_data():
    batch_size = 1
    img_input = np.ones((batch_size,1,512,512), dtype=np.float32)
    return img_input

def get_torch_model():
    # args = get_args()
    args = get_parameters()
    model = build_model(args.model, args)
    model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    model.cuda()
    #pass
    return model
#定义参数
input_name = ['data']
output_name = ['prob']
'''input为输入模型图片的大小'''
input = torch.randn(1,1,512,512).cuda()

# 创建模型并载入权重
model = get_torch_model()
model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
model.cuda()

#导出onnx
torch.onnx.export(model, input, ONNX_MODEL_PATH, input_names=input_name, output_names=output_name, verbose=False,opset_version=11)

补充:也可以对onnx进行简化

# pip install onnxsim

from onnxsim import simplify
import onnx
onnx_model = onnx.load("./checkpoints/model.onnx")  # load onnx model
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "./checkpoints/model.onnx")
print('finished exporting onnx')

2. onnx2engine

// OnnxToEngine.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <iostream>
#include <chrono>
#include <vector>
#include "cuda_runtime_api.h"
#include "logging.h"
#include "common.hpp"
#include "NvOnnxParser.h"
#include"NvCaffeParser.h"
const char* INPUT_BLOB_NAME = "data";
using namespace std;
using namespace nvinfer1;
using namespace nvonnxparser;
using namespace nvcaffeparser1;

unsigned int maxBatchSize = 1;

int main()
{
    //step1:创建logger:日志记录器
    static Logger gLogger;
    //step2:创建builder
    IBuilder* builder = createInferBuilder(gLogger);

    //step3:创建network
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);//0改成1,
    //step4:创建parser
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);

    //step5:使用parser解析模型填充network
    const char* onnx_filename = "..\\onnx\\model.onnx";
    parser->parseFromFile(onnx_filename, static_cast<int>(Logger::Severity::kWARNING));
    for (int i = 0; i < parser->getNbErrors(); ++i)
    {
        std::cout << parser->getError(i)->desc() << std::endl;
    }
    std::cout << "successfully load the onnx model" << std::endl;
    //step6:创建config并设置最大batchsize和最大工作空间
    // Create builder
   // unsigned int maxBatchSize = 1;
    builder->setMaxBatchSize(maxBatchSize);
    IBuilderConfig* config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize( (1 << int(20)));

    //step7:创建engine
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    //assert(engine);

    //step8:序列化保存engine到planfile
    IHostMemory* serializedModel = engine->serialize();
    //assert(serializedModel != nullptr);
    //std::ofstream p("D:\\TensorRT-7.2.2.322\\engine\\unet.engine");
    //p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());
    std::string engine_name = "..\\engine\\model.engine";
    std::ofstream p(engine_name, std::ios_base::out | std::ios_base::binary);
    if (!p) {
        std::cerr << "could not open plan output file" << std::endl;
        return -1;
    }
    p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());
    std::cout << "successfully build an engine model" << std::endl;
    //step9:释放资源
    serializedModel->destroy();
    engine->destroy();
    parser->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();

}

2. 多帧处理(加速)

2.1 pt2onnx

import onnx
import torch
import numpy as np
from parameters import get_parameters as get_parameters
from models._model_builder import build_model
TORCH_WEIGHT_PATH = './checkpoints/model.pth'
ONNX_MODEL_PATH = './checkpoints/model.onnx'
args = get_parameters()
def get_torch_model():
    # args = get_args()
    print(args.model)
    model = build_model(args.model, args)
    model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    model.cuda()
    #pass
    return model



if __name__ == "__main__":
    # 设置输入参数
    Batch_size = 1
    Channel = 1
    Height = 384
    Width = 640
    input_data = torch.rand((Batch_size, Channel, Height, Width)).cuda()

    # 实例化模型
    # 创建模型并载入权重
    model = get_torch_model()
    #model.load_state_dict(torch.load(TORCH_WEIGHT_PATH))
    #model.cuda()

    # 导出为静态输入
    input_name = 'data'
    output_name = 'prob'
    torch.onnx.export(model,
                      input_data,
                      ONNX_MODEL_PATH,
                      verbose=True,
                      input_names=[input_name],
                      output_names=[output_name])

    # 导出为动态输入
    torch.onnx.export(model,
                      input_data,
                      ONNX_MODEL_PATH2,
                      opset_version=11,
                      input_names=[input_name],
                      output_names=[output_name],
                      dynamic_axes={
                          #input_name: {0: 'batch_size'},
                          #output_name: {0: 'batch_size'}}
                          input_name: {0: 'batch_size', 1: 'channel', 2: 'input_height', 3: 'input_width'},
                          output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}}
                       )

2.2 onnx2engine

 OnnxToEngine.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
#include <iostream>
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "logging.h"
#include "opencv2/opencv.hpp"
#include <fstream>
#include <sstream>
#include "cuda_runtime_api.h"
static Logger gLogger;
using namespace nvinfer1;


bool saveEngine(const ICudaEngine& engine, const std::string& fileName)
{
	std::ofstream engineFile(fileName, std::ios::binary);
	if (!engineFile)
	{
		std::cout << "Cannot open engine file: " << fileName << std::endl;
		return false;
	}

	IHostMemory* serializedEngine = engine.serialize();
	if (serializedEngine == nullptr)
	{
		std::cout << "Engine serialization failed" << std::endl;
		return false;
	}

	engineFile.write(static_cast<char*>(serializedEngine->data()), serializedEngine->size());
	return !engineFile.fail();
}
void print_dims(const nvinfer1::Dims& dim)
{
	for (int nIdxShape = 0; nIdxShape < dim.nbDims; ++nIdxShape)
	{

		printf("dim %d=%d\n", nIdxShape, dim.d[nIdxShape]);

	}
}

int main()
{

	//	1、创建一个builder
	IBuilder* pBuilder = createInferBuilder(gLogger);
	// 2、 创建一个 network,要求网络结构里,没有隐藏的批量处理维度
	INetworkDefinition* pNetwork = pBuilder->createNetworkV2(1U << static_cast<int>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

	// 3、 创建一个配置文件
	nvinfer1::IBuilderConfig* config = pBuilder->createBuilderConfig();
	// 4、 设置profile,这里动态batch专属
	IOptimizationProfile* profile = pBuilder->createOptimizationProfile();
	// 这里有个OptProfileSelector,这个用来设置优化的参数,比如(Tensor的形状或者动态尺寸),

	profile->setDimensions("data", OptProfileSelector::kMIN, Dims4(1, 1, 512, 512));
	profile->setDimensions("data", OptProfileSelector::kOPT, Dims4(2, 1, 512, 512));
	profile->setDimensions("data", OptProfileSelector::kMAX, Dims4(4, 1, 512, 512));

	config->addOptimizationProfile(profile);

	auto parser = nvonnxparser::createParser(*pNetwork, gLogger.getTRTLogger());

	const char* pchModelPth = "..\\onnx\\model.onnx";

	if (!parser->parseFromFile(pchModelPth, static_cast<int>(gLogger.getReportableSeverity())))
	{

		printf("解析onnx模型失败\n");
	}

	int maxBatchSize = 4;
	//IBuilderConfig::setMaxWorkspaceSize

	pBuilder->setMaxWorkspaceSize(1 << 32);  //pBuilderg->setMaxWorkspaceSize(1<<32);改为config->setMaxWorkspaceSize(1<<32);
	pBuilder->setMaxBatchSize(maxBatchSize);
	//设置推理模式
	pBuilder->setFp16Mode(true);
	ICudaEngine* engine = pBuilder->buildEngineWithConfig(*pNetwork, *config);

	std::string strTrtSavedPath = "..\\engine\\model.trt";
	// 序列化保存模型
	saveEngine(*engine, strTrtSavedPath);
	nvinfer1::Dims dim = engine->getBindingDimensions(0);
	// 打印维度
	print_dims(dim);
}

3. c++调用tensorRT模型

整个工程:链接
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1245846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

毛里塔尼亚市场开发攻略,收藏一篇就够了

毛里塔尼亚是非洲西北部的一个国家&#xff0c;也是中国长期援建的一个国家&#xff0c;也是一带一路上的国家。毛里塔尼亚生产生活资料依赖进口&#xff0c;长期依赖跟我们国家的贸易关系也是比较紧密的&#xff0c;今天就来给大家介绍一下毛里塔尼亚的市场开发公路。文章略长…

“关爱零距离.情暖老人心”主题活动

为提高社区老年人的生活质量&#xff0c;促进邻里间的互动与友谊&#xff0c;以及弘扬尊老爱幼的社区精神&#xff0c;11月21日山东省潍坊市金阳公益服务中心、重庆市潼南区同悦社会工作服务中心在潼南区桂林街道东风社区共同在潼南区桂林街道东风社区举办了“关爱零距离.情暖老…

BMS实战: BMS产品介绍,电池外观分析,电芯种类分析,焊接方式分析,充电方式,电压平台,电芯型号分析。

快速入门的办法就是了解产品,了解现在市面上正在流通的成熟产品方案。光看基础知识是没有效果的。 首先我们找到了一张市面上正在出售的电池pack包。 图片来源网上,侵权删 电池外观分析 外壳: 一般是金属外壳,大部分都是铁壳加喷漆,特殊材质可以定制。 提手 一般是…

22款奔驰S400L升级主动式氛围灯 光影彰显奔驰的完美

新款奔驰S级原车自带64色氛围灯&#xff0c;还可以升级原厂的主动式氛围灯&#xff0c;增加车内的氛围效果。主动式环境氛围灯包含263个LED光源&#xff0c;每隔1.6厘米就有一个LED光源&#xff0c;照明效果较过去明亮10倍&#xff0c;视觉效果更加绚丽&#xff0c;它还可结合智…

怎么申请IP地址证书?

IP地址证书&#xff0c;也称为SSL证书&#xff0c;是一种数字证书&#xff0c;用于在网络传输过程中对IP地址进行加密和解密。它是由受信任的证书颁发机构&#xff08;CA&#xff09;颁发的&#xff0c;用于证明网站所有者身份的真实性和合法性。 一、选择证书颁发机构。首先需…

计算机是如何执行指令的

计算机组成 现在所说的计算机基本上都是冯诺依曼体系的计算机。其核心原理&#xff1a; 冯诺依曼计算的核心思想是将程序指令和数据以二进制形式存储存储在同一存储器中&#xff0c;并使用相同的数据格式和处理方式来处理它们。这种存储程序的设计理念使得计算机能够以可编程…

DataFunSummit:2023年因果推断在线峰会-核心PPT资料下载

一、峰会简介 因果推断是指从数据中推断变量之间的因果关系&#xff0c;而不仅仅是相关关系。因果推断可以帮助业务增长理解数据背后的机制&#xff0c;提高决策的效率和质量&#xff0c;避免被相关性误导&#xff0c;找到真正影响业务的因素和策略。 因果推断在推荐系统中的…

SpringBoot整合SpringSecurity+jwt+knife4生成api接口(从零开始简单易懂)

一、准备工作 ①&#xff1a;创建一个新项目 1.事先创建好一些包 ②&#xff1a;引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency>&…

时间敏感网络TSN的车载设计实践: 802.1Qbv协议

▎概述 IEEE 802.1Qbv[1]是TSN系列协议中备受关注的技术之一&#xff0c;如图1所示&#xff0c;它定义了一种时间感知整形器&#xff08;Time Aware Shaper&#xff0c;TAS&#xff09;&#xff0c;支持Qbv协议的交换机可以按照配置好的门控列表来打开/关闭交换机出口队列&…

某上市证券公司:管控文件交换行为 保护核心数据资产

客户简介 某上市证券公司成立于2001年&#xff0c;经营范围包括&#xff1a;证券经纪、证券投资咨询、证券承销与保荐、证券自营等。经过多年发展&#xff0c;在北京、上海、深圳、重庆、杭州、厦门等国内主要中心城市及甘肃省内各地市设立了15家分公司和80余家证券营业部。20…

字符串函数的模拟实现(strlen,strcpy,strcat,strcmp,strstr)(图文并茂,清晰易懂)

目录 1. strlen函数2. strcpy函数3. strcat函数4. strcmp函数5. strstr函数 个人专栏&#xff1a; 《零基础学C语言》 1. strlen函数 strlen函数&#xff08;Get string length&#xff09;的功能是求字符串长度 使用注意事项&#xff1a; 字符串以 ‘\0’ 作为结束标志&…

如何预防数据泄露?六步策略帮您打造企业信息安全壁垒

大家好&#xff01;我是恒小驰&#xff0c;今天我想和大家聊聊一个非常重要的话题——如何预防数据泄露。在这个数字化的时代&#xff0c;数据已经成为了我们生活中不可或缺的一部分。然而&#xff0c;随着数据的价值日益凸显&#xff0c;数据泄露的风险也随之增加。企业应该如…

windows电脑定时开关机设置

设置流程 右击【此电脑】>【管理】 【任务计划程序】>【创建基本任务】 gina 命令 查看 已经添加的定时任务从哪看&#xff1f;这里&#xff1a; 往下滑啦&#xff0c;看你刚才添加的任务&#xff1a;

Lora学习资料汇总

目录 LoRa联盟 Semtech lora网关供应商: LoRaMAC API文档 论坛 开发板 主流技术对比分析 LoRa网络距离模拟测试方法 LoRa应用 Lora LoraWAN教程 LoRa联盟 LoRa联盟&#xff1a;LoRaWAN规范的制定组织 https://www.lora-alliance.org/ LoRa技术白皮书&#xff1a;htt…

计算机毕业设计项目选题推荐(免费领源码)java+springboot+mysql 城市房屋租赁管理系统01855

摘 要 本论文主要论述了如何使用springboot 城市房屋租赁管理系统 &#xff0c;本系统将严格按照软件开发流程进行各个阶段的工作&#xff0c;采用B/S架构JAVA技术&#xff0c;面向对象编程思想进行项目开发。在引言中&#xff0c;作者将论述城市房屋租赁管理系统的当前背景以及…

SAP指针Field-Symbols:<FS>用法及实例

指针Field-Symbols:用法 内部字段定义 : FIELD-SYMBOLS: [TYPE>] 一、在ABAP编程中使用非常广泛&#xff0c;类似于指针&#xff0c;可以指代任何变量。 当不输入时&#xff0c;继承赋给它的变量的所有属性 当输入时&#xff0c;赋给它的变量必须与同类型。 举个简…

一文带你了解多文件混淆加密

目录 &#x1f512; 一文带你了解 JavaScript 多文件混淆加密 ipaguard加密前 ipaguard加密后 &#x1f512; 一文带你了解 JavaScript 多文件混淆加密 JavaScript 代码多文件混淆加密可以有效保护源代码不被他人轻易盗取。虽然前端的 JS 无法做到纯粹的加密&#xff0c;但通…

Grails 启动

Grails系列 Grails项目启动 文章目录 Grails系列Grails一、项目创建二、可能的问题1.依赖下载2.项目导入到idea失败3.项目导入到idea后运行报错 Grails Grails是一款基于Groovy语言的Web应用程序框架&#xff0c;它使用了许多流行的开源技术&#xff0c;如Spring Framework、…

技术部工作职能规划分析

前言 技术部的职能。以下是一个基本的框架,其中涵盖了技术部在公司中的关键职能和子职能。 主要职能 技术部门的主要职能分为以下几个板块: - 技术规划与战略: 制定技术规划和战略,与业务团队合作确定技术需求。 研究和预测技术趋势,引领公司在技术创新和数字化转型方…

外网讨论疯了的神秘模型Q*(Q-Star)究竟是什么?OpenAI的AGI真的要来了吗 | 详细解读

大家好&#xff0c;我是极智视界&#xff0c;欢迎关注我的公众号&#xff0c;获取我的更多前沿科技分享 邀您加入我的知识星球「极智视界」&#xff0c;星球内有超多好玩的项目实战源码和资源下载&#xff0c;链接&#xff1a;https://t.zsxq.com/0aiNxERDq 这几天&#xff0c;…