文章目录
- 一、背景
- 二、方法
- 2.1 DETR
- 2.2 MDETR
- 三、效果
- 3.1 预训练调整后的检测器
- 3.2 下游任务
论文:MDETR - Modulated Detection for End-to-End Multi-Modal Understanding
代码:https://github.com/ashkamath/mdetr
出处:ICCV 2021 Oral | Yann LeCun | NYU | Facebook AI
时间:2021.10
贡献:
- 提出了端到端的 text-modulated 检测系统
- 打破了传统目标检测只能检测特定类别的限制,可以实现对任意形式文本输入中提及的内容进行检测
一、背景
目标检测在很多多模态理解系统中有着很重要的作用,一般是用作一个黑盒,用于检测特定类别的目标,然后后面进行多模态对齐。
但这种 pipeline 会限制模型只能检测特定的目标,而不能实现对整个图像的描述。
此外,目标检测系统的类别是固定的,也会阻碍模型性能的改进,这些系统也不能识别任意类别的 text 输入
在一些新的工作中,将该问题归结为一个 text-conditioned 目标检测问题,这些工作将一些主流的单阶段和双阶段目标检测器进行扩展,来实现对任意输入的检测。但由于很多检测器不是端到端可微的,无法和下游任务一起训练,所以难以对下游任务产生很好的指导作用。
本文提出的 MDETR,是一个端到端可调节的检测器,基于 DETR 检测结构,并且将目标检测和自然语言处理理解联合训练,能够实现完全端到端的多模态推理
phrase grounding 任务:给定一个自由形式的文本,找出文本中提到的目标
二、方法
2.1 DETR
DETR:
- DETR 是一个端到端的目标检测器,使用卷积网络作为 backbone,后面接一个 transformer encoder-decoder 来进行编解码
- encoder:对卷积提取到的特征经过 flattened,使用一些 transformer 层来进行编码
- decoder:输入是一系列 N 个可学习的 embedding(object queries),所有的 object queries 被并行输入 decoder
- 训练策略:使用匈牙利匹配 loss,对 N 个预测的目标和真实的目标进行双边匹配,和 gt 匹配成功的预测结果就会使用该 gt 进行监督,没有匹配成功的就会被监督预测出 ‘no object’ label ϕ \phi ϕ
- 分类头:使用 cross-entropy loss 来监督
- 回归头:使用 L1 loss 和 GIoU 的结合来监督
2.2 MDETR
MDETR 的结构如图 2 所示:
- 图像先经过 CNN 来提取特征,并 flattened
- 给经过 flattened 的向量加上 2-D 位置编码
- 将输入的 text 使用经过预训练的 transformer language model 进行编码
- 使用线性映射,来对图像和文本特征映射到共同的编码空间,然后按序列维度来将生成的编码连接起来形成一个更长的编码,输入 cross encoder 中
loss 函数:
- 作者在 DETR loss 函数之外还使用了两个额外的 loss 函数
- 第一个是 soft token prediction loss:无参数的对齐 loss,因为 MDETR 不会直接输出类别,而是会输出目标和 token 的相似度
- 第二个是 text-query contrastive alignment loss:有参数的 loss,计算 object query 和 token 的相似程度
1、soft token prediction loss
对于 modulated detection,不同于标准的目标检测设置,而是参考每个 matched object 来从 original text 中预测 span of tokens
- 首先,将给定句子的最大 token 长度设为 256,对于每个和 gt box 匹配上的预测 box(根据双边匹配原则),模型都会预测一个和 object 相关的 token 的标准分布,如图 2 所示,cat 的 box 就能够预测前两个 words 的标准分布,图 6 也做了展示
- 没有和目标匹配上的 query 会被训练来预测 ‘no object’ 的 label
- 此外,可以注意到,可能会出现多个 words 对应图像中一个相同的目标,也可能会出现多个目标对应相同的 text,这样的 loss 设计能够让模型学习有共指关系的目标
2、text-query contrastive alignment loss
尽管 soft token prediction 使用 positional 信息来对齐目标和文本,但对比对齐 loss 能够更加增强以下两者的对齐程度:
- object embedded representation:object 经过 decoder 的输出
- text representation:cross encoder 的输出
contrastive alignment loss 的作用:能够确保图像目标的编码和与其对应的 text token 的编码比与其无关的 text token 的编码更加接近
对比对齐 loss 的公式如下:
- token 的最大数量为 L,目标的最大数量为 N
- T i + T_i^+ Ti+ 是一系列需要和给定 object o i o_i oi 进行对齐的 tokens
- O i + O_i^+ Oi+ 是一系列需要和给定 token t i t_i ti 进行对齐的 objects
- τ \tau τ 是温度参数,设置为 0.07
① 对所有 object 的 contrastive alignment loss 如下,归一化的方式是除以每个 object 对应的 positive token 数量:
② 对所有 token 的 constrastive loss 如下,归一化的方式是除以每个 token 的 positive object 数量:
③ 整体的 constrastive alignment loss 是上述两个子 loss 的平均
3、所有 loss 的结合
DETR 中,使用双边匹配的方法来寻找预测和 gt 中的最佳匹配
MDETR 中,最大的不同在于对没有预测的 object 没有特定的类别 label,而是预测 object 对应 text 中所有相关位置 token 的标准分布(soft token prediction),使用 soft cross entropy 监督
matching cost 由 L1 和 GIoU 组成
总体的 loss :box prediction loss(L1、GIoU)、soft-token prediction loss、contrastive alignment loss
三、效果
3.1 预训练调整后的检测器
本文中所谓的 modulated detection,表示的就是将传统的按类别检测的检测器进行调整或修改后的检测器,能够实现对任意文本输入设计到的目标都进行检测,而非只能检测特定的类别。
数据集:混合数据集
- Flickr30k
- MS COCO
- Visual Genome (VG)
如何进行数据集混合:
- 对每个图像,把整个数据集中的所有标注拿出来,将同一图片中目标的所有 text 结合起来
- 如何结合成句子:使用 graph coloring algorithm,只把 GIoU<=0.5 的 box 对应的 phrase 结合起来,整个句子的长度小于 250
- 通过上述方法,作者获得了 1.3M 个 image-text pairs
这种句子结合方式的好处:
- 数据有效性,能够将很多信息打包到一个训练样本中
- 能够为 soft token prediction loss 提供更好的学习信号,因为模型必须学习小区多个相同类别目标同时出现时的歧义,如图 3 所示
模型:
- text encoder:pre-trained RoBERTa-base,有 12 层 transformer encoder,每个有 768 hidden dimension,12 heads,使用 HuggingFace 提供的权重
- visual backbone:尝试了两个,一个是 ResNet-101,一个是 EfficientNet family
3.2 下游任务
1、Phrase grounding:给定一个或多个 phrases,该任务的目的是为每个 phrase 预测一系列的 bbox
2、Referring expression comprehension:给定一个图像和对应的文本描述,该任务是理解文本描述,并返回需要返回的目标,而非返回全部提及的目标
有三个数据集:
- RefCOCO
- RefCOCO+ [70]
- RefCOCOg [36]
因为在预训练中,训练方式包括所有 text 涉及到的目标,和这个子任务有些不同,这个任务的一个示例如下:The woman wearing a blue dress standing next to the rose bush
- 不需要返回所有的 woman、blue dress、rose bush
- 只需要返回 woman box 即可
MDETR 在这个任务上进行了 5 epoch 的 finetuned,在推理阶段,使用 ϕ \phi ϕ 来对 100 个检测到的 box 进行排序, P ( ϕ ) P(\phi) P(ϕ) 表示被分配为 no label 的概率,使用 1 − P ( ϕ ) 1-P(\phi) 1−P(ϕ) 降序排列,结果见表 2。
3、Referring expression segmentation:本文的方法也可以扩展到分割上
4、Visual Question Answering:如图 4 展示了本文模型的结构