目标检测算法之RT-DETR

news2024/10/6 16:19:43

RT-DETR算法理解

  • Background
  • Model Architecture
    • Efficient Hybrid Encoder
    • Uncertainty-minimal Query Selection
  • 总结

Background

Real-time Detection Transformer(RT-DETR)是一个基于tranformer的实时推理目标检测模型。RT-DETR是2023年百度发布的一个新目标检测模型,它兼顾了速度和精度俩个特性,在速度上超越yolo,同时仍保持不低于yolo模型的精度。其分别从encoder部分、query选择俩个方面进行改进,保持了模型的精度,同时提高了模型的推理速度。
在这里插入图片描述
论文地址:https://arxiv.org/pdf/2304.08069
代码地址:https://github.com/lyuwenyu/RT-DETR

Model Architecture

在这里插入图片描述
模型的结构如上图所示,输出图片经过Backbone进行特征提取,获取三个特征图 S 3 、 S 4 、 S 5 S_3、S_4、S_5 S3S4S5。然后将它们输入Efficient Hybrid Encoder层。Efficient Hybrid Encoder层对特征图 S 5 S_5 S5做AIFI获得特征图 F 5 F_5 F5,然后通过CCFF结合 S 3 、 S 4 、 F 5 S_3、S_4、F_5 S3S4F5输出。然后用Uncertainty-minimal Query Selection选取query,再和Encoder的输出一起输入decoder中,最后输出检测结果。

Efficient Hybrid Encoder

作者分析了特征图自交互的情况,认为低级特征具备丰富的图像语义,交互的需求不大。同时通过实验验证了这一观点。这里的出发点是从缩短输入的AIFI的长度出发,由于计算复杂度与长度的平方成正比,由于高级特征的长度较小,所以计算量较少,同时能够验证低级特征交互是不必要,那么就可以较少这一部分的计算。
整个Efficient Hybrid Encoder模块可以用公式表达出来,即 Q = K = V = F l a t t e n ( C 5 ) F 5 = R e s h a p e ( A I F I ( Q , K , V ) ) O = C C F F ( { S 3 , S 4 , F 5 } ) \begin{align*}Q =& K=V = Flatten(C_5)\\F_5 = &Reshape(AIFI(Q,K,V))\\O=&CCFF(\{S_3,S_4,F_5\})\end{align*} Q=F5=O=K=V=Flatten(C5)Reshape(AIFI(Q,K,V))CCFF({S3,S4,F5})这里就是将 C 5 C_5 C5打平,然后输入AIFI中,AIFI是一个普通的transformer encoder模块,然后复原获得特征图 F 5 F_5 F5。然后将三个特征图输入CCFF模块中。官方的CCFF图看起来有些许不明显,所以这里重新画了一下这块,可能让读者更好地了解CCFF,具体见下图。
在这里插入图片描述
CCFF模块其实就是类似于yolo neck中的FPN+PAN,用于融合不同尺度的特征图。这里主要了解一下Fusion的结构,论文中给出了fusion的结构图,具体如下
在这里插入图片描述
Fusion的结构采用了CSP的方法,将输入的特征concat后用1x1的卷积分成了俩份,然后一边经过RepBlock,另一边直接与RepBlock输出直接concat,然后经过flatten层输出。
接下来结合一下源码分析一下CCFF的结构,下面的代码来自hybrid_encoder.py

        inner_outs = [proj_feats[-1]] #获取特征图F5
        for idx in range(len(self.in_channels) - 1, 0, -1): #总共俩层,即idx为2,1
            feat_high = inner_outs[0] #第一次遍历为F5
            feat_low = proj_feats[idx - 1] #第一次遍历为S4
            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)#这一部分就是图中的黄色模块,由1x1的卷积+BN层+SiLU组成,第一次遍历时处理F5
            inner_outs[0] = feat_high
            upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest') #第一次遍历对经过lateral_conv的F5做上采样
            inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) #这里就是论文中的fusion模块
            inner_outs.insert(0, inner_out)   #相信集合图形可以很好地理解,第二次的遍历对着图就可以了

        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1): #这里也是遍历俩次
            feat_low = outs[-1] #获得FPN的最后一层输出
            feat_high = inner_outs[idx + 1] #第二次lateral_conv的输出 
            downsample_feat = self.downsample_convs[idx](feat_low) #上采样
            out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1)) #经过fusion模块
            outs.append(out) #这里也是分析了第一次遍历,第二次也是类似的

Uncertainty-minimal Query Selection

作者分析认为,以往选择query时未同时考虑分类和回归的结果,所以导致模型的预测结果中,并不是分类和回归都是最优。所以它为了降低这种不确定性,在query的选择中加入整个因素,即衡量不确定性定义为 U ( x ^ ) U(\hat{x}) U(x^),其中 U ( x ^ ) = ∣ ∣ P ( x ^ ) − C ( x ^ ) ∣ ∣ U(\hat{x}) = ||P(\hat{x})-C(\hat{x})|| U(x^)=∣∣P(x^)C(x^)∣∣其中 x ^ \hat{x} x^为encoder的输出, P P P位置预测, C C C指分类预测。
然后在最后的损失中加上 U U U,即 L ( x ^ , y ^ , y ) = L b o x ( b ^ , b ) + L c l s ( U ( x ^ ) , c ^ , c ) \mathcal{L}(\hat{x},\hat{y},y) = \mathcal{L} _{box}(\hat{b},b)+ \mathcal{L} _{cls}(U(\hat{x}),\hat{c},c) L(x^,y^,y)=Lbox(b^,b)+Lcls(U(x^),c^,c)这里的思想其实就是做了一个分类和回归的对齐,核心上就是分类分数高回归结果也要准。在源码的具体实现中,采用了VFL的方法,VFL公式具体如下 V F L ( p , q ) = { − q ( q log ⁡ ( p ) + ( 1 − q ) log ⁡ ( 1 − p ) ) q > 0 − α p γ log ⁡ ( 1 − p ) q = 0 VFL(p,q)=\left\{\begin{matrix}-q(q\log(p)+(1-q)\log(1-p))&q>0\\ -\alpha p^{\gamma}\log(1-p) &q=0\end{matrix}\right. VFL(p,q)={q(qlog(p)+(1q)log(1p))αpγlog(1p)q>0q=0其中 q q q为预测框的iou, p p p则为分类概率。
源码中的实现如下

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)

        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
        ious = torch.diag(ious).detach()

        src_logits = outputs['pred_logits']
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = ious.to(target_score_o.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        pred_score = F.sigmoid(src_logits).detach()
        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
        
        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {'loss_vfl': loss}

总结

对RT-DETR的encoder部分,整体看下来像是yolo的backbone+neck。RT-DETR的核心还是在增速上,所以这里它的优化思想是值得借鉴的,但是yolo结构跟DETR结构之间的界限越来越模糊了。对query的优化上,只是做了对齐,使其选择的query更加精确。整体而言模型的创新不大。虽然DETR提倡的是NMS-Free,但是对于某些对精装度要求较高的任务中,如果阈值设置过低,导致最后得出的框过多,仍然需要借助NMS的方法去改进。设置过高则存在丢框的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1872119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

七天速通javaSE:第五天 数组进阶

文章目录 前言一、二维数组二、Arrays类1.toString打印数组内各元素1.1 示例1.2 自己实现内部逻辑 2. sort升序排列3. fill数组填充(重新赋值)4.equals比较数组元素是否相等 三、冒泡排序 前言 本文将学习二维数组、arrays类以及冒泡排序 一、二维数组 …

重生奇迹MU新手攻略:如何一步步往大佬发展

装备强化攻略: 提纯装备:通过提纯装备可以提升基础属性,选择合适的装备进行提纯可以获得更好的效果。 镶嵌宝石:使用宝石进行装备镶嵌可以增加装备的属性,根据需要选择适合的宝石进行镶嵌。 洗练装备:通…

基于盲信号处理的声音分离——最大化信噪比的ICA算法

基于最大化信噪比的ICA算法是一种较新模式的ICA算法,在该算法中利用输出信号的信噪比建立信噪比函数作为该算法的代价函数。 在上式中,用S表示原信号,Y表示输出信号。由于原信号S并不知道,因此采用估计信号Y的滑动平均 来代替&…

激励视频广告的eCPM更高,每天的展示频次有限制吗?

在APP发展初期,由于DUA量级有限,所需的广告资源比较少,往往接入1-2家广告平台就能满足APP用户每日需要的广告展示量。而随着APP用户规模的扩大、广告场景的不断丰富,开发者要提升APP整体广告变现收益,一是可以尽可能多…

PLC数据采集案例

--------天津三石峰科技案例分享 项目介绍 项目背景 本项目为天津某钢铁集团下数字化改造项目,主要解决天津大型钢厂加氢站数字化改造过程中遇到的数据采集需求。项目难点PLC已经在运行了,需要采集里面数据,不修改程序,不影响P…

3D立体卡片动效(附源码)

3D立体卡片动效 欢迎关注:xssy5431 小拾岁月参考链接:https://mp.weixin.qq.com/s/9xEjPAA38pRiIampxjXNKQ 效果展示 思路分析 需求含有立体这种关键词,我们第一反应是采用动画中的平移、倾斜等实现。如果是立体,必然产生阴影&…

浅谈制造业EHS管理需要关注的重点

在快速发展的制造业中,EHS(环境、健康、安全)管理体系如同一道坚实的屏障,守护着企业的绿色与安全。那么,这个管理体系到底包含哪些内容呢?接下来,让我们一同探寻其奥秘。 一、EHS管理体系的丰富…

你的钱花得值不值?简谈FMEA培训的投资与回报

在探讨 FMEA(失效模式及影响分析)培训是否值得投资时,需要综合考虑多个方面。 从投资的角度来看,FMEA 培训通常需要一定的费用支出,包括培训课程的费用、培训期间员工的时间成本以及可能涉及的培训材料和设备成本。 然…

利用MMDetection将单阶段检测器作为Faster R-CNN的RPN

将单阶段检测器作为RPN 一、在 Faster R-CNN 中使用 FCOSHead 作为 RPNHead与原始配置的对比结果Neck (FPN)RPN HeadROI Head学习率 使用单阶段检测器作为RPN的优势1. 速度提升2. 准确性3. 简化架构4. 灵活性 二、评估候选区域三、用预先训练的 FCOS 训练定制的 Faster R-CNN 本…

Excel单元格输入逐字动态提示可选输入效果制作

Excel单元格输入逐字动态提示可选输入效果制作。INDEX函数整理动态列表,再配合IF函数干净界面,“数据验证”完成点选。 (笔记模板由python脚本于2024年06月27日 22:26:14创建,本篇笔记适合喜欢用Excel处理数据的coder翻阅) 【学习的细节是欢悦…

【数据集划分——针对于原先图片已经整理好类别】训练集|验证集|测试集

目标:用split-folders进行数据集划分 学习资源:https://www.youtube.com/watch?vC6wbr1jJvVs 努力的小巴掌 记录计算机视觉学习道路上的所思所得。 现在已经有了数据集,并且,注意,是已经划分好类别的! …

基于ARM的通用的Qt移植思路

文章目录 实验环境介绍一、确认Qt版本二、确认交叉编译工具链三、配置Qt3.1、修改qmake.conf3.2、创建autoConfig.sh配置文件 四、编译安装Qt五、移植Qt安装目录六、配置Qt creator6.1、配置qmake6.2、配置GCC编译器6.3、配置G编译器6.4、配置编译器套件6.5、创建应用 七、总结…

MySQL 主从复制集群高可用

在实际的生产环境中,如果对数据库的读和写都在同一个数据库服务器中操作,无论是在安全性、高可用性还是高并发等各个方面都是完全不能满足实际需求的。因此,一般来说 都是通过主从复制(Master-Slave)来同步数据&#x…

微信小程序毕业设计-线上教育商城系统项目开发实战(附源码+论文)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

基于STM32F103最小系统板和DL-LN33 2.4G通信 ZigBee无线串口自组网采集温湿度

文章目录 前言一、组网概述二、产品特性三、电气特性四、引脚配置五、UART通信协议5.1 UART参数5.2 包分割5.3 端口5.4 举例通信5.4.1 一个节点给另一个节点发送数据5.4.2 一个节点给另一个节点的内部端口发送数据5.4.3 一个节点给自己的内部端口发送数据5.4.4 不推荐的数据传输…

【单片机毕业设计选题24033】-基于STM32的智能饮水机设计

系统功能: 系统上电后显示“欢迎使用智能饮水系统请稍后”两秒后进入正常显示页面。 第一页面第一行显示“系统状态信息”,第二行显示温湿度信息,第三行显示 水温&水位值,第四行显示系统状态(锁定或解锁状态)。…

World of Warcraft [CLASSIC] Level 70 Dire Maul (DM)

[月牙钥匙] [大型爆盐炸弹] World of Warcraft [CLASSIC] Level 70 厄运之槌,完美贡品,Dire Maul (DM) Foror‘s Compendium of Dragon Slaying 佛洛尔的屠龙技术纲要 因为不是兽王宝宝,而且开始位置放的不对&am…

【python011】经纬度点位可视化html生成(有效方案)

1.熟悉、梳理、总结项目研发实战中的Python开发日常使用中的问题、知识点等,如获取省市等边界区域经纬度进行可视化,从而辅助判断、决策。 2.欢迎点赞、关注、批评、指正,互三走起来,小手动起来! 3.欢迎点赞、关注、批…

输出100以内的质数

质数&#xff1a;只能被1和自身整除的数 let count; for(let i2; i<100; i){for(let j1; j<i; j){if(i % j 0){// 只要能被整除&#xff0c;count就加1count;}} if(count 2) {// 从1到自身被整除完之后&#xff0c;如果count只有两次&#xff0c;则说明i为质数co…

应急响应靶机-Linux(1)

前言 本次应急响应靶机采用的是知攻善防实验室的Linux-1应急响应靶机 靶机下载地址为&#xff1a; https://pan.quark.cn/s/4b6dffd0c51a 相关账户密码&#xff1a; defend/defend root/defend 解题 第一题-攻击者的IP地址 先找到的三个flag&#xff0c;最后才找的ip地址 所…