RT-DETR:端到端的实时Transformer检测模型(目标检测+跟踪)

news2024/11/19 17:49:50

博主一直一来做的都是基于Transformer的目标检测领域,相较于基于卷积的目标检测方法,如YOLO等,其检测速度一直为人诟病。
终于,RT-DETR横空出世,在取得高精度的同时,检测速度也大幅提升。

那么RT-DETR是如何做到的呢?

在研究RT-DETR的改进前,我们先来了解下DETR类目标检测方法的发展历程吧

  • 首先是DETR,该方法作为Transformer在目标检测领域的开山之作,一经推出,便引发了极大的轰动,该方法巧妙的利用Transformer进行特征提取与解码,同时通过匈牙利匹配方法完成预测框与真实框的匹配,避免了NMS等后处理过程。
  • 随后DAB-DETR引入了动态锚框作为查询向量,从而对DETR中的100个查询向量进行了解释。
  • Deformable-DETR针对Transformer中自注意力计算复杂度高的问题,提出可变形注意力计算,即通过可学习的选取少量向量进行注意力计算,大幅的降低了计算量。
  • DN-DETR认为匈牙利匹配的二义性是导致DETR训练收敛慢的原因,因此提出查询降噪机制,即利用先前DAB-DETR中将查询向量解释为锚框的原理,给查询向量添加一些噪声来辅助模型收敛,最终大幅提升了模型的训练速度。
  • DINO则是在DAB-DETR与DN-DETR的基础上进行进一步的融合与改进。
  • H-DETR为使模型获取更多的正样本特征,从而提升检测精度,因此提出混合匹配方法,在训练阶段,包含原始的匈牙利匹配分支与一个一对多的辅助匹配分支,而在推理阶段,则只有一个匈牙利匹配分支。

然而,上述方法尽管已经大幅提升了检测精度,降低了计算复杂度,但其受Transformer本身高计算复杂度的制约,DETR类目标检测方法的实时性始终令人难以满意,尤其是相较于YOLO等单阶段目标检测方法,其检测速度的确差别巨大。

为了解决这个问题,百度提出了RT-DETR,该方法依旧是在DETR的基础上改进生成的,从论文中给出的实验结果来看,该方法无论在检测速度还是检测精度方法都已经超过了YOLOv8,实现了真正的实时性。

在这里插入图片描述

  • 创新点1:高效混合编码器:RT-DETR使用了一种高效的混合编码器,通过解耦尺度内交互和跨尺度融合来处理多尺度特征。这种独特的基于视觉Transformer的设计降低了计算成本,并允许实时物体检测。
  • 创新点2:IoU感知查询选择:RT-DETR通过利用IoU感知的查询选择改进了目标查询初始化。这使得模型能够聚焦于场景中最相关的目标,从而提高了检测精度。
  • 创新点3:自适应推理速度:RT-DETR支持通过使用不同的解码器层来灵活调整推理速度,而无需重新训练。这种适应性便于在各种实时目标检测场景中的实际应用。

RT-DETR的代码有两个,一个是官方提供的代码,但该代码功能比较单一,只有训练与验证,另一个则是集成在YOLOv8中,该代码的设计就比较全面了

环境部署

conda create -n rtdetr python=3.8
conda activate rtdetr
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
cd RT-DETR-main/rtdetr_pytorch  //这个路径根据你自己的改
pip install -r requirement.txt

该算法的环境为pytorch=2.0.1,注意,尽量要用pytorch2以上的版本,否则可能会报错:

AttributeError: module 'torchvision' has no attribute 'disable_beta_transforms_warning'

官方模型训练

参数配置

该算法的配置封装较好,我们只需要修改配置即可:train.py,指定要使用的骨干网络。

parser.add_argument('--config', '-c', default="/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )

修改数据集配置文件:RT-DETR-main\rtdetr_pytorch\configs\dataset\coco_detection.yml
修改训练集与测试集路径,同时修改类别数。

在这里插入图片描述

随后便可以开启训练:该文件中指定 epochs

RT-DETR-main\rtdetr_pytorch\configs\rtdetr\include\optimizer.yml

首次训练,需要下载骨干网络的预训练模型

在这里插入图片描述

在这里,博主使用ResNet18作为骨干特征提取网络

训练结果

开始运行,查看GPU使用情况,此时的batch-size=8,可以看到显存占用4.5G左右,相较于博主先前提出的方法或者DINO,其显存占用少了许多,DINObatch-size=2时的显存占用将近16G.

在这里插入图片描述

训练了24轮的结果。

在这里插入图片描述

训练的结果会保存在output文件夹内:

在这里插入图片描述

官方模型推理

在进行模型推理前,需要先导出模型,在官方代码的tools文件夹下有个export_onnx.py文件,只需要指定配置文件与训练好的模型文件:

parser.add_argument('--config', '-c',  default="/rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )
parser.add_argument('--resume', '-r', default="rtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth",type=str, )

导出的文件是onnx格式

ONNX(Open Neural Network Exchange)是一种开放式的文件格式,用于存储和交换训练好的机器学习模型。它使得不同的人工智能框架(如PyTorch、TensorFlow)可以共享模型,促进了模型在不同平台之间的迁移和复用。ONNX文件采用Protobuf序列化技术进行存储,具有高效、紧凑的特点。

在这里插入图片描述
随后开始推理,代码如下:

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
if __name__ == "__main__":
    ##################
    classes = ['car','truck',"bus"]
    ##################
    # print(onnx.helper.printable_graph(mm.graph))
    #############
    img_path = "1.jpg"
    #############
    im = Image.open(img_path).convert('RGB')
    im = im.resize((640, 640))
    im_data = ToTensor()(im)[None]
    print(im_data.shape)
    size = torch.tensor([[640, 640]])
    sess = ort.InferenceSession("model.onnx")
    import time
    start = time.time()
    output = sess.run(
        # output_names=['labels', 'boxes', 'scores'],
        output_names=None,
        input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
    )
    end = time.time()
    fps = 1.0 / (end - start)
    print(fps)
    
    labels, boxes, scores = output
    draw = ImageDraw.Draw(im)
    thrh = 0.6
    for i in range(im_data.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        print(i, sum(scr > thrh))
        #print(lab)
        print(f'box:{box}')
        for l, b in zip(lab, box):
            draw.rectangle(list(b), outline='red',)
            print(l.item())

            draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )
    #############
    im.save('2.jpg')
    #############

YOLOv8集成RT-DETR训练

在YOLOv8中,给出了YOLO先前的诸多版本,此外还包含RT-DETR
其运行环境与官方的相同,这里就不再赘述了,另外,如果想要了解YOLO及其集成算法的更多功能,可以查看:

https://docs.ultralytics.com/

ultralytics集成了多种算法,已有将YOLO目标检测算法大一统的趋势,涵盖语义分割、目标检测、姿势估计、分类、跟踪等多个任务。

数据集配置

YOLO版本的RT-DETR的数据集支持的数据集格式有多种,这里博主选用的是YOLO格式的

coco
    images
    	train2017
    	val2017
    lables
    	train2017
    	val2017

开始训练

随后在根目录下新建一个run.py文件,文件中写入如下代码:

from ultralytics.models import RTDETR
if __name__ == '__main__':
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    #model.load('rtdetr-l.pt') # 不使用预训练权重可注释掉此行
    model.train(pretrained=True, data='ultralytics\cfg\datasets\cocomine.yaml', epochs=200, batch=16, device=0, imgsz=320, workers=2,cache=False,)

运行报错:

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.

解决方法,这是由于Anconda的torch中的某个文件与环境中的某个文件冲突导致的,找到环境中的文件:
环境路径:

D:\softwares\Anconda\envs\detr\Library\bin

将下面的文件给重命名即可。

在这里插入图片描述

随后便开始训练了,如下:

在这里插入图片描述

至此,RT-DETR的训练过程便完成了。博主设置训练200个epoch,但考虑到接下来的任务,因此训练到一半就停止了,生成的文件存放在run文件中,如下:

在这里插入图片描述

YOLOv8集成RT-DETR推理

在YOLOv8集成的RT-DETR中,其设计就非常完备了,我们只需要新建一个predict.py,里面的内容如下:
这里的images即为一个文件夹,里面可以放入多张图像,save代表保存

model=RTDETR("runs\detect/train\weights/best.pt")
model.predict(source="images",save=True)

推理结果、保存路径与推理速度都会显示在下面

在这里插入图片描述

当然我们还可以指定conf参数,即置信度,可以帮我们筛选一下结果:设置置信度为0.6,此时原本的汽车就不再框选了。

在这里插入图片描述

在这里插入图片描述

视频推理

视频推理也很简单,只需要将原来的图像换为视频即可

model=RTDETR("runs\detect/train\weights/best.pt")
model.predict(source="images/1.mp4",save=True,conf=0.6)

在这里插入图片描述

目标跟踪

在先前的目标跟踪中,都是通过先检测,后跟踪的方式,如采用YOLOv5+DeepSort的方式进行目标跟踪,而在YOLOv8中,他将该功能集成到里面,我们可以直接采用执行跟踪任务的方式完成目标跟踪。

from ultralytics.models import RTDETR
model=RTDETR("runs\detect/train\weights/best.pt")
results = model.track(source="images/1.mp4", conf=0.3, iou=0.5,save=True)

RT-DETR目标跟踪视频

轨迹绘制

from collections import defaultdict

import cv2
import numpy as np
from ultralytics import RTDETR

# Load the YOLOv8 model
model=RTDETR("D:\graduate\programs\yolo8/ultralytics-main/runs\detect/train\weights/best.pt")

# Open the video file
video_path = "images/1.mp4"
cap = cv2.VideoCapture(video_path)

# Store the track history
track_history = defaultdict(lambda: [])

# Loop through the video frames
while cap.isOpened():
    # Read a frame from the video
    success, frame = cap.read()

    if success:
        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True)

        # Get the boxes and track IDs
        boxes = results[0].boxes.xywh.cpu()
        track_ids = results[0].boxes.id.int().cpu().tolist()

        # Visualize the results on the frame
        annotated_frame = results[0].plot()

        # Plot the tracks
        for box, track_id in zip(boxes, track_ids):
            x, y, w, h = box
            track = track_history[track_id]
            track.append((float(x), float(y)))  # x, y center point
            if len(track) > 30:  # retain 90 tracks for 90 frames
                track.pop(0)

            # Draw the tracking lines
            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(annotated_frame, [points], isClosed=False, color=(230, 230, 230), thickness=10)

        # Display the annotated frame
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    else:
        # Break the loop if the end of the video is reached
        break

# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1720597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库(13)——DQL分组查询

语法 SELECT 字段列表 FROM 表名 [WHERE 条件] GROUP BY 分组字段名 [HAVING 分组后过滤条件] 示例 原始表: 根据性别分组并统计人数 select sex,count(*) from information group by sex; 根据性别分组,并求年龄的平均值:

2024抖音流量认知课:掌握流量底层逻辑,明白应该选择什么赛道 (43节课)

课程下载:https://download.csdn.net/download/m0_66047725/89360865 更多资源下载:关注我。 课程目录 01序言:拍前请看.mp4 02抖音建模逻辑1.mp4 03抖音标签逻辑2.mp4 04抖音推流逻辑3.mp4 05抖音起号逻辑4.mp4 06养号的意义.mp4 0…

Java | Leetcode Java题解之第123题买卖股票的最佳时机III

题目&#xff1a; 题解&#xff1a; class Solution {public int maxProfit(int[] prices) {int n prices.length;int buy1 -prices[0], sell1 0;int buy2 -prices[0], sell2 0;for (int i 1; i < n; i) {buy1 Math.max(buy1, -prices[i]);sell1 Math.max(sell1, b…

Bean作用域和生产周期已经Bean的线程安全问题

bean 的作用域 单例(Singletion) : Spring 容器中只有一个 bean &#xff0c;这个 bean 在整个应用程序内共享。 原话(Prototype) : 每次 getBean()&#xff0c; 都是不同的bean&#xff0c;都会创建一个实例。 请求(Request)&#xff1a;每个HTTP请求都会创建一个新的 Bean …

开发者工具-sources(源代码选项)

一、概要说明 源代码面板从视觉效果上分为三个区域&#xff1a;菜单区、内容区、监听区。 菜单区里面有5个子分类&#xff1a; 网页(Page)&#xff1a;指页面源&#xff0c;包含了该页面中所有的文件&#xff0c;即使多个域名下的文件也都会展示出来&#xff0c;包括iframe…

束测后台实操文档2-OpenWrt

束测后台实操文档1-PVE、PBS 上面文&#xff0c;把proxmox装好并添加好PBS上的镜像存储空间后&#xff0c;还原已经做好的镜像基本上就可以在已有的镜像下开展工作了。 调试的PVE环境一般两个网口&#xff0c;一个外网wan&#xff0c;一个子网lan&#xff0c;虚拟机一般在lan…

【redis】宝塔,线上环境报Redis error: ERR unknown command del 错误

两种方式&#xff1a; 1.打开宝塔上的redis&#xff0c;通过配置文件修改权限&#xff0c;注释&#xff1a;#rename-command DEL “” 2.打开服务器&#xff0c;宝塔中默认redis安装位置是&#xff1a;cd /www/server/redis 找到redis.conf,拉到最后&#xff0c;注释#rename-co…

大语言模型技术系列讲解:大模型应用了哪些技术

为了弄懂大语言模型原理和技术细节&#xff0c;笔者计划展开系列学习&#xff0c;并将所学内容从简单到复杂的过程给大家做分享&#xff0c;希望能够体系化的认识大模型技术的内涵。本篇文章作为第一讲&#xff0c;先列出大模型使用到了哪些技术&#xff0c;目的在于对大模型使…

C++设计模式-策略模式

文章目录 27. 策略模式 运行在VS2022&#xff0c;x86&#xff0c;Debug下。 27. 策略模式 策略模式让算法的选择与使用独立开来&#xff0c;使得代码更灵活、可扩展和易维护。应用&#xff1a;如在游戏开发中&#xff0c;AI角色需要根据环境和条件做出不同的行为&#xff0c;如…

基于云服务器使用DreamBooth训练主体

资源整理 参考教程&#xff1a;StableDiffusion/NAI DreamBooth自训练全教程 - 知乎 (zhihu.com) 云服务器平台&#xff1a;AutoDL算力云 | 弹性、好用、省钱。租GPU就上AutoDL 镜像链接&#xff1a;CrazyBoyM/dreambooth-for-diffusion/dreambooth-for-diffusion、 代码仓…

使用Python操作Git

大家好&#xff0c;当谈及版本控制系统时&#xff0c;Git是最为广泛使用的一种&#xff0c;而Python作为一门多用途的编程语言&#xff0c;在处理Git仓库时也展现了其强大的能力。通过Python&#xff0c;我们可以轻松地与Git仓库进行交互&#xff0c;执行各种操作&#xff0c;从…

为参数设置默认值

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 调用函数时&#xff0c;如果没有指定某个参数将抛出异常&#xff0c;为了解决这个问题&#xff0c;我们可以为参数设置默认值&#xff0c;即在定义函…

Blueprints - Collision Presets相关

一些以前的学习笔记归档&#xff1b; 在Static Mesh或SkeletalMesh等的属性中&#xff0c;都有Collision Presets&#xff1a; 其中Oject Type只是一个枚举参数&#xff0c;代表设置该Actor为什么类型&#xff0c;Collision Responses代表该Actor对各种类型的Actor有什么反应&a…

MYSQL四大操作——查!查!查!

目录 简洁版&#xff1a; 详解版&#xff1a; SQL通用语法&#xff1a; 分类&#xff1a; 1. DDL —库 1.1 查询&#xff1a; 1.2 创建&#xff1a; 1.3 删除 1.4 使用库 2. DDL—表 2.1 查询 2.1.1 查询当前库的所有表&#xff1a; 2.1.2 查询表结构 &#xff1a; 2.1.…

408数据结构-图的存储与基本操作 自学知识点整理

前置知识&#xff1a;图的基本概念 图的存储必须完整、准确地反映顶点集和边集的信息。根据不同图的结构和算法&#xff0c;采用不同的存储方式将对程序的效率产生相当大的影响&#xff0c;因此选取的存储结构应适合于待求解的问题。 图的存储 邻接矩阵法 所谓邻接矩阵存储&a…

Perplexity 搜索引擎刚刚推出了新的页面功能——维基百科可以扔了

Perplexity 允许用户根据搜索结果创建自定义页面 人工智能搜索引擎初创公司 Perplexity 推出了一项新功能&#xff0c;使其结果更具粘性&#xff0c;允许用户将研究转变为易于共享的页面。页面建立在 Perplexity 中现有的人工智能驱动的搜索功能之上&#xff0c;该功能使用与 …

javascript DOM 设置样式

No.内容链接1Openlayers 【入门教程】 - 【源代码示例300】 2Leaflet 【入门教程】 - 【源代码图文示例 150】 3Cesium 【入门教程】 - 【源代码图文示例200】 4MapboxGL【入门教程】 - 【源代码图文示例150】 5前端就业宝典 【面试题详细答案 1000】 文章目录 一、直接…

Mac vm虚拟机激活版:VMware Fusion Pro for Mac支持Monterey 1

相信之前使用过Win版系统的朋友们对这款VMware Fusion Pro for Mac应该都不会陌生&#xff0c;这款软件以其强大的功能和适配能力广受用户的好评&#xff0c;在Mac端也同样是一款最受用户欢迎之一的虚拟机软件&#xff0c;VM虚拟机mac版可以让您能够轻松的在Apple的macOS和Mac的…

单片机原理及应用复习

单片机原理及应用 第二章 在AT89S52单片机中&#xff0c;如果采用6MHz晶振&#xff0c;一个机器周期为 2us 。 时钟周期Tocs1focs 机器周期 Tcy12focs 指令周期&#xff1a;一条指令所用的时间&#xff0c;单字和双字节指令一般为单机器周期和双机器周期。 AT89S5…

代码审计(工具Fortify 、Seay审计系统安装及漏洞验证)

源代码审计 代码安全测试简介 代码安全测试是从安全的角度对代码进行的安全测试评估。&#xff08;白盒测试&#xff1b;可看到源代码&#xff09; 结合丰富的安全知识、编程经验、测试技术&#xff0c;利用静态分析和人工审核的方法寻找代码在架构和编码上的安全缺陷&#xf…