【ECCV2022】DaViT: Dual Attention Vision Transformers

news2024/11/15 12:57:24

DaViT: Dual Attention Vision Transformers, ECCV2022

解读:【ECCV2022】DaViT: Dual Attention Vision Transformers - 高峰OUC - 博客园 (cnblogs.com)

DaViT:双注意力Vision Transformer - 知乎 (zhihu.com) 

DaViT: Dual Attention Vision Transformers - 知乎 (zhihu.com) 

论文:https://arxiv.org/abs/2204.03645

代码:https://github.com/dingmyu/davit

动机

以往的工作一般是,在分辨率、全局上下文和计算复杂度之间权衡:像素级和patch级的self-attention要么是有二次计算成本,要么损失全局上下文信息。除了像素级和patch级的self-attention的变化之外,是否可以设计一个图像级的self-attention机制来捕获全局信息?

作者提出了Dual Attention Vision Transformers (DaViT),能够在保持计算效率的同时捕获全局上下文。提出的方法具有层次结构和细粒度局部注意的优点,同时采用 group channel attention,有效地建模全局环境。

创新点:

  • 提出 Dual Attention Vision Transformers(DaViT),它交替地应用spatial window attentionchannel group attention来捕获长短依赖关系。
  • 提出 channel group attention,将特征通道划分为几个组,并在每个组内进行图像级别的交互。通过group attention,作者将空间和通道维度的复杂性降低到线性。

方法

dual attention

双attention机制是从两个正交的角度来进行self-attention:

一是对spatial tokens进行self-attention,此时空间维度(HW)定义了tokens的数量,而channel维度(C)定义了tokens的特征大小,这其实也是ViT最常采用的方式;

二是对channel tokens进行self-attention,这和前面的处理完全相反,此时channel维度(C)定义了tokens的数量,而空间维度(HW)定义了tokens的特征大小。

可以看出两种self-attention完全是相反的思路。为了减少计算量,两种self-attention均采用分组的attention:对于spatial token而言,就是在空间维度上划分成不同的windows,这就是Swin中所提出的window attention,论文称之为spatial window attention;而对于channel tokens,同样地可以在channel维度上划分成不同的groups,论文称之为channel group attention

 (a)空间窗口多头自注意将空间维度分割为局部窗口,其中每个窗口包含多个空间token。每个token也被分成多个头。(b)通道组单自注意组将token分成多组。在每个通道组中使用整个图像级通道作为token进行Attention。在(a)中也突出显示了捕获全局信息的通道级token。交替地使用这两种类型的注意力机制来获得局部的细粒度,以及全局特征。

两种attention能够实现互补:spatial window attention能够提取windows内的局部特征,而channel group attention能学习到全局特征,这是因为每个channel token在图像空间上都是全局的。

dual attention block

dual attention block的模型架构,它包含两个transformer block:空间window self-attention block和通道group self-attention block。通过交替使用这两种类型的attention机制,作者的模型能实现局部细粒度和全局图像级交互。图3(a)展示了作者的dual attention block的体系结构,包括一个空间window attention block和一个通道group attention block。

Spatial Window Attention

将patchs按照空间结构划分为Nw个window,每个window 里的patchs(Pw)单独计算self-attention:(P=Nw*Pw)

 Channel Group Attention

将channels分为Ng个group,每个group的channel数量为Cg,有C=Ng*Cg,计算如下: 

关键代码

class SpatialBlock(nn.Module):
    r""" Windows Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7,
                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 ffn=True, cpe_act=False):
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        # conv位置编码
        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer)

    def forward(self, x, size):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = self.cpe[0](x, size) # depth-wise conv
        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)

        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)

        # merge windows
        attn_windows = attn_windows.view(-1,
                                         self.window_size,
                                         self.window_size,
                                         C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        x = self.cpe[1](x, size) # 第2个depth-wise conv
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size


class ChannelAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads # 这里的num_heads实际上是num_groups
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        # 得到query,key和value,是在channel维度上进行线性投射
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        k = k * self.scale
        attention = k.transpose(-1, -2) @ v # 对维度进行反转
        attention = attention.softmax(dim=-1)
        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

DaViT采用金字塔结构,共包含4个stages,每个stage的开始时都插入一个 patch embedding 层。作者在每个stage叠加dual attention block,这个block就是将两种attention(还包含FFN)交替地堆叠在一起,其分辨率和特征维度保持不变。

采用stride=4的7x7 conv,然后是4个stage,各stage通过stride=2的2x2 conv来进行降采样。其中DaViT-Tiny,DaViT-Small和DaViT-Base三个模型的配置如下所示:

 

实验

 ​​​​​​

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/612558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在线教育APP小程序系统开发 教培行业一站式解决方案

移动互联网如今已经深入到我们生活的方方面面,教育行业也不例外。如今市面上的在线教育APP小程序系统开发大受欢迎,很多学校、培训机构等都争相开发应用软件,以求通过全新的模式来满足不断扩大的市场需求,为用户提供更高质量的服务…

约瑟夫环(递归+迭代)

剑指 Offer 62. 圆圈中最后剩下的数字 leetcode 这题让我对递归和迭代又有了新的一层认识,首先一定要把图画对,就是模拟约瑟夫的这个过程 红色是被淘汰的位置,绿色的3是最后会活下来的人的位置 0 ~ n 正好是数组中的下标 重点在于计算 不同…

Java009——Java数据类型简单认识

围绕以下3点学习: 1、什么是Java数据类型? 2、Java数据类型的作用? 3、Java有哪些数据类型? 4、熟悉Java8大基本数据类型 一、什么是Java数据类型? 当我们写Java代码时,需要把数据保存在变量(…

【备战秋招】每日一题:4月8日美团春招(三批)第一题:题面+题目思路 + C++/python/js/Go/java带注释

2023大厂笔试模拟练习网站(含题解) www.codefun2000.com 最近我们一直在将收集到的各种大厂笔试的解题思路还原成题目并制作数据,挂载到我们的OJ上,供大家学习交流,体会笔试难度。现已录入200道互联网大厂模拟练习题&…

Unity之反向动力学IK

1、如何使用 (1)给物体的父对象加上IK Manager的脚本 (2)在人物四肢骨骼末端和权杖末端创建空对象 (3)添加IK节点 选择Player 添加后会发现出现了一个Player的子对象IK节点 将权杖末端的GameObject拖入…

deepin搭建go开发环境(git、go、neovim、NvChad、Nerd Font)

安装deepin虚拟机 官网下载地址 vmware中记得版本选择是debian 10.x 64位 然后就是一些确认操作,然后就可以了 安装git apt install gedit apt install git git config --global user.name "hello" git config --global user.email hello126.com git c…

Nginx之location与rewrite

Nginx之location与rewrite 一.location location 对访问的路径做访问控制或者代理转发1.匹配分类 精准匹配: location / {...} 前缀匹配: location ^~ / {...} 正则匹配: location ~ / {...} location ~* / {...} 部分…

直播带货APP小程序系统开发功能有哪些?

直播带货APP小程序系统开发功能有哪些? 1、直播带货:主播一边带货一边直播间活跃气氛,直观地了解产品,下单的概率会更高。还集有观看、打赏、购物、分享于一体。 2、短视频带货:短视频种草,利…

图数据库(一):Neo4j入门

什么是Neo4j 我们可以看一下百度百科对其的定义:Neo4j是一个高性能的,NOSQL图形数据库,它将结构化数据存储在网络上而不是表中。它是一个嵌入式的、基于磁盘的、具备完全的事务特性的Java持久化引擎,但是它将结构化数据存储在网络(从数学角度…

北京云服务器安装anaconda、cuda、cudnn、torch流程

安装顺序:Anaconda->cuda->cudnn ->torch(根据cuda安装torch) 1读取anaconda module load anaconda/2020.11 2读取cuda cudnn module load cuda/11.1 cudnn/8.2.1_cuda11.x (查看cuda版本: nvcc -V) 3运行脚本 sbatch train.sh 脚本写法 #!/bin/bash #SB…

Nextcloud私有云 - 零基础搭建私有云盘

文章目录 摘要视频教程1. 环境搭建2. 测试局域网访问3. 内网穿透3.1 ubuntu本地安装cpolar3.2 创建隧道3.3 测试公网访问 4 配置固定http公网地址4.1 保留一个二级子域名4.1 配置固定二级子域名4.3 测试访问公网固定二级子域名 转载自cpolar极点云的文章:使用Nextcl…

hash传递攻击

简介 Pass the hash也就是Hash传递攻击,简称为PTH。模拟用户登录不需要用户明文密码只需要hash值就可以直接来登录目标系统。 利用前提条件是: 开启445端口开启ipc$共享 Metasploit pesexec模块 windows/smb/psexec 这里主要设置smbuser、smbPass …

Docker镜像与Docker容器详解

文章目录 Docker帮助信息Docker镜像拉取镜像镜像仓库官方与非官方镜像仓库 拉取指定标签的镜像为镜像打多个标签搜索镜像镜像和分层共享镜像层根据摘要拉取镜像删除镜像总结 Docker容器启动一个简单的容器容器进程容器生命周期优雅的退出容器利用重启策略进行容器的自我修复Web…

7. JVM调优实战及常量池详解

JVM性能调优 1. 阿里巴巴Arthas详解1.1 Arthas使用 本文是按照自己的理解进行笔记总结,如有不正确的地方,还望大佬多多指点纠正,勿喷。 课程内容: 1、阿里巴巴Arthas调优工具详解 2、GC日志详解与调优分析 3、Class常量池与运行…

氟化物选择吸附树脂Tulsimer ®CH-87 ,锂电行业废水行业矿井水除氟专用树脂

氟化物选择吸附树脂 Tulsimer CH-87 是一款去除水溶液中氟离子的专用的凝胶型选择性离子交换树脂。它是具有氟化物选择性官能团的交联聚苯乙烯共聚物架构的树脂。 去除氟离子的能力可以达到 1ppm 以下的水平。中性至碱性的PH范围内有较好的工作效率,并且很容易再生…

Ingress详解

Ingress Service对集群外暴露端口两种方式,这两种方式都有一定的缺点: NodePort :会占用集群集群端口,当集群服务变多时,缺点明显LoadBalancer:每个Service都需要一个LB,并且需要k8s之外设备支…

[洛谷]P2960 [USACO09OCT]Invasion of the Milkweed G(BFS,坑点多多)

、1:坐标是反的,(1,1)是在左下角,正常在右上角,所用建图的时候要小心 2: 加node(),搭配 中的构造 3: 不用判断位置是否越界,数组从1,1开始&…

网络安全通信HTTPS原理

背景 HTTPS并不是一个全新的协议,而是为了保证网络通信的数据安全,在HTTP的基础上加上了安全套件SSL的一个数据通信协议。 HTTP协议全称为HyperText Transfer Protocol即超文本传输协议,是客户端浏览器与Web服务器之间的应用层通信协议。HT…

秘钥失效问题

一、发现问题 问题如下图代码: $ ssh root108.61.163.242 WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now (man-in-the-middle attack)! It is …

前端入门学习

封装axios axios的基础使用 axios基础使用方法: axios.create({config}) //创建axios实例 axios.get(url,{config}) //get请求 axios.post(url, data,{config}) //post请求 axios.interceptors.request.use() // 请求拦截器 axios.interce…