RT-DETR+Flask实现目标检测推理案例

news2024/9/23 19:20:24

今天,带大家利用RT-DETR(我们可以换成任意一个模型)+Flask来实现一个目标检测平台小案例,其实现效果如下:

目标检测案例

这个案例很简单,就是让我们上传一张图像,随后选择一下置信度,即可检测出图像中的目标,那么具体该如何实现呢?

RT-DETR模型推理

在先前的学习过程中,博主对RT-DETR进行来了简要的介绍,作为百度提出的实时性目标检测模型,其无论是速度还是精度均取得了较为理想的效果,今天则主要介绍一下RT-DETR的推理过程,与先前使用DETR中使用pth权重与网络结构相结合的推理方式不同,RT-DETR中使用的是onnx这种权重文件,因此,我们需要先对onnx文件进行一个简单了解:
在这里插入图片描述

ONNX模型文件

import onnx
# 加载模型
model = onnx.load('onnx_model.onnx')
# 检查模型格式是否完整及正确
onnx.checker.check_model(model)
# 获取输出层,包含层名称、维度信息
output = self.model.graph.output
print(output)

在原本的DETR类目标检测算法中,推理是采用权重文件与模型结构代码相结合的方式,而在RT-DETR中,则采用onnx模型文件来进行推理,即只需要该模型文件即可。

首先是将pth文件与模型结构进行匹配,从而导出onnx模型文件

"""by lyuwenyu
"""

import os 
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import argparse
import numpy as np 

from src.core import YAMLConfig

import torch
import torch.nn as nn 


def main(args, ):
    """main
    """
    cfg = YAMLConfig(args.config, resume=args.resume)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu') 
        if 'ema' in checkpoint:
            state = checkpoint['ema']['module']
        else:
            state = checkpoint['model']
    else:
        raise AttributeError('only support resume to load model.state_dict by now.')

    # NOTE load train mode state -> convert to deploy mode
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        def __init__(self, ) -> None:
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()
            print(self.postprocessor.deploy_mode)
            
        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            return self.postprocessor(outputs, orig_target_sizes)
    

    model = Model()

    dynamic_axes = {
        'images': {0: 'N', },
        'orig_target_sizes': {0: 'N'}
    }

    data = torch.rand(1, 3, 640, 640)
    size = torch.tensor([[640, 640]])

    torch.onnx.export(
        model, 
        (data, size), 
        args.file_name,
        input_names=['images', 'orig_target_sizes'],
        output_names=['labels', 'boxes', 'scores'],
        dynamic_axes=dynamic_axes,
        opset_version=16, 
        verbose=False
    )


    if args.check:
        import onnx
        onnx_model = onnx.load(args.file_name)
        onnx.checker.check_model(onnx_model)
        print('Check export onnx model done...')


    if args.simplify:
        import onnxsim
        dynamic = True 
        input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
        onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
        onnx.save(onnx_model_simplify, args.file_name)
        print(f'Simplify onnx model {check}...')
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c',  default="D:\graduate\programs\RT-DETR-main\RT-DETR-main//rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )
    parser.add_argument('--resume', '-r', default="D:\graduate\programs\RT-DETR-main\RT-DETR-main/rtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth",type=str, )
    parser.add_argument('--file-name', '-f', type=str, default='model.onnx')
    parser.add_argument('--check',  action='store_true', default=False,)
    parser.add_argument('--simplify',  action='store_true', default=False,)

    args = parser.parse_args()

    main(args)

随后,便是利用onnx模型文件进行目标检测推理过程了
onnx也有自己的一套流程:

onnx前向InferenceSession的使用

关于onnx的前向推理,onnx使用了onnxruntime计算引擎。
onnx runtime是一个用于onnx模型的推理引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,顺路提供了一个专门用于ONNX模型推理的引擎(onnxruntime)。

import onnxruntime
# 创建一个InferenceSession的实例,并将模型的地址传递给该实例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 调用实例sess的润方法进行推理
outputs = sess.run(output_layers_name, {input_layers_name: x})

推理详细代码

推理代码如下:

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor

if __name__ == "__main__":
    ##################
    classes = ['car','truck',"bus"]
    ##################
    # print(onnx.helper.printable_graph(mm.graph))
    #############
    img_path = "1.jpg"
    #############
    im = Image.open(img_path).convert('RGB')
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    print(im_data.shape)

    size = torch.tensor([[640, 640]])
    sess = ort.InferenceSession("model.onnx")
    import time
    start = time.time()
    output = sess.run(
        output_names=['labels', 'boxes', 'scores'],
        #output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
    )
    end = time.time()
    fps = 1.0 / (end - start)
    print(fps)
    # print(type(output))
    # print([out.shape for out in output])

    labels, boxes, scores = output

    draw = ImageDraw.Draw(im)
    thrh = 0.6

    for i in range(im_data.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        print(i, sum(scr > thrh))
        #print(lab)
        print(f'box:{box}')
        for l, b in zip(lab, box):
            draw.rectangle(list(b), outline='red',)
            print(l.item())

            draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )
    #############
    im.save('2.jpg')
    #############

前端代码

前端代码包含两部分,一个是上传页面,一个是显示页面

上传页面如下:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
    <title></title>
    <script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script>
    <style>
        #addCommodityIndex {
            text-align: center;
            width: 300px;
            height: 340px;
            position: absolute;
            left: 50%;
            top: 50%;
            margin: -200px 0 0 -200px;
            border: solid #ccc 1px;
            padding: 35px;
        }
        
        #imghead {
            cursor: pointer;
        }

        .btn {
            width: 100%;
            height: 40px;
            text-align: center;
        }
    </style>
    <link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head>

<body>

    <div id="addCommodityIndex">
        <h2>目标检测</h2>
        <div class="form-group row">
            <form id="upload"  action="/upload" enctype="multipart/form-data" method="POST">
                <img src="">
                <div class="form-group row">
                    <label>上传图像</label>
                    <input type="file" class="form-control"  name='file'>
                </div>
                <div class="form-group row">
                    <label>选择置信度</label>
                    <select class="form-control" name="score" id="exampleFormControlSelect1">
                        <option value="0.5">0.5</option>
                        <option value="0.6">0.6</option>
                        <option value="0.7">0.7</option>
                        <option value="0.8">0.8</option>
                        <option value="0.9">0.9</option>
                    </select>
                </div>
                <div class="form-group row">
                <div class="btn"><input type="submit" class="btn btn-success" value="提交图像" /></div>
                </div>
            </form>
        </div>
    </div>

</body>
</html>

显示页面:

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
    <title></title>
    <script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script>
    <style>
        #addCommodityIndex {
            text-align: center;
            position: absolute;
            left: 40%;
            top: 50%;
            margin: -200px 0 0 -200px;
            border: solid #ccc 1px;
        }
        #imghead {
            cursor: pointer;
        }
        .result {
            width: 100%;
            height: 100%;
            text-align: center;
        }
    </style>
    <link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head>

<body>

<div id="addCommodityIndex">
<div class="card mb-3" style="max-width: 680px;">
    <div class="row no-gutters">
        <div class="col-md-5">
            <img src="../static/img/result.jpg" class="result">
        </div>
        <div class="col-md-5">
            <div class="card-body">
                <h5 class="card-title">检测结果</h5>
                <p class="card-text">目标数量:{{num}}</p>
                <p class="card-text">检测速度:{{fps}}/</p>
                <a  href="/home" class="btn btn-success">继续提交</a>
            </div>
        </div>
    </div>
</div>
</div>
</body>
</html>

Flask框架代码:

# -*- coding: utf-8 -*-
from flask import Flask,request,render_template
import json
import os
import time
app = Flask(__name__)
import infer
@app.route('/home',methods=['GET'])
def home():
    return render_template('upload.html')

@app.route('/upload',methods=['GET','POST'])
def upload():
    if request.method == 'POST':
        f = request.files['file'] #获取数据流
        rootPath = os.path.dirname(os.path.abspath(__file__)) #根目录路径
        #创建存储文件的文件夹,使用时间戳防止重名覆盖
        file_path = 'static/upload/' + str(int(time.time()))
        absolute_path = os.path.join(rootPath,file_path).replace('\\','/') #存储文件的绝对路径,window路径显示\\要转化/
        if not os.path.exists(absolute_path): #不存在改目录则会自动创建
            os.makedirs(absolute_path)
        save_file_name = os.path.join(absolute_path,f.filename).replace('\\','/') #文件存储路径(包含文件名)
        f.save(save_file_name)
        score=request.values.to_dict().get("score")
        num,fps=infer.inference(save_file_name,score)

        #return json.dumps({'code':200,'url':url_path},ensure_ascii=False)
        return render_template("show.html",num=num,fps=fps)

app.run(port='5000',debug=True)

上述项目博主已经上传到github上

git init
git add README.md
git commit -m "first commit"
git branch -M main
git remote add origin https://github.com/pengxiang1998/rt-detr.git
git push -u origin main

项目地址

在使用onnx时,安装了onnxruntime后,出现了下面的错误:

ImportError: cannot import name 'create_and_register_allocator_v2' from 'onnxruntime.capi._pybind_state'

这是由于onnxruntime-gpu版本与CUDA、CuDNN版本不匹配导致的,可以查看下面的网址来查看匹配版本

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

在这里插入图片描述
随后又出现错误:

> This ORT build has ['TensorrtExecutionProvider',
> 'CUDAExecutionProvider', 'CPUExecutionProvider'] enabled. Since ORT
> 1.9, you are required to explicitly set the providers parameter when instantiating InferenceSession. For example,
> onnxruntime.InferenceSession(...,
> providers=['TensorrtExecutionProvider',

这是由于InferenceSession中没有提供对应的provider,修改代码如下:

if torch.cuda.is_available():
        print("GPU")
        sess = ort.InferenceSession("model.onnx", None, providers=["CUDAExecutionProvider"])
    else:
        print("CPU")
        sess= ort.InferenceSession("model.onnx", None)

随后运行,发现安装了onnxruntime-gpu后的速度竟然满了下来,fps仅为0.2,而原本使用onnxruntime的fps则为7左右,这到底是怎么回事呢?

在这里插入图片描述

YOLO集成推理

而在YOLO集成的RT-DETR项目中,训练得到的权重 文件为.pt,在推理时需要与RT-DETR搭配使用,从而实现推理过程:
需要注意的是,由于YOLO里面集成了多种模型,因此为了具有适配性,其代码都具有通用性

from ultralytics.models import RTDETR
if __name__ == '__main__':
    model=RTDETR("weights/best.pt")
    model.predict(source="images/1.mp4",save=True,conf=0.6)

随后执行predict,代码如下:

def predict(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        predictor=None,
        **kwargs,
    ) -> list:
        if source is None:
            source = ASSETS
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")

        is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
            x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
        )

        custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}  # method defaults
        args = {**self.overrides, **custom, **kwargs}  # highest priority args on the right
        prompts = args.pop("prompts", None)  # for SAM-type models

        if not self.predictor:
            self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
            self.predictor.setup_model(model=self.model, verbose=is_cli)
        else:  # only update args if predictor is already setup
            self.predictor.args = get_cfg(self.predictor.args, args)
            if "project" in args or "name" in args:
                self.predictor.save_dir = get_save_dir(self.predictor.args)
        if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
            self.predictor.set_prompts(prompts)
        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

这部分代码在功能上具有复用性,因此在理解上存在一定难度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933938.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks

阅读时间&#xff1a;2023-12-26 1 介绍 年份&#xff1a;2019 作者&#xff1a;Johannes von Oswald&#xff0c;Google Research&#xff1b;Christian Henning&#xff0c;EthonAI AG&#xff1b;Benjamin F. Grewe&#xff0c;苏黎世联邦理工学院神经信息学研究所 期刊&a…

解决虚拟机与主机ping不通,解决主机没有vmware网络

由于注册表文件缺失导致&#xff0c;使用这个工具 下载cclean 白嫖就行 https://www.ccleaner.com/ 是 点击修复就可以了

《TF2.x》强化学习手册-P47-P59-TD时序差分-Monte_carlo蒙特卡洛预测与控制算法

文章目录 实现时序差分学习前期准备实现步骤工作原理 构建强化学习中的蒙特卡洛预测和控制算法前期准备实现步骤工作原理 实现时序差分学习 时序差分&#xff08;Temporal Difference &#xff0c;TD&#xff09;算法。TD算法是一种预测值或目标值校正的方法&#xff0c;用于强…

JRT实体视图查询

JRT的设计目标就是多数据库支持&#xff0c;对于爬行周边数据提供DolerGet解决爬取多维数据问题。但是对于通过父表字段筛选子表数据就不能通过DolerGet取数据了&#xff0c;因为查询到的父表数据没有子表数据的ID。 比如下面表&#xff1a; 我需要按登记号查询这个登记号的报…

tree组件实现折叠与展开功能(方式2 - visible计算属性)

本示例节选自vue3最新开源组件实战教程大纲&#xff08;持续更新中&#xff09;的tree组件开发部分。考察Vue3 Composition API形式的计算属性的用法&#xff0c;computed可以单独用在ts文件中&#xff0c;实现ts的计算属性类型的定义。 父节点属性 在IFlatTreeNode中定义父节…

Blackbox AI:你的智能编程伙伴

目录 Blackbox AI 产品介绍 Blackbox AI 产品使用教程 Blackbox AI体验 AI问答 代码验证 实时搜索 探索&代理 拓展集成 总结 Blackbox AI 产品介绍 Blackbox是专门为程序员量身定制的语言大模型&#xff0c;它针对20多种编程语言进行了特别训练和深度优化&#xff0c;在AI代…

MySQL JDBC

JDBC&#xff1a;Java的数据库编程 JDBC&#xff0c;即Java Database Connectivity&#xff0c;java数据库连接。是一种用于执行SQL语句的Java API&#xff0c;它是 Java中的数据库连接规范。这个API由 java.sql.*,javax.sql.* 包中的一些类和接口组成&#xff0c;它为Java 开…

MySQL:基础操作(增删查改)

目录 一、库的操作 创建数据库 查看数据库 显示创建语句 修改数据库 删除数据库 备份和恢复 二、表的操作 创建表 查看表结构 修改表 删除表 三、表的增删查改 新增数据 插入否则更新 插入查询的结果 查找数据 为查询结果指定别名 结果去重 where 条件 结…

tree组件实现折叠与展开功能(方式1 - expandedTree计算属性)

本示例节选自vue3最新开源组件实战教程大纲&#xff08;持续更新中&#xff09;的tree组件开发部分。考察响应式对象列表封装和computed计算属性的使用&#xff0c;以及数组reduce方法实现结构化树拍平处理的核心逻辑。 实现思路 第一种方式&#xff1a;每次折叠或展开后触发…

经纬恒润全新第二代行泊一体域控制器成功量产

随着L2自动驾驶功能的普及&#xff0c;整车架构的升级&#xff0c;传统分布式控制器已不能适应市场的发展&#xff0c;如何以低成本高性能实现高阶自动驾驶功能的落地, 成为了众多整车厂的迫切需求&#xff0c;行泊一体域控制器应运而生。据高工数据显示&#xff0c;2023年仅1-…

NVIDIA GPU 监控观测最佳实践

1、DCGM 介绍 DCGM&#xff08;Data Center GPU Manager&#xff09;即数据中心 GPU 管理器&#xff0c;是一套用于在集群环境中管理和监视 Tesla™GPU 的工具。它包括主动健康监控&#xff0c;全面诊断&#xff0c;系统警报以及包括电源和时钟管理在内的治理策略。它可以由系…

TypeScript 基础类型(一)

简介 它是 JavaScript 的超集&#xff0c;具有静态类型检查和面向对象编程的特性。TypeScript 的出现&#xff0c;为开发者提供了一种更加严谨和高效的开发方式。 主要特点&#xff1a; 、静态类型检查。 通过静态类型检查&#xff0c;开发者可以在编译时发现错误&#xff0…

大气热力学(8)——热力学图的应用之一(气象要素求解)

本篇文章源自我在 2021 年暑假自学大气物理相关知识时手写的笔记&#xff0c;现转化为电子版本以作存档。相较于手写笔记&#xff0c;电子版的部分内容有补充和修改。笔记内容大部分为公式的推导过程。 文章目录 8.1 复习斜 T-lnP 图上的几种线8.1.1 等温线和等压线8.1.2 干绝热…

搬运5款我觉得超级好用的软件

​ 今天再来推荐5个超级好用的效率软件&#xff0c;无论是对你的学习还是办公都能有所帮助&#xff0c;每个都堪称神器中的神器&#xff0c;用完后觉得不好用你找我。 1.PDF阅读——Sumatra PDF ​ Sumatra PDF 是一款 PDF 阅读器&#xff0c;它的界面简洁&#xff0c;使用起…

乐尚代驾一项目概述

前言 2024年7月17日&#xff0c;最近终于在低效率的情况下把java及其生态的知识点背的差不多了&#xff0c;投了两个礼拜的简历&#xff0c;就一个面试&#xff0c;总结了几点原因。 市场环境不好 要知道&#xff0c;前两年找工作&#xff0c;都不需要投简历&#xff0c;把简历…

DevExpress WPF中文教程 - 为项目添加GridControl并将其绑定到数据

DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序&#xff0c;这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

小程序-3(页面导航+页面事件+生命周期+WXS)

目录 1.页面导航 声明式导航 导航到tabBar页面 导航到非tabBar页面 后退导航 编程式导航 后退导航 导航传参 声明式导航传参 编程式导航传参 在onload中接收导航参数 2.页面事件 下拉刷新 停止下拉刷新的效果 ​编辑 上拉触底 配置上拉触底距离 上拉触底的节…

函数返回右值的一点学习研究

https://zhuanlan.zhihu.com/p/511371573?utm_mediumsocial&utm_oi939219201949429760 下面情况下不会调用&#xff1a; DPoint3d fun1() {return DPoint3d{1,2,3}; // 默认构造 }int main() {DPoint3d&& a fun1();a.y 20;int i 0;i; } 下面情况下&#xff0c…

【内网穿透】如何本地搭建Whisper语音识别模型并配置公网地址

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…

数据库管理的艺术(MySQL):DDL、DML、DQL、DCL及TPL的实战应用(下:数据操作与查询)

文章目录 DML数据操作语言1、新增记录2、删除记录3、修改记录 DQL数据查询语言1、查询记录2、条件筛选3、排序4、函数5、分组条件6、嵌套7、模糊查询8、limit分页查询 集合操作union关键字和运算符in关键字any关键字some关键字all关键字 联合查询1、广义笛卡尔积2、等值连接3、…