8.5.tensorRT高级(3)封装系列-基于生产者消费者实现的yolov5封装

news2024/11/25 15:48:46

目录

    • 前言
    • 1. yolov5封装
    • 总结

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-基于生产者消费者实现的yolov5封装

课程大纲可看下面的思维导图

在这里插入图片描述

1. yolov5封装

这节我们学习使用封装好的组件,配合生产者消费者,实现一个完整的 yolov5 推理

对于 yolov5 的封装,主要考虑以下:

1. 希望调用者是线程安全的,可以随意进行 commit,而不用考虑是否冲突

2. 希望结果是懒加载的,也就是需要的时候才等待,不需要的时候可以不等待

  • 由 promise 与 future 配合实现
  • 这样的灵活度和效率性能都是最好的

3. 希望最大化利用 GPU,如何利用呢?需要尽可能的使得计算密集

  • 实际体现就是抓取一个批次,一次性进行推理
  • 这在内部的消费者模型里面,一次抓一批

我们来看代码:

yolov5.hpp

#ifndef YOLOV5_HPP
#define YOLOV5_HPP

#include <string>
#include <future>
#include <memory>
#include <opencv2/opencv.hpp>

/
// 封装接口类
namespace YoloV5{

    struct Box{
        float left, top, right, bottom, confidence;
        int class_label;

        Box() = default;

        Box(float left, float top, float right, float bottom, float confidence, int class_label)
        :left(left), top(top), right(right), bottom(bottom), confidence(confidence), class_label(class_label){}
    };
    typedef std::vector<Box> BoxArray;
    
    class Infer{
    public:
        virtual std::shared_future<BoxArray> commit(const cv::Mat& input) = 0;
    };

    std::shared_ptr<Infer> create_infer(
        const std::string& file,
        int gpuid=0, float confidence_threshold=0.25, float nms_threshold=0.45
    );
};

#endif // YOLOV5_HPP

yolov5.cpp


#include "yolov5.hpp"
#include <thread>
#include <vector>
#include <condition_variable>
#include <mutex>
#include <string>
#include <future>
#include <queue>
#include <functional>
#include "trt-infer.hpp"
#include "cuda-tools.hpp"
#include "simple-logger.hpp"

/
// 封装接口类
using namespace std;

namespace YoloV5{

    struct Job{
        shared_ptr<promise<BoxArray>> pro;
        cv::Mat input;
        float d2i[6];
    };

    class InferImpl : public Infer{
    public:
        virtual ~InferImpl(){
            stop();
        }

        void stop(){
            if(running_){
                running_ = false;
                cv_.notify_one();
            }

            if(worker_thread_.joinable())
                worker_thread_.join();
        }

        bool startup(const string& file, int gpuid, float confidence_threshold, float nms_threshold){

            file_    = file;
            running_ = true;
            gpuid_   = gpuid;
            confidence_threshold_ = confidence_threshold;
            nms_threshold_        = nms_threshold;

            promise<bool> pro;
            worker_thread_ = thread(&InferImpl::worker, this, std::ref(pro));
            return pro.get_future().get();
        }

        virtual shared_future<BoxArray> commit(const cv::Mat& image) override{

            if(image.empty()){
                INFOE("Image is empty");
                return shared_future<BoxArray>();
            }
            
            Job job;
            job.pro.reset(new promise<BoxArray>());

            float scale_x = input_width_ / (float)image.cols;
            float scale_y = input_height_ / (float)image.rows;
            float scale   = std::min(scale_x, scale_y);
            float i2d[6];
            i2d[0] = scale;  i2d[1] = 0;  i2d[2] = (-scale * image.cols + input_width_ + scale  - 1) * 0.5;
            i2d[3] = 0;  i2d[4] = scale;  i2d[5] = (-scale * image.rows + input_height_ + scale - 1) * 0.5;

            cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);
            cv::Mat m2x3_d2i(2, 3, CV_32F, job.d2i);
            cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);

            job.input.create(input_height_, input_width_, CV_8UC3);
            cv::warpAffine(image, job.input, m2x3_i2d, job.input.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));
            job.input.convertTo(job.input, CV_32F, 1 / 255.0f);

            shared_future<BoxArray> fut = job.pro->get_future();
            {
                lock_guard<mutex> l(lock_);
                jobs_.emplace(std::move(job));
            }
            cv_.notify_one();
            return fut;
        }

        vector<Box> cpu_decode(
            float* predict, int rows, int cols, float* d2i,
            float confidence_threshold = 0.25f, float nms_threshold = 0.45f
        ){
            vector<Box> boxes;
            int num_classes = cols - 5;
            for(int i = 0; i < rows; ++i){
                float* pitem = predict + i * cols;
                float objness = pitem[4];
                if(objness < confidence_threshold)
                    continue;

                float* pclass = pitem + 5;
                int label     = std::max_element(pclass, pclass + num_classes) - pclass;
                float prob    = pclass[label];
                float confidence = prob * objness;
                if(confidence < confidence_threshold)
                    continue;

                float cx     = pitem[0];
                float cy     = pitem[1];
                float width  = pitem[2];
                float height = pitem[3];

                // 通过反变换恢复到图像尺度
                float left   = (cx - width * 0.5) * d2i[0] + d2i[2];
                float top    = (cy - height * 0.5) * d2i[0] + d2i[5];
                float right  = (cx + width * 0.5) * d2i[0] + d2i[2];
                float bottom = (cy + height * 0.5) * d2i[0] + d2i[5];
                boxes.emplace_back(left, top, right, bottom, confidence, (float)label);
            }

            std::sort(boxes.begin(), boxes.end(), [](Box& a, Box& b){return a.confidence > b.confidence;});
            std::vector<bool> remove_flags(boxes.size());
            std::vector<Box> box_result;
            box_result.reserve(boxes.size());

            auto iou = [](const Box& a, const Box& b){
                float cross_left   = std::max(a.left, b.left);
                float cross_top    = std::max(a.top, b.top);
                float cross_right  = std::min(a.right, b.right);
                float cross_bottom = std::min(a.bottom, b.bottom);

                float cross_area = std::max(0.0f, cross_right - cross_left) * std::max(0.0f, cross_bottom - cross_top);
                float union_area = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top) 
                                + std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top) - cross_area;
                if(cross_area == 0 || union_area == 0) return 0.0f;
                return cross_area / union_area;
            };

            for(int i = 0; i < boxes.size(); ++i){
                if(remove_flags[i]) continue;

                auto& ibox = boxes[i];
                box_result.emplace_back(ibox);
                for(int j = i + 1; j < boxes.size(); ++j){
                    if(remove_flags[j]) continue;

                    auto& jbox = boxes[j];
                    if(ibox.class_label == jbox.class_label){
                        // class matched
                        if(iou(ibox, jbox) >= nms_threshold)
                            remove_flags[j] = true;
                    }
                }
            }
            return box_result;
        }

        void worker(promise<bool>& pro){

            // load model
            checkRuntime(cudaSetDevice(gpuid_));
            auto model = TRT::load_infer(file_);

            if(model == nullptr){

                // failed
                pro.set_value(false);
                INFOE("Load model failed: %s", file_.c_str());
                return;
            }

            auto input    = model->input();
            auto output   = model->output();
            input_width_  = input->size(3);
            input_height_ = input->size(2);

            // load success
            pro.set_value(true);

            int max_batch_size = model->get_max_batch_size();
            vector<Job> fetched_jobs;
            while(running_){
                
                {
                    unique_lock<mutex> l(lock_);
                    cv_.wait(l, [&](){
                        return !running_ || !jobs_.empty();
                    });

                    if(!running_) break;
                    
                    for(int i = 0; i < max_batch_size && !jobs_.empty(); ++i){
                        fetched_jobs.emplace_back(std::move(jobs_.front()));
                        jobs_.pop();
                    }
                }

                for(int ibatch = 0; ibatch < fetched_jobs.size(); ++ibatch){
                    
                    auto& job = fetched_jobs[ibatch];
                    auto& image = job.input;
                    cv::Mat channel_based[3];
                    for(int i = 0; i < 3; ++i){
                        // 这里实现bgr -> rgb
                        // 做的是内存引用,效率最高
                        channel_based[i] = cv::Mat(input_height_, input_width_, CV_32F, input->cpu<float>(ibatch, 2-i));
                    }
                    cv::split(image, channel_based);
                }

                // 一次加载一批,并进行批处理
                // forward(fetched_jobs)
                model->forward();

                for(int ibatch = 0; ibatch < fetched_jobs.size(); ++ibatch){
                    auto& job = fetched_jobs[ibatch];
                    float* predict_batch = output->cpu<float>(ibatch);
                    auto boxes = cpu_decode(
                        predict_batch, output->size(1), output->size(2), job.d2i, confidence_threshold_, nms_threshold_
                    );
                    job.pro->set_value(boxes);
                }
                fetched_jobs.clear();
            }

            // 避免外面等待
            unique_lock<mutex> l(lock_);
            while(!jobs_.empty()){
                jobs_.back().pro->set_value({});
                jobs_.pop();
            }
            INFO("Infer worker done.");
        }

    private:
        atomic<bool> running_{false};
        int gpuid_;
        float confidence_threshold_;
        float nms_threshold_;
        int input_width_;
        int input_height_;
        string file_;
        thread worker_thread_;
        queue<Job> jobs_;
        mutex lock_;
        condition_variable cv_;
    };

    shared_ptr<Infer> create_infer(const string& file, int gpuid, float confidence_threshold, float nms_threshold){
        shared_ptr<InferImpl> instance(new InferImpl());
        if(!instance->startup(file, gpuid, confidence_threshold, nms_threshold)){
            instance.reset();
        }
        return instance;
    }
};

头文件中定义了一个 Infer 类,该类只有一个 commit 纯虚函数,接收图像数据,然后 shared_future 对象,create_infer 函数用于 RAII,它创建并返回一个 Infer 接口类的实现,通过 startup 方法初始化实例

在 startup 函数中我们创建了一个 bool 类型的 promise 变量,用于判断资源是否获取成功,通过引用的方式传递到了消费者线程 worker 中,在 worker 线程里面会处理模型加载的过程,加载成功或者失败都会把结果反馈到对应的 promise 变量上,而通过 future 对象的 get 方法我们可以获取到是否成功,而成功之后消费者线程也启动了,会继续往下走

在 worker 消费者线程中,有个条件变量在等待,如果条件为 true 则退出等待,也就是当队列不为空时会退出等待,然后将队列中的数据移动到 vector 容器中,循环 vector 容器中的每个图像进行 brg2rgb 同时 split,然后将这一批数据送到网络进行推理拿到结果,然后对拿到的 box 进行decode,最好将通过 promise 的 set_value 方法将 boxes 返回回去

我们再来看下 commit 生产者线程,它会将接收的图像数据进行预处理后放入到队列中,然后利用条件变量通知消费者线程可以处理了,其中预处理使用的是 warpAffine 仿射变换,

总的来说,上述代码提供了 YOLOv5 的推理功能。它实现了一个生产者-消费者模式,其中生产者可以异步地提交推理任务,并使用 future-promise 获取结果。worker 线程作为消费者执行这些任务。以下是 yolov5 封装的关键点:

1. 封装的接口:定义了 Infer 接口,提供异步推理的方法

2. 生产者-消费者模式:该实现采用生产者-消费者模式,其中生产者提交推理任务,worker 线程作为消费者去执行

3. 异步处理:使用 future-promise 机制,生产者可以异步地提交任务并等待结果

4. 预处理:预处理采用 warpAffine 仿射变换

5. decode:采用 IM 逆矩阵进行解码恢复成框

6. 多线程安全:使用互斥锁和条件变量确保多线程安全

7. 资源管理:使用 RAII 原则确保资源的正确管理

最好我们来看下 main.cpp 中的内容变化:


// tensorRT include
// 编译用的头文件
#include <NvInfer.h>

// onnx解析器的头文件
#include <onnx-tensorrt/NvOnnxParser.h>

// 推理用的运行时头文件
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <math.h>

#include <iostream>
#include <fstream>
#include <vector>
#include <memory>
#include <functional>
#include <unistd.h>
#include <opencv2/opencv.hpp>

#include "trt-builder.hpp"
#include "simple-logger.hpp"
#include "yolov5.hpp"

using namespace std;

static const char* cocolabels[] = {
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light", "fire hydrant",
    "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
    "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis",
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass",
    "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
    "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
    "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush"
};

static bool exists(const string& path){

#ifdef _WIN32
    return ::PathFileExistsA(path.c_str());
#else
    return access(path.c_str(), R_OK) == 0;
#endif
}

// 上一节的代码
static bool build_model(){

    if(exists("yolov5s.trtmodel")){
        printf("yolov5s.trtmodel has exists.\n");
        return true;
    }

    //SimpleLogger::set_log_level(SimpleLogger::LogLevel::Verbose);
    TRT::compile(
        TRT::Mode::FP32,
        10,
        "yolov5s.onnx",
        "yolov5s.trtmodel",
        1 << 28
    );
    INFO("Done.");
    return true;
}

static void inference(){

    auto image = cv::imread("rq.jpg");
    auto yolov5 = YoloV5::create_infer("yolov5s.trtmodel");
    auto boxes = yolov5->commit(image).get();
    for(auto& box : boxes){
        cv::Scalar color(0, 255, 0);
        cv::rectangle(image, cv::Point(box.left, box.top), cv::Point(box.right, box.bottom), color, 3);

        auto name      = cocolabels[box.class_label];
        auto caption   = cv::format("%s %.2f", name, box.confidence);
        int text_width = cv::getTextSize(caption, 0, 1, 2, nullptr).width + 10;
        cv::rectangle(image, cv::Point(box.left-3, box.top-33), cv::Point(box.left + text_width, box.top), color, -1);
        cv::putText(image, caption, cv::Point(box.left, box.top-5), 0, 1, cv::Scalar::all(0), 2, 16);
    }
    cv::imwrite("image-draw.jpg", image);
}

int main_old();

int main(){

    // 旧的实现,请参照main-old.cpp
    main_old();

    // 新的实现
    if(!build_model()){
        return -1;
    }
    inference();
    return 0;
}

模型构建部分非常简单,一行代码解决模型编译问题,非常方便。推理部分通过 create_infer 拿到一个 shared_ptr 对象,通过调用 commit 方法把图像加入到生成者队列中,然后通过 get 方法等待消费者线程拿到推理结果,然后直接绘制目标框即可,相比于之前的纯裸的,无封装的 yolov5 来说非常简便了

总结

本次课程学习了基于生产者和消费者实现的 yolov5 封装,commit 生成者线程不断往队列中抛数据,通过条件变量 cv_ 通知消费者进行消费,worker 消费者线程会一直等待,直到队列不为空,它会把一批次数据全部扔到模型中进行推理,然后解码,通过 promise 将结果传递回来。create_infer 是 RAII 和接口模式的体现,通过 startup 获取资源并初始话,如果资源获取失败则直接退出,同时实例化的是实现类 InferImpl,而返回的是接口类 Infer,只对使用者暴露 commit 接口。这都是我们之前在生产者消费者课程中讲到过的知识,这边是直接拿过来用了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/904044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

日撸java_day63-65

文章目录 Booster代码运行截图 Booster 代码 package machineLearning.adaboosting;import weka.core.Instances;import java.io.FileReader; import java.util.Arrays;/*** ClassName: WeightedInstances* Package: machineLearning.adaboosting* Description:Weighted inst…

计算机提示mfc120u.dll缺失(找不到)怎么解决

在计算机领域&#xff0c;mfc120u.dll是一个重要的动态链接库文件。它包含了Microsoft Foundation Class (MFC) 库的特定版本&#xff0c;用于支持Windows操作系统中的应用程序开发。修复mfc120u.dll可能涉及到解决与该库相关的问题或错误。这可能包括程序崩溃、运行时错误或其…

DAY23

题目一 给定一个全是小写字母的字符串str.删除多余字符&#xff0c;使得每种字符只保留一个&#xff0c;并让最终结果字符串的字典序最小 str "acbc"&#xff0c; 删掉第一个c&#xff0c; 得到"abc", 是所有结果字符串中字典序最小的。str "dbcacbc…

Python入门教程 | Python简介和环境搭建

Python 简介 Python是一种高级编程语言&#xff0c;由荷兰人Guido van Rossum于1991年创建。它以其简单易学、可读性强和丰富的生态系统而受到广泛喜爱。它被广泛应用于各个领域&#xff0c;包括Web开发、科学计算、数据分析、人工智能等。 Python的特点 简洁易读&#xff1a…

idea新建web项目

步骤一 步骤二 步骤三 新建两个目录lib、classes 步骤四 设置两个目录的功能lib、classes 步骤五 发布到tomcat

docker项目实战

1、使用mysql:5.6和 owncloud 镜像&#xff0c;构建一个个人网盘 1&#xff09;拉取mysql:5.6和owncloud镜像 [rootmaster ~]# docker pull mysql:5.6 5.6: Pulling from library/mysql 35b2232c987e: Pull complete fc55c00e48f2: Pull complete 0030405130e3: Pull compl…

MetaMask Mobile +Chrome DevTools 调试Web3应用教程

注&#xff1a;本教程来源网络&#xff0c;根据项目做的整理 写好了WEB3应用&#xff0c;在本地调试用得好好的&#xff0c;但是用钱包软件访问就报莫名的错&#xff0c;但是又不知道是什么原因&#xff0c;排查的过程非常浪费时间 。 因此在本地同一局域网进行调试就非常有必要…

河北人事档案管理系统

河北人事档案管理系统是一个集数字化管理、高效服务、安全可靠于一体的人事档案管理平台&#xff0c;可以集中管理机关事业单位人事档案、农村党员档案、参保职工档案、流动人才档案等&#xff0c;并实现高效、便捷的查阅和调阅服务。 河北人事档案管理系统的建设主要是为了更好…

【C++】模拟实现哈希(闭散列和开散列两种方式)

哈希 前言正式开始map、set 与 unordered_map、unordered_set 的不同遍历结果不同查找速度不同 哈希闭散列概念介绍模拟实现字符串等自定义类型找位置字符串哈希算法二次探测 开散列概念介绍模拟实现存储自定义类型哈希表大小设置为素数 前言 在C98中&#xff0c;STL提供了底层…

论文学习——FOLEY SOUND SYNTHESIS AT THE DCASE 2023 CHALLENGE(声音生成介绍)

文章目录 引言正文AbstractIntroduction问题 2 Problem And Task Definition3. Official Dataset And Baseline第一部分问题 4. Evaluation问题 4.1 Step 1&#xff1a;Objective Evaluation问题 4.2 Step 2: Subjective Evaluation问题 4.3 Execution&#xff08;非重点&#…

实验一 ubuntu 网络环境配置

ubuntu 网络环境配置 【实验目的】 掌握 ubuntu 下网络配置的基本方法&#xff0c;能够通过有线网络连通 ubuntu 和开发板 【实验环境】 ubuntu 14.04 发行版FS4412 实验平台 【注意事项】 实验步骤中以“$”开头的命令表示在 ubuntu 环境下执行&#xff0c;以“#”开头的…

华为OD机试 - ABR 车路协同场景 - (Java 2023 B卷 100分)

目录 专栏导读一、题目描述1、问题2、条件3、原型 二、输入描述三、输出描述四、Java算法源码五、效果展示1、输入2、输出 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&am…

七夕特辑(一)浪漫表白方式 用神经网络生成一首情诗

目录 一、准备工作二、用神经网络生成一首诗&#xff0c;代码说明 牛郎织女相会&#xff0c;七夕祝福要送来。祝福天下有情人&#xff0c;终成眷属永相伴。 七夕是中国传统的情人节&#xff0c;也是恋人们表达爱意的好时机。在这个特别的日子里&#xff0c;送上温馨的祝福&…

idea创建javaweb项目,jboss下没有web application

看看下图这个地方有没有web application

mybatis入门环境搭建及CRUD

一、MyBatis介绍 二、MyBatis环境搭建 创建一个maven项目&#xff0c;名为mybatis01&#xff0c;如下&#xff1a; 2.1 pom.xml修改 代码如下&#xff1a; <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.o…

Java-抽象类和接口(下)

接口使用实例 给对象数组排序 两个学生对象的大小关系怎么确定? 需要我们额外指定. 这里需要用到Comparable 接口 在Comparable 接口内部有一个compareTo 的方法&#xff0c;我们需要实现它 在下图中&#xff0c;我们需要将o强制转换为Student 之后调用Arrays.sort(array)即…

电商项目part04 微服务拆分

微服务架构拆分 微服务介绍 英文:https://martinfowler.com/articles/microservices.html 中文:http://blog.cuicc.com/blog/2015/07/22/microservices 微服务拆分时机 如下场景是否需要进行微服务拆分&#xff1f; 代码维护困难&#xff0c;几百人同时开发一个模块&…

01 背包算法

描述 王强决定把年终奖用于购物&#xff0c;他把想买的物品分为两类&#xff1a;主件与附件&#xff0c;附件是从属于某个主件的&#xff0c;下表就是一些主件与附件的例子&#xff1a; 主件附件电脑打印机&#xff0c;扫描仪书柜图书书桌台灯&#xff0c;文具工作椅无 如果…

漏洞指北-VulFocus靶场专栏-中级02

漏洞指北-VulFocus靶场专栏-中级02 中级005 &#x1f338;thinkphp lang 命令执行&#xff08;thinkphp:6.0.12&#xff09;&#x1f338;step1&#xff1a;burp suite 抓包 修改请求头step2 修改成功&#xff0c;访问shell.php 中级006 &#x1f338;Metabase geojson任意文件…