8.1.tensorRT高级(3)封装系列-模型编译过程封装,简化模型编译代码

news2024/11/25 18:45:28

目录

    • 前言
    • 1. 模型编译过程封装
    • 2. 问答环节
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-模型编译过程封装,简化模型编译代码

课程大纲可看下面的思维导图

在这里插入图片描述

1. 模型编译过程封装

我们来开始学习 tensorRT 的封装

1. 对 tensorRT 的封装,更像是对推理引擎的封装

2. 封装的意义在于对技术的标准化、工具化,能够使得使用时更加便利,效率更高,定制更多的默认行为

3. 封装推理引擎的思想,还可以应用到更多其它地方,嵌入式等等。由于大多推理引擎提供的默认方式不够友好,对其进行包装,能够很好的使得自己的代码具有复用性,一套代码多处用

4. 还可以实现同样的封装,通过简单的配置,切换不同的推理后端,这都取决于需求

5. 我们的唯一目的就是让工作更简单,让代码复用性更强,让技术可以沉淀

这节课我们主要学习对 builder 进行封装,使得编译的接口足够简单

我们来看代码,首先是关于 cuda 小工具的封装,代码如下:

cuda-tools.hpp

#ifndef CUDA_TOOLS_HPP
#define CUDA_TOOLS_HPP

#include <cuda_runtime.h>
#include <string>

#define checkRuntime(call) CUDATools::check_runtime(call, #call, __LINE__, __FILE__)

#define checkKernel(...)                                                                             \
    __VA_ARGS__;                                                                                     \
    do{cudaError_t cudaStatus = cudaPeekAtLastError();                                               \
    if (cudaStatus != cudaSuccess){                                                                  \
        INFOE("launch failed: %s", cudaGetErrorString(cudaStatus));                                  \
    }} while(0);

namespace CUDATools{
    
    bool check_runtime(cudaError_t e, const char* call, int iLine, const char *szFile);
    bool check_device_id(int device_id);
    int current_device_id();
    std::string device_description();

    // 自动切换当前的deviceid,并在析构的时候切换回去
    class AutoDevice{
    public:
        AutoDevice(int device_id = 0);
        virtual ~AutoDevice();
    
    private:
        int old_ = -1;
    };
}


#endif // CUDA_TOOLS_HPP

cuda-tools.cpp


/*
 *  系统关于CUDA的功能函数
 */

#include "cuda-tools.hpp"
#include <stdio.h>
#include <stdarg.h>
#include <string>
#include <simple-logger.hpp>

using namespace std;

namespace CUDATools{

    bool check_runtime(cudaError_t e, const char* call, int line, const char *file){
        if (e != cudaSuccess) {
            INFOE("CUDA Runtime error %s # %s, code = %s [ %d ] in file %s:%d", 
                call, 
                cudaGetErrorString(e), 
                cudaGetErrorName(e), 
                e, file, line
            );
            return false;
        }
        return true;
    }

    bool check_device_id(int device_id){
        int device_count = -1;
        checkRuntime(cudaGetDeviceCount(&device_count));
        if(device_id < 0 || device_id >= device_count){
            INFOE("Invalid device id: %d, count = %d", device_id, device_count);
            return false;
        }
        return true;
    }

    static std::string format(const char* fmt, ...) {
        va_list vl;
        va_start(vl, fmt);
        char buffer[2048];
        vsnprintf(buffer, sizeof(buffer), fmt, vl);
        return buffer;
    }

    string device_description(){

        cudaDeviceProp prop;
        size_t free_mem, total_mem;
        int device_id = 0;

        checkRuntime(cudaGetDevice(&device_id));
        checkRuntime(cudaGetDeviceProperties(&prop, device_id));
        checkRuntime(cudaMemGetInfo(&free_mem, &total_mem));

        return format(
            "[ID %d]<%s>[arch %d.%d][GMEM %.2f GB/%.2f GB]",
            device_id, prop.name, prop.major, prop.minor, 
            free_mem / 1024.0f / 1024.0f / 1024.0f,
            total_mem / 1024.0f / 1024.0f / 1024.0f
        );
    }

    int current_device_id(){
        int device_id = 0;
        checkRuntime(cudaGetDevice(&device_id));
        return device_id;
    }

    AutoDevice::AutoDevice(int device_id){
        
        cudaGetDevice(&old_);
        checkRuntime(cudaSetDevice(device_id));
    }

    AutoDevice::~AutoDevice(){
        checkRuntime(cudaSetDevice(old_));
    }
}

CUDA 工具集的封装,首先是两个宏定义:

checkRuntime(call): 这是一个用于检查 CUDA 运行时函数调用的宏。它会记录函数的名称、文件名和行号,并在发生错误时输出错误信息。

checkKernel(…): 这个宏用于检查 CUDA 核函数的执行。它首先执行核函数,然后检查其执行是否有任何错误,如果有,它将输出错误信息。

命名空间 CUDATools 中实现了一些函数和类:

  • check_runtime:这个函数检查 CUDA 运行时错误,并打印详细的错误信息,包括文件名、行号、错误代码和错误字符串
  • check_device_id: 这个函数检查给定的设备 ID 是否有效,即是否在有效的范围内
  • current_device_id:这个函数返回当前 CUDA 设备的 ID
  • AutoDevice 类:这是一个 RAII 风格的类,用于自动设置和恢复 CUDA 设备。当你创建这个类的对象时,它会设置指定的设备 ID(默认为0),并在其析构函数中恢复原始的设备 ID

我们再来看下 logger 日志的封装,代码如下:

simple-logger.hpp

#ifndef SIMPLE_LOGGER_HPP
#define SIMPLE_LOGGER_HPP

#include <stdio.h>

#define INFOD(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Debug, __VA_ARGS__)
#define INFOV(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Verbose, __VA_ARGS__)
#define INFO(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Info, __VA_ARGS__)
#define INFOW(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Warning, __VA_ARGS__)
#define INFOE(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Error, __VA_ARGS__)
#define INFOF(...)			SimpleLogger::__log_func(__FILE__, __LINE__, SimpleLogger::LogLevel::Fatal, __VA_ARGS__)


namespace SimpleLogger{

    enum class LogLevel : int{
        Debug   = 5,
        Verbose = 4,
        Info    = 3,
        Warning = 2,
        Error   = 1,
        Fatal   = 0
    };

    void set_log_level(LogLevel level);
    LogLevel get_log_level();
    void __log_func(const char* file, int line, LogLevel level, const char* fmt, ...);

};  // SimpleLogger

#endif // SIMPLE_LOGGER_HPP

simple-logger.cpp


#include "simple-logger.hpp"
#include <string>
#include <stdarg.h>

using namespace std;

namespace SimpleLogger{

    static LogLevel g_level = LogLevel::Info;

    const char* level_string(LogLevel level){
        switch (level){
            case LogLevel::Debug: return "debug";
            case LogLevel::Verbose: return "verbo";
            case LogLevel::Info: return "info";
            case LogLevel::Warning: return "warn";
            case LogLevel::Error: return "error";
            case LogLevel::Fatal: return "fatal";
            default: return "unknow";
        }
    }

    void set_log_level(LogLevel level){
        g_level = level;
    }

    LogLevel get_log_level(){
        return g_level;
    }

    string file_name(const string& path, bool include_suffix){

        if (path.empty()) return "";

        int p = path.rfind('/');
        p += 1;

        //include suffix
        if (include_suffix)
            return path.substr(p);

        int u = path.rfind('.');
        if (u == -1)
            return path.substr(p);

        if (u <= p) u = path.size();
        return path.substr(p, u - p);
    }

    string time_now(){
        char time_string[20];
        time_t timep;							
        time(&timep);							
        tm& t = *(tm*)localtime(&timep);

        sprintf(time_string, "%04d-%02d-%02d %02d:%02d:%02d", t.tm_year + 1900, t.tm_mon + 1, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec);
        return time_string;
    }

    void __log_func(const char* file, int line, LogLevel level, const char* fmt, ...){
        if(level > g_level) return;

        va_list vl;
        va_start(vl, fmt);
        
        char buffer[2048];
        auto now = time_now();
        string filename = file_name(file, true);
        int n = snprintf(buffer, sizeof(buffer), "[%s]", now.c_str());

        if (level == LogLevel::Fatal or level == LogLevel::Error) {
            n += snprintf(buffer + n, sizeof(buffer) - n, "[\033[31m%s\033[0m]", level_string(level));
        }
        else if (level == LogLevel::Warning) {
            n += snprintf(buffer + n, sizeof(buffer) - n, "[\033[33m%s\033[0m]", level_string(level));
        }
        else if (level == LogLevel::Info) {
            n += snprintf(buffer + n, sizeof(buffer) - n, "[\033[35m%s\033[0m]", level_string(level));
        }
        else if (level == LogLevel::Verbose) {
            n += snprintf(buffer + n, sizeof(buffer) - n, "[\033[34m%s\033[0m]", level_string(level));
        }
        else {
            n += snprintf(buffer + n, sizeof(buffer) - n, "[%s]", level_string(level));
        }

        n += snprintf(buffer + n, sizeof(buffer) - n, "[%s:%d]:", filename.c_str(), line);
        vsnprintf(buffer + n, sizeof(buffer) - n, fmt, vl);
        fprintf(stdout, "%s\n", buffer);

        if(level == LogLevel::Fatal || level == LogLevel::Error){
            fflush(stdout);
            abort();
        }
    }
};

上面是一个简单的日志工具封装。通过定义不同的宏,它为我们提供了方便的日志记录功能

  • 我们是通过使用一些宏 (INFOD, INFOV, INFO, INFOW, INFOE, INFOF),开发者可以轻松地在代码中添加日志消息。
  • 每个宏都将当前的文件名 (__FILE__)、行号 (__LINE__)、日志级别以及日志消息传递给 __log_func 函数。
  • __log_func 是执行实际日志记录的函数。它使用变参列表 (va_list) 和 vsnprintf 函数格式化日志消息,根据日志级别,它为日志消息添加颜色

这个日志工具提供了一个简单而有效的方法来在应用程序中添加日志记录功能。其设计允许开发者轻松地添加、修改和控制日志消息。对于使用者来说找 bug 也好找,对于大型项目来讲 logger 这种基本组件肯定是要有的,不然到时候出了问题调试都操心得很

我们来看下核心的 tensorRT 模型编译的封装部分,代码如下:

trt_builder.hpp



#ifndef TRT_BUILDER_HPP
#define TRT_BUILDER_HPP

#include <string>
#include <vector>
#include <functional>

namespace TRT {

	enum class Mode : int {
		FP32,
		FP16
	};

	const char* mode_string(Mode type);

	bool compile(
		Mode mode,
		unsigned int maxBatchSize,
		const std::string& source,
		const std::string& saveto,
		const size_t maxWorkspaceSize = 1ul << 30                // 1ul << 30 = 1GB
	);
};

#endif //TRT_BUILDER_HPP

trt_builder.cpp


#include "trt_builder.hpp"

#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <NvInfer.h>
#include <NvInferPlugin.h>
//#include <NvCaffeParser.h>
#include <onnx-tensorrt/NvOnnxParser.h>
#include <string>
#include <vector>
#include <iostream>
#include <memory>
#include <sstream>
#include <assert.h>
#include <stdarg.h>
#include "cuda-tools.hpp"
#include "simple-logger.hpp"
#include <chrono>

using namespace nvinfer1;
using namespace std;   
//using namespace nvcaffeparser1  ;

class Logger : public ILogger {
public:
	virtual void log(Severity severity, const char* msg) noexcept override {

		if (severity == Severity::kINTERNAL_ERROR) {
			INFOE("NVInfer INTERNAL_ERROR: %s", msg);
			abort();
		}else if (severity == Severity::kERROR) {
			INFOE("NVInfer: %s", msg);
		}
		else  if (severity == Severity::kWARNING) {
			INFOW("NVInfer: %s", msg);
		}
		else  if (severity == Severity::kINFO) {
			INFOD("NVInfer: %s", msg);
		}
		else {
			INFOD("%s", msg);
		}
	}
};

static Logger gLogger;

namespace TRT {

	static string join_dims(const vector<int>& dims){
		stringstream output;
		char buf[64];
		const char* fmts[] = {"%d", " x %d"};
		for(int i = 0; i < dims.size(); ++i){
			snprintf(buf, sizeof(buf), fmts[i != 0], dims[i]);
			output << buf;
		}
		return output.str();
	}

	bool save_file(const string& file, const void* data, size_t length){

        FILE* f = fopen(file.c_str(), "wb");
        if (!f) return false;

        if (data && length > 0){
            if (fwrite(data, 1, length, f) not_eq length){
                fclose(f);
                return false;
            }
        }
        fclose(f);
        return true;
    }

	static string format(const char* fmt, ...) {
		va_list vl;
		va_start(vl, fmt);
		char buffer[10000];
		vsprintf(buffer, fmt, vl);
		return buffer;
	}

	static string dims_str(const nvinfer1::Dims& dims){
		return join_dims(vector<int>(dims.d, dims.d + dims.nbDims));
	}

	static const char* padding_mode_name(nvinfer1::PaddingMode mode){
		switch(mode){
			case nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN: return "explicit round down";
			case nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP: return "explicit round up";
			case nvinfer1::PaddingMode::kSAME_UPPER: return "same supper";
			case nvinfer1::PaddingMode::kSAME_LOWER: return "same lower";
			case nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN: return "caffe round down";
			case nvinfer1::PaddingMode::kCAFFE_ROUND_UP: return "caffe round up";
		}
		return "Unknow padding mode";
	}

	static const char* pooling_type_name(nvinfer1::PoolingType type){
		switch(type){
			case nvinfer1::PoolingType::kMAX: return "MaxPooling";
			case nvinfer1::PoolingType::kAVERAGE: return "AveragePooling";
			case nvinfer1::PoolingType::kMAX_AVERAGE_BLEND: return "MaxAverageBlendPooling";
		}
		return "Unknow pooling type";
	}

	static const char* activation_type_name(nvinfer1::ActivationType activation_type){
		switch(activation_type){
			case nvinfer1::ActivationType::kRELU: return "ReLU";
			case nvinfer1::ActivationType::kSIGMOID: return "Sigmoid";
			case nvinfer1::ActivationType::kTANH: return "TanH";
			case nvinfer1::ActivationType::kLEAKY_RELU: return "LeakyRelu";
			case nvinfer1::ActivationType::kELU: return "Elu";
			case nvinfer1::ActivationType::kSELU: return "Selu";
			case nvinfer1::ActivationType::kSOFTSIGN: return "Softsign";
			case nvinfer1::ActivationType::kSOFTPLUS: return "Parametric softplus";
			case nvinfer1::ActivationType::kCLIP: return "Clip";
			case nvinfer1::ActivationType::kHARD_SIGMOID: return "Hard sigmoid";
			case nvinfer1::ActivationType::kSCALED_TANH: return "Scaled tanh";
			case nvinfer1::ActivationType::kTHRESHOLDED_RELU: return "Thresholded ReLU";
		}
		return "Unknow activation type";
	}

	static string layer_type_name(nvinfer1::ILayer* layer){
		switch(layer->getType()){
			case nvinfer1::LayerType::kCONVOLUTION: return "Convolution";
			case nvinfer1::LayerType::kFULLY_CONNECTED: return "Fully connected";
			case nvinfer1::LayerType::kACTIVATION: {
				nvinfer1::IActivationLayer* act = (nvinfer1::IActivationLayer*)layer;
				auto type = act->getActivationType();
				return activation_type_name(type);
			}
			case nvinfer1::LayerType::kPOOLING: {
				nvinfer1::IPoolingLayer* pool = (nvinfer1::IPoolingLayer*)layer;
				return pooling_type_name(pool->getPoolingType());
			}
			case nvinfer1::LayerType::kLRN: return "LRN";
			case nvinfer1::LayerType::kSCALE: return "Scale";
			case nvinfer1::LayerType::kSOFTMAX: return "SoftMax";
			case nvinfer1::LayerType::kDECONVOLUTION: return "Deconvolution";
			case nvinfer1::LayerType::kCONCATENATION: return "Concatenation";
			case nvinfer1::LayerType::kELEMENTWISE: return "Elementwise";
			case nvinfer1::LayerType::kPLUGIN: return "Plugin";
			case nvinfer1::LayerType::kUNARY: return "UnaryOp operation";
			case nvinfer1::LayerType::kPADDING: return "Padding";
			case nvinfer1::LayerType::kSHUFFLE: return "Shuffle";
			case nvinfer1::LayerType::kREDUCE: return "Reduce";
			case nvinfer1::LayerType::kTOPK: return "TopK";
			case nvinfer1::LayerType::kGATHER: return "Gather";
			case nvinfer1::LayerType::kMATRIX_MULTIPLY: return "Matrix multiply";
			case nvinfer1::LayerType::kRAGGED_SOFTMAX: return "Ragged softmax";
			case nvinfer1::LayerType::kCONSTANT: return "Constant";
			case nvinfer1::LayerType::kRNN_V2: return "RNNv2";
			case nvinfer1::LayerType::kIDENTITY: return "Identity";
			case nvinfer1::LayerType::kPLUGIN_V2: return "PluginV2";
			case nvinfer1::LayerType::kSLICE: return "Slice";
			case nvinfer1::LayerType::kSHAPE: return "Shape";
			case nvinfer1::LayerType::kPARAMETRIC_RELU: return "Parametric ReLU";
			case nvinfer1::LayerType::kRESIZE: return "Resize";
		}
		return "Unknow layer type";
	}

	static string layer_descript(nvinfer1::ILayer* layer){
		switch(layer->getType()){
			case nvinfer1::LayerType::kCONVOLUTION: {
				nvinfer1::IConvolutionLayer* conv = (nvinfer1::IConvolutionLayer*)layer;
				return format("channel: %d, kernel: %s, padding: %s, stride: %s, dilation: %s, group: %d", 
					conv->getNbOutputMaps(),
					dims_str(conv->getKernelSizeNd()).c_str(),
					dims_str(conv->getPaddingNd()).c_str(),
					dims_str(conv->getStrideNd()).c_str(),
					dims_str(conv->getDilationNd()).c_str(),
					conv->getNbGroups()
				);
			}
			case nvinfer1::LayerType::kFULLY_CONNECTED:{
				nvinfer1::IFullyConnectedLayer* fully = (nvinfer1::IFullyConnectedLayer*)layer;
				return format("output channels: %d", fully->getNbOutputChannels());
			}
			case nvinfer1::LayerType::kPOOLING: {
				nvinfer1::IPoolingLayer* pool = (nvinfer1::IPoolingLayer*)layer;
				return format(
					"window: %s, padding: %s",
					dims_str(pool->getWindowSizeNd()).c_str(),
					dims_str(pool->getPaddingNd()).c_str()
				);   
			}
			case nvinfer1::LayerType::kDECONVOLUTION:{
				nvinfer1::IDeconvolutionLayer* conv = (nvinfer1::IDeconvolutionLayer*)layer;
				return format("channel: %d, kernel: %s, padding: %s, stride: %s, group: %d", 
					conv->getNbOutputMaps(),
					dims_str(conv->getKernelSizeNd()).c_str(),
					dims_str(conv->getPaddingNd()).c_str(),
					dims_str(conv->getStrideNd()).c_str(),
					conv->getNbGroups()
				);
			}
			case nvinfer1::LayerType::kACTIVATION:
			case nvinfer1::LayerType::kPLUGIN:
			case nvinfer1::LayerType::kLRN:
			case nvinfer1::LayerType::kSCALE:
			case nvinfer1::LayerType::kSOFTMAX:
			case nvinfer1::LayerType::kCONCATENATION:
			case nvinfer1::LayerType::kELEMENTWISE:
			case nvinfer1::LayerType::kUNARY:
			case nvinfer1::LayerType::kPADDING:
			case nvinfer1::LayerType::kSHUFFLE:
			case nvinfer1::LayerType::kREDUCE:
			case nvinfer1::LayerType::kTOPK:
			case nvinfer1::LayerType::kGATHER:
			case nvinfer1::LayerType::kMATRIX_MULTIPLY:
			case nvinfer1::LayerType::kRAGGED_SOFTMAX:
			case nvinfer1::LayerType::kCONSTANT:
			case nvinfer1::LayerType::kRNN_V2:
			case nvinfer1::LayerType::kIDENTITY:
			case nvinfer1::LayerType::kPLUGIN_V2:
			case nvinfer1::LayerType::kSLICE:
			case nvinfer1::LayerType::kSHAPE:
			case nvinfer1::LayerType::kPARAMETRIC_RELU:
			case nvinfer1::LayerType::kRESIZE:
				return "";
		}
		return "Unknow layer type";
	}

	static bool layer_has_input_tensor(nvinfer1::ILayer* layer){
		int num_input = layer->getNbInputs();
		for(int i = 0; i < num_input; ++i){
			auto input = layer->getInput(i);
			if(input == nullptr)
				continue;

			if(input->isNetworkInput())
				return true;
		}
		return false;
	}

	static bool layer_has_output_tensor(nvinfer1::ILayer* layer){
		int num_output = layer->getNbOutputs();
		for(int i = 0; i < num_output; ++i){

			auto output = layer->getOutput(i);
			if(output == nullptr)
				continue;

			if(output->isNetworkOutput())
				return true;
		}
		return false;
	}  

	template<typename _T>
	shared_ptr<_T> make_nvshared(_T* ptr){
		return shared_ptr<_T>(ptr, [](_T* p){p->destroy();});
	}

	const char* mode_string(Mode type) {
		switch (type) {
		case Mode::FP32:
			return "FP32";
		case Mode::FP16:
			return "FP16";
		default:
			return "UnknowTRTMode";
		}
	}

	static nvinfer1::Dims convert_to_trt_dims(const std::vector<int>& dims){

		nvinfer1::Dims output{0};
		if(dims.size() > nvinfer1::Dims::MAX_DIMS){
			INFOE("convert failed, dims.size[%d] > MAX_DIMS[%d]", dims.size(), nvinfer1::Dims::MAX_DIMS);
			return output;
		}

		if(!dims.empty()){
			output.nbDims = dims.size();
			memcpy(output.d, dims.data(), dims.size() * sizeof(int));
		}
		return output;
	}

	static string align_blank(const string& input, int align_size, char blank = ' '){
        if(input.size() >= align_size) return input;
        string output = input;
        for(int i = 0; i < align_size - input.size(); ++i)
            output.push_back(blank);
        return output;
    }

	static long long timestamp_now() {
        return chrono::duration_cast<chrono::milliseconds>(chrono::system_clock::now().time_since_epoch()).count();
    }

    static double timestamp_now_float() {
        return chrono::duration_cast<chrono::microseconds>(chrono::system_clock::now().time_since_epoch()).count() / 1000.0;
    }


	bool compile(
		Mode mode,
		unsigned int maxBatchSize,
		const string& source,
		const string& saveto,
		const size_t maxWorkspaceSize) {

		INFO("Compile %s %s.", mode_string(mode), source.c_str());
		auto builder = make_nvshared(createInferBuilder(gLogger));
		if (builder == nullptr) {
			INFOE("Can not create builder.");
			return false;
		}

		auto config = make_nvshared(builder->createBuilderConfig());
		if (mode == Mode::FP16) {
			if (!builder->platformHasFastFp16()) {
				INFOW("Platform not have fast fp16 support");
			}
			config->setFlag(BuilderFlag::kFP16);
		}

		shared_ptr<INetworkDefinition> network;
		//shared_ptr<ICaffeParser> caffeParser;
		const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
		network = make_nvshared(builder->createNetworkV2(explicitBatch));
		
		shared_ptr<nvonnxparser::IParser> onnxParser = make_nvshared(nvonnxparser::createParser(*network, gLogger));
		if (onnxParser == nullptr) {
			INFOE("Can not create parser.");
			return false;
		}

		if (!onnxParser->parseFromFile(source.c_str(), 1)) {
			INFOE("Can not parse OnnX file: %s", source.c_str());
			return false;
		}

		auto inputTensor = network->getInput(0);
		auto inputDims = inputTensor->getDimensions();

		INFO("Input shape is %s", join_dims(vector<int>(inputDims.d, inputDims.d + inputDims.nbDims)).c_str());
		INFO("Set max batch size = %d", maxBatchSize);
		INFO("Set max workspace size = %.2f MB", maxWorkspaceSize / 1024.0f / 1024.0f);
		INFO("Base device: %s", CUDATools::device_description().c_str());

		int net_num_input = network->getNbInputs();
		INFO("Network has %d inputs:", net_num_input);
		vector<string> input_names(net_num_input);
		for(int i = 0; i < net_num_input; ++i){
			auto tensor = network->getInput(i);
			auto dims = tensor->getDimensions();
			auto dims_str = join_dims(vector<int>(dims.d, dims.d+dims.nbDims));
			INFO("      %d.[%s] shape is %s", i, tensor->getName(), dims_str.c_str());

			input_names[i] = tensor->getName();
		}

		int net_num_output = network->getNbOutputs();
		INFO("Network has %d outputs:", net_num_output);
		for(int i = 0; i < net_num_output; ++i){
			auto tensor = network->getOutput(i);
			auto dims = tensor->getDimensions();
			auto dims_str = join_dims(vector<int>(dims.d, dims.d+dims.nbDims));
			INFO("      %d.[%s] shape is %s", i, tensor->getName(), dims_str.c_str());
		}

		int net_num_layers = network->getNbLayers();
		INFO("Network has %d layers:", net_num_layers);
		for(int i = 0; i < net_num_layers; ++i){
			auto layer = network->getLayer(i);
			auto name = layer->getName();
			auto type_str = layer_type_name(layer);
			auto input0 = layer->getInput(0);
			if(input0 == nullptr) continue;
			
			auto output0 = layer->getOutput(0);
			auto input_dims = input0->getDimensions();
			auto output_dims = output0->getDimensions();
			bool has_input = layer_has_input_tensor(layer);
			bool has_output = layer_has_output_tensor(layer);
			auto descript = layer_descript(layer);
			type_str = align_blank(type_str, 18);
			auto input_dims_str = align_blank(dims_str(input_dims), 18);
			auto output_dims_str = align_blank(dims_str(output_dims), 18);
			auto number_str = align_blank(format("%d.", i), 4);

			const char* token = "      ";
			if(has_input)
				token = "  >>> ";
			else if(has_output)
				token = "  *** ";

			INFOV("%s%s%s %s-> %s%s", token, 
				number_str.c_str(), 
				type_str.c_str(),
				input_dims_str.c_str(),
				output_dims_str.c_str(),
				descript.c_str()
			);
		}
		
		builder->setMaxBatchSize(maxBatchSize);
		config->setMaxWorkspaceSize(maxWorkspaceSize);

		auto profile = builder->createOptimizationProfile();
		for(int i = 0; i < net_num_input; ++i){
			auto input = network->getInput(i);
			auto input_dims = input->getDimensions();
			input_dims.d[0] = 1;
			profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
			profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
			input_dims.d[0] = maxBatchSize;
			profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
		}

		// not need
		// for(int i = 0; i < net_num_output; ++i){
		// 	auto output = network->getOutput(i);
		// 	auto output_dims = output->getDimensions();
		// 	output_dims.d[0] = 1;
		// 	profile->setDimensions(output->getName(), nvinfer1::OptProfileSelector::kMIN, output_dims);
		// 	profile->setDimensions(output->getName(), nvinfer1::OptProfileSelector::kOPT, output_dims);
		// 	output_dims.d[0] = maxBatchSize;
		// 	profile->setDimensions(output->getName(), nvinfer1::OptProfileSelector::kMAX, output_dims);
		// }
		config->addOptimizationProfile(profile);

		// error on jetson
		// auto timing_cache = shared_ptr<nvinfer1::ITimingCache>(config->createTimingCache(nullptr, 0), [](nvinfer1::ITimingCache* ptr){ptr->reset();});
		// config->setTimingCache(*timing_cache, false);
		// config->setFlag(BuilderFlag::kGPU_FALLBACK);
		// config->setDefaultDeviceType(DeviceType::kDLA);
		// config->setDLACore(0);

		INFO("Building engine...");
		auto time_start = timestamp_now();
		auto engine = make_nvshared(builder->buildEngineWithConfig(*network, *config));
		if (engine == nullptr) {
			INFOE("engine is nullptr");
			return false;
		}
		INFO("Build done %lld ms !", timestamp_now() - time_start);
		
		// serialize the engine, then close everything down
		auto seridata = make_nvshared(engine->serialize());
		return save_file(saveto, seridata->data(), seridata->size());
	}
}; //namespace TRTBuilder

模型编译封装内容主要可分为以下几部分:(from chatGPT)

1. 日志处理

  • 一个自定义的 Logger 类,该类继承自 ILogger,用于处理 TensorRT 的日志消息。
  • 根据消息的严重性,它将打印不同类型的日志消息。

2. 实用函数

  • join_dims:格式化张量的维度为字符串。
  • save_file:将给定数据保存到文件。
  • format:一个简单的字符串格式化函数。
  • dims_str:将 nvinfer1::Dims 对象转换为字符串表示。

3. compile 编译函数

4. 其他功能:文件中还包含一些其他辅助功能和工具,如时间戳获取、CUDA错误处理等。

5. 错误处理:在整个编译过程中,都有对可能出现的错误的检查,如解析错误、文件写入错误等,并在出错时返回相应的错误信息。

上述封装提供了一个从 ONNX 模型到优化后的 TensorRT 推理引擎的完整流程。其主要目的是简化 TensorRT 的使用,让用户只需调用一个函数即可完成模型的导入、优化和序列化。

我们来重点关注 compile 函数中的内容,主要有以下几部分:

1. TensorRT 构建器初始化

函数首先创建一个 TensorRT 构建器实例,并为其设置日志。这个构建器是 TensorRT 框架中用来创建推理引擎的主要组件。

auto builder = make_nvshared(nvinfer1::createInferBuilder(gLogger));

2. 配置设置

设置 TensorRT 构建器的最大批处理大小和最大工作空间大小。这些参数对于控制推理引擎的资源使用和性能至关重要。

builder->setMaxBatchSize(maxBatchSize);
builder->setMaxWorkspaceSize(maxWorkspaceSize);

3. 精度模式

根据传入的模式(FP16或FP32)来设置构建器的精度模式。如果选择 FP16 模式,构建器会使用半精度浮点数进行计算,这可能会提高性能但牺牲一些精度。

if (mode == Mode::FP16) {
    builder->setFp16Mode(true);
}

4. 导入ONNX模型

使用 TensorRT 的 ONNX 解析器导入模型。这里从文件中读取 ONNX 模型并将其解析成 TensorRT 可以理解的格式。

auto parser = make_nvshared(nvonnxparser::createParser(*network, gLogger));
if (!parser->parseFromFile(source.c_str(), static_cast<int>(ILogger::Severity::kWARNING))) {
    return false;
}

5. 优化配置文件

为网络的输入创建优化配置文件。这些配置文件定义了输入的不同大小,从而使 TensorRT 能够为不同大小的输入优化引擎。

auto profile = builder->createOptimizationProfile();
for (int i = 0; i < net_num_input; ++i) {
    auto input = network->getInput(i);
    auto input_dims = input->getDimensions();
    input_dims.d[0] = 1;
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);
    input_dims.d[0] = maxBatchSize;
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);
}
config->addOptimizationProfile(profile);

6. 构建引擎

使用前面的配置和网络信息构建 TensorRT 引擎。这是一个耗时的过程,因为 TensorRT 会尝试多种优化技术来提高模型的推理速度。

auto engine = make_nvshared(builder->buildEngineWithConfig(*network, *config));

7. 序列化引擎

最后,构建好的引擎被序列化,这样就可以在没有原始模型和配置的情况下重新加载它。序列化后的引擎被保存到文件中。

auto seridata = make_nvshared(engine->serialize());
return save_file(saveto, seridata->data(), seridata->size());

这个模型编译的封装代码涵盖了从原始 ONNX 模型到优化后的 TensorRT 推理引擎的整个过程。其简化了 TensorRT 模型编译过程,让用户只需要调用一个函数即可完成模型的编译工作。

二话不说先去执行一下,执行效果如下所示:

在这里插入图片描述

图1-1 builder封装案例

可以看到它会打印一个时间,然后是日志级别,然后是文件名和行号,这是 logger 起到的作用

在编译过程我们也可以获取到它的详细信息,带网络结构的详细信息,如下所示:

在这里插入图片描述

图1-2 详细网络结构信息查看

以上就是关于 builder 的封装,虽然有很多废代码对于性能的提升没有半毛钱关心,但是它有利于我们调试,

2. 问答环节

以下内容是来自于博主与 chatGPT 之间的对话。

问题1:我比较好奇宏定义 checkRuntime 中像 call #call __LINE__ __FILE__ 这种变量没有定义,代码是如何识别的呢?

回答:这些是 C/C++ 预处理器的特性。在宏定义中,您可以使用一些特殊的宏运算符来实现额外的功能。让我们详细解释 checkRuntime 宏中使用的这些元素:

1. call:

  • 在宏定义中,call 是一个宏参数。当您使用这个宏并传入一个参数时,call 会被替换为该参数的值。例如,如果您写 checkRuntime(cudaMalloc(…));,那么在宏扩展中,所有的 call 都会被替换为 cudaMalloc(…)

2. #call:

  • # 是一个宏运算符,称为 “字符串化” 运算符。当它出现在宏定义中并放在一个宏参数前面时,它会将该参数转换为一个字符串字面量。因此,#call 会将 call 参数转换为一个字符串。例如,如果 callcudaMalloc(…),那么 #call 就是 “cudaMalloc(…)”

3. __LINE__ 和 __FILE__:

  • 这些是预定义的宏。它们不需要手动定义。
  • __LINE__ 会被替换为当前的源代码行号。
  • __FILE__ 会被替换为当前的源文件名。

所以,当在代码中写 checkRuntime(cudaMalloc(…)); 时,预处理器会将其扩展为:

CUDATools::check_runtime(cudaMalloc(...), "cudaMalloc(...)", 当前行号, "当前文件名");

这样就可以在 check_runtime 函数中获取到调用的具体函数、行号和文件名,从而使错误报告更加详细和有用。

问题2:在checkKernel中 … 和 __VA_ARGS__ 这两个变量代表又代表啥意思呢?

回答:在 C++ 宏中,__VA_ARGS__ 用于处理可变数量的参数。这是 C++11 引入的变参宏(Variadic Macros)的一部分。

1.

  • 在宏定义的参数列表中, 用于表示该宏可以接受可变数量的参数。
  • 使用 允许您为宏提供任意数量的参数

2. __VA_ARGS__:

  • 在宏的替换部分,您可以使用 __VA_ARGS__ 来引用传递给宏的可变参数。
  • __VA_ARGS__ 将被替换为传递给宏的所有参数(除了已经明确命名的参数之外)。

允许 checkKernel 接受任何数量的参数。然后,宏的替换部分使用 __VA_ARGS__ 来执行这些参数。这允许您直接在 checkKernel 宏前放置 CUDA 核函数调用。

因此当你写下面的代码时:

checkKernel(myKernel<<<blocks, threads>>>(arg1, arg2));

宏会被扩展为:

myKernel<<<blocks, threads>>>(arg1, arg2);
do {
   cudaError_t cudaStatus = cudaPeekAtLastError();
   if (cudaStatus != cudaSuccess){
       INFOE("launch failed: %s", cudaGetErrorString(cudaStatus));
   }
} while(0);

因此,__VA_ARGS__ 允许您为宏提供可变数量的参数,并在宏的定义中使用这些参数。这为宏提供了更大的灵活性,并允许它们处理各种不同的情况。

总结

本次课程学习了 builder 的封装,tensorRT 的模型编译过程其实很多地方都是通用的,我们完全可以将其封装起来。首先我们封装了一个 cuda-tools 的小工具,里面包含检查 cuda runtime 函数以及 cuda kernel 核函数是否执行成功的定义,另外包含了一些 device 信息的小工具;然后我们封装了一个 logger 日志文件,这个日志文件可以输出对应的时间、日志级别、文件名、行号等,非常有利于我们的调试工作;最后是我们的核心,关于 builder 的封装,我们重点关注了 compile 函数的实现,也就是把我们平时的 tensorRT 编译模型的过程给封装了起来,其中我们通过 logger 日志工具输出了许多信息,甚至可以打印出网络结构的详细信息,非常实用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/872993.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详细讲解如何在github上编辑个人主页?

在 GitHub 上编辑个人主页可以让您展示您的项目、技能和个人信息&#xff0c;以及与其他开发者互动。以下是详细的步骤来在 GitHub 上编辑个人主页&#xff1a; 创建 GitHub 账户 如果您还没有 GitHub 账户&#xff0c;首先需要注册一个。 登录到 GitHub 使用您的用户名和密…

StringBuilder的基本操作

1、为什么要学习StringBuilder? 1.1、String拼接100万次 String对象做字符串拼接&#xff0c;字符串直接拼接100万次&#xff0c;运行速度非常非常的慢&#xff0c;当数据量比较大的时候&#xff0c;一般不用字符串直接拼接 package stringdemo;public class StringTest {publ…

数据结构:选择排序

简单选择排序 选择排序是一种简单直观的排序算法。首先在未排序序列中找到最大&#xff08;最小&#xff09;的元素&#xff0c;存放到排序学列的其实位置&#xff0c;然后在剩余的未排序的元素中寻找最小&#xff08;最大&#xff09;元素&#xff0c;存放在已排序序列的后面…

9.利用matlab完成 泰勒级数展开 和 符号表达式傅里叶变换和反变换 (matlab程序)

1.简述 matlab之傅里叶变换和逆变换 首先生成一个方波&#xff08;或者其他组合波形&#xff09;&#xff0c;然后对这个信号做傅里叶变换&#xff0c;拆解到频域&#xff0c;可以看到这个信号是由哪些频率的信号叠加而来。 然后把频域信号&#xff0c;用傅里叶逆变换恢复到时…

Grafana Prometheus 通过JMX监控kafka 【2023最新方式】

第三方kafka exporter方案 目前网上关于使用Prometheus 监控kafka的大部分资料都是使用一个第三方的 kafka exporter&#xff0c;他的原理大概就是启动一个kafka客户端&#xff0c;获取kafka服务器的信息&#xff0c;然后提供一些metric接口供Prometheus使用&#xff0c;随意它…

编写守护进程

守护进程是一个后台进程&#xff0c;当操作系统启动时就可以运行的进程&#xff0c;当操作系统结束时结束的进程&#xff0c;与终端无关。 结果 不想要了就杀死

rt_hw_stack_init 线程栈初始化参数分析

rt_hw_stack_init rt-thread线程栈初始化参数分析 文章目录 rt_hw_stack_init rt-thread线程栈初始化参数分析Q:A:1. rt_hw_stack_init调用分析2. rt_hw_stack_init 实现分析2.1 向下增长型栈 rt_hw_stack_init 实现2.1 向上增长型栈 rt_hw_stack_init 实现 扩展知识&#xff1…

MybatisPlus中安装MybatisX插件及代码生成

创建一个新的模块 导入依赖 设置application.yml 连接数据库 点击下面红色框里面的&#xff0c;然后点apply&#xff0c;最后点ok 选好之后点next base path 根据自己的路径写 选完之后点finish 这样就自己生成了代码。 快速生成CRUD 选择方法后altenter 还可以多条件 方法…

NetSuite 固定资产租赁101

目录 前言 1.新租赁准则的相关内容 1.1 主要变化 1.2 IFRS 16/ASC 842/CAS 21的区别与联系 1.3 新租赁准则实行的意义 2.NetSuite中的租赁功能 2.1 概述 2.2 设置 2.2.1 相关科目设置 2.2.2 资产类型设置 2.3 功能详细说明 2.3.1 案例一 2.3.2 案例二 3.新租赁准则…

kafka 02——三个重要的kafka客户端

kafka 02——三个重要的kafka客户端 1. 前言1.1 关于 Kafka 的安装1.2 常用客户端简介1.3 依赖 2. AdminClient2.1 Admin Configs2.2 AdminClient API2.2.1 设置 AdminClient 对象2.2.2 创建 topic 获取 topic 列表2.2.3 删除topic2.2.4 查看 topic 的描述信息2.2.5 查看 topi…

小白带你部署LNMP分布式部署

目录 前言 一、概述 二、LNMP环境部署 三、配置nginx 1、yum安装 2、编译安装 四、安装 1、编译安装nginx 2、网络源 3、稍作优化 4、修改配置文件vim /usr/local/nginx/conf/nginx.conf 5、书写测试页面 五、部署应用 前言 LNMP平台指的是将Linux、Nginx、MySQL和…

Spring Boot + Vue3前后端分离实战wiki知识库系统<十二>--用户管理单点登录开发一

目标&#xff1a; 在上一次Spring Boot Vue3前后端分离实战wiki知识库系统&#xff1c;十一&#xff1e;--文档管理功能开发三我们已经完成了文档管理的功能模块开发&#xff0c;接下来则开启新模块的学习---用户登录&#xff0c;这块还是有不少知识点值得学习的&#xff0c;…

谈谈语音助手

目录 1.什么是语音助手 2.语音助手的发展过程 3.现在有哪些成熟的语音助手 4.语音助手对人类发展的影响 1.什么是语音助手 语音助手是一种能够通过语音交互与用户进行沟通和执行任务的虚拟助手。它基于人工智能和自然语言处理技术&#xff0c;能够理解用户的语音指令&#x…

LeetCode 36题:有效的数独

题目 请你判断一个 9 x 9 的数独是否有效。只需要 根据以下规则 &#xff0c;验证已经填入的数字是否有效即可。 数字 1-9 在每一行只能出现一次。数字 1-9 在每一列只能出现一次。数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。&#xff08;请参考示例图&#xff…

MEC | 条款4 非必要不提供 default constructor

条款4 非必要不提供 default constructor 文章目录 条款4 非必要不提供 default constructorclasses 必须有默认构造函数&#xff1f;Example1. 产生数组解决方法 2.不适用于 template-based container clases3.虚基函数 探讨>>>>> 欢迎关注公众号【三戒纪元】 …

CANoe自动化工程的搭建

基于XMLCAPL建立自动化工程 1、导入ini文件2、新建 Test Environment3、报告类型4、代码编写 1、导入ini文件 工程的配置的文件&#xff0c;配置DUT相关信息&#xff0c;具体视工程而编写内容。 2、新建 Test Environment 1、新建XML测试用例环境 2、导入XML测试用例文件 …

vulnhub靶机Deathnote

难度&#xff1a;easy 下载地址&#xff1a;https://download.vulnhub.com/deathnote/Deathnote.ova 主机发现 arp-scan -l 端口扫描 nmap --min-rate 10000 -p- 192.168.21.140 进一步查看目标的端口的服务和版本 nmap -sV -sT -O -p22,80 192.168.21.140 扫描端口的漏洞…

微服务实战项目-学成在线-项目优化(redis缓存优化)

微服务实战项目-学成在线-项目优化(redis缓存优化) 1 优化需求 视频播放页面用户未登录也可以访问&#xff0c;当用户观看试学课程时需要请求服务端查询数据&#xff0c;接口如下&#xff1a; 1、根据课程id查询课程信息。 2、根据文件id查询视频信息。 这些接口在用户未认…

Java面试——一分钟搞懂限流算法

为什么限流 运营网站&#xff0c;经常会遇到各种挑战&#xff1a;某黑客发起DoS攻击、网络爬虫网页抓取、商品秒杀活动、双十一与618等场景&#xff0c;会使流量突然激增&#xff0c;如果不限制流量的访问就会使系统宕机。 常见的限流算法 1.漏桶算法&#xff08; LEAKY BUC…

Eclipse-配置彩色输出打印

文章目录 前言配置下载查看是否安装 前言 这是一篇古老的文章&#xff0c;那个时候还在用Eclipse &#xff0c;现在已经换 IDEA 了… 这是一篇 2018 年的文章&#xff0c;我只是将文章从个人比较挪到了CSDN 中 配置 配置完然后下载下面插件即可生成彩色代码。 下载 ANSI …