DeformableAttention的原理解读和源码实现

news2024/11/15 19:59:33

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

目录

  • 原理
    • 第一步看看输入:
    • 第二步,准备工作:
      • 生成参考点的偏移量
      • 生成参考点的权重
      • 生成参考点
    • 第三步,工作:
  • 源码

原理

目前流行3D转2DBEV方案的都绕不开的transfomer变体-DeformableAttention.
在这里插入图片描述
传统transformer注意力关注全局特征,速度慢.而DeformableAttention注意力模块只关注一个目标周围的一小部分的关键采样点特征.原来的DETR需要很多个 epoch 才能找到特征,在Deformable DTER中可以更快,据说1/10的耗时。
原理:以DETR3D的做法为例.

第一步看看输入:

定义一个shape为(900,256)的query,代表900和目标,每个目标256维查询信息.
定义一个query_pos shape同query.
定义一个shape为(900,3)的reference_points,作为目标参考点.
输入为:pts_feats(1,43054,256),多尺度flatten结果,
多尺度特征图尺寸记录:spatial_shapes:([[180, 180],[ 90, 90],[ 45, 45],[ 23, 23]])
特征图在pts_feats起点记录:level_start_index:([ 0, 32400, 40500, 42525])
可自行验算下.

第二步,准备工作:

pts_feats reshape为(1,43054,8,32)

value = value.view(bs, num_value, self.num_heads, -1)

生成参考点的偏移量

query经过self.sampling_offsets线性映射再reshape输出:
sampling_offsets(torch.Size([1, 900, 8, 4, 4, 2]))
其中8是多头数量,4是特征层数, 4是采样点数, 2是采样点xy两个维度.意思是8次在4层特征图上分别采样4个点,这844个点的xy方向的偏移量.

生成参考点的权重

query经过self.attention_weights线性映射再reshape输出:
attention_weights(torch.Size([1, 900, 8, 4, 4]))
对应上述点的权重.

生成参考点

reference_points加上参考点的偏移量生成,真正的参考点.

sampling_location = reference_poins[:, :, None, None, None, :2] + sampling_offsets

sampling_locations(torch.Size([1, 900, 8, 4, 4, 2]))

说白就是,就是定义一个query_embed,它生成自己即将要去采样的点位置和采样点权重.

第三步,工作:

输入:
value shape(torch.Size([b,43054,8,32]))
sampling_locations(torch.Size([b, 900, 8, 4, 4, 2]))
attention_weights(torch.Size([b, 900, 8, 4, 4]))
spatial_shapes:([[180, 180],[ 90, 90],[ 45, 45],[ 23, 23]])

value 根据spatial_shapes分解出各个level:
[torch.Size([b,180180,8,32],torch.Size([b,9090,8,32])),torch.Size([b,4545,8,32])),torch.Size([b,2323,8,32]))]
reshape为正常图像torch.Size([b*8,32,180,180]

sampling_locations原本为采样点位置,范围为[0,1),为了适应F.grid_sample采样函数的用法,调整为[-1,1)分布,
调用F.grid_sample对每一层特征进行采样,输入value为torch.Size([b8,32,level_h,level_w]),采样点为sampling_grid:torch.Size([b8,900,4,2])
则输出为sampling_value:torch.Size([b8,32,900,4])
意思是,900个query在特征图(32,level_h,level_w)中各采样4个点,采样结果为900个对应的4个通道为32的像素特征.
将4层采样结果sampling_value拍在一起torch.Size([b
8,32,900,4*4])

attention_weights变成相同形式(torch.Size([b8, 1,900, 44])),然后对16个采样特征进行加权求和输出outputtorch.Size([b,32*8,900]).后续交给FFN对多头特征进行全连接融合.

源码

import torch
import torch.nn.functional as F
import torch.nn as nn


def multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    batch, _, num_head, embeding_dim_perhead = value.shape
    _, query_size, _, level_num, sample_num, _ = sampling_locations.shape
    split_list = []
    for h, w in spatial_shapes:
        split_list.append(int(h * w))
    value_list = value.split(split_size=tuple(split_list), dim=1)
    # [0,1)分布变成 [-1,1)分布,因为要调用F.grid_sample函数
    sampling_grid = 2 * sampling_locations - 1
    output_list = []
    for level_id, (h, w) in enumerate(spatial_shapes):
        h = int(h)
        w = int(w)
        # batch, value_len, num_head, embeding_dim_perhead
        # batch, num_head, embeding_dim_perhead, value_len
        # batch*num_head, embeding_dim_perhead, h, w
        value_l = value_list[level_id].permute(0, 2, 3, 1).view(batch * num_head, embeding_dim_perhead, h, w)
        # batch,query_size,num_head,level_num,sample_num,2
        # batch,query_size,num_head,sample_num,2
        # batch,num_head,query_size,sample_num,2
        # batch*num_head,query_size,sample_num,2
        sampling_grid_l = sampling_grid[:, :, :, level_id, :, :].permute(0, 2, 1, 3, 4).view(batch * num_head,
                                                                                             query_size, sample_num, 2)
        # batch*num_head embeding_dim,,query_size, sample_num
        output = F.grid_sample(input=value_l,
                               grid=sampling_grid_l,
                               mode='bilinear',
                               padding_mode='zeros',
                               align_corners=False)
        output_list.append(output)
    # batch*num_head, embeding_dim_perhead,query_size, level_num, sample_num
    outputs = torch.stack(output_list, dim=-2)
    # batch,query_size,num_head,level_num,sample_num
    # batch,num_head,query_size,level_num,sample_num
    # batch*num_head,1,query_size,level_num,sample_num
    attention_weights = attention_weights.permute(0, 2, 1, 3, 4).view(batch * num_head, 1, query_size, level_num,
                                                                      sample_num)
    outputs = outputs * attention_weights
    # batch*num_head, embeding_dim_perhead,query_size
    # batch,num_head, embeding_dim_perhead,query_size
    # batch,query_size,num_head, embeding_dim_perhead
    # batch,query_size,num_head*embeding_dim_perhead
    outputs = outputs.sum(-1).sum(-1).view(batch, num_head, embeding_dim_perhead, query_size).permute(0, 3, 1, 2). \
        view(batch, query_size, num_head * embeding_dim_perhead)
    return outputs.contiguous()


if __name__ == '__main__':
    batch = 1
    num_head = 8
    embeding_dim = 256
    query_size = 900
    spatial_shapes = torch.Tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
    value_len = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().int()
    value = torch.rand(size=(batch, value_len, embeding_dim))
    query_embeding = torch.rand(size=(batch, query_size, embeding_dim * 2 + 3))
    query = query_embeding[..., :embeding_dim]
    query_pos = query_embeding[..., embeding_dim:2 * embeding_dim]
    reference_poins = query_embeding[..., 2 * embeding_dim:]
    # 讨论1:在deformale-att中这个query并不会和value交互生成att-weights,att-weights只和query有关,
    # 也就是推理过程att-weights(包括sampling_locations)是固定的.
    # 据作者解释这是因为采用前者的方式计算的attention权重存在退化问题,
    # 即最后得到的attention权重与并没有随key的变化而变化。
    # 因此,这两种计算attention权重的方式最终得到的结果相当,
    # 而后者耗时更短、计算代价更小,所以作者选择直接对query做projection得到attention权重。
    # 讨论2:在query固定情况下,第一个layer的att-weights无法改变,
    # 但是第二个layer的query与value有关,att-weights则会发生变化.so the self-att in frist layer is not nesscerary
    level_num = 4
    sample_num = 4
    sampling_offsets_net = nn.Linear(in_features=embeding_dim, out_features=num_head * level_num * sample_num * 2)
    sampling_offsets = sampling_offsets_net(query).view(batch, query_size, num_head, level_num, sample_num, 2)
    sampling_location = reference_poins[:, :, None, None, None, :2] + sampling_offsets
    attention_weights_net = nn.Linear(in_features=embeding_dim, out_features=num_head * level_num * sample_num)
    attention_weights = attention_weights_net(query).view(batch, query_size, num_head, level_num * sample_num)
    attention_weights = attention_weights.softmax(dim=-1).view(batch, query_size, num_head, level_num,
                                                               sample_num)  # sum of 16 points weight is equal to 1
    embeding_dim_perhead = embeding_dim // num_head
    value = value.view(batch, value_len, num_head, -1)

    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_location, attention_weights)
    pass

如需获取全套代码请参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1526980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DataFunSummit 2023因果推断在线峰会:解码数据与因果,引领智能决策新篇章(附大会核心PPT下载)

在数据驱动的时代,因果推断作为数据科学领域的重要分支,正日益受到业界的广泛关注。DataFunSummit 2023年因果推断在线峰会,汇聚了国内外顶尖的因果推断领域专家、学者及业界精英,共同探讨因果推断的最新进展、应用与挑战。本文将…

【小白笔记:JetsonNano学习(一)SDKManager系统烧录】

参考文章:SDKManager系统烧录 小白烧录文件系统可能遇到的问题 担心博主删除文章,可能就找不到比较详细的教程了,特意记录一下。 Jetson Nano采用四核64位ARM CPU和128核集成NVIDIA GPU,可提供472 GFLOPS的计算性能。它还包括4GB…

24计算机考研调剂 | 【官方】山东师范大学(22自命题)

山东师范大学2024年拟接收调剂 考研调剂信息 调剂专业目录如下: 计算机技术(085404)、软件工程(085405) 补充内容 我校2024年硕士研究生调剂工作将于4月8日教育部“中国研究生招生信息网”(https://yz.ch…

海外问卷调查:代理IP使用方法

在进行问卷调查时,为了避免被限制访问或被封禁IP,使用代理IP已经成为了必要的选择。 其中,口子查和渠道查也不例外。 使用代理IP可以隐藏本机IP地址,模拟不同的IP地址,从而规避被封禁的风险。但是,对于很…

登录-前端部分

登录表单和注册表单在同一个页面中,通过注册按钮以及返回按钮来控制要显示哪个表单 一、数据绑定和校验 (1)绑定数据,复用注册表单的数据模型: //控制注册与登录表单的显示, 默认false显示登录 true时显…

linux 安装常用软件

文件传输工具 sudo yum install –y lrzsz vim编辑器 sudo yum install -y vimDNS 查询 sudo yum install bind-utils用法可以参考文章 《掌握 DNS 查询技巧,dig 命令基本用法》 net-tools包 yum install net-tools -y简单用法: # 查看端口占用情况…

3_springboot_shiro_jwt_多端认证鉴权_Redis缓存管理器

1. 什么是Shiro缓存管理器 上一章节分析完了Realm是怎么运作的,自定义的Realm该如何写,需要注意什么。本章来关注Realm中的一个话题,缓存。再看看 AuthorizingRealm 类继承关系 其中抽象类 CachingRealm ,表示这个Realm是带缓存…

stm32-模拟数字转化器ADC

接线图: #include "stm32f10x.h" // Device header//1: 开启RCC时钟,包括ADC和GPIO的时钟//2:配置GPIO将GPIO配置为模拟输入模式//3:配置多路开关将左边的通道接入到规则组中//4:配置ADC转…

在Python中执行分位数回归

线性回归被定义为根据给定的变量集构建因变量和自变量之间关系的统计方法。在执行线性回归时,我们对计算响应变量的平均值感到好奇。相反,我们可以使用称为分位数回归的机制来计算或估计响应值的分位数(百分位数)值。例如&#xf…

Unity UGUI之Toggle基本了解

在Unity中,Toggle一般用于两种状态之间的切换,通常用于开关或复选框等功能。 它的基本属性如图: 其中, Interactable(可交互):指示Toggle是否可以与用户交互。设置为false时,禁用To…

代码随想录|Day23|回溯03|39.组合总和、40.组合总和II、131.分割回文串

39.组合总和 本题和 216.组合总和III 类似,但有几个区别: 没有元素个数限制:树的深度并不固定,因此递归终止条件有所变化每个元素可以使用多次:下层递归的起始位置和上层相同(startIndex不需要改动&#xf…

#每天一道面试题# 什么是MySQL的回表查询

MySQL中的索引按照物理存储的方式分为聚集索引和非聚集索引; 聚集索引索引和数据存储在一起,B树的叶子节点就是表数据,如果通过聚集索引查询数据,直接就可以查询出我们想要的数据;非聚集索引B树的叶子节点存储的是主键…

Hive SQL必刷练习题:连续问题 间断连续(*****)

问题描述: 1) 连续问题:找出连续三天(或者连续几天的啥啥啥)。 2) 间断连续:统计各用户连续登录最长天数,间断一天也算连续,比如1、3、4、6也算登陆了6天 问题分析&am…

Java八股文(XXL-JOB)

Java八股文のXXL-JOB XXL-JOB XXL-JOB xxl-job 是什么?它的主要作用是什么? xxl-job 是一款分布式任务调度平台,用于解决分布式系统中的定时任务和异步任务调度问题。 它提供了任务的注册、调度、执行和监控等功能,能够帮助开发者…

激光打标机的维护与保养:确保设备长期稳定运行的关键

​ 激光打标机的维护与保养是确保设备长期稳定运行的关键,以下是一些关键的维护和保养步骤: 一、定期清洁 1. 清洁镜片:定期清洁激光打标机的镜片是维护保养的重要环节。使用纯净的酒精或专用的激光镜片清洗剂,轻轻擦拭镜片表面&…

WPS制作甘特图

“ 甘特图(Gantt chart)又称为横道图、条状图(Bar chart),通过条状图来显示项目、进度和其他时间相关的系统进展的内在关系随着时间进展的情况。” 设置基础样式 设置行高 设置宽度 准备基础数据 计算持续时间 …

C语言数组—二维数组

二维数组的创建 //数组创建 int arr[3][4]; //三行四列,存放整型变量 double arr[2][4];二维数组的初始化 我们如果这样初始化,效果是什么样的呢 int arr[3][4] { 1,2,3,4,5,6,7,8,9,10,11,12 };那如果我们不写满十二个呢 int arr[3][4] { 1,2,3,4…

超实用!免费软件站大盘点,总有一款适合你

相信用Mac电脑的同学都知道一个网站MacWK,可以白嫖几乎所有常用软件,不用付费,但不好的消息是在2022年10月宣布关站,小编从此走上了开源免费的道路,尽管不太好用,奈何口袋木有钱,经过小编的不断…

一个页面请求从在浏览器中输入网址到页面最终呈现

前言-与正文无关 生活远不止眼前的苦劳与奔波,它还充满了无数值得我们去体验和珍惜的美好事物。在这个快节奏的世界中,我们往往容易陷入工作的漩涡,忘记了停下脚步,感受周围的世界。让我们一起提醒自己,要适时放慢脚步…

通过调整报文偏移解决CAN应用报文丢帧或周期过长问题

偏移原理 报文很多都是周期性发送的,但是如果每条报文都以一开始作为开始计时的时间点,也就是一开始就发送第一条报文,可能会导致CAN堵塞,导致丢帧或者某些报文某一时刻周期过长,就像下图这样,同一时刻CAN…