1. 总体结构
上图为BEVFormer在t时刻的网络结构。图(a) 表示的是BEVFormer的encoder层。BEVFormer有6个encoder层,每一个encoder除了本文自定义的三个组件外都和传统的transformers结果一致。自定义的三个组件分别是网格状的BEV queries,TSA和SCA。其中BEV queries的参数是可学习的,它通过注意力机制查询多相机视角下的BEV空间特征;图(b)是SCA空间交叉注意力机制,每个BEV query只会提取图像感兴趣区域的特征;©是TSA时间自注意力机制,每个BEV query包括当前时刻的BEV queries和历史时刻的BEV 特征。
在推理时,在t时刻,首先喂入多相机视角图片进入backbone,获得多视角特征图,同时会保留t-1时刻的BEV 特征。在每一个encoder层,首先会使用BEV queries通过TSA查询t-1时刻的BEV 特征,然后BEV queries通过SCA查询多视角特征图的空间特征。在经过FFN(feed-forward network)后的输出作为下一层encoder的输入。就这样经过6个encoder层后,产生当前t时刻的BEV特征。将该特征作为输入,经过3D检测头和地图分割头输出最终预测的3D目标框和语义地图。
2. deformable attention
参考
Deformable DETR: DEFORMABLE TRANSFORMERSFOR END-TO-END OBJECT DETECTION
https://blog.csdn.net/weixin_40671425/article/details/121453942
之后有时间再细看。
3. BEV queries
预先定义一个形状为[H,W,C]的可学习的queries,代码中设置H=200,W=100,C=256。对于BEV平面上的每个网格大小都对应着真实世界的s米。BEV queries的中心对应当前自车位置。依据常见做法,在输入BEVFormer之前,BEV queries中添加可学习的位置编码。
# bevformer_head.py 151行 [20000,256]
bev_queries = self.bev_embedding.weight.to(dtype)
在transformer.py第158-162行中,对queries做如下处理:
# add can bus signals
can_bus = bev_queries.new_tensor(
[each['can_bus'] for each in kwargs['img_metas']]) # [:, :]
can_bus = self.can_bus_mlp(can_bus)[None, :, :]
bev_queries = bev_queries + can_bus * self.use_can_bus
其中,can_bus为车辆自身在当前时刻的一些数据信息,包括线速度、角速度等信息。
4. temporal self-attention
为了更好的检测移动中物体的速度和解决严重遮挡的物体,bevformer作者提出temporal self-attention(TSA)用来将历史信息融合进当前特征中。
5. 历史BEV queries如何结合当前时刻的queries?
根据当前车辆位置信息,对前一时刻的prev_bev做微调,确保在同样的网格grid区域对应同样的真实世界的实际位置。
if prev_bev is not None:
if prev_bev.shape[1] == bev_h * bev_w:
prev_bev = prev_bev.permute(1, 0, 2)
if self.rotate_prev_bev:
for i in range(bs):
# num_prev_bev = prev_bev.size(1)
rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
tmp_prev_bev = prev_bev[:, i].reshape(
bev_h, bev_w, -1).permute(2, 0, 1)
tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
center=self.rotate_center)
tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
bev_h * bev_w, 1, -1)
prev_bev[:, i] = tmp_prev_bev[:, 0]
在encoder.py BEVFormerEncoder类中,对prev_bev做如下处理:
当prev_bev存在时,将上面调整后的prev_bev和当前的bev_query合并。
if prev_bev is not None:
prev_bev = prev_bev.permute(1, 0, 2)
prev_bev = torch.stack(
[prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
在encoder.py BEVFormerLayer类中,在TSA的forward函数里,传入的value和key都是prev_bev,如果prev_bev为None时,用当前query代替:
if value is None:
assert self.batch_first
bs, len_bev, c = query.shape
value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
从t-1到t时刻,不同移动的目标在真实世界中有着不同的偏移,故在不同时刻对于相同的目标的BEV特征建立精确的关联性是非常有挑战的事情。本文作者通过TSA来建立这种关联性。具体公式:
对应代码(temporal_self_attention.py):
output = self.multi_scale_deformable_attn(
value, spatial_shapes, reference_points, sampling_offsets, attention_weights
).flatten(2)
output = torch.mean(output,keepdim=True,dim=0)
6. Spatial Cross-Attention
由于在自动驾驶的3D感知任务中,输入是基于多视角图像的,其多头注意力计算消耗相当高。在bevformer中,作者基于2D的deformable attention方法,提出了Spatial Cross-Attention(SCA)。每一个BEV query都会根据六路相机视角提取感兴趣区域的空间特征。如下图所示:
上图为SCA示意图,中间区域的BEV query,周围为六路图像特征。从图中可以看到BEV query平面在z轴延伸四个点,用于采样图像中不同高度的空间特征。
首先,初始化采样点ref_3d,W为bev平面的宽,H为bev平面的高,Z为柱子的高,num_points_in_pillar表示每个柱子上采样的点的数目,在Z轴方向上也初始化了采样点,包含了路面上不同高度的物体。(代码在encoder.py的get_reference_points函数)
zs=torch.arange(0.5, Z-0.5+((Z-1)/num_points_in_pillar), (Z-1)/(num_points_in_pillar-1),dtype=dtype,
device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
xs=torch.arange(0.5, W-0.5+((W-1)/W), (W-1)/(W-1),dtype=dtype,
device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
ys=torch.arange(0.5, H-0.5+((H-1)/H), (H-1)/(H-1),dtype=dtype,
device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1).view(1,4,-1,3)
初始化后,将ref_3d点集范围缩放至pc_range指定的范围内,其中pc_range=[-30,-15,-2,30,15,2],表示以当前自车作为中心,取前后30米范围,左右15米范围,上下2米范围内的物体坐标。
# reference_points等于上面的ref_3d
reference_points = reference_points * torch.tensor(
[
pc_range[3] - pc_range[0],
pc_range[4] - pc_range[1],
pc_range[5] - pc_range[2],
],
dtype=reference_points.dtype,
device=reference_points.device,
).view(1, 1, 1, 3) + torch.tensor(
pc_range[:3], dtype=reference_points.dtype, device=reference_points.device
)
再利用传进来的lidar2img,将lidar坐标系下的采样点投至像素坐标系下。
reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
reference_points.to(torch.float32)).squeeze(-1)
因为某个物体不一定在六路视角下同时都存在,所以这里对reference_points_cam做了条件判断,滤除在某些视角下不存在的坐标点,这样大大节约了计算资源。
上文获得的reference_points_cam即作为query的参考点,根据这些参考点采样对应视角下对应的特征。最后将采样的特征加权求和作为SCA的输出。具体公式如下:
对应代码(spatial_cross_attention.py):
queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value,
reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes,
level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims)
for j in range(bs):
for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]
count = bev_mask.sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
count = torch.clamp(count, min=1.0)
slots = slots / count[..., None]
7. 结果
具体结果可以查看原文。