文章目录
- 前言
- 一、DETR理论
- 二、模型架构
- 1. CNN
- 2. Transformer
- 3. FFN
- 三、损失函数
- 四、代码实现
- 总结
前言
DETR是Facebook团队在2020年提出的一篇论文,名字叫做《End-to-End Object Detection with Transformers》端到端的基于Transformers的目标检测,DETR是Detection Transformers的缩写。DETR摆脱了传统目标检测模型中复杂的组件,例如NMS、先验框等,是一种基于Transformer的简单的端到端的目标检测架构,并且取的了与Faster R-CNN差不多的成绩。接下来对DETR这篇论文进行介绍。
一、DETR理论
DETR将目标检测看做一个直接的预测集合问题,集合的大小是固定的,集合中的每个元素可以当作一个检测框,可以理解为DETR根据图片的输入去预测集合中元素的属性(分类,坐标)。
DETR的推理概如下图所示:输入是一张图片,图片首先经过CNN网络进行下采样,然后将CNN输出的特征图拉直成向量进入Transformer的编码器中,Transformer的解码器的输入是object queries,(object queries是可学习的参数,可以理解为object queries学习预测框的先验知识)然后与Transformer的编码器输出做交叉注意力并行得到最终的检测框。在DETR论文中这个集合的大小取100,也就是对于每张图片,都会一口气预测出100个框,对于预测框中的分类预测为’no object’的则不显示,然后我们可以设置一个阈值,把集合中置信度低于阈值的预测框去掉,从而得到最终的输出。
DETR根据Transformer的全局建模能力对图像进行全局的上下文推理,因此对于大物体的检测效果很好,但是对于小物体的检测效果不如Faster R-CNN。但是由于其架构的简单性深受大家喜欢。
二、模型架构
DETR的模型结构主要由三部分组成,分别是CNN,Transformer和FFN。如下图所示:
1. CNN
DETR使用CNN骨干网络(例如ResNet50)将输入的照片由[3,H,W]下采样成[2048,H/32,W/32],然后在使用一个卷积核用来减少通道数,由[2048,H/32,W/32]变为[d,H/32,W/32]。然后将其进行展平与位置编码表进行相加送入到Transformer编码器中。
2. Transformer
在论文中,DETR使用了6层的Transformer。
编码器 DETR中的编码器架构与经典的Transformer编码器相同,由多头自注意力层和FFN组成。下图为encoder部分的自注意力可视化,可以看到encoder主要负责预测物体的主体部分。
解码器 DETR中的解码器架构也与经典的Transformer解码器相同,稍微不同的是,经典的Transformer模型是自回归方式,而DETR在每个解码器层并行解码N个对象,因此N个输入嵌入必须不同,输入嵌入是可以学习的位置编码,称之为’object query’,在经过解码器的解码后并行计算出N个输出。下图为解码器的可视化输出,可以看到解码器主要用于区分物体的边缘或者轮廓。
3. FFN
当得到解码器的N个输出后,使用两个线性层和ReLU激活函数将每个输出分别映射到预测框的输出类别和坐标位置。由于集合中的边界框比实际照片中的物体数量多,所以集合中剩余的那些边界框预测的标签则为’no object’,表示预测为背景。
三、损失函数
DETR在进行训练时,例如我们设置集合的大小为100,那么模型最终会输出100个框,Ground True可能只有2个,那么我们如何算Loss呢?
1.论文里首先是使用匈牙利算法进行最佳二分图匹配。 如何理解最佳二分图匹配呢?
例如:有100个预测框,最终只有2个预测框与真实的Ground True相匹配。最佳二分图匹配就是那么选哪两个预测框与2个Ground True进行匹配得到的最终Loss最小。
那么我们如何衡量一个预测框与Groud True之间的损失呢(匹配损失)?公式如下:
通过这个公式我们可以获得预测框分别于Ground True进行匹配的Loss,最终使用匈牙利算法选定哪两个预测框与Ground True相对应,而其他的98个预测框可以看作与’no object‘相匹配。
2. 接着我们得出100个预测框如何与Ground True进行匹配,使用如下公式算的最终的Loss:
在实际中,为了保持正负样本的均衡,当预测框的种类为’no object’,其对数概率项的权重被缩小了10倍。同时为了使Bounding box loss (
L
b
o
x
L_{box}
Lbox)对于不同大小的预测框的惩罚项尽可能公平,
L
b
o
x
L_{box}
Lbox由
L
o
s
s
I
o
u
Loss_{Iou}
LossIou和
L
1
L_1
L1组成,公式如下:
3 辅助损失函数。 在训练过程中Transformer的decoder中加入了辅助损失函数,也就是将每层decoder的输出都通过参数共享的FFN映射成预测框然后计算Loss。
四、代码实现
DETR的前向推理过程代码如下所示:
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
总结
DETR是基于Transformer架构的端到端的目标检测模型,把目标检测看作一个直接的集合预测问题,他简化了传统目标检测模型繁杂的前处理和后处理过程,并且随着Transformer架构的增加模型的性能也有所增加。DETR相当于使用’object queries’替代了’anchor’,使用二分图匹配去掉了之前的NMS。同时DETR还有一些不足,因为他是Transformer架构的所以不好优化,对于小目标的物体检测不足等等。但由于DETR的简单性有效性后续出来了一大批工作对于DETR做出了改进。