DETR存在的问题
1.收敛速度慢
2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。
论文主要观点
-
通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。
-
论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。
-
detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn
-
Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。
-
提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛
-
原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦
变为
网络结构
对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量
bbox回归输出:
其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。
classifier分类输出:
f是decoder的输出,输出每个候选框的类别
decoder Pq生成:
提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。
s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq
代码spatial query这一部分的实现:
# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2] 从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos) # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
for layer_id, layer in enumerate(self.layers):
# 图里的s,代表了query的预测大概范围
obj_center = reference_points[..., :2].transpose(0, 1) # [num_queries, batch_size, 2]
# For the first decoder layer, we do not apply transformation over p_s
## pos_transformation代表图里的T,表示decoder embedding的特征经过ffn后其实得到的是相对于s的偏移量
if layer_id == 0:
pos_transformation = 1
else:
pos_transformation = self.query_scale(output)
# get sine embedding for the query vector
query_sine_embed = gen_sineembed_for_position(obj_center)
# apply transformation
# 最终的Pq,代表空间特征信息
query_sine_embed = query_sine_embed * pos_transformation
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0))
decoder中cross attention的实现
# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
# k的位置编码
k_pos = self.ca_kpos_proj(pos)
# For the first decoder layer, we concatenate the positional embedding predicted from
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:
q_pos = self.ca_qpos_proj(query_pos)
q = q_content + q_pos
k = k_content + k_pos
else:
q = q_content
k = k_content
q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# decoder embedding cat spatial query
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
# encoder embdeding cat position embedding
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
tgt2 = self.cross_attn(query=q,
key=k,
value=v, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
# ========== End of Cross-Attention =============
head的实现
# hs代表decoder embedding,reference代表s(reference point)
hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):
# 回归head hs输出相对于 reference的偏移量,得到检测框
tmp = self.bbox_embed(hs[lvl])
tmp[..., :2] += reference_before_sigmoid
outputs_coord = tmp.sigmoid()
outputs_coords.append(outputs_coord)
outputs_coord = torch.stack(outputs_coords)
#分类head,hs输出分类结果
outputs_class = self.class_embed(hs)
总结思考
实际上conditional DETR有点像transfoermer版本的faster-RCNN。将特征信息与空间信息进行了解耦。reference point像anchor的概念,让网络自己为每个query设定一个anchor范围,从而使得二分匹配更加问题,所以加快了网络的收敛
作者论文解读:https://zhuanlan.zhihu.com/p/401916664
公式解释得更加详细