Deformable DETR 论文学习

news2024/11/25 1:33:17

1. 解决了什么问题?

DETR 去除了目标检测算法中的人为设计,取得了不错的表现。但是其收敛速度很慢,对低分辨率特征识别效果差:

  • 模型初始化时,注意力模块给特征图上所有的像素点分配的权重是均匀的,就需要较长的训练 epochs 使注意力权重学习关注到稀疏的、有意义的像素位置。
  • Transformer encoder 中的注意力权重的计算量关于像素的个数是 quadratic,因此计算量和内存占用就非常高。
  • 缺少多尺度特征来检测小目标。

2. 提出了什么方法?

Deformable DETR 将 deformable conv 的稀疏空间采样和 transformer 的关系建模能力结合,缓解了收敛慢和高复杂度的问题,提出了 deformable attention 模块,只关注于特征图上每个像素少量的采样位置。
在这里插入图片描述

2.1 Multi-head Attention

给定一个 query element 以及一个 key elements 的集合,multi-head attention 模块根据注意力权重聚合 key 对应的 values,该注意力权重表示 query-key pairs 之间的匹配程度。为了让模型关注于不同表征子空间和位置的内容,我们将不同的 attention heads 的输出用可学习的权重线性聚合起来。 q ∈ Ω q q\in\Omega_q qΩq是 query,其对应的表征是 z q ∈ R C z_q\in\mathbb{R}^C zqRC k ∈ Ω k k\in \Omega_k kΩk是 key,对应表征是 x k ∈ R C x_k\in\mathbb{R}^C xkRC C C C是特征维度。Multi-head attention 特征计算过程如下:

MultiHeadAttn ( z q , x ) = ∑ m = 1 M W m [ ∑ k ∈ Ω k A m q k ⋅ W m ′ x k ] \text{MultiHeadAttn}(z_q,x)=\sum_{m=1}^M W_m \left[ \sum_{k\in\Omega_k} A_{mqk}\cdot W'_m x_k \right] MultiHeadAttn(zq,x)=m=1MWm[kΩkAmqkWmxk]

其中 m m m是 attention head 索引, W m ′ ∈ R C v × C , W m ∈ R C × C v W'_m\in \mathbb{R}^{C_v\times C}, W_m\in \mathbb{R}^{C\times C_v} WmRCv×C,WmRC×Cv是可学习权重( C v = C / M C_v=C/M Cv=C/M)。注意力权重 A m q k ∝ exp ⁡ { z q T U m T V m x k C v } A_{mqk}\propto \exp\lbrace \frac{z_q^T U_m^T V_m x_k}{\sqrt{C_v}} \rbrace Amqkexp{Cv zqTUmTVmxk},并且 ∑ k ∈ Ω k A m q k = 1 \sum_{k\in \Omega_k}A_{mqk}=1 kΩkAmqk=1,其中 U m , V m ∈ R C v × C U_m, V_m\in \mathbb{R}^{C_v\times C} Um,VmRCv×C是 query 和 key 的映射矩阵,也是可学习权重。为了区分不同空间位置, z q z_q zq x k x_k xk会加上 positional encodings。在图像里面,query 和 key 元素都是像素点,因此 MultiHeadAttn 的复杂度是 O ( N q C 2 + N k C 2 + N q N k C ) , N q = N k ≫ C \mathcal{O}(N_q C^2 + N_k C^2 + N_qN_kC), N_q=N_k \gg C O(NqC2+NkC2+NqNkC),Nq=NkC,复杂度由第三项决定,因此 multi-head attn 的复杂度与特征图大小呈 quadratic。

2.2 Deformable Attention Module

在图像特征上使用 transformer attention,它会查看所有可能的空间位置。于是,作者提出了 deformable attention,它只会关注于 reference point 附近的 key sampling points,不管特征图有多大。这样,就缓解了收敛速度慢和空间分辨率的问题。
给定特征图 x ∈ R C × H × W x\in\mathbb{R}^{C\times H\times W} xRC×H×W q q q是 query 的索引,对应的特征是 z q z_q zq p q p_q pq是 2D reference point,deformable attention 特征计算过程如下:
DeformAttn ( z q , p q , x ) = ∑ m = 1 M W m [ ∑ k = 1 K A m q k ⋅ W m ′ ⋅ x ( p q + Δ p m q k ) ] \text{DeformAttn}(z_q,p_q,x)=\sum_{m=1}^M W_m \left[ \sum_{k=1}^K A_{mqk}\cdot W'_m \cdot x(p_q+\Delta p_{mqk}) \right] DeformAttn(zq,pq,x)=m=1MWm[k=1KAmqkWmx(pq+Δpmqk)]

m m m是注意力 head 的索引, k k k是采样 keys 的索引, K K K是所有采样点的个数( K ≪ H W K\ll HW KHW)。 Δ p m q k \Delta p_{mqk} Δpmqk A m q k A_{mqk} Amqk分别表示第 m m m个注意力 head 里面第 k k k个采样点的偏移量和注意力权重。 A m q k ∈ [ 0 , 1 ] , ∑ k = 1 K A m q k = 1 , Δ p m q k ∈ R 2 A_{mqk}\in[0,1],\sum_{k=1}^K A_{mqk}=1,\Delta p_{mqk}\in\mathbb{R}^2 Amqk[0,1],k=1KAmqk=1,ΔpmqkR2。通过双线性插值计算出 x ( p q + Δ p m q k ) x(p_q + \Delta p_{mqk}) x(pq+Δpmqk),对 z q z_q zq做 linear projection 得到 Δ p m q k , A m q k \Delta p_{mqk}, A_{mqk} Δpmqk,Amqk。将 z q z_q zq输入进一个 3 M K 3MK 3MK通道的 linear projection ,前 2 M K 2MK 2MK个通道编码了采样点偏移量 Δ p m q k \Delta p_{mqk} Δpmqk,后 M K MK MK个通道输入进 softmax 得到注意力权重 A m q k A_{mqk} Amqk

2.2.1 计算复杂度

N q N_q Nq是 query 个数, M K MK MK相对较小,deformable attention 模块的复杂度就是 O ( 2 N q C 2 + min ⁡ ( H W C 2 , N q K C 2 ) ) \mathcal{O}(2N_qC^2+\min(HWC^2, N_qKC^2)) O(2NqC2+min(HWC2,NqKC2)),过程如下:

  • 计算采样坐标偏移量 Δ p m q k \Delta p_{mqk} Δpmqk和注意力权重 A m q k A_{mqk} Amqk的复杂度是 O ( N q C ⋅ 3 M K ) \mathcal{O}(N_qC\cdot 3MK) O(NqC3MK)
  • 在 decoder 中,有了坐标偏移量和注意力权重,DeformAttn 的复杂度就是 O ( N q C 2 + N q K C 2 + 5 N q K C ) \mathcal{O}(N_qC^2 + N_qKC^2+5N_qKC) O(NqC2+NqKC2+5NqKC),其中 5 N q K C 5N_qKC 5NqKC是因为注意力的双线性插值和加权和操作。
  • 在 encoder 中, N q = H W N_q=HW Nq=HW,我们可以在采样前计算 W m ′ x W'_mx Wmx,DeformAttn 的复杂度就是 O ( N q C 2 + H W C 2 + 5 N q K C ) \mathcal{O}(N_qC^2 + HWC^2+5N_qKC) O(NqC2+HWC2+5NqKC)
  • 总的复杂度就是 O ( N q C 2 + min ⁡ ( N q K C 2 , H W C 2 ) + 5 N q K C + 3 N q C M K ) \mathcal{O}(N_qC^2 + \min(N_qKC^2,HWC^2)+5N_qKC+3N_qCMK) O(NqC2+min(NqKC2,HWC2)+5NqKC+3NqCMK),在实验中 M = 8 , K ≤ 4 , C = 256 M=8,K\leq 4,C=256 M=8,K4,C=256,因此 5 K + 3 M K < C 5K+3MK < C 5K+3MK<C,因此复杂度为 O ( 2 N q C 2 + min ⁡ ( H W C 2 , N q K C 2 ) ) \mathcal{O}(2N_qC^2 + \min(HWC^2,N_qKC^2)) O(2NqC2+min(HWC2,NqKC2))

2.3 多尺度 Deformable Attention 模块

{ x l } l = 1 L \lbrace x^l \rbrace^L_{l=1} {xl}l=1L是多尺度特征图,其中 x l ∈ R C × H l × W l x^l\in \mathbb{R}^{C\times H_l\times W_l} xlRC×Hl×Wl。$\hat{p}_q\in \left[ 0,1\right]^2 $是每个 query q q q的 reference point 归一化后的坐标,多尺度 deformable attention 模块计算过程为:

MSDeformAttn ( z q , p ^ q , { x l } l = 1 L ) = ∑ m = 1 M W m [ ∑ l = 1 L ∑ k = 1 K A m l q k ⋅ W m ′ ⋅ x l ( ϕ l ( p ^ q ) + Δ p m l q k ) ] \text{MSDeformAttn}(z_q,\hat{p}_q,\lbrace x^l \rbrace_{l=1}^L)=\sum_{m=1}^M W_m \left[ \sum_{l=1}^L \sum_{k=1}^K A_{mlqk} \cdot W'_m \cdot x^l (\phi_l(\hat{p}_q) + \Delta p_{mlqk}) \right] MSDeformAttn(zq,p^q,{xl}l=1L)=m=1MWm[l=1Lk=1KAmlqkWmxl(ϕl(p^q)+Δpmlqk)]

其中 m m m是 attention head 的索引, l l l是输入特征层级的索引, k k k是采样点的索引。 Δ p m l q k \Delta p_{mlqk} Δpmlqk A m l q k A_{mlqk} Amlqk表示第 l l l个特征层级、第 m m m个注意力 head 上第 k k k个采样点的采样位置偏移量和注意力权重, ∑ l = 1 L ∑ k = 1 K A m l q k = 1 \sum_{l=1}^L \sum_{k=1}^K A_{mlqk}=1 l=1Lk=1KAmlqk=1。函数 ϕ l ( p ^ q ) \phi_l(\hat{p}_q) ϕl(p^q)将归一化的坐标 p ^ q \hat{p}_q p^q重新缩放到第 l l l层级的特征图上。

2.4 Deformable Transformer Encoder

Encoder 的输入和输出是分辨率相同的多尺度特征图。从 ResNet 中提取 C 3 C_3 C3 C 5 C_5 C5阶段的输出特征图 { x l } l = 1 L − 1 , ( L = 4 ) \lbrace x^l \rbrace_{l=1}^{L-1},(L=4) {xl}l=1L1(L=4)。其中 C l C_l Cl的分辨率要比输入图片小 2 l 2^l 2l。最低的分辨率是 C 6 C_6 C6,通过对 C 5 C_5 C5输出的特征图使用 3 × 3 3\times 3 3×3且步长为 2 的卷积操作得到。所有的特征图通道数都是 256。
多尺度特征图的输出的分辨率与输入相同。Key 和 query 元素都是多尺度特征图上的像素点。对于 query 像素点,reference point 就是其本身。为了鉴别 query 像素所在的特征层级,作者在特征表征中加入了一个尺度层级 embedding,即 e l e_l el。尺度层级 embedding { e l } l = 1 L \lbrace e_l \rbrace_{l=1}^L {el}l=1L随机初始化,与网络协同训练。

2.5 Deformable Transformer Decoder

Decoder 包括 cross attention 和 self attention。这两种 attention 模块的 query 元素都是 object queries。在 cross attention 中,object queries 从特征图上提取特征,key 元素来自于 encoder 输出的特征图。在 self attention 中,object queries 互相交流,key 元素就是 object queries。作者只将 cross attention 替换为了 multi-scale deformable attention,而保留了 self attention。对于每个 object query,使用 linear projection 以及一个 sigmoid 函数,预测一个reference point p ^ q \hat{p}_q p^q的 2D 归一化坐标。然后,检测 head 关于这个reference point 预测边框的坐标,降低优化难度。这个reference point 就可以看作为边框中心点的 initial guess。

2.6 实验

在 MS COCO 2017 数据集进行的实验。在主干网络中使用了 ImageNet 上预训练的 ResNet-50 权重,没有用 FPN 提取多尺度特征图。在 deformable attention 中, M = 8 , K = 4 M=8,K=4 M=8,K=4。Deformable transformer encoder 中不同特征层级的参数共享权重。Focal Loss 中边框分类损失的权重设为 2。Object queries 个数设为了 300。所有的模型训练了 50 个 epochs,学习率在第 40 个 epoch 时乘以 0.1 衰减。模型训练采用了 Adam 优化器,初始学习率为 2 × 1 0 − 4 , β 1 = 0.9 , β 2 = 0.999 2\times 10^{-4},\beta_1=0.9,\beta_2=0.999 2×104,β1=0.9,β2=0.999,weight decay 为 1 0 − 4 10^{-4} 104。用于预测 object query reference point 和采样偏移量的 linear projection,它的学习率会乘以系数 0.1。
在这里插入图片描述

从上图可以看出,Deformable DETR 的收敛速度和表现明显优于 DETR。

3. 有什么优点?

  • 将 deformable conv 和 transformer attention 结合,只关注reference point 附近少量的采样点,大幅提高了收敛速度,并保证了准确性。
  • 通过 multi-scale 方式,提高了对小目标的检测效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/552110.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能Python-python5的阶乘

Python5的阶乘介绍 Python是一门广泛应用于编写脚本、自动化、爬虫、数据分析等方面的编程语言&#xff0c;也是很多科研领域使用的首选。Python的功能和灵活性可以帮助用户解决各种问题&#xff0c;而本文要介绍的是Python中计算阶乘的方法。 阶乘是数学中的一个概念&#x…

fullter 学习记录_01_插件整理

flutter学习记录第一节--搭建项目及路由的设置 1.轮播图: flutter_swiper1.1 用处1.2 导入flutter_swiper库1.3 导入库&#xff0c;运行后可能遇到的问题1.4 属性说明1.5 代码案例 2. flutter_screenutil2.1 用处2.2 引用2.3 使用说明2.4 代码实现按理2.5 ScreenUtl 的封装 1.轮…

UNIX环境高级编程——守护进程

13.1 引言 守护进程&#xff08;daemon&#xff09;是生存期长的一种进程。它们常常在系统引导装入时启动&#xff0c;仅在系统关闭时才终止。因为它们没有控制终端&#xff0c;所以说它们是在后台运行的。 13.2 守护进程的特征 系统进程依赖于操作系统实现。父进程ID为0的各…

xxs跨站之原理分类及攻击手法

xss跨站达到原理&#xff0c;危害和特点 他和语言没有太大关系&#xff0c;它大部分都是属于一个前端的漏洞&#xff0c;搭建一个简易的php网站存在xss跨站漏洞 访问这个网站&#xff0c;x1&#xff0c;就输出1&#xff0c; 如果我们把x<script>alert(1)</script&g…

系统分析师考试之论文框架

系统分析师考试之论文框架 系统分析师考试之论文框架

SpringBoot—常用注解

目录 一、注解(annotations)列表 二、注解(annotations)详解 三、JPA注解 四、springMVC相关注解 五、全局异常处理 一、注解(annotations)列表 SpringBootApplication&#xff1a; 包含了ComponentScan、Configuration和EnableAutoConfiguration注解。其中ComponentScan…

linux 读写锁 pthread_rwlock

专栏内容&#xff1a;linux下并发编程个人主页&#xff1a;我的主页座右铭&#xff1a;天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物&#xff0e; 目录 前言 概念介绍 应用场景 接口说明 头文件 rwlock定义 初始化/销毁 两种初始化方…

算法设计与智能计算 || 专题九: 基于拉普拉斯算子的谱聚类算法

谱聚类 文章目录 谱聚类1. 信息增益的度量2. 谱聚类: 寻找最优的函数向量 f \boldsymbol{f} f2.1 : 寻找一个最优的函数向量 f \boldsymbol{f} f2.2 寻找鲁棒性更强的多个函数向量2.3 谱聚类(spectral clustering)算法 小结 1. 信息增益的度量 由于数据集 X [ x 1 , x 2 ,…

自动化仓储管理系统(WMS)

仓储是现代物流的一个重要组成部分&#xff0c;在物流系统中起着至关重要的作用&#xff0c;是厂商研究和规划的重点。高效合理的仓储可以帮助厂商加快物资流动的速度&#xff0c;降低成本&#xff0c;保障生产的顺利进行&#xff0c;并可以实现对资源有效控制和管理。 随着我…

CVE-2023-27363 FOXIT PDF READER与EDITOR任意代码执行漏洞复现

目录 0x01 声明&#xff1a; 0x02 简介&#xff1a; 0x03 漏洞概述&#xff1a; 0x04 影响版本&#xff1a; 0x05 环境搭建&#xff1a; 文件下载&#xff1a; 0x06 漏洞复现&#xff1a; POC下载&#xff1a; 利用POC&#xff1a; RCE&#xff1a; 0x07 修复建议&a…

2023.5.21 第五十四次周报

目录 前言 文献阅读:跨多个时空尺度进行预测的时空 LSTM 模型 背景 本文思路 本文解决的问题 方法论 SPATIAL 自动机器学习模型 数据处理 模型性能 代码 用Python编写的LSTM多变量预测模型 总结 前言 This week, I studied an article that uses LSTM to solve p…

Kuberadm安装部署

目录 1、所需配置 2、准备环境 3、安装Docker 4、安装Kubeadm、Kubelet和Kuberctl 5、部署K8s集群 6、初始化kubeadm 7、设置kubecrl 8、部署网络插件flannel 测试 9、部署Dashboard 10、安装Harbor私有仓库 11、内核参数优化方案 1、所需配置 主机名系统要求IP地…

一款C#开发的窗口文本提取开源软件

窗体上的文字&#xff0c;我们是无法直接复制的&#xff0c;如果要获取&#xff0c;大家第一反应肯定是利用OCR&#xff0c;但是今天给大家推荐一个通过Hook原理来抓取文本的开源项目。 项目简介 这是一个基于.Net Framework开发的&#xff0c;功能强大的文本提取工具&#x…

【消息中间件】为什么要用消息中间件?解决了什么问题?

文章目录 前言一、什么是消息中间件&#xff1f;二、为什么使用&#xff1f;二、使用消息中间件前后对比三、当前几种中间件&#xff1a;四、使用消息中间件场景&#xff1a; 前言 一般认为&#xff0c;采用消息传送机制/消息队列 的中间件技术&#xff0c;进行数据交流&#…

[CISCN 2019华东南]Web4 day5

考察&#xff1a;任意文件读取 获取网卡地址 伪随机 打开界面&#xff0c;点击read somethings直接进行了跳转 直接修改url&#xff0c;发现没显示&#xff0c;但是访问错误的路由就会有no response 读取flag也无果&#xff0c;那就读一下/app/app.py&#xff0c;为什么读这个&…

postgresql|数据库|插件学习(二)---postgresql-12的外置插件pg_profile的启用和使用

前言&#xff1a; postgresql数据库有非常多的插件&#xff0c;那么&#xff0c;pg_profile算是监控类的插件&#xff0c;该插件会通过内置的pg_stat_statements插件和dblink插件这两个插件监控查询postgresql的状态&#xff0c;并可以通过打快照的方式得到awr报告。 ###注&a…

攻防世界 wife_wife

查看提示&#xff1a;不需要爆破 进入到靶场中&#xff0c;发现需要注册用户 到达注册页面&#xff0c;is admin需要打勾&#xff0c;并输入同样 burpsuite抓包 原来payload&#xff1a;{"username":"1","password":"1","isAdmin…

MySQL数据库---笔记2

MySQL数据库---笔记2 一、函数1.1、字符串函数1.2、数值函数1.3、日期函数1.4、流程函数 二、约束2.1、概述2.2、演示2.3、外键约束 一、函数 函数 是指一段可以直接被另一段程序调用的程序或代码。 1.1、字符串函数 函数功能CONCAT(S1,S2,…Sn)字符串拼接&#xff0c;将S1&…

分布式医疗云平台【基础功能搭建、数据表定义、ERP业务表、系统表、其它配置表 】(三)-全面详解(学习总结---从入门到深化)

基础功能搭建 创建数据库 数据表定义 ERP业务表 系统表 其它配置表 基础功能搭建 创建数据库 执行脚本&#xff0c;创建基础数据 数据表定义 业务表 his_care_order: 药用处方表&#xff0c;涉及的业务模块&#xff1a;处方收费、处方发药、新开检查 his_care_history…

jetson填坑-单独安装cuda,cudnn,tensorrt任意适用版本

前言 jetson无法单独安装cuda&#xff0c;cudnn&#xff0c;tensorrt的解决方法&#xff0c;比下载SDK manager刷机安装简单好多倍 这个方法是直接下载deb包安装&#xff0c;deb包安装网站 https://repo.download.nvidia.com/jetson/ 单独安装cuda 1 sudo apt-get install …