[Paper Notes] BiFormer: Vision Transformer with Bi-Level Routing Attention


Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention

Code: https://github.com/rayleizhu/BiFormer

Attention is a crucial module in vision transformers, but it has a serious drawback: its computational cost is very high, since every token attends to every other token.

BiFormer proposes Bi-Level Routing Attention, which attends only to the most relevant tokens during the attention computation, thereby reducing the cost.

1. Bi-Level Routing Attention

The figure below shows the regions attended to by several different attention modules: (a) is vanilla (full) attention, while the others are sparse attention designs. Bi-Level Routing Attention is shown in (f).

Unlike the other sparse attention designs, Bi-Level Routing Attention first partitions the feature map into regions (each region is an S×S window). Each region is linearly projected to obtain Q, K, V; Q and K are then averaged within each S×S window to produce one token per region (see the code), giving Q^{r} and K^{r} (r stands for region, i.e., the S×S window). The region-to-region adjacency matrix A^{r} is then obtained by a matrix product:

A^{r} = Q^{r}(K^{r})^{T}
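To make this concrete, here is a minimal PyTorch sketch of the region-level routing step; the sizes and the helper region_pool are hypothetical, not the paper's reference implementation:

import torch

# Hypothetical sizes: batch n, an H x W feature map with c channels, window size S.
n, H, W, c, S = 2, 28, 28, 64, 7
p = (H // S) * (W // S)            # number of S x S regions (p = 16 here)

q = torch.randn(n, H, W, c)        # queries after the linear projection
k = torch.randn(n, H, W, c)        # keys after the linear projection

def region_pool(t):
    # average-pool inside each S x S window -> one token per region
    t = t.view(n, H // S, S, W // S, S, c)
    return t.mean(dim=(2, 4)).reshape(n, p, c)   # (n, p, c)

q_r, k_r = region_pool(q), region_pool(k)

# region-to-region adjacency: A^{r} = Q^{r}(K^{r})^{T}, shape (n, p, p)
a_r = q_r @ k_r.transpose(-2, -1)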

With the adjacency matrix A^{r} in hand, for each region we take the indices of the k regions it is most correlated with, I^{r} = topkIndex(A^{r}), so every window knows which k windows are most relevant to it.
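Continuing the sketch above, the index matrix is just a torch.topk over the adjacency matrix (the value of topk here is an arbitrary choice):

topk = 4                                  # number of routed regions per query region
_, i_r = torch.topk(a_r, k=topk, dim=-1)  # I^{r}: (n, p, topk) region indices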

Given I^{r}, a gather operation collects the keys and values of the routed regions: K^{g} = gather(K, I^{r}), V^{g} = gather(V, I^{r}).
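Continuing the same sketch, the gather expands the routing indices so that every pixel-level key/value of the selected regions is collected; the tensor kv below is a stand-in for the concatenated pixel-level K and V of each region:

kv = torch.randn(n, p, S * S, 2 * c)      # per-region pixel-level K and V, concatenated
idx = i_r.view(n, p, topk, 1, 1).expand(-1, -1, -1, S * S, 2 * c)
kv_g = torch.gather(kv.unsqueeze(1).expand(-1, p, -1, -1, -1), dim=2, index=idx)
k_g, v_g = kv_g.split(c, dim=-1)          # K^{g}, V^{g}: (n, p, topk, S*S, c) each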

Finally, attention is computed on the gathered key-value pairs:

O = Attention(Q, K^{g}, V^{g}) + LCE(V)

where LCE(V) is a local context enhancement term implemented as a depthwise convolution over V.
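In the same toy setting, the final step is ordinary scaled dot-product attention between each window's pixel-level queries and its gathered keys/values (the LCE(V) term is omitted from this sketch):

q_pix = q.view(n, H // S, S, W // S, S, c).permute(0, 1, 3, 2, 4, 5).reshape(n, p, S * S, c)
k_flat = k_g.reshape(n, p, topk * S * S, c)   # flatten the topk routed windows
v_flat = v_g.reshape(n, p, topk * S * S, c)
attn = torch.softmax((q_pix * c ** -0.5) @ k_flat.transpose(-2, -1), dim=-1)
out = attn @ v_flat                           # (n, p, S*S, c)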

The full computation of Bi-Level Routing Attention is illustrated in the figure below, where k is the parameter used when selecting the top-k routing indices.

2. Code

The code for Bi-Level Routing Attention is as follows:

"""
Core of BiFormer, Bi-Level Routing Attention.

To be refactored.

author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


class TopkRouting(nn.Module):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        param_routing: bool, whether to incorporate learnable params in the routing unit
        diff_routing: bool, whether to make the routing differentiable
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)
    
    def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor, Tensor]:
        """
        Args:
            query, key: (n, p^2, c) tensors
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key) # optional linear embedding, (n, p^2, c)
        attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)
        
        return r_weight, topk_index
        

class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel? 
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpy
                                dim=2,
                                index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv)
                               )

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # else: #'none'
        #     topk_kv = topk_kv # do nothing

        return topk_kv

class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)
    
    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v

class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo': linear projections for q, k, v and output; 'qkv': linear projections for q, k, v only
    param_routing: extra linear for routing
    diff_routing: whether to make the routing differentiable
    soft_routing: whether to multiply the soft routing weights
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5


        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)
        
        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')
        
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kernel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kernel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v 
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not supported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NHWC tensor

        Return:
            NHWC tensor
        """
        # NOTE: use padding for semantic segmentation
        ###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, # dim=-1
                          pad_l, pad_r, # dim=-2
                          pad_t, pad_b)) # dim=-3
            _, H, W, _ = x.size() # padded size
        else:
            N, H, W, C = x.size()
            assert H%self.n_win == 0 and W%self.n_win == 0 #
        ###################################################


        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        #################qkv projection###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issues caused by gather
        q, kv = self.qkv(x) 

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)

        ##################side_dwconv(lepe)##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################

        r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)
        
        ######### do attention as normal ####################
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        # param-free multihead attention
        attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H//self.n_win, w=W//self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out
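A quick smoke test of the module above (a minimal sketch; the sizes are arbitrary, and the NHWC layout plus the divisibility-by-n_win constraint follow the forward docstring):

attn = BiLevelRoutingAttention(dim=64, num_heads=8, n_win=7, topk=4)
x = torch.randn(2, 56, 56, 64)   # NHWC; 56 is divisible by n_win=7
y = attn(x)                      # output keeps the input shape: (2, 56, 56, 64)
print(y.shape)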
