Contents
- Principle
- Step 1: Inspect the inputs
- Step 2: Preparation
- Generating the sampling offsets
- Generating the attention weights
- Generating the sampling locations
- Step 3: The attention computation
- Source code
Principle
DeformableAttention is the transformer variant that virtually every popular 3D-to-2D BEV scheme builds on.
Traditional transformer attention attends to global features and is slow. The DeformableAttention module instead attends only to a small set of key sampling points around each target. The original DETR needs many epochs before it learns where to look; Deformable DETR converges much faster, reportedly at about 1/10 of the training cost.
The walkthrough below takes DETR3D's approach as the example.
Step 1: Inspect the inputs
Define a query of shape (900, 256): 900 targets, each carrying a 256-dim query feature.
Define a query_pos with the same shape as query.
Define reference_points of shape (900, 3) as the target reference points.
The input is pts_feats (1, 43054, 256), the flattened multi-scale features,
with the per-level feature-map sizes recorded in spatial_shapes: ([[180, 180], [90, 90], [45, 45], [23, 23]])
and each level's start offset within pts_feats recorded in level_start_index: ([0, 32400, 40500, 42525])
You can verify these numbers yourself: 180*180 + 90*90 + 45*45 + 23*23 = 32400 + 8100 + 2025 + 529 = 43054.
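A minimal check in PyTorch (this also mirrors how level_start_index is typically derived from spatial_shapes):

import torch

spatial_shapes = torch.tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
level_sizes = spatial_shapes[:, 0] * spatial_shapes[:, 1]  # tensor([32400, 8100, 2025, 529])
print(level_sizes.sum())  # 43054, the value_len of pts_feats
# each level starts where the previous ones end
level_start_index = torch.cat([level_sizes.new_zeros(1), level_sizes.cumsum(0)[:-1]])
print(level_start_index)  # tensor([    0, 32400, 40500, 42525])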
Step 2: Preparation
pts_feats is reshaped to (1, 43054, 8, 32):
value = value.view(bs, num_value, self.num_heads, -1)
Generating the sampling offsets
query goes through the self.sampling_offsets linear layer and is reshaped, giving:
sampling_offsets (torch.Size([1, 900, 8, 4, 4, 2]))
Here 8 is the number of heads, the first 4 is the number of feature levels, the second 4 is the number of sampling points per level, and 2 is the xy dimension of each point. In other words, each of the 8 heads samples 4 points on each of the 4 feature levels, and these are the xy offsets of those 8*4*4 points.
Generating the attention weights
query goes through the self.attention_weights linear layer and is reshaped, giving:
attention_weights (torch.Size([1, 900, 8, 4, 4]))
These are the weights of the sampling points above (softmax-normalized over each head's 4*4 = 16 points, as the source code below shows).
Generating the sampling locations
Adding the offsets to reference_points produces the actual sampling locations:
sampling_location = reference_points[:, :, None, None, None, :2] + sampling_offsets
sampling_locations (torch.Size([1, 900, 8, 4, 4, 2]))
In plain terms: we define a query_embed that, on its own, generates the locations it is about to sample and the weights of those samples. A compact sketch of this whole step follows.
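Condensed into code, Step 2 looks like this (the same construction as the __main__ section of the full listing below, with the same shapes):

import torch
import torch.nn as nn

batch, query_size, embed_dim = 1, 900, 256
num_head, level_num, sample_num = 8, 4, 4

query = torch.rand(batch, query_size, embed_dim)
reference_points = torch.rand(batch, query_size, 3)

# two linear heads predict, per query: 8*4*4 xy offsets and 8*4*4 scalar weights
sampling_offsets_net = nn.Linear(embed_dim, num_head * level_num * sample_num * 2)
attention_weights_net = nn.Linear(embed_dim, num_head * level_num * sample_num)

sampling_offsets = sampling_offsets_net(query).view(batch, query_size, num_head, level_num, sample_num, 2)
attention_weights = attention_weights_net(query).view(batch, query_size, num_head, level_num * sample_num)
attention_weights = attention_weights.softmax(-1).view(batch, query_size, num_head, level_num, sample_num)

# broadcast each reference point's (x, y) against its 8*4*4 offsets
sampling_locations = reference_points[:, :, None, None, None, :2] + sampling_offsets
print(sampling_locations.shape)  # torch.Size([1, 900, 8, 4, 4, 2])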
Step 3: The attention computation
Inputs:
value shape(torch.Size([b,43054,8,32]))
sampling_locations(torch.Size([b, 900, 8, 4, 4, 2]))
attention_weights(torch.Size([b, 900, 8, 4, 4]))
spatial_shapes:([[180, 180],[ 90, 90],[ 45, 45],[ 23, 23]])
value is split into the individual levels according to spatial_shapes:
[torch.Size([b, 180*180, 8, 32]), torch.Size([b, 90*90, 8, 32]), torch.Size([b, 45*45, 8, 32]), torch.Size([b, 23*23, 8, 32])]
and each level is reshaped back to a normal image layout, e.g. torch.Size([b*8, 32, 180, 180]).
sampling_locations are sampling positions in the range [0, 1); to match the convention of F.grid_sample they are rescaled to [-1, 1).
F.grid_sample is then called to sample each level: the input value has shape torch.Size([b*8, 32, level_h, level_w]) and the sampling grid sampling_grid has shape torch.Size([b*8, 900, 4, 2]),
so the output is sampling_value: torch.Size([b*8, 32, 900, 4]).
That is, each of the 900 queries samples 4 points from the (32, level_h, level_w) feature map, and the result is, per query, 4 corresponding 32-channel pixel features. A standalone toy example of this shape convention follows.
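Only the shapes matter here; the values are random:

import torch
import torch.nn.functional as F

feat = torch.rand(8, 32, 180, 180)  # (batch*num_head, C, H, W): one feature level
grid = torch.rand(8, 900, 4, 2) * 2 - 1  # 900 queries x 4 points, xy in [-1, 1)
out = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(out.shape)  # torch.Size([8, 32, 900, 4])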
The sampling_value results of the 4 levels are then stacked together into torch.Size([b*8, 32, 900, 4*4]).
attention_weights is brought into the same layout (torch.Size([b*8, 1, 900, 4*4])), and the 16 sampled features are weighted and summed, giving an output of torch.Size([b, 900, 32*8]). This is afterwards handed to an FFN that fuses the multi-head features through fully connected layers.
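The fusion step is not part of the listing below. As a rough sketch only, assuming a plain linear output projection followed by a two-layer FFN (the names output_proj and ffn and the hidden width 512 are my assumptions, not from the original):

import torch
import torch.nn as nn

batch, query_size, embed_dim = 1, 900, 256

# hypothetical fusion head: linear projection over the concatenated heads, then an FFN
output_proj = nn.Linear(embed_dim, embed_dim)
ffn = nn.Sequential(nn.Linear(embed_dim, 512), nn.ReLU(), nn.Linear(512, embed_dim))

attn_output = torch.rand(batch, query_size, embed_dim)  # stand-in for the deformable-attn output
fused = ffn(output_proj(attn_output))
print(fused.shape)  # torch.Size([1, 900, 256])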
Source code
import torch
import torch.nn.functional as F
import torch.nn as nn


def multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    batch, _, num_head, embedding_dim_per_head = value.shape
    _, query_size, _, level_num, sample_num, _ = sampling_locations.shape
    # split value into one chunk per feature level
    split_list = []
    for h, w in spatial_shapes:
        split_list.append(int(h * w))
    value_list = value.split(split_size=tuple(split_list), dim=1)
    # map the [0, 1) locations to [-1, 1), as expected by F.grid_sample
    sampling_grid = 2 * sampling_locations - 1
    output_list = []
    for level_id, (h, w) in enumerate(spatial_shapes):
        h = int(h)
        w = int(w)
        # batch, value_len, num_head, embedding_dim_per_head
        # -> batch, num_head, embedding_dim_per_head, value_len
        # -> batch*num_head, embedding_dim_per_head, h, w
        value_l = value_list[level_id].permute(0, 2, 3, 1).reshape(batch * num_head, embedding_dim_per_head, h, w)
        # batch, query_size, num_head, level_num, sample_num, 2
        # -> batch, query_size, num_head, sample_num, 2
        # -> batch, num_head, query_size, sample_num, 2
        # -> batch*num_head, query_size, sample_num, 2
        sampling_grid_l = sampling_grid[:, :, :, level_id, :, :].permute(0, 2, 1, 3, 4).reshape(
            batch * num_head, query_size, sample_num, 2)
        # batch*num_head, embedding_dim_per_head, query_size, sample_num
        output = F.grid_sample(input=value_l,
                               grid=sampling_grid_l,
                               mode='bilinear',
                               padding_mode='zeros',
                               align_corners=False)
        output_list.append(output)
    # batch*num_head, embedding_dim_per_head, query_size, level_num, sample_num
    outputs = torch.stack(output_list, dim=-2)
    # batch, query_size, num_head, level_num, sample_num
    # -> batch, num_head, query_size, level_num, sample_num
    # -> batch*num_head, 1, query_size, level_num, sample_num
    attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        batch * num_head, 1, query_size, level_num, sample_num)
    # weighted sum over the level_num * sample_num sampled features
    outputs = outputs * attention_weights
    # batch*num_head, embedding_dim_per_head, query_size
    # -> batch, num_head, embedding_dim_per_head, query_size
    # -> batch, query_size, num_head, embedding_dim_per_head
    # -> batch, query_size, num_head*embedding_dim_per_head
    outputs = outputs.sum(-1).sum(-1).view(batch, num_head, embedding_dim_per_head, query_size) \
        .permute(0, 3, 1, 2).reshape(batch, query_size, num_head * embedding_dim_per_head)
    return outputs.contiguous()
if __name__ == '__main__':
    batch = 1
    num_head = 8
    embedding_dim = 256
    query_size = 900
    spatial_shapes = torch.tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
    value_len = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    value = torch.rand(size=(batch, value_len, embedding_dim))
    query_embedding = torch.rand(size=(batch, query_size, embedding_dim * 2 + 3))
    query = query_embedding[..., :embedding_dim]
    query_pos = query_embedding[..., embedding_dim:2 * embedding_dim]
    reference_points = query_embedding[..., 2 * embedding_dim:]
    # Discussion 1: in deformable attention the query does not interact with value to
    # produce the attention weights; the weights depend on the query alone, so at
    # inference time the attention weights (and sampling_locations) are fixed.
    # According to the authors, weights computed via query-key interaction suffer from
    # a degeneration problem: the resulting attention weights do not actually vary
    # with the keys. Since both ways of computing the weights give comparable results,
    # and projecting the query is faster and cheaper, the authors obtain the attention
    # weights directly from a projection of the query.
    # Discussion 2: with a fixed query, the first layer's attention weights cannot
    # change, but the second layer's query depends on value, so its attention weights
    # do vary. So the self-attention in the first layer is not necessary.
    level_num = 4
    sample_num = 4
    sampling_offsets_net = nn.Linear(in_features=embedding_dim, out_features=num_head * level_num * sample_num * 2)
    sampling_offsets = sampling_offsets_net(query).view(batch, query_size, num_head, level_num, sample_num, 2)
    sampling_location = reference_points[:, :, None, None, None, :2] + sampling_offsets
    attention_weights_net = nn.Linear(in_features=embedding_dim, out_features=num_head * level_num * sample_num)
    attention_weights = attention_weights_net(query).view(batch, query_size, num_head, level_num * sample_num)
    # softmax so that each head's 16 sampling-point weights sum to 1
    attention_weights = attention_weights.softmax(dim=-1).view(batch, query_size, num_head, level_num, sample_num)
    value = value.view(batch, value_len, num_head, -1)
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_location, attention_weights)
    print(output.shape)  # torch.Size([1, 900, 256])