DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制

news2025/2/23 1:00:53

空洞自注意力机制

文章目录

  • 摘要
  • 1. 模型解释
    • 1.1. 滑动窗口扩张注意力
    • 1.2. 多尺度扩张注意力
  • 2. 代码
  • 3. 流程图
    • 3.1. MultiDilatelocalAttention
    • 3.2. DilateAttention
    • 3.3. MLP

摘要

    本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳理其实现原理。

1. 模型解释

1.1. 滑动窗口扩张注意力

    根据在普通视觉变换器(ViTs)中浅层全局注意力中观察到的局部性稀疏性特性,我们提出了一种滑动窗口扩张注意力(SWDA) 操作,其中,keys和values被以query patch为中心的滑动窗口稀疏地选择。然后对这些代表性patches进行自注意力。我们的 SWDA 正式描述如下:

X = S W D A ( Q , K , V , r ) ( 1 ) \begin{aligned} &&&&&&&&&&&&& X = SWDA(Q,K,V,r) &&&&&&&&&&&&&&&& (1) \end{aligned} X=SWDA(Q,K,V,r)(1)

其中, Q , K , V Q,K,V Q,K,V分别代表query、key和value矩阵,三个矩阵的每一行表示一个query/key/value特征向量。对于原始特征图上 ( i , j ) (i,j) (i,j)位置的query,SWDA以尺寸为 w × w w×w w×w大小的滑动窗口,稀疏地选择key和value去指导自注意力。

    而且,我们定义一个扩张率 r ϵ N + r \epsilon N^+ N+去控制稀疏程度。特别地,对于位置 ( i , j ) (i,j) (i,j)SWDA计算的输出 X X X中的相应分量 x i j x_{ij} xij定义如下:

x i j = A t t e n t i o n ( q i j , K r , V r ) , ( 2 ) = S o f t m a x ( q i j K r T d k ) V r , 1 ≤ i ≤ W , 1 ≤ i ≤ H \begin{aligned} &&&&&&&&&&&& x_{ij} &= Attention(q_{ij},K_r,V_r), &&&&&&&&&&&&&&&& (2)\\ &&&&&&&&&&&&&=Softmax(\frac{q_{ij}K^T_r}{\sqrt{d_k}})V_r,& 1≤i≤W, 1≤i≤H \\ \end{aligned} xij=Attention(qij,Kr,Vr),=Softmax(dk qijKrT)Vr,1iW,1iH(2)

其中, H H H W W W 是特征图的高和宽。 K r K_r Kr V r V_r Vr表示从特征图 K K K V V V 中选择的keys和values。

    给定位于 ( i , j ) (i,j) (i,j)的query,位于坐标 ( i ′ , j ′ ) (i', j') (i,j) 下keys和values将被选择去指导自注意力(self-attetion):

{ ( i ′ , j ′ ) ∣ i ′ = i + p × r , j ′ = j + q × r } , − w 2 ≤ p , q ≤ w 2 . ( 3 ) \begin{aligned} &&&&&&&&&&&&& \{(i',j')|i'=i+p×r, j'=j+q×r \}, \frac{-w}{2}≤p, q≤\frac{w}{2}. &&&&&&&&&&&&&&&& (3) \end{aligned} {(i,j)i=i+p×r,j=j+q×r},2wp,q2w.(3)

    我们的 SWDA 以滑动窗口的方式对所有query patches进行自注意力操作。对于特征图边缘的query,我们简单地使用卷积运算中常用的 补零策略 来保持特征图的大小。通过稀疏地选择以queries为中心的keys和values,所提出的 SWDA 明确满足局部性和稀疏性属性,并且可以有效地对远程依赖关系进行建模

1.2. 多尺度扩张注意力

在这里插入图片描述

图4. 多尺度空洞注意力。

    首先,特征图的通道被划分不同的heads。然后,自注意力操作是在红色查询块周围的窗口中的彩色块之间执行的,在不同的头中使用不同的膨胀率。此外,不同heads中的特征被连接在一起,然后输入到线性层中。默认情况下,我们使用 3 × 3 的内核大小,膨胀率 r = 1、2 和 3,不同头中参与感受野的大小为 3 × 3、5 × 5 和 7 × 7。

    为了利用块级自注意力机制在不同尺度上的稀疏性,我们进一步提出了多尺度扩张注意力(MSDA) 块来提取多尺度语义信息。如图4所示,给定特征图 X X X,我们通过 线性投影(linear projection) 获得相应的query、kay和value。之后,我们将特征图的通道划分到 n n n 个不同的 h e a d s heads heads,并在不同的 h e a d s heads heads中以不同的膨胀率(dilation rates)执行多尺度SWDA。具体来说,我们的MSDA计算如下:

h i = S W D A ( Q i , K i , V i , r i ) , 1 ≤ i ≤ n , ( 4 ) X = L i n e a r ( C o n c a t [ h 1 , . . . , h n ] ) , ( 5 ) \begin{aligned} &&&&&&&&&&&&& h_i=SWDA(Q_i,K_i,V_i,r_i), &1≤i≤n, &&&&&&&&&&&&&&&& (4)\\ &\\ &&&&&&&&&&&&& X=Linear(Concat[h_1,...,h_n]), &&&&&&&&&&&&&&&&& (5) \end{aligned} hi=SWDA(Qi,Ki,Vi,ri),X=Linear(Concat[h1,...,hn]),1in,(4)(5)

其中, r i r_i ri是第 i i i h e a d head head的扩张率, Q i , K i Q_i,K_i Qi,Ki V i V_i Vi 代表馈入第 i i i h e a d head head的特征图切片。输出 { h i } i = 1 n \{h_i\}_{i=1}^n {hi}i=1n被concat到一起,然后送到线性层进行特征聚合。

    通过为不同的 h e a d s heads heads 设置不同的扩张率,我们的 MSDA 有效地聚合了参与感受野内不同尺度的语义信息,并有效地减少了自注意力机制的冗余,而无需复杂的操作和额外的计算成本。

2. 代码

import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DilateAttention(nn.Module):
    "Implementation of Dilate-attention"
    def __init__(self, head_dim, qk_scale=None, attn_drop=0, kernel_size=3, dilation=1):
        super().__init__()
        self.head_dim = head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.kernel_size=kernel_size
        self.unfold = nn.Unfold(kernel_size, dilation, dilation*(kernel_size-1)//2, 1)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self,q,k,v):
        #B, C//3, H, W
        q, k, v = q.detach(), k.detach(), v.detach()  # todo:!!!
        B,d,H,W = q.shape
        q = q.reshape([B, d//self.head_dim, self.head_dim, 1 ,H*W]).permute(0, 1, 4, 3, 2)  # B,h,N,1,d
        k = self.unfold(k).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 2, 3)  #B,h,N,d,k*k
        attn = (q @ k) * self.scale  # B,h,N,1,k*k
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        v = self.unfold(v).reshape([B, d//self.head_dim, self.head_dim, self.kernel_size*self.kernel_size, H*W]).permute(0, 1, 4, 3, 2)  # B,h,N,k*k,d
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, d)
        return x


class MultiDilatelocalAttention(nn.Module):
    "Implementation of Dilate-attention"

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0.,proj_drop=0., kernel_size=3, dilation=[1, 2, 3]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.scale = qk_scale or head_dim ** -0.5
        self.num_dilation = len(dilation)
        assert num_heads % self.num_dilation == 0, f"num_heads{num_heads} must be the times of num_dilation{self.num_dilation}!!"
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.dilate_attention = nn.ModuleList(
            [DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])
             for i in range(self.num_dilation)])
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)# B, C, H, W
        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C//self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        #num_dilation,3,B,C//num_dilation,H,W
        x = x.reshape(B, self.num_dilation, C//self.num_dilation, H, W).permute(1, 0, 3, 4, 2 )
        # num_dilation, B, H, W, C//num_dilation
        for i in range(self.num_dilation):
            x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])# B, H, W,C//num_dilation
        x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DilateBlock(nn.Module):
    "Implementation of Dilate-attention block"
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, dilation=[1, 2, 3],
                 cpe_per_block=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.cpe_per_block = cpe_per_block
        if self.cpe_per_block:
            self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = MultiDilatelocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                attn_drop=attn_drop, kernel_size=kernel_size, dilation=dilation)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        if self.cpe_per_block:
            x = x + self.pos_embed(x)
        x = x.permute(0, 2, 3, 1)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.permute(0, 3, 1, 2)
        #B, C, H, W
        return x


if __name__ == "__main__":
    x = torch.rand([2,72,56,56])

    B, C, H, W = x.shape
    dim = C
    num_heads = 3   # 必须是dilation的整数倍 且 被dim整除
    head_dim = dim // num_heads
    #######################

    drop_path=0.1
    depths = [2, 2, 6, 2]
    num_layers = len(depths)
    dpr = [x.item() for x in torch.linspace(0, drop_path, sum(depths))]
    for i_layer in range(num_layers):
        drop_paths = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]
    #######################
    m = DilateBlock(dim=C,
                    num_heads=num_heads,
                    kernel_size=3,
                    dilation=[1,2,3],
                    mlp_ratio=4.,
                    qkv_bias=True,
                    qk_scale=head_dim ** -0.5,
                    drop=0.,
                    attn_drop=0.,
                    drop_path=drop_paths[1] if isinstance(drop_paths, list) else drop_paths,
                    norm_layer=nn.LayerNorm, act_layer=nn.GELU, cpe_per_block=True)

    y = m(x)
    print(y.shape)

3. 流程图

在这里插入图片描述


3.1. MultiDilatelocalAttention

在这里插入图片描述

3.2. DilateAttention

在这里插入图片描述

3.3. MLP

在这里插入图片描述

完整流程图如下:

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2280036.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型GUI系列论文阅读 DAY2续:《一个具备规划、长上下文理解和程序合成能力的真实世界Web代理》

摘要 预训练的大语言模型(LLMs)近年来在自主网页自动化方面实现了更好的泛化能力和样本效率。然而,在真实世界的网站上,其性能仍然受到以下问题的影响:(1) 开放领域的复杂性,(2) 有限的上下文长度&#xff…

Qt按钮美化教程

前言 Qt按钮美化主要有三种方式:QSS、属性和自绘 QSS 字体大小 font-size: 18px;文字颜色 color: white;背景颜色 background-color: rgb(10,88,163); 按钮边框 border: 2px solid rgb(114,188,51);文字对齐 text-align: left;左侧内边距 padding-left: 10…

云IDE:开启软件开发的未来篇章

敖行客一直致力于将整个研发协作流程线上化,从而打破物理环境依赖,让研发组织模式更加灵活、自由且高效,今天就来聊聊AT Work(一站式研发协作平台)的重要组成部分-云IDE。 在科技领域,历史常常是未来的风向…

AI agent 在 6G 网络应用,无人机群控场景

AI agent 在 6G 网络应用,无人机群控场景 随着 6G 时代的临近,融合人工智能成为关键趋势。借鉴 IT 行业 AI Agent 应用范式,提出 6G AI Agent 技术框架,包含多模型融合、定制化 Agent 和插件式环境交互理念,构建了涵盖四层结构的框架。通过各层协同实现自主环境感知等能力…

【Linux 重装】Ubuntu 启动盘 U盘无法被识别,如何处理?

背景 U盘烧录了 Ubuntu 系统作为启动盘,再次插入电脑后无法被识别 解决方案(Mac 适用) (1)查找 USB,(2)格式化(1)在 terminal 中通过 diskutil list 查看是…

【优选算法篇】2----复写零

---------------------------------------begin--------------------------------------- 这道算法题相对于移动零,就上了一点点强度咯,不过还是很容易理解的啦~ 题目解析: 这道题如果没理解好题目,是很难的,但理解题…

高效建站指南:通过Portainer快速搭建自己的在线网站

文章目录 前言1. 安装Portainer1.1 访问Portainer Web界面 2. 使用Portainer创建Nginx容器3. 将Web静态站点实现公网访问4. 配置Web站点公网访问地址4.1公网访问Web站点 5. 固定Web静态站点公网地址6. 固定公网地址访问Web静态站点 前言 Portainer是一个开源的Docker轻量级可视…

redis性能优化参考——筑梦之路

基准性能测试 redis响应延迟耗时多长判定为慢? 比如机器硬件配置比较差,响应延迟10毫秒,就认为是慢,机器硬件配置比较高,响应延迟0.5毫秒,就认为是慢。这个没有固定的标准,只有了解了你的 Red…

Python 入门教程(2)搭建环境 | 2.3、VSCode配置Python开发环境

文章目录 一、VSCode配置Python开发环境1、软件安装2、安装Python插件3、配置Python环境4、包管理5、调试程序 前言 Visual Studio Code(简称VSCode)以其强大的功能和灵活的扩展性,成为了许多开发者的首选。本文将详细介绍如何在VSCode中配置…

Trimble三维激光扫描-地下公共设施维护的新途径【沪敖3D】

三维激光扫描技术生成了复杂隧道网络的高度详细的三维模型 项目背景 纽约州北部的地下通道网络已有100年历史,其中包含供暖系统、电线和其他公用设施,现在已经开始显露出老化迹象。由于安全原因,第三方的进入受到限制,在没有现成纸…

【强化学习】策略梯度(Policy Gradient,PG)算法

📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅…

Apache SeaTunnel 2.3.9 正式发布:多项新特性与优化全面提升数据集成能力

近日,Apache SeaTunnel 社区正式发布了最新版本 2.3.9。本次更新新增了Helm 集群部署、Transform 支持多表、Zeta新API、表结构转换、任务提交队列、分库分表合并、列转多行 等多个功能更新! 作为一款开源、分布式的数据集成平台,本次版本通过…

4 AXI USER IP

前言 使用AXI Interface封装IP,并使用AXI Interface实现对IP内部寄存器进行读写实现控制LED的demo,这个demo是非常必要的,因为在前面的笔记中基本都需哟PS端与PL端就行通信互相交互,在PL端可以通过中断的形式来告知PS端一些事情&…

B站评论系统的多级存储架构

以下文章来源于哔哩哔哩技术 ,作者业务 哔哩哔哩技术. 提供B站相关技术的介绍和讲解 1. 背景 评论是 B站生态的重要组成部分,涵盖了 UP 主与用户的互动、平台内容的推荐与优化、社区文化建设以及用户情感满足。B站的评论区不仅是用户互动的核心场所&…

电子科大2024秋《大数据分析与智能计算》真题回忆

考试日期:2025-01-08 课程:成电信软学院-大数据分析与智能计算 形式:开卷 考试回忆版 简答题(4*15) 1. 简述大数据的四个特征。分析每个特征所带来的问题和可能的解决方案 2. HDFS的架构的主要组件有哪些&#xff0…

多选multiple下拉框el-select回显问题(只显示后端返回id)

首先保证v-model的值对应options数据源里面的id <el-form-item prop"subclass" label"分类" ><el-select v-model"formData.subclass" multiple placeholder"请选择" clearable :disabled"!!formData.id"><e…

JavaWeb开发(十五)实战-生鲜后台管理系统(二)注册、登录、记住密码

1. 生鲜后台管理系统-注册功能 1.1. 注册功能 &#xff08;1&#xff09;创建注册RegisterServlet&#xff0c;接收form表单中的参数。   &#xff08;2&#xff09;service创建一个userService处理业务逻辑。   &#xff08;3&#xff09;RegisterServlet将参数传递给ser…

【MySQL系列文章】Linux环境下安装部署MySQL

前言 本次安装部署主要针对Linux环境进行安装部署操作,系统位数64 getconf LONG_BIT 64MySQL版本&#xff1a;v5.7.38 一、下载MySQL MySQL下载地址&#xff1a;MySQL :: Download MySQL Community Server (Archived Versions) 二、上传MySQL压缩包到Linuxx环境&#xff0c…

嵌入式硬件篇---基本组合逻辑电路

文章目录 前言基本逻辑门电路1.与门&#xff08;AND Gate&#xff09;2.或门&#xff08;OR Gate&#xff09;3.非门&#xff08;NOT Gate&#xff09;4.与非门&#xff08;NAND Gate&#xff09;5.或非门&#xff08;NOR Gate&#xff09;6.异或门&#xff08;XOR Gate&#x…

基于微信小程序的手机银行系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…