TRT3-trt-basic - 5.1 封装插件

news2024/11/25 23:14:39

之前TRTbasic4是新增插件,这次我们看看不新增插件,仅凭封装可不可以达到一样的功能

首先可以看到这次的g.op不再是MYSELU了,而是plugin,那为什么.cu还能识别出来呢?

 是因为在这里做了一个通用的plugin


DEFINE_BUILTIN_OP_IMPORTER(Plugin)
{
    std::vector<nvinfer1::ITensor*> inputTensors;
    std::vector<onnx2trt::ShapedWeights> weights;
    for(int i = 0; i < inputs.size(); ++i){
        auto& item = inputs.at(i);
        if(item.is_tensor()){
            nvinfer1::ITensor* input = &convertToTensor(item, ctx);
            inputTensors.push_back(input);
        }else{
            weights.push_back(item.weights());
        }
    }

    OnnxAttrs attrs(node, ctx);
    auto name = attrs.get<std::string>("name", "");
    auto info = attrs.get<std::string>("info", "");

    // Create plugin from registry
    auto registry = getPluginRegistry();
    auto creator = registry->getPluginCreator(name.c_str(), "1", "");
    if(creator == nullptr){
        printf("%s plugin was not found in the plugin registry!", name.c_str());
        ASSERT(false, ErrorCode::kUNSUPPORTED_NODE);
    }
    
    nvinfer1::PluginFieldCollection pluginFieldCollection;
    pluginFieldCollection.nbFields = 0;

    ONNXPlugin::TRTPlugin* plugin = (ONNXPlugin::TRTPlugin*)creator->createPlugin(name.c_str(), &pluginFieldCollection);
    if(plugin == nullptr){
        LOG_ERROR(name << " plugin was not found in the plugin registry!");
        ASSERT(false, ErrorCode::kUNSUPPORTED_NODE);
    }

    std::vector<std::shared_ptr<ONNXPlugin::Weight>> weightTensors;
    for(int i = 0; i < weights.size(); ++i){
        auto& weight = weights[i];
        std::vector<int> dims(weight.shape.d, weight.shape.d + weight.shape.nbDims);
        std::shared_ptr<ONNXPlugin::Weight> dweight(new ONNXPlugin::Weight(dims, ONNXPlugin::DataType::Float32));
        
        if(weight.type != ::onnx::TensorProto::FLOAT){
            LOG_ERROR("unsupport weight type: " << weight.type);
        }
        
        memcpy(dweight->pdata_host_, weight.values, dweight->data_bytes_);
        weightTensors.push_back(dweight);
    }
    
    plugin->pluginInit(name, info, weightTensors);
    auto layer = ctx->network()->addPluginV2(inputTensors.data(), inputTensors.size(), *plugin);
    std::vector<TensorOrWeights> outputs;
    for( int i=0; i< layer->getNbOutputs(); ++i )
      outputs.push_back(layer->getOutput(i));
    return outputs;
}
} // namespace

仅仅通过 设置name就可以设置这个模块的名字

class MYSELUImpl(torch.autograd.Function):

    # reference: https://pytorch.org/docs/1.10/onnx.html#torch-autograd-functions
    @staticmethod
    def symbolic(g, x, p):
        print("==================================call symbolic")
        return g.op("Plugin", x, p, 
            g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
            name_s="MYSELU",
            info_s=json.dumps(
                dict(
                    attr1_s="这是字符串属性", 
                    attr2_i=[1, 2, 3], 
                    attr3_f=222
                ), ensure_ascii=False
            )
        )

    @staticmethod
    def forward(ctx, x, p):
        return x * 1 / (1 + torch.exp(-x))

并且对于attribute也不用单独的设置了,可以直接用json把这个dicts存进去

所以这里也是通过info读出来

    OnnxAttrs attrs(node, ctx);
    auto name = attrs.get<std::string>("name", "");
    auto info = attrs.get<std::string>("info", "");

传入cu里的时候也是通过info传入,之后再通过config的读取就可以读取出来各种类型的文件,这样就不用再设置字符串类型还是float32类型。

class MYSELU : public TRTPlugin {
public:
	SetupPlugin(MYSELU);

	virtual void config_finish() override{
		printf("\033[33minit MYSELU config: %s\033[0m\n", config_->info_.c_str());
		printf("weights count is %d\n", config_->weights_.size());
	}

	int enqueue(const std::vector<GTensor>& inputs, std::vector<GTensor>& outputs, const std::vector<GTensor>& weights, void* workspace, cudaStream_t stream) override{
		
		int n = inputs[0].count();
		const int nthreads = 512;
		int block_size = n < nthreads ? n : nthreads;
		int grid_size = (n + block_size - 1) / block_size;

		MYSELU_kernel_fp32 <<<grid_size, block_size, 0, stream>>> (inputs[0].ptr<float>(), outputs[0].ptr<float>(), n);
		return 0;
	}
};

RegisterPlugin(MYSELU);

从导出的onnx文件也可以看出来,类型是plugin,name是MYSELU,剩下的都在info里

而且在这里creator什么的用的都是默认的实现

    auto creator = registry->getPluginCreator(name.c_str(), "1", "");

class class_##PluginCreator__ : public nvinfer1::IPluginCreator{																				\
	public:																																			\
		const char* getPluginName() const noexcept override{return #class_;}																					\
		const char* getPluginVersion() const noexcept override{return "1";}																					\
		const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override{return &mFieldCollection;}													\
																																					\
		nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override{									\
			auto plugin = new class_();																												\
			mFieldCollection = *fc;																													\
			mPluginName = name;																														\
			return plugin;																															\
		}																																			\
																																					\
		nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override{								\
			auto plugin = new class_();																												\
			plugin->pluginInit(name, serialData, serialLength);																						\
			mPluginName = name;																														\
			return plugin;																															\
		}																																			\
																																					\
		void setPluginNamespace(const char* libNamespace) noexcept override{mNamespace = libNamespace;}														\
		const char* getPluginNamespace() const noexcept override{return mNamespace.c_str();}																	\
																																					\
	private:																																		\
		std::string mNamespace;																														\
		std::string mPluginName;																													\
		nvinfer1::PluginFieldCollection mFieldCollection{0, nullptr};																				\
	};																																				\
	REGISTER_TENSORRT_PLUGIN(class_##PluginCreator__);

在这里通过自定义ConfigPlugin可以将权重,输入等新信息全都输入到config中。

经过以上这些等等操作,就可以达成在cu里获取基本信息。        

	virtual void config_finish() override{
		printf("\033[33minit MYSELU config: %s\033[0m\n", config_->info_.c_str());
		printf("weights count is %d\n", config_->weights_.size());
	}

抑或是自定义实现enqueue这些操作:


	int enqueue(const std::vector<GTensor>& inputs, std::vector<GTensor>& outputs, const std::vector<GTensor>& weights, void* workspace, cudaStream_t stream) override{
		
		int n = inputs[0].count();
		const int nthreads = 512;
		int block_size = n < nthreads ? n : nthreads;
		int grid_size = (n + block_size - 1) / block_size;

		MYSELU_kernel_fp32 <<<grid_size, block_size, 0, stream>>> (inputs[0].ptr<float>(), outputs[0].ptr<float>(), n);
		return 0;
	}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/774193.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常用API学习03(Java)

String char charAt(int index) 返回char指定索引处的值 char[] toCharArray() 将此字符串转换为新的字符数组 int compareTo(String anotherString) 按字典顺序比较两个字符串 boolean contains(CharSequence s) 当且仅当此字符串包含指定的char值序列才返…

osg osgQt模块 基于 QGLWidget 源码分析

osgQt基于QGLWidget实现了在Qt窗口内OSG渲染操作。Qt以其开源、跨平台、方便快捷、现代化的界面风格等优点&#xff0c;已经成为了目前桌面版CAD/CAE/CAM等软件开发的首选组件。因此&#xff0c;非常有必要在OSG的基础之上&#xff0c;研究Qt桌面系统内集成OSG渲染功能的相关技…

Linux查看机器内存空间

执行 fdisk命令查看磁盘空间 fdisk -l更多方法参考&#xff1a; Linux检查磁盘空间

一种工业机器人绝对精度的提升方法

摘要&#xff1a;一种新的校准方法&#xff0c;使用动作捕捉作为测量工具&#xff0c;利用ELM神经网络作为非几何误差源补偿&#xff0c;提升工业机器人的绝对精度。 工业机器人的绝对精度是评估其综合性能的重要指标之一。然而&#xff0c;由于机器人受到多种因素的影响&#…

概率论的学习和整理19:条件概率我知道,但什么是条件期望?---用来解决递归问题(草稿)

目录 1 目标问题&#xff1a; 什么是条件期望&#xff1f; 条件期望有什么用&#xff1f; 2 条件期望&#xff0c;全期望公式 3 条件期望&#xff0c;全期望公式 和 条件概率&#xff0c;全概率公式的区别和联系 3.1 公式如下 3.2 区别和联系 3.3 概率和随机过程 4 有什…

8年软件测试工程师感悟 —— 写给还在迷茫中的朋友

去年还在全网声讨互联网企业996呢&#xff0c;今年突然没声音了&#xff0c;也不用讨论在哪个路灯上吊死互联网资本家了&#xff0c;因为都被裁了。 继教育培训领域大幅度裁员之后&#xff0c;大厂裁员消息也开始陆续传出&#xff0c;百度AIG,MEG多条业务线进行精简&#xff0…

线程系列 3 - 关于 CompletableFuture

线程系列3-关于 CompletableFuture 1、从 Future 接口说起2、CompletableFuture 对 Future 的改进2.1、CompletionStage 接口类2.2、runAsync 和 supplyAsync 创建子任务2.3、 whenComplete 和 exceptionally 异步任务回调钩子2.4、调用 handle() 方法统一处理异常和结果2.5、异…

重温黑盒、白盒与灰盒测试方法

黑盒、白盒和灰盒测试方法是软件测试中常用的测试策略&#xff0c;用于评估系统的功能和质量。 对于黑盒、白盒与灰盒测试方法的理解&#xff0c;几年前我在某乎做过一个概念性的回答&#xff0c;当时提问者询问&#xff1a;如何跟非技术人员解释黑盒、白盒、灰盒测试的区别&a…

UG\NX二次开发 使用exception类,异常处理的方法

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan 简介 异常处理是一种编程中常见的错误处理机制,它允许程序在遇到错误或异常情况时优雅地处理。在C++中,异常类是一种重要的异常处理工具。 当程序发生错误或异常时,我们可以使用excepti…

查看maven包依赖关系,一行命令搞定。

1&#xff1a;以logback-classic包为例&#xff0c;在命令端运行&#xff0c;下面命令 mvn dependency:tree -Dincludesch.qos.logback:logback-classic 2.会出现以下日志&#xff0c;就可以清楚的知道这个jar包&#xff0c;是在谁的下面。

python 批量下载图片(协程)

要下载的图片网站 1、总共多少页&#xff0c;得到每页的 url 列表 2、每页的图片详情的 ulr 列表&#xff08;因为该高清大图在图片详情页&#xff0c;因此需要去图片详情页拿图片的url) ​​​​​​​ 3、进入图片详情页&#xff0c;获取到图片url 然后下载。 完整代码如下&…

玩转代码|详细盘点JavaScript 数据类型

目录 什么是JavaScript JavaScript 拥有动态类型 JavaScript 字符串 JavaScript 数字 JavaScript 布尔 JavaScript 数组 JavaScript 对象 Undefined 和 Null JS 中如何判断 undefined JS 中如何判断 null 声明变量类型 什么是JavaScript JavaScript&#xff08;简称…

“nacos is starting with cluster”之nacos启动报错问题

下载并解压nacos后&#xff0c;通过点击startup.cmd启动nacos&#xff0c;出现nacos is starting with cluster的错误&#xff0c;导致nacos未能启动成功。 这是因为&#xff0c;通过startup.cmd命令启动nacos&#xff0c;默认是以集群的方式进行启动的&#xff0c;我们可以改…

为3.7亿用户提供优质服务的微众银行,如何保障应用安全、及时上线

微众银行成立于2014年&#xff0c;是国内首家数字银行。作为银行业改革创新的产物&#xff0c;开业八年多来&#xff0c;微众银行积极把握数字经济时代发展新机遇&#xff0c;运用科技手段为小微企业及普罗大众提供特色化、差异化的优质金融服务&#xff0c;在以数字普惠金融服…

地下供水管漏水监测-供水管道漏水监测设备

地下供水管道作为城市供水系统的重要组成部分&#xff0c;承载着为居民和企业提供清洁饮用水的重要使命。然而&#xff0c;由于管道老化、施工质量、外力损伤等因素&#xff0c;地下供水管道泄漏问题时有发生&#xff0c;这不仅造成了宝贵的水资源浪费&#xff0c;还会导致供水…

流程编排及可视化

写在前面 这里只介绍liteflow的简单基础使用以及作者对liteflow进行可视化扩展的相关阐述 一、背景及意义 背景&#xff1a;对于拥有复杂业务逻辑的系统承载着核心业务逻辑&#xff0c;这些核心业务逻辑涉及内部逻辑运算&#xff0c;缓存操作&#xff0c;持久化操作&#xf…

LiveQing视频点播RTMP推流直播功能-点播拉转在线资源拉转转推到鉴权直播间云端录像集中录像存储

LiveQing点播拉转在线资源拉转转推到鉴权直播间云端录像集中录像存储 1、基本功能2、拉转直播2.1、点播资源拉转2.2、在线资源拉转2.3、服务器本地文件拉转 3、拉转直播如何录像&#xff1f;4、RTMP推流视频直播和点播流媒体服务 1、基本功能 LiveQing RTMP直播点播流媒体服务…

Electron运行报错: Failed to fetch extension, trying ...

Script: "electron:serve": "vue-cli-service electron:serve", 运行 npm run electron:serve 时报错&#xff1a; 解决方法&#xff1a; 检查你的electron配置文件也就是 vue.config.js 中的 mian 的文件 注释其中关于开发工具安装的部分&#xff1a;…

搭建zyplayer-doc个人WIKI文档管理工具,问题记录及简单使用

目录 项目简介各模块介绍项目部署准备工作修改配置及数据库初始化 编译部署编译后文件前后端在同一个部署包当中&#xff08;无需单独部署前端&#xff09; 环境部署目录规划启动脚本编写登录 部署问题记录错误: 找不到或无法加载主类Failed to instantiate [javax.sql.DataSou…

Linux--标记位:flag

我们知道&#xff0c;标记位赋予的值不同&#xff0c;就会生成不同的选项。那么如何给一个变量的位置赋予多个值呢&#xff1f; int整型有32个比特位&#xff0c;故我们可以通过改变位的方式改变值的大小 示例&#xff1a; #include <stdio.h> #include <unistd.h&…