End-to-End Object Detection with Transformers(论文解析)

news2024/11/25 20:30:59

End-to-End Object Detection with Transformers

    • 摘要
    • 介绍
    • 相关工作
      • 2.1 集合预测
      • 2.2 transformer和并行解码
      • 2.3 目标检测
    • 3 DETR模型
      • 3.1 目标检测集设置预测损失
      • 3.2 DETR架构

摘要

我们提出了一种将目标检测视为直接集合预测问题的新方法。我们的方法简化了检测流程,有效地消除了许多手工设计的组件的需求,如显式编码我们关于任务的先验知识的非极大值抑制过程或锚点生成。新框架的主要要素,称为DEtection TRansformer或DETR,包括一个基于集合的全局损失,通过二分图匹配强制执行唯一的预测,以及一个Transformer编码器-解码器架构。给定一组固定的学习目标查询,DETR通过推理对象之间的关系和全局图像上下文,直接并行输出最终的预测。这个新模型在概念上很简单,不需要专门的库,与许多其他现代检测器不同。DETR在具有挑战性的COCO目标检测数据集上表现出与经过充分优化的Faster R-CNN基线相当的准确性和运行时性能。此外,DETR可以轻松推广为以统一的方式生成全景分割。我们展示它明显优于竞争基线。训练代码和预训练模型可在https://github.com/facebookresearch/detr获得。

介绍

目标检测的目标是为每个感兴趣的对象预测一组边界框和类别标签。现代检测器通过在大量的建议区域[37,5]、锚点[23]或窗口中心[53,46]上定义代理回归和分类问题,以间接的方式解决这个集合预测任务。它们的性能受到后处理步骤的显著影响,以折叠近似重复的预测,受到锚点集设计和将目标框分配给锚点的启发式方法的影响。为了简化这些流程,我们提出了一种直接的集合预测方法,绕过了代理任务。这种端到端的哲学已经在复杂的结构化预测任务中取得了重大进展,例如机器翻译或语音识别,但在目标检测领域尚未实现:以前的尝试[43,16,4,39]要么增加了其他形式的先验知识,要么在具有挑战性的基准测试上未能与强基线竞争。本文旨在弥合这一差距。
在这里插入图片描述
图1:DETR通过将常见的CNN与Transformer架构结合在一起,直接(并行地)预测最终的一组检测结果。在训练期间,二分图匹配将预测值唯一地分配给与真值框相匹配的情况。没有匹配的预测应该产生一个“无对象”(∅)的类别预测。

我们通过将目标检测视为直接的集合预测问题,简化了训练流程。我们采用了基于变换器(transformers)[47]的编码器-解码器架构,这是一种用于序列预测的流行架构。变换器的自注意机制明确地模拟了序列中所有元素之间的两两交互作用,使这些架构特别适用于集合预测的特定约束,如去除重复的预测。

我们的DEtection TRansformer(DETR,见图1)一次性预测所有对象,并通过一个集合损失函数进行端到端训练,该函数在预测对象和真值对象之间执行二分图匹配。DETR通过删除多个手工设计的组件,如空间锚点或非极大值抑制,来简化检测流程,这些组件用于编码先验知识。与大多数现有的检测方法不同,DETR不需要任何定制的层,因此可以在包含标准CNN和transformer类的任何框架中轻松复现。

与大多数以前的直接集合预测工作相比,DETR的主要特点是二分图匹配损失和具有(非自回归)并行解码的transformer结合[29,12,10,8]。相比之下,以前的工作侧重于使用RNN进行自回归解码[43,41,30,36,42]。我们的匹配损失函数将预测与真值对象唯一地分配给一个对象,并且不受预测对象排列的影响,因此我们可以并行地发出它们。

DETR的训练设置在多个方面与标准目标检测器不同。新模型需要更长的训练周期,并受益于变换器中辅助解码损失。我们深入探讨了哪些组件对所展示的性能至关重要,
DETR的设计理念很容易扩展到更复杂的任务。在我们的实验中,我们展示了在预训练的DETR之上训练的简单分割头在全景分割(Panoptic Segmentation)[19]上优于竞争基线的结果。全景分割是一项具有挑战性的像素级识别任务,最近变得越来越受欢迎。

相关工作

2.1 集合预测

目前没有一个经典的深度学习模型可以直接预测集合。基本的集合预测任务是多标签分类(请参见例如[40,33]中的参考文献,涉及计算机视觉领域),对于这种任务,基线方法——一对多(one-vs-rest)方法在检测等存在元素之间存在一定结构的问题中不适用(即,存在近似相同的边界框)。在这些任务中的第一个困难是避免近似重复。大多数当前的检测器使用后处理方法,如非极大值抑制,来解决这个问题,但直接集合预测是无需后处理的。它们需要全局推理方案,以模拟所有预测元素之间的交互,以避免冗余。对于固定大小的集合预测,密集全连接网络[9]足够,但成本较高。一种通用的方法是使用自回归序列模型,如循环神经网络[48]。在所有情况下,损失函数应该对预测的排列具有不变性。通常的解决方案是设计基于匈牙利算法[20]的损失函数,以找到真值和预测之间的二分图匹配。这强制执行排列不变性,并保证每个目标元素都有一个唯一的匹配。我们采用了二分图匹配损失方法。然而,与大多数以前的工作不同,我们放弃了自回归模型,而是使用了具有并行解码的变换器,我们将在下面进行描述。

2.2 transformer和并行解码

变换器(Transformers)是由Vaswani等人[47]引入的,作为一种新的基于注意力的机器翻译构建块。注意力机制[2]是神经网络层,可以从整个输入序列中汇总信息。变换器引入了自注意层,类似于非局部神经网络[49],它们会扫描序列中的每个元素,并通过汇总整个序列的信息来更新它。注意力模型的主要优势之一是其全局计算和完美记忆,这使它们在处理长序列时比循环神经网络更适用。在自然语言处理、语音处理和计算机视觉等领域,变换器现在正在取代循环神经网络,应用广泛[8,27,45,34,31]。

transformer首先用于自回归模型,遵循了早期的序列到序列模型[44],逐个生成输出标记。然而,由于推断成本过高(与输出长度成正比,难以批量处理),这导致了并行序列生成的发展,在音频[29]、机器翻译[12,10]、单词表示学习[8]以及更近期的语音识别[6]等领域进行了研究。我们还结合了transformer和并行解码,以在计算成本和执行集合预测所需的全局计算之间找到适当的折衷方案。

2.3 目标检测

大多数现代目标检测方法都相对于一些初始猜测进行预测。两阶段检测器[37,5]根据建议(proposals)预测边界框,而单阶段方法则根据锚点[23]或可能的物体中心网格[53,46]进行预测。最近的研究[52]表明,这些系统的最终性能在初始猜测的确切设置方式上具有很大的依赖性。在我们的模型中,我们能够通过直接预测与输入图像而不是锚点相关的一组检测结果,消除了这个手工制作的过程,并简化了检测过程。

基于集合的损失。一些目标检测器[9,25,35]使用了二分图匹配损失。然而,在这些早期的深度学习模型中,不同预测之间的关系仅使用卷积或全连接层来建模,而手动设计的非极大值抑制后处理可以提高它们的性能。更近期的检测器[37,23,53]在真值和预测之间使用了非唯一的分配规则,同时使用了非极大值抑制。

可学习的非极大值抑制方法[16,4]和关系网络[17]使用注意力明确建模了不同预测之间的关系。使用直接的集合损失,它们不需要任何后处理步骤。然而,这些方法使用额外的手工设计的上下文特征,如建议框坐标,以有效地建模检测之间的关系,而我们寻找减少模型中编码的先验知识的解决方案。

递归检测器。与我们的方法最接近的是用于目标检测[43]和实例分割[41,30,36,42]的端到端集合预测。与我们类似,它们使用基于CNN激活的编码器-解码器架构,使用二分图匹配损失直接生成一组边界框。然而,这些方法仅在小型数据集上进行了评估,而没有与现代基线模型进行比较。特别地,它们基于自回归模型(更精确地说是RNN),因此它们没有利用最近的具有并行解码的变换器模型。

3 DETR模型

在检测中进行直接集合预测需要两个关键因素:(1) 一种集合预测损失,它强制预测的边界框与真值边界框之间具有唯一匹配;(2) 一种体系结构,可以在单次传递中预测一组对象并建模它们之间的关系。我们在图2中详细描述了我们的体系结构。

3.1 目标检测集设置预测损失

DETR通过解码器单次推断出一个固定大小的N个预测,其中N被设置为明显大于图像中典型对象的数量。训练的主要困难之一是如何根据真值对预测的对象(类别、位置、大小)进行评分。我们的损失函数产生了预测对象和真值对象之间的最优二分图匹配,然后优化特定于对象的(边界框)损失。

让我们用y表示真值对象的集合,而ˆy = {ˆyi}N i=1表示N个预测的集合。假设N大于图像中的对象数量,我们也将y视为大小为N的集合,其中包括∅(表示没有对象的占位符)。为了在这两个集合之间找到一个二分图匹配,我们搜索一个具有最低成本的N个元素的排列σ ∈ SN:
在这里插入图片描述
其中Lmatch(yi, ˆyσ(i))是真值yi和索引σ(i)的预测之间的成本。这个最优分配是通过匈牙利算法高效计算的,这是根据之前的工作(例如[43])完成的。

匹配成本考虑了类别预测和预测框与真值框的相似性。真值集合的每个元素i可以看作是yi = (ci, bi),其中ci是目标类别标签(可能为∅),bi ∈ [0, 1]4是一个向量,定义了真值框的中心坐标以及相对于图像大小的高度和宽度。对于具有索引σ(i)的预测,我们将类别ci的概率定义为ˆpσ(i)(ci),并将预测框定义为ˆbσ(i)。使用这些符号,我们将Lmatch(yi, ˆyσ(i))定义为-1{ci=∅}ˆpσ(i)(ci) + 1{ci=∅}Lbox(bi, ˆbσ(i))。其中,1{ci=∅}是指示函数,如果ci不等于∅则为1,否则为0。这个成本函数综合考虑了类别匹配和框匹配。

这种找到匹配的过程在直接集合预测中起到了与现代检测器中用于将提议[37]或锚[22]与真值对象匹配的启发式分配规则相同的作用。主要区别在于,我们需要为直接集合预测找到不包含重复的一对一匹配。
第二步是计算损失函数,即在前一步中匹配的所有成对的匈牙利损失。我们将损失定义为类似于常见目标检测器的损失,即类别预测的负对数似然和稍后定义的框损失的线性组合:
在这里插入图片描述
其中ˆσ是第一步中计算的最优分配(1)。在实践中,当ci = ∅时,我们通过10倍的因子减小对数概率项的权重,以考虑类别不平衡。这类似于Faster R-CNN训练过程通过子采样平衡正样本/负样本提议[37]的方法。请注意,对象与∅之间的匹配成本不依赖于预测,这意味着在这种情况下成本是常数。在匹配成本中,我们使用概率ˆpˆσ(i)(ci)而不是对数概率。这使得类别预测项与Lbox(·, ·)(下文描述)具有可比性,并且我们观察到了更好的经验性能。

边界框损失。匹配成本和匈牙利损失的第二部分是Lbox(·),用于评分边界框。与许多检测器不同,它们根据与一些初始猜测的∆进行边界框预测,我们直接进行边界框预测。尽管这种方法简化了实现,但它在损失的相对缩放方面存在问题。最常用的1损失即使相对误差相似,对小框和大框也有不同的尺度。为了减轻这个问题,我们使用了1损失和广义IoU损失[38]的线性组合Liou(·, ·),这是尺度不变的。总的来说,我们的框损失是Lbox(bi, ˆbσ(i)),定义如下:
λiouLiou(bi, ˆbσ(i)) + λL1||bi − ˆbσ(i)||1,
其中λiou、λL1 ∈ R是超参数。这两个损失都被批次中的对象数量归一化。

3.2 DETR架构

DETR的总体架构出奇地简单,如图2所示。它包含三个主要组件,我们将在下面描述:一个CNN骨干网络用于提取紧凑的特征表示,一个编码器-解码器Transformer,以及一个简单的前馈网络(FFN)用于进行最终的检测预测。

在这里插入图片描述

与许多现代检测器不同,DETR可以在任何提供通用CNN骨干网络和Transformer架构实现的深度学习框架中实现,只需几百行代码。在PyTorch [32]中,可以使用不到50行代码实现DETR的推理代码。我们希望我们的方法的简单性能够吸引新的研究人员加入检测领域。

骨干网络。从初始图像ximg ∈ R3×H0×W0(具有3个颜色通道)开始,传统的CNN骨干网络会生成一个低分辨率的激活图f ∈ RC×H×W。我们通常使用的典型值为C = 2048和H,W = H0 32,W0 32。

Transformer编码器首先,通过1x1卷积将高级别激活图f的通道维度从C减小到较小的维度d,创建一个新的特征图z0 ∈ Rd×H×W。编码器期望以序列作为输入,因此我们将z0的空间维度折叠成一个维度,得到一个d×HW的特征图。每个编码器层都具有标准的体系结构,包括多头自注意力模块和前馈网络(FFN)。由于Transformer架构是排列不变的,我们通过固定的位置编码[31,3]来补充它,这些编码被添加到每个注意层的输入中。我们将详细的架构定义放在了补充材料中,它遵循了[47]中描述的架构。
Transformer解码器。解码器遵循Transformer的标准架构,使用多头自注意力机制和编码器-解码器注意力机制来转换大小为d的N个嵌入。与原始的Transformer不同的是,我们的模型在每个解码器层上并行解码N个对象,而Vaswani等人[47]使用自回归模型,逐个元素地预测输出序列。不熟悉这些概念的读者可以参考补充材料。由于解码器也是排列不变的,因此N个输入嵌入必须不同以产生不同的结果。这些输入嵌入是学习的位置编码,我们称之为对象查询,与编码器类似,我们将它们添加到每个注意力层的输入中。N个对象查询通过解码器转换为输出嵌入。然后,它们通过前馈网络(在下一小节中描述)独立解码为边界框坐标和类标签,生成N个最终的预测。使用这些嵌入上的自注意力和编码器-解码器注意力,模型通过它们之间的成对关系全局推理所有对象,同时能够使用整个图像作为上下文。

预测前馈网络(FFNs)。最终的预测由一个包含ReLU激活函数和隐藏维度d的3层感知器以及一个线性投影层计算。FFN预测了相对于输入图像的标准化中心坐标、高度和宽度,并且线性层使用softmax函数来预测类别标签。由于我们预测一个固定大小的N个边界框,其中N通常远大于图像中感兴趣的实际对象数量,因此额外的特殊类别标签∅ 用于表示某个槽内没有检测到对象。这个类别在标准目标检测方法中起着类似于“背景”类别的作用。
辅助解码损失。我们发现,在训练过程中使用辅助损失[1]特别有帮助,尤其是帮助模型输出每个类别的正确数量的对象。我们在每个解码器层之后添加了预测的前馈网络(FFNs)和匈牙利损失。所有预测的FFNs共享它们的参数。我们使用一个额外的共享层归一化来规范来自不同解码器层的预测FFNs的输入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/988961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第三方软件信息安全测评服务范围

安全测试 第三方软件信息安全cnas资质测评服务范围: 1、信息安全风险评估 依据《GB/T 20984-2007 信息安全技术信息安全风险评估规范》,通过风险评估项目的实施,对信息系统的重要资产、资产所面临的威胁、资产存在的脆弱性、已采取的防护措…

企业架构LNMP学习笔记24

学习目标和内容: 1、能够描述高可用HA的作用 2、能够理解VIP的切换:虚拟IP。 3、能够描述keepalived作用:保持活跃。主备的服务器的关系。 4、能够理解主master和备backup服务器关系 5、能够实现主备服务器高可用配置:主服务…

合宙Air724UG LuatOS-Air LVGL API控件-开关 (Switch)

开关 (Switch) 示例代码 function event_handler(obj, event)if event lvgl.EVENT_VALUE_CHANGED thenprint("State", lvgl.switch_get_state(obj))end endsw1 lvgl.switch_create(lvgl.scr_act(), nil) lvgl.obj_align(sw1, nil, lvgl.ALIGN_CENTER, 0, -50) lvg…

超强大JS表格:DataViewsJS 1.8.16.1407 Crack

DataViewsJS完整的 JavaScript 数据呈现和数据网格平台 通过从各种不同的演示视图中进行选择,包括树、卡片、砖石、网格、时间线、甘特图、日历和网格,超越传统的表格显示。 快速地!纯 JavaScript,针对速度进行了优化 从多种视图中…

数据结构与算法_树和二叉树

目录 一、树的概念 二、树的衍生概念 三、二叉树 顺序结构 链式存储 二叉树连式结构的遍历 一、树的概念 树是一种非线性的数据结构,它由n(n>0)个有限结点组成一个具有层次关系的集合。 在树中,有一个特殊节点成为根节…

【GO语言基础】前言

系列文章目录 【Go语言学习】ide安装与配置 【GO语言基础】前言 【GO语言基础】变量常量 【GO语言基础】数据类型 【GO语言基础】运算符 文章目录 系列文章目录一、基础知识包和函数函数声明语法简洁性 括号成对出现GO常用DOS命令命名规则项目目录结构注释 总结 一、基础知识 …

The Sandbox 中的建设者:走进 Sandja Studio

随着 The Sandbox 出版新功能的到来,来自世界各地的熟练建设者正在完成必要的步骤,推出他们的可玩体验并与社区分享。 在本期的《The Sandbox 建设者》中,我们将与Sandja Studio一起讨论他们的《星际考古学家》体验。 Sandja Studio 是一家…

ARM接口编程—UART(exynos 4412平台)

UART简介 UART Universal Asynchronous Receiver Transmitter即通用异步收发器,是一种通用的串行、异步通信总线该总线有两条数据线,可以实现全双工的发送和接收在嵌入式系统中常用于主机与辅助设备之间的通信 波特率 波特率用于描述UART通信时的通信…

企业架构LNMP学习笔记25

高可用服务搭建: HA高可用:是一个解决方案。 高可用HA(High Availability)是分布式系统架构中必须考虑的因素之一。它通常是指通过设计,减少系统服务不可用的时间,假设系统一直能够提供服务,我…

Navicat连接openGauss数据库报错

错误信息:fe_sendauth:invalid authentication request from server:AUTH_REQ_SASL_CONT without AUTH_REQ_SASL 解决步骤: 1)关闭防火墙: 切换root用户执行:su - root 输入密码 systemctl status firewalld 查…

用于独立系统应用的光伏MPPT铅酸电池充电控制器建模(Simulink实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【C++】静态库lib和动态库dll的优缺点、使用方法

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

uniapp 悬浮窗-任意界面 Ba-FloatWinWeb

简介&#xff08;下载地址&#xff09; Ba-FloatWinWeb 是一款支持自定义任意界面的悬浮窗插件。采用webview方式&#xff0c;同时支持本地、网络地址&#xff1b;自带几种界面&#xff0c;可直接使用。 支持显示、更新、隐藏支持记录显示位置支持拖动支持监听点击事件支持自…

查看linux开发板的CPU频率

1&#xff09;查看CPU可设置的频率列表 cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_available_frequencies 2&#xff09;查看CPU当前所使用的频率&#xff1a; cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_cur_freq 3&#xff09;设置CPU频率&#xff08;最高…

linux常用命令及解释大全(三)

目录 前言 一、字符设置和文件格式转换 二、文件系统分析 三、初始化一个文件系统 四、备份 五、光盘 六、网络 总结 前言 本篇文章继续介绍了一部分linux常用命令&#xff0c;包括字符设置和文件格式转换&#xff0c;文件系统分析&#xff0c;初始化一个文件系统&…

17 mysql global_variables session_variables

前言 这是一个关于 mysql 中的一些配置的探索 起因是需要 看一下 mysql 自增长的实现, 这里面涉及到两个变量 auto_increment_increment, auto_increment_offset, 然后 需要探索一下 这两个变量的来历 然后 就有了这里的相关介绍 global_variables 的初始化 global_va…

云数据库知识学习——云数据库产品、云数据库系统架构

一、云数据库产品 1.1、云数据库厂商概述 云数据库供应商主要分为三类。 ① 传统的数据库厂商&#xff0c;如 Teradata、Oracle、IBM DB2 和 Microsoft SQL Server 等。 ② 涉足数据库市场的云供应商&#xff0c;如 Amazon、Google、Yahoo!、阿里、百度、腾讯…

使用Python将网页数据保存到NoSQL数据库的方法和示例

随着大数据和人工智能技术的快速发展&#xff0c;对于大规模数据的处理需求日益增多。NoSQL数据库作为一种新兴的数据存储解决方案&#xff0c;具有高可扩展性、高性能和灵活性数据模型等优势&#xff0c;已经在许多行业得到广泛应用。传统的关系型数据库在处理海量数据时可能会…

我们来看看Kubernetes、Docker、Dockershim、Containerd、runc、CRI、CRI-O、OCI的到底有什么关系?

Kubernetes v1.20版本 的 release note 里说 deprecated docker。并且在后续版本 v1.24 正式删除了 dockershim 组件&#xff0c;这对我们有什么影响呢&#xff1f; 为了搞明白这件事情&#xff0c;以及理解一系列容器名词 docker, dockershim, containerd, containerd-shim, …

第10章 注册字符设备实验(iTOP-RK3568开发板驱动开发指南 )

在上一小节中已经对设备号的相关知识进行了讲解&#xff0c;并成功申请到了设备号&#xff0c;那在Linux系统中&#xff0c;设备号是怎样与字符设备进行关联的呢&#xff1f;字符设备又是怎样注册的呢&#xff1f;带着疑问&#xff0c;让我们开始本章节的学习吧。 10.1 注册字…