【CV学习笔记】tensorrtx-yolov5 逐行代码解析

news2025/1/22 22:06:10
1、前言

TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因:

  • Flexible 很容易修改模型的任意一层,删除、增加、替换等操作。
  • Debuggable 可以容易获得模型中间某一层的结果
  • Chance to learn 可以对模型结构有进一步的了解

尽管onnx2trt的方式目前已经在绝大部分情况下都不会出现问题,但在trtx下,我们能够掌握更底层的原理和代码,有利于我们对模型的部署以及优化。下文将会以yolov5s在trtx框架下的例子,来逐行解析是trtx是如何工作的。

TensorRTx项目链接:https://github.com/wang-xinyu/tensorrtx。

2、步骤解析

在trtx中,对一个模型加速的过程可以分为两个步骤

  • 提取pytorch模型参数 wts
  • 利用trt底层API搭建网络结构,并将wts中的参数填充到网络中
2.1、get_wts.py

首先需要将pytorch中的模型参数提取出来,pytorch中的模型参数是以caffe中blob的格式存在的,每个操作都有对应的名字、数据长度、数据.

for k, v in model.state_dict().items():
    # k-> blob的名字
    vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度
    f.write('{} {} '.format(k, len(vr)))
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制
        f.write('\n')

通过上get_wts.py,就可以得到包含yolov5s.pth的模型参数,打开yolov5s.wts如下图所示:

在这里插入图片描述

其中第一行的351为总的blob数量,第二行的model.0.conv.weight为第一个blob的名字,3456表示为该blob的数据长度,3a198000 3ca58000…为实际参数。

得到了上述的参数之后,就可以以trtx的方式进行加速了。

2.2、构造engine

在利用wts转engine的之前,需要十分清楚模型的网络结构,不太清楚的同学可以参考太阳花的小绿豆关于yolov5的网络结构图。了解完yolov5的网络结构后,就可以着手利用trt的api来搭建网络模型了。搭建模型的代码在 model.cpp中的build_det_engine函数,本文将其中的代码过程直接画到yolov5的网络结构图中了,可以直接对照代码和图来进行查看。
在这里插入图片描述

//yolov5_det.cpp
viod serialize_engine(...){
	if (is_p6) {
        ...
	} else {
        // 以yolov5s为例
        engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
  	}
    // 序列化
    IHostMemory* serialized_engine = engine->serialize();
    std::ofstream p(engine_name, std::ios::binary);
    // 写到文件中
    p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());

}

model.cpp

// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {
    int32_t count;  // wts文件第一行,共有351个blob
  	input >> count;
    //每一行是一个blob,模型名称 + 数据长度 + 参数
    while (count--) {
        // 一个blob的参数
     	Weights wt{ DataType::kFLOAT, nullptr, 0 };
        uint32_t size;  //blob 数据长度
        std::string name; // blob 数据名字
        for (uint32_t x = 0, y = size; x < y; ++x) {
      		input >> std::hex >> val[x];  // 将数据转化成十进制,并放到val中
    	}
        // 每个blob名字对应一个wt
        weightMap[name] = wt;
    }
}


ICudaEngine* build_det_engine(){
   // 初始化网络结构
   INetworkDefinition* network = builder->createNetworkV2(0U);
   // 定义模型输入
   ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW });
   // 加载pytorch模型中的参数
   std::map<std::string, Weights> weightMap = loadWeights(wts_name);
    
   // 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图
 
   // 增加yolo后处理decode模块,使用了plugin
   auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});
   network->markOutput(*yolo->getOutput(0));  //将plugin的输出设置为模型的最后输出(decode)
    
   #if defined(USE_FP16)
  	// FP16
	config->setFlag(BuilderFlag::kFP16);
   #elif defined(USE_INT8)
    // INT8 量化
    std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;
    assert(builder->platformHasFastInt8());
    config->setFlag(BuilderFlag::kINT8);
    Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);
      config->setInt8Calibrator(calibrator);
    #endif
    // 根据网络结构来生成engine
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
	return engine;
}
3、plugin

本人对plugin也在学习当中,下面是我在学习trtx-yolo5代码中对plugin浅显的认知。原作者在模型后面增加了一个模型解码的plugin,用于获得每个特征层上的bbox,调用代码在model.cpp中

auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});

static IPluginV2Layer* addYoLoLayer(...){
    // 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错
    auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
    
    // plugin的数据
    PluginField plugin_fields[2];
    int netinfo[5] = {kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation};  //维度数据
  	plugin_fields[0].data = netinfo;  
  	plugin_fields[0].length = 5; 
  	plugin_fields[0].name = "netinfo";
  	plugin_fields[0].type = PluginFieldType::kFLOAT32;
    
    // 所有plugin的参数
    PluginFieldCollection plugin_data;
  	plugin_data.nbFields = 2;
  	plugin_data.fields = plugin_fields;
    // 创建plugin的对象 
    IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}

实现代码在yololayer.h/cu中

class API YoloLayerPlugin : public IPluginV2IOExt {
    	
    // 设置插件名称,在注册插件时会寻找对应的插件
      const char* getPluginType() const TRT_NOEXCEPT override{
          return "YoloLayer_TRT";
      }

    
    //插件构造函数
	YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){
      /*
      	classCount:类别数量
      	netWidth:输入宽
      	netHeight:输入高
      	maxOut:最大检测数量
      	is_segmentation:是否含有实例分割
      	vYoloKernel:anchors参数
      */
    }
    
}

// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){
    // 输出结果 1+ 是在第一个位置记录解码的数量
    int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
    
    // 将存放结果的内存置为0
    for (int idx = 0; idx < batchSize; ++idx) {
    	CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
 
    // 遍历三种不同尺度的anchor
    for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
        // 调用核函数进行解码
     	CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)
    }
    
}

__global__ void CalDetection(...){
    // input:模型输出结果
    // output:decode存放地址
    // 当前线程的的全局索引ID
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    // yoloWidth * yoloHeight
    int total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数
    int bnIdx = idx / total_grid;    // 第n个batch    
    // x,y,w,h,score + 80
    int info_len_i = 5 + classes;
    // 如果带有实例分割分析,需要再加上32个分割系数
    if (is_segmentation) info_len_i += 32;
    
    // 第n个batch的推理结果开始地址
    const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);
    // 遍历三种不同尺寸的anchor
    for (int k = 0; k < kNumAnchor; ++k) {
        //每个框的置信度
    	float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
        if (box_prob < kIgnoreThresh) continue;
        for (int i = 5; i < 5 + classes; ++i) {
            // 每个类别的概率
        	float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
            // 提取最大概率以及类别ID
            if (p > max_cls_prob) {
        		max_cls_prob = p;
        		class_id = i - 5;
      		}
        }
        // 
        float *res_count = output + bnIdx * outputElem;
        // 统计decode框的数量	
        int count = (int)atomicAdd(res_count, 1);
		// 下面是按照论文的公式将预测的宽和高恢复到原图大小
		...
    }
}
4、总结

通过本次对trtx开源代码的深入学习,知道了如何利用trt的api对模型进行加速,同时还了解到plugin的实现,后续还会继续学习trtx里面的知识点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1043715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

渗透测试——信息收集思路

文章目录 信息收集域名与 IPOSINTCDNCDN的作用如何检测是否存在CDN CDN 绕过多地Ping邮件服务器子域名真实IP寻找国外地址请求查找老域名查找关联域名信息泄露/配置文件网站漏洞DNS记录&#xff0c;证书域名历史 搜索引擎语法WHOIS端口对外开放情况Nmap 网站的三种部署模式网站…

chrome extensions mv3通过content scripts注入/获取原网站的window数据

开发插件的都知道插件的content scripts和top window只共享Dom不共享window和其他数据&#xff0c;如果想拿挂载在window的数据还有点难度&#xff0c;下面会通过事件的方式传递cs和top window之间的数据写一个例子 代码 manifest.json 这里只搞了2个js&#xff0c;content.…

DataX - 在有总bps限速条件下,单个channel的bps值不能为空,也不能为非正数

更新服务器上的datax版本后&#xff0c;发现执行以前的任务全都失败&#xff0c;查看日志都有报 com.alibaba.datax.common.exception.DataXException: Code:[Framework-03], Description:[DataX引擎配置错误&#xff0c;该问题通常是由于DataX安装错误引起&#xff0c;请联系…

ssl证书 阿里的域名,腾讯云的证书

目录 1.腾讯云申请ssl免费证书 2.去阿里云进行解析 3.回到腾讯云 4.nginx的配置 说明&#xff1a;阿里云的免费证书用完了&#xff08;每年可以申请20个&#xff09;&#xff0c;还有个项目要用证书&#xff0c;第三方的证书免费的都是90天的。看了下腾讯云业可以申请免费的…

史上最全的公司各种体系流程图,直接拿走!

大家好&#xff0c;我是老原。 优秀企业和卓越企业的区别在哪里&#xff1f; 两个字&#xff1a;流程。 流程的水平高低在一定程度上也体现了项目经理做项目的能力&#xff0c;一个企业能否持续成功的过程能力。 拥有稳定高效的流程管理体系&#xff0c;项目经理的管理水平…

ABB机器人如何在示教器上查看输入输出以及强制输出DO信号

ABB机器人如何在示教器上查看输入输出以及强制输出DO信号 如下图所示,点击左上角的菜单—选择“输入输出“, 如下图所示,进入输入输出画面后,点击右下角的视图,选择“数字输出“, 如下图所示,此时可以看到所有的DO信号及其当前值, 如下图所示,这里以 Local_IO_0_DO1 为…

AI大模型服务上线,助力企业AI大模型应用落地

在数字时代的浪潮中&#xff0c;人工智能(AI)技术的发展和应用已经深入到我们生活的方方面面。其中&#xff0c;企业AI大模型作为AI技术的重要形式之一&#xff0c;正在成为推动企业创新、提高效率和优化决策的关键力量。为顺应AI大模型的新趋势需求&#xff0c;近日&#xff0…

游戏技术亮点|Aavegotchi 与 GameSwift 建立合作伙伴关系

构建一个优秀的游戏只是成功发布的一部分&#xff0c;让数百万玩家体验这款游戏才是真正的乐趣所在。 这也是为什么我们很高兴宣布与 GameSwift 建立了新的合作伙伴关系&#xff0c;GameSwift 是一款先进的模块化游戏区块链&#xff0c;采用 zkEVM 技术构建&#xff0c;是全球…

【通意千问】大模型GitHub开源工程学习笔记(1)

9月25日&#xff0c;阿里云开源通义千问140亿参数模型Qwen-14B及其对话模型Qwen-14B-Chat,免费可商用。 立马就到了GitHub去fork。 GitHub&#xff1a; GitHub - QwenLM/Qwen: The official repo of Qwen (通义千问) chat & pretrained large language model proposed b…

解决谷歌Redux DevTools调试React+Typescript项目数据对不上/连接不上问题

上文 ReactTypescript项目环境中搭建并使用redux环境 我们创建了一个redux项目的环境 但是我们用谷歌浏览器插件调试 会发现 要不 匹配的数据有问题 看不到数据 要不 就压根连接不到 而且 我们点击加减号 去改变值 调试工具也没有任何反应 我们终端输入 npm install --save-d…

VSCode安装离线插件

1. 打开 VSCode 插件市场网址 Extensions for the Visual Studio family of product&#xff0c;输入你想要的插件名称&#xff0c;比如这里我想要安装的是 Markdown All in One 插件 2. 点击进入插件主页&#xff0c;点击右侧的 Download Extension 链接&#xff0c;得到下载…

Hugging News #0925: 一览近期的新功能发布

每一周&#xff0c;我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新&#xff0c;包括我们的产品和平台更新、社区活动、学习资源和内容更新、开源库和模型更新等&#xff0c;我们将其称之为「Hugging News」。本期 Hugging News 有哪些有趣的消息&#xff0…

通讯网关软件013——利用CommGate X2ORACLE实现Modbus RTU数据转储ORACLE

本文介绍利用CommGate X2ORACLE实现从Modbus RTU设备读取数据并转储至ORACLE数据库。CommGate X2ORACLE是宁波科安网信开发的网关软件&#xff0c;软件可以登录到网信智汇(wangxinzhihui.com)下载。 【案例】如下图所示&#xff0c;实现从Modbus RTU设备读取数据并转储至ORACL…

【漏洞复现】某友GRP-U8 SQL注入

漏洞描述 某友GRP-U8是某友软件推出的企业级管理软件套件,旨在助力企业实现全面数字化管理及业务优化,某友GRP-U8的bx_historyDataCheck.jsp接口对用户传入的参数未进行有效的过滤,直接拼接至SQL查询的语句中,导致SQL注入漏洞,攻击者可利用该漏洞获取数据库的敏感信息 …

springboot实战(八)之整合redis

目录 序言&#xff1a; 环境&#xff1a; 依赖&#xff1a; 配置&#xff1a; 测试&#xff1a; redis序列化配置&#xff1a; 连接池&#xff1a; 序言&#xff1a; Redis是我们Java开发中&#xff0c;使用频次非常高的一个nosql数据库&#xff0c;数据以key-value键…

Linux下使用yum安装的东西都去哪儿了?(新手友好)

常见的安装路径 使用yum安装的软件包通常都会遵循相似的目录结构 安装路径含义/etc配置文件/var/log日志文件/usr/sbin可执行文件(包括服务管理工具) 面对不同的软件如何看安装位置 上面给出的是一些软件包安装几乎必备的几个安装路径&#xff0c;具体用yum去安装不同的软件…

element-ui form表单,内嵌表单数据校验

在最近开发的功能的过程中,遇到一个很复杂的表单;外层一个大表单;里面有一项是动态添加的,而且内嵌一个表单。每一项还有校验规则;如下图 记录一下调试结果。 无论多少层form, 注意几个事项; form的model/ref; form_item的prop这个关系到,校验作用具体那个框框 数据…

【LeetCode热题100】--48.旋转图像

48.旋转图像 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在原地旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要使用另一个矩阵来旋转图像。 使用辅助数组 class Solution {public void rotate(int[][] matrix)…

2023-9-26 JZ 复杂链表的复制

题目链接&#xff1a;复杂链表的复制 import java.util.*; /* public class RandomListNode {int label;RandomListNode next null;RandomListNode random null;RandomListNode(int label) {this.label label;} } */ public class Solution {public RandomListNode Clone(Ra…

跨域问题的原理及解决方法

一.同源策略 如果没有进行特殊处理&#xff0c;我们在进行前后端联调的时候游览器会发生报错&#xff1a; 这是因为请求被同源策略被阻止&#xff0c;浏览器出于安全的考虑&#xff0c;使用XMLHttpRequest对象发起HTTP请求&#xff08;异步请求&#xff09;时必须遵守同源策略…