Transformer学习之DETR

news2024/12/22 16:05:44

文章目录

  • 1.算法简介
    • 1.1 算法主要贡献
    • 1.2 算法网络结构
  • 2.损失函数设计
    • 2.1 二分图匹配(匈牙利算法)
    • 2.2 二分图匹配Loss_match
    • 2.3 训练Loss_Hungarian
  • 3.网络核心模块
    • 3.1 BackBone模块
    • 3.2 空间位置编码(spatial positional encoding)
      • 3.2.1 输入与输出
      • 3.2.2 空间位置编码原理
    • 3.3 TransFormer之Encoder模块
      • 3.3.1 输入与输出
    • 3.4 TransFormer之Decoder模块
      • 3.4.1 object queries的理解
      • 3.4.2 多头自注意力机制
      • 3.4.3 多头cross attention机制
    • 3.5 预测头

1.算法简介

Detection Transformer(DETR) 首次将Transformer拓展到目标检测领域中,DETR抛弃了几乎所有的前处理和后处理操作,不需要进行设计锚框来提供参考,或者利用非极大值(NMS)抑制来筛除多余的框,使模型做到了真正的End-to-End检测,即对于输入的任意图像统一输出N个带有类别和置信度的Box结果。
参考链接1:沐神论文精读
参考链接2:搞懂DEtection TRanformer(DETR)
参考链接3:详细解读DETR
参考链接4:Transformer中的position encoding

1.1 算法主要贡献

  1. 提出一种新的目标函数,通过二分图匹配的方式为每个目标输出独一无二的预测,避免更多冗余检测框的出现
  2. 首次将Transformer拓展到目标检测领域
  3. 在Decoder部分设计一种可学习的object queries,将其和全局图像信息结合在一起,通过不断地进行注意力操作,使得模型可以直接并行地输出最后一组预测框
  4. 可以简单的拓展到其他任务(如全景分割),只需要修改预测头(prediction heads)即可

1.2 算法网络结构

在这里插入图片描述
如上图所示,DETR主要包括以下几个部分:

  1. Backbone模块:将输入图像利用卷积神经网络(CNN)映射为特征图
  2. Transformer Encoder模块:输入特征向量和空间位置编码,输出相同维度的全局特征向量,描述了每个patch或者像素与全局图像下的其他patch之间的关系
  3. Transformer Decoder模块
    1. 输入Encoder模块得到的特征向量空间位置编码和固定数量 N = 100 N=100 N=100Object queries,输出 N = 100 N=100 N=100个固定的预测结果
    2. 二分图匹配:对Object queries和真实的标签GroundTruth,利用二分图匹配筛选出和GroundTruth对应的object query用于计算损失(在推理时,会设置一个阈值,将大于阈值的object query作为检测结果)
  4. 利用计算的损失(计算类别和预测框的损失)反向更新卷积神经网络(CNN)和Transformer模型的参数。
  5. 预测头Prediction heads: FFN 是由具有 ReLU 激活函数且具有隐藏层的3层线性层计算的,或者说就是 1 × 1 1\times1 1×1卷积。FFN 预测框标准化中心坐标高度和宽度,然后使用 softmax 函数激活获得预测类标签。

2.损失函数设计

前面提到,在训练过程中Transformer Decoder模块会创建𝑁=100个object queries,但是正常情况下检测目标的数量会小于100这个值,如果想要计算损失函数必须有一一对应的预测值和GroundTruth。

本文使用了一种二分图匹配方法,在object queries集合中寻找和GroundTruth对应的object query,然后再计算损失。

2.1 二分图匹配(匈牙利算法)

想要计算Loss,必须找到一组一一对应的预测值和真值,假设我们现在有两个sets:

  • 左边的sets是模型预测得到的N个object query,每个元素里有一个bbox和对这个bbox预测的类别的概率分布,预测的类别可以是空,用 ϕ \phi ϕ来表示;
  • 右边的sets是我们的ground truth,每个元素里有一个标签的类别和对应的bbox,如果标签的数量不足N 则用 ϕ \phi ϕ来补充, ϕ \phi ϕ可以认为是background。

两边sets的元素数量都是N,所以我们是可以做一个配对的操作,让左边的元素都能找到右边的一个配对元素,每个左边元素找到的右边元素都是不同的,也就是一一对应。这样的组合可以有 N ! N! N!

在这里插入图片描述
对每一组匹配都可以根据class概率和bbox计算得到一个损失,匈牙利算法就是寻找使得总得损失最小时对应的一组object query和ground truth,代码中使用linear_sum_assignment函数来完成这项任务。

2.2 二分图匹配Loss_match

要注意:训练时使用的Loss_match和寻找最优匹配时的Loss_Hungarian不是同一个!!!

在寻找最优匹配关系时,需要提前计算好object query和ground truth之间的损失,这里的损失主要包括两部分:类别损失和bbox损失
在这里插入图片描述

  • 类别损失: 前者表示类别预测部分,和Loss_Hungarian不同之处在于取消了-log操作,目的是为了让计算值区间和bbox损失保持一致
  • bbox损失: 由于文章中所使用的方法是没有预先设计好的anchor的,是直接预测bbox的,所以如果像其他方法那样直接计算 L 1 L_1 L1loss的话,就会导致对于大的框和小的框的惩罚力度不一致,所以文章在使用 L 1 L_1 L1loss的同时,也使用了scale-invariant的IoU loss
    在这里插入图片描述
    最后使用匈牙利算法找到使得上述损失最小时对应的匹配关系。

2.3 训练Loss_Hungarian

要注意:训练时使用的Loss和寻找最优匹配时的Loss不是同一个!!!

训练时使用的Loss同样也是包括类别损失和bbox损失两个部分,区别在于在计算类别损失中保留了−log操作

在这里插入图片描述
也就是说我们在找match的时候,把和ground truth类别一致的,且bbox最接近的预测结果对应上就完事了,其他那些 ϕ \phi ϕ,模型预测出来啥,我match并不关心。但是在算训练模型的Hungarian loss时,就不一样了,我不希望模型会预测出乱七八糟的结果, ϕ \phi ϕ就是 ϕ \phi ϕ,没有就是没有,别整得似有似无的,该 ϕ \phi ϕ的时候预测出东西了,就要惩罚你。因为我预测的时候可是没有ground truth的,我没法知道哪几个是对的了。

3.网络核心模块

在这里插入图片描述

3.1 BackBone模块

BackBone模块主要包括一个CNN卷积层和一次 1 × 1 1 \times 1 1×1卷积。

  1. 将图像(维度为 3 × H × W 3 \times H \times W 3×H×W)输入至卷积神经网络(比如说ResNet50),经过五次尺寸上的缩减(每次降为原来1/2)后,输出维度为 2048 × H 32 × W 32 2048 \times \frac{H}{32} \times \frac{W}{32} 2048×32H×32W的特征图
  2. 利用 1 × 1 1\times1 1×1卷积层将CNN输出的特征图的维度降低至256,记做输入特征图,其维度为 256 × H 32 × W 32 256 \times \frac{H}{32} \times \frac{W}{32} 256×32H×32W
    因为最终需要送到Transformer模块,需要将2048压缩到合适的大小

3.2 空间位置编码(spatial positional encoding)

1. Transformer模块简述
Transformer模块中的核心是注意力机制,如自注意力机制、交叉注意力机制等,已经知道在注意力机制中会为输出的向量分别创建额QKV的值,通过相应的计算可以得到每个输入序列与其他序列之间的关联程度,因此Transformer模块可以学习到图像全局的特征。

2. 什么是位置编码
位置编码同样也描述了不同向量之间的位置关系,可以将其理解为一个权重信息,两个向量在图像中距离近相对位置编码对应权重就大,反之就越小。这种权重可以作为一种额外的信息用于计算输入向量的关联程度,即QKV的值。

3. DETR中的空间位置编码
spatial positional encoding是作者自己提出的二维空间位置编码方法,该位置编码分别被加入到了encoder的self attention和decoder的cross attention,同时下文中的object queries也被加入到了decoder的两个attention中。
在这里插入图片描述

3.2.1 输入与输出

和Transformer的常规操作一样,需要为输入序列的每个patch生成一个位置编码。位置编码的大小为 256 × H 32 × W 32 256\times \frac{H}{32} \times \frac{W}{32} 256×32H×32W,将其与输入特征图按位相加,其相加后维度依旧是 256 × H 32 × W 32 256\times \frac{H}{32} \times \frac{W}{32} 256×32H×32W

3.2.2 空间位置编码原理

在DETR代码中提供了两种位置编码模式:正弦位置编码可学习位置编码,原理和TransformerSwin相对位置偏置类似,只不过在计算偏置时采用了正余弦函数。得到空间位置偏置矩阵后,在计算向量之间的关联程度时就可以通过向量位置索引找到两个向量之间的空间偏置权重,将其加入到关联程度的计算中。

我们假设上文经过backbone模块得到的特征向量为维度时3x10,即长度为3维度为10,基于这一假设为其计算空间位置编码。

1. 生成Mask和反Mask:
假设图像的维度为3x3,Mask的维度设置为4x4。

下图为mask生成的4x4维度的矩阵,根据对应与输入图像大小3*3生成以下的mask编码tensor,下右图为反mask编码tensor,这一步就得到了图像的大小及对应与mask下的位置。

在这里插入图片描述
2. 生成Y_embed和X_embed的tensor
Y_embed对为mask编码True的进行行方向累加1,X_embed对为mask编码True的进行列方向累加1
在这里插入图片描述
3. 分别计算pos_x以及pos_y
这里使用正余弦编码,对奇数位置采用正弦编码,对偶数位置采用余弦编码,公式如下:
在这里插入图片描述
公式中的pose指第2步计算的Y_embed和X_embed,因为此时假设的特征向量的维度为10,i指position所在的维度取值为 [ 0 , 9 ] [0,9] [0,9] d m o d e l d_{model} dmodel指维度大小,则 d m o d e l = 10 d_{model}=10 dmodel=10,这样一来对于不同的维度在计算正余弦编码时的分母都不同,所以在代码中首先计算分母的值,即 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel

在得到每个维度的分母后,根据上述公式分别为Y_embed和X_embed计算pos信息.

在这里插入图片描述
4. 组合pos_x和pos_y
因为上述位置编码的生成是行列方向分开的,这一步需要进行组合。
在这里插入图片描述
5. 计算位置编码
在组合pos_x和pos_y后将其带入正余弦编码公式,可以得到最终关于特征向量的位置编码。 组合以后会发现16个位置的分母已经根据pos的不同,达到了位置编码的不同,因为本文采用的是10维的position,分子i的范围为0-10,每个位置就形成了1*20的tensor数据,包括pos_x的10维和pos_y的10维。

回过头来看,最开始的特征向量维度维 3 × 10 3\times10 3×10,Mask的维度为 4 × 4 × 10 4\times4\times10 4×4×10,最终得到的位置编码的维度是 4 × 4 × 20 4\times4\times20 4×4×20,包括pos_x的10维和pos_y的10维。

上述两个位置的编码就可以理解为1*20的tensor数据,因为比较长,分开写了,不是4*5的,而是1*20的tensor数据,通过上图可以很直观的理解position encoding。

在这里插入图片描述
在得到位置编码信息后,则可以通过索引找到两个特征向量之间的位置编码。

3.3 TransFormer之Encoder模块

Enconder模块和VIT基本一致,这里不再作过多赘述。

在位置编码结束后会进行6次Encoder的串联操作,然后再进入Decoder模块。每个编码模块由:多头自注意力机制+ 残差add & 层归一化LayerNorm + 前馈网络FFN + 残差add & 层归一化LayerNorm组成

  • 多头自注意力机制:核心部分,原理同VIT模块
  • add+LayerNorm:经过多头自注意力机制后再与输入相加,并经过层归一化LayerNorm,即在最后一个维度C上做归一化
  • 前馈网络FFN:是由两个全连接层+ReLu激活函数组成

在这里插入图片描述

3.3.1 输入与输出

输入为经过卷积操作提取的特征向量空间位置编码信息

将输入特征图reshape成 ( H 32 × W 32 ) × 256 (\frac{H}{32} \times \frac{W}{32})×256 (32H×32W)×256(在图中为850x56)大小喂给Transformer编码器,输出同大小的特征图),其维度依旧是 ( H 32 × W 32 ) × 256 (\frac{H}{32} \times \frac{W}{32})×256 (32H×32W)×256,但此时的输出中的每个向量都包含了与其他输入向量之间的权重关系;

3.4 TransFormer之Decoder模块

Decoder模块输入有三个:(1)Encoder模块输出的特征向量 ( H 32 × W 32 ) × 256 (\frac{H}{32} \times \frac{W}{32})×256 (32H×32W)×256;(2)一组Object queries 100 × 256 100\times256 100×256;(3)空间位置编码 ( H 32 × W 32 ) × 256 (\frac{H}{32} \times \frac{W}{32})×256 (32H×32W)×256

TransFormer Decoder由6个解码模块组成,每个解码模块由多头自注意力机制+残差add&层归一化LayerNorm+多头cross attention机制+add&LayerNorm+前馈网络FFN+add&LayerNorm

在这里插入图片描述

3.4.1 object queries的理解

在这里插入图片描述
1. 什么是object queries?
前面讲到DERT的主要贡献之一是丢弃了Anchor设置以及NMS等操作,真正实现了End-to-End的目标检测,而这一创新则是得益于object queries,从Transformer整体上讲,object queries提供了注意力机制QKV中的Q向量,Encoder模块提供K与V,然后使用QKV进行注意力操作。

2. object queries的初始化与更新
在DETR模型中,Object query是一组可学习的向量,维度为 100 × 256 100\times256 100×256。这些向量通常初始化为零向量,然后在训练过程中通过反向传播进行优化(即Object query连同KV一起被写在损失函数中,在训练迭代时进行优化)。

3. Object query与Anchor的联系
对于每一个Object query,我们希望他可以预测一个物体类别的概率并回归出一个bbox的位置。object queries是预定义的目标查询的个数,代码中默认为100。

它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。

综上可以看到Object query其实是起到了Anchor的作用的,Anchor是一组提前进行设计的好的bbox,然后将预测和提前定义好的bbox进行对比;而Object query则是生成一组可以学习的数量固定的目标框,然后使用匹配机制寻找和GT的匹配关系然后进行对比

4. Object query的可视化
下图为在COCO验证集中每个Object query对应的预测框中心点分布情况,不同点的颜色表示不同大小的框,绿色表示小框,红色表示横向的的大框,蓝色表示竖向点的大框。

借助参考链接1中的解释,以第一个Object query为例,当算法训练得到Object query向量后,当有新的图像进来时,第一个Object query会检测左侧是否存在小框,中间是否存在大的框,每个Object query都有自己的检测方式。

因为COCO数据集本身的特点,目标在图像中都占据了比较大的空间,每个Object query都会检测是否存在横向的较大物体。当然Object queries中预测框中心点的分布和数据集相关,如果换成其他的数据集会得到不同的分布。
在这里插入图片描述

3.4.2 多头自注意力机制

在Object queries输入到Decoder模块后,首先其本身会进行一次自注意力操作。其目的是为了尽可能的移除冗余框,经过自注意力操作每个Object query互相交换信息,可以了解到各自的检测框在什么位置,进而避免多个Object query预测同一个位置的情况。

例如对于图像中的位置1,如果Object query1会针对这个位置进行预测,则其他的Object queries则不再位置1进行预测。

3.4.3 多头cross attention机制

在Decoder模块,其输入为两个不同的数据源,即经过Encoder的输入向量850*256,以及Object queries100*256,因此选择cross attention进行Decoder模块的注意力操作。

3.5 预测头

在经过TransFormer的decoder模块后,会得到一个shape=[N, 100, C]的向量,其中N表示BatchNum,100即为Object queries的数量每张图像都会预测100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)。

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

分类loss采用的是交叉熵损失,针对所有predictions;

bbox loss采用了L1 loss和giou loss,针对匹配成功的predictions

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1985146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教学类-74-01】袜子配对01(UIBOT图片背景填充白色)

背景需求: 前期用PS修图(灰色背景变成白色背景),200张图片用了6个小时。 【教学类-73-02】20240805广口瓶(宽口瓶)02-CSDN博客文章浏览阅读744次,点赞17次,收藏20次。【教学类-73-…

鸿萌成功案例:Lenovo SystemX 3650M5 MT:5462 数据“起死回生”

鸿萌数据恢复中心,自 2003 年创立伊始,便凭借其出类拔萃的专业数据恢复技术,在形形色色、错综复杂的数据恢复情境中展露了令人叹服的强大实力,铸就了数不胜数的成功范例。涵盖的情形包括但不限于服务器突发故障、硬盘意外损毁、文…

AI 汹涌而至!三波冲击下将淘汰大部分程序员

作者:老余捞鱼 原创不易,转载请标明出处及原作者。 写在前面的话: 最近到处都能听到“人工智能会不会在不久的将来取代程序员”的争论。本人的观点为:人工智能将会取代程序员,本文将对此予以阐述。(注&…

指针小课堂

目录 一.内存和地址 二.指针变量和地址 1.取地址操作符(&) 2.指针变量和解引⽤操作符(*) 2.1指针变量 2.2如何理解指针类型 2.3解引用操作符 2.4 指针的解引用 2.5.不同指针类型的运加减性质 2.5.1指针与整数相加&am…

写一个gradio录音的webui界面并展现波形图

如图下:这是需求 要创建一个 Gradio 录音的 Web UI 界面,你可以使用 Gradio 的 Audio 组件来实现。下面是一个简单的示例,展示了如何创建一个 Gradio 应用程序,其中包含一个录音按钮,用户可以录制音频并提交给服务器处…

JVM详解(个人学习笔记)

前言 本篇文章为我个人在学习JVM时所记录的笔记,内容把部分来自《深入理解java虚拟机》一书,笔记中总结了JVM中一些比较重要的知识点并作出了自己的解释。 java运行时数据区域 程序计数器(线程内私有) 程序计数器(P…

Java每日一练_模拟面试题4(volatile和synchronized)

volatile加原子操作能取代synchronized和锁吗?答案是否定的。它能保证单操作原子性,对任意单个volatile变量的读写具有原子性,但对于复合操作不保证原子性,如x。

智慧公厕系统的重要性与发展

在城市发展的进程中,智慧公厕系统正逐渐成为一项不可或缺的重要设施。智慧公厕系统利用信息技术和物联网等先进手段,将公共厕所的建设、使用、运营和管理进行信息化整合与优化,实现了公厕运行的高效、智能和可持续发展。 智慧公厕系统的重要性…

MySQL —— CRUD

CRUD CRUD 即增加(Create)、查询(Retrieve)、更新(Update)、删除(Delete)四个单词的首字母缩写。 我们常说增删查改,增删改查… 这里我们的增删查改是对表格的数据行进行操作的~~ 新增 1.1.1 单行数据 全列插入 插入一行新数据行,使用 insert into t…

【Bug记录】函数错误匹配,非法的间接寻址

项目场景: 当我写模拟vector的时候,写出下面测试代码准备稍微测试一下新写的构造函数 新写的构造函数,n个value构造 问题描述 当写出上面测试代码的时候,会报错: 这是什么鬼??&#xff1f…

【老张的程序人生】我命由我不由天:我的计算机教师中级岗之旅

在计算机行业的洪流中,作为一名20年计算机专业毕业的博主,我深知这几年就业的坎坷与辉煌。今天,我想与大家分享我的故事,一段关于梦想、挑战与坚持的计算机教师中级岗之旅。希望我的经历能为大家提供一个发展方向,在计…

CCRC-CISAW信息安全保障人员证书含金量

在数字化时代背景下,CISAW认证受到越来越多个人的青睐。 特别是在互联网技术高速发展的今天,随着5G技术的广泛应用,市场对CISAW专业人才的需求急剧增加。 这种职业不仅地位显著,而且职业生涯相对较长。 目前市场上,…

SAP MIGO新增字段 自定义字段

效果 原先是没有的 清单里面找了没有 自定义字段 待新增字段 F1打开200 screen 加字段 zzplusl

非负数(0和正数) 限制最大值且保留两位小数,在elementpuls表单中正则自定义验证传更多参数

一、结构 <el-form-item label="单价:" prop="price"><el-inputv-model.trim="formData.price"placeholder="请输入"><template #append>(元)</template></el-input></el-form-item>二、验证方…

一个为90后设计的Shell,早知道,当年学Shell也不至于那么痛苦了,Star 25K+!

一个现代、用户友好的命令行界面&#xff0c;以其智能特性、语法高亮、实时自动建议、花式标签补全、直观的历史搜索和跨平台支持而著称。它提供了一个美观、易用且功能丰富的Shell环境&#xff0c;旨在简化Shell命令行操作&#xff0c;提高用户的工作效率。号称一个为90后设计…

数据库|SQLServer数据库:企业管理器的使用

哈喽&#xff0c;你好啊&#xff0c;我是雷工&#xff01; 之前学习了通过脚本创建数据库数据表以及增删改查的相关操作。 接下来了解企业管理器的使用。 以下为学习笔记。 01 新建数据库 1.1、登录数据库后&#xff0c;选中【数据库】-->右击【新建数据库】。 1.2、可以…

swift 自定义DatePacker

import Foundationenum AppDatePickerStyle {case KDatePickerDate //年月日case KDatePickerTime //年月日时分case kDatePickerMonth // 年月case KDatePickerSecond //秒}class AppDatePicker: UIView {private let jk_rootView UIApplication.shared.keyWindow!pri…

电池放电的速率对电池寿命有影响吗?

电池放电的速率对电池寿命确实有很大的影响&#xff0c;电池的寿命通常是指电池在正常使用条件下&#xff0c;能够保持其额定容量的时间。电池的容量会随着充放电次数的增加而逐渐减少&#xff0c;这个过程被称为电池的老化。电池的老化速度受到许多因素的影响&#xff0c;其中…

自闭症的孩子有哪些症状

在自闭症这个复杂而广阔的领域中&#xff0c;作为长期从事自闭症教育的工作者&#xff0c;我们深知每一位自闭症孩子都是独一无二的&#xff0c;他们面对的世界充满了挑战与不解。自闭症&#xff0c;也被称为孤独症谱系障碍&#xff0c;其核心症状往往体现在社交互动、沟通以及…

git安装图文

1.下载 通过百度网盘分享的文件&#xff1a;git安装图文 链接&#xff1a;https://pan.baidu.com/s/17ZMiWUIULtrGGba5n-WLeA 提取码&#xff1a;anjm --来自百度网盘超级会员V3的分享 2.安装