YoloV9改进策略:主干网络改进|DeBiFormer,可变形双级路由注意力|全网首发


摘要

在目标检测领域,YoloV9以其高效和准确的性能而闻名。为了进一步提升其检测能力,我们引入DeBiFormer作为YoloV9的主干网络。这个主干网络的计算量比较大;不过,之前介绍的双级路由注意力(BiFormer)论文受到了很大关注,所以我也把这篇论文中的主干网络用来改进YoloV9,显卡资源充足的同学可以试试。

DeBiFormer是一种新型的视觉转换器,它结合了可变形注意力和双级路由注意力的优点。通过引入可变形双级路由注意力(DBRA)机制,DeBiFormer能够灵活且语义化地获取数据依赖的注意力模式。这种注意力中注意力的架构使得模型能够更高效地定位关键特征,从而提高检测的准确性。

将DeBiFormer应用于YoloV9的主干网络,我们实现了以下显著的改进:

  1. 更强的特征表示能力:DeBiFormer的DBRA机制能够捕获更多信息性特征,并将其回传给查询,从而增强了模型的特征表示能力。这使得YoloV9在检测目标时能够更准确地识别其形状、纹理等关键特征。
  2. 更高的检测精度:由于DeBiFormer具有更强的特征表示能力,YoloV9在检测目标时能够实现更高的精度。实验结果表明,在相同的数据集和训练策略下,改进后的YoloV9在各类目标上的检测精度均有显著提升。
  3. 更好的泛化性能:DeBiFormer的注意力机制使得模型能够更好地适应不同的场景和数据分布。因此,改进后的YoloV9在面临新的、未见过的目标时,能够表现出更好的泛化性能。

论文:《DeBiFormer: 带可变形代理双级路由注意力的视觉Transformer》

带有各种注意力模块的视觉Transformer在视觉任务上已表现出卓越的性能。虽然使用稀疏自适应注意力(如DAT)在图像分类任务中取得了显著成果,但在对语义分割任务进行微调时,由可变形点选择的键值对缺乏语义相关性。BiFormer中的查询感知稀疏注意力旨在使每个查询关注前 $k$ 个路由区域。然而,在注意力计算过程中,所选的键值对受到过多不相关查询的影响,从而降低了对更重要查询的关注度。为了解决这些问题,我们提出了可变形双级路由注意力(DBRA)模块,该模块使用代理查询优化键值对的选择,并增强了注意力图中查询的可解释性。在此基础上,我们引入了带有DBRA模块的新型通用视觉Transformer——可变形双级路由注意力Transformer(DeBiFormer)。DeBiFormer已在图像分类、目标检测和语义分割等各种计算机视觉任务上得到验证,有力地证明了其有效性。代码可访问:https://github.com/maclong01/DeBiFormer
关键词:视觉Transformer,自注意力机制,图像识别

1 引言

视觉Transformer在计算机视觉领域近期展现出了巨大的潜力[15,29,44]。它能够捕获数据中的长距离依赖关系[29,41],并催生了一类更灵活、更适合拟合海量数据的无卷积模型[44]。此外,它还具有高并行性,有利于大型模型的训练和推理[11,41]。因此,计算机视觉领域中视觉Transformer的采用和发展呈现出爆炸式增长[1,14,15,29,44,45]。

为了提高注意力的计算效率,大量研究精心设计了高效的注意力模式,使每个查询只选择性地关注一小部分键值对。如图1所示,这些稀疏模式包括局部窗口[50]和空洞窗口[45,40,24]等;另一些研究则让稀疏模式自适应于数据,如[5,47]中的工作。然而,尽管在合并或选择键、值令牌时采用了不同策略,这些令牌对于查询而言并不具有语义相关性:从预训练的ViT[41]和DETR[1]在下游任务上的可视化可以看出,查询实际关注的键值对往往来自语义相关的区域,因此强制所有查询只关注同一小部分令牌可能无法产生最优结果。最近出现的动态查询感知稀疏注意力,即双级路由注意力[56],让查询聚焦于语义上最相关的键值对。然而,这种方法中键值对的选取只来自语义区域而非更细粒度的区域;此外,在计算注意力时,为所有查询选择的这些键和值会受到过多相关性较低的查询的影响,导致对重要查询的关注被稀释,这对分割等密集预测任务[13,25]影响较大。

为了使查询的注意力更加高效,我们提出了可变形双级路由注意力(DBRA),这是一种用于视觉识别的"注意力中注意力"架构。DBRA面临的第一个问题是如何定位可变形点。我们借鉴[47]中的做法,使用一个以查询特征为输入的偏移网络,为所有参考点生成相应的偏移量,使候选可变形点能够高效、灵活地移向重要区域,以捕获更多信息性特征。第二个问题是如何从语义相关的键值对中聚合信息,再将信息回传给查询。为此,我们提出一种注意力中注意力架构,将上述移向重要区域的可变形点作为查询的代理:在为可变形点选择键值对时,我们沿用[56]中的做法,只选择一小部分语义上最相关的键值对,即每个区域仅关注前 $k$ 个路由区域。选出语义相关的键值对后,我们首先以可变形点为查询执行第一次令牌到令牌注意力;随后执行第二次令牌到令牌注意力,把信息回传给原始查询,此时作为键值对的可变形点被设计为代表语义区域中最重要的点。

综上所述,我们的贡献如下:

  1. 我们提出了可变形双级路由注意力(DBRA),这是一种用于视觉识别的注意力中注意力架构,能够灵活且语义化地获取数据依赖的注意力模式。
  2. 通过利用DBRA模块,我们提出了一种新的主干网络,称为DeBiFormer。根据注意力热图的可视化结果,该网络具有更强的识别能力。
  3. 在ImageNet[35]、ADE20K[55]和COCO[17]上的大量实验表明,我们的模型始终优于其他竞争基线。

2 相关工作

2.1 视觉Transformer

基于Transformer的主干网络结合了通道级MLP[38]块,通过通道混合嵌入每个位置的特征。此外,还使用注意力[41]块进行跨位置关系建模并促进空间混合。Transformer最初是为自然语言处理[41,11]而设计的,随后通过DETR[1]和ViT[41]等工作被引入计算机视觉领域。与卷积神经网络(CNN)相比,Transformer的主要区别在于它使用注意力替代卷积,从而促进了全局上下文建模。然而,传统的注意力机制计算所有空间位置之间的成对特征亲和力,这带来了巨大的计算负担和内存占用,特别是在处理高分辨率输入时。因此,一个关键的研究重点是设计更高效的注意力机制,这对于减轻计算需求至关重要,尤其是处理高分辨率输入时。

2.2 注意力机制

大量研究旨在减轻传统注意力机制带来的计算和内存复杂性。方法包括稀疏连接模式[6]、低秩近似[42]和循环操作[10]。在视觉Transformer的上下文中,稀疏注意力变得流行起来,特别是在Swin Transformer[29]取得显著成功后。在Swin Transformer框架中,注意力被限制在非重叠的局部窗口中,并引入了一种创新的移位窗口操作。该操作促进了相邻窗口之间的通信,为其处理注意力机制提供了独特的方法。为了在不超过计算限制的情况下实现更大或近似全局的感受野,最近的研究结合了多种手动设计的稀疏模式。这些模式包括空洞窗口[45,40,24]和十字形窗口[14]的集成。此外,一些研究致力于使稀疏模式适应数据,如DAT[47]、TCFormer[53]和DPT[5]等工作所示。尽管它们通过使用不同的合并或选择策略来减少关键值令牌的数量,但重要的是要认识到这些令牌缺乏语义特异性。相反,我们加强了查询感知的关键值令牌选择。

我们的工作受到一个观察结果的启发:对于重要的查询,其语义上关注的区域可能差异显著,这可以从ViT[41]和DETR[1]等预训练模型的可视化中看出。为了以由粗到细的方式实现查询自适应的稀疏性,我们提出了一种注意力中注意力架构,它结合了可变形注意力[47]和双级路由注意力[56]的优点。与二者不同,我们的可变形双级路由注意力旨在强化最具语义性且最灵活的键值对;相比之下,双级路由注意力仅关注定位少数高度相关的键值对,而可变形注意力则优先识别少数最灵活的键值对。

3 我们的方法:DeBiFormer

3.1 预备知识

首先,我们回顾了最近视觉Transformer中使用的注意力机制。以扁平化的特征图 $x \in \mathbb{R}^{N \times C}$ 作为输入,具有 $M$ 个头的多头自注意力(MHSA)块表示为

$$\begin{aligned} q &= xW_q,\quad k = xW_k,\quad v = xW_v \\ z^{(m)} &= \sigma\!\left(q^{(m)} k^{(m)\top}/\sqrt{d}\right) v^{(m)},\quad m = 1,\ldots,M \\ z &= \operatorname{Concat}\!\left(z^{(1)},\ldots,z^{(M)}\right) W_o \end{aligned}$$

其中,$\sigma(\cdot)$ 表示softmax函数,$d = C/M$ 是每个头的维度。$z^{(m)}$ 表示第 $m$ 个注意力头的嵌入输出,而 $q^{(m)}, k^{(m)}, v^{(m)} \in \mathbb{R}^{N \times d}$ 分别表示查询、键和值嵌入。$W_q, W_k, W_v, W_o \in \mathbb{R}^{C \times C}$ 是投影矩阵。带有归一化层和恒等捷径的第 $l$ 个Transformer块(其中LN表示层归一化)表示为
$$\begin{aligned} z_l' &= \operatorname{MHSA}(\mathrm{LN}(z_{l-1})) + z_{l-1} \\ z_l &= \operatorname{MLP}(\mathrm{LN}(z_l')) + z_l' \end{aligned}$$
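为了直观说明上述公式中各投影矩阵与softmax、残差结构的对应关系,下面给出一个最小的MHSA与前归一化Transformer块的PyTorch示意实现;类名与维度设置均为示例假设,并非论文官方实现。

```python
import torch
import torch.nn as nn

class MHSA(nn.Module):
    """Minimal multi-head self-attention matching q = xWq, k = xWk, v = xWv above."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.h, self.d = num_heads, dim // num_heads
        self.wq, self.wk, self.wv, self.wo = (nn.Linear(dim, dim) for _ in range(4))

    def forward(self, x):                                   # x: (B, N, C)
        B, N, C = x.shape
        split = lambda t: t.view(B, N, self.h, self.d).transpose(1, 2)   # (B, M, N, d)
        q, k, v = split(self.wq(x)), split(self.wk(x)), split(self.wv(x))
        attn = (q @ k.transpose(-2, -1) / self.d ** 0.5).softmax(dim=-1)
        z = (attn @ v).transpose(1, 2).reshape(B, N, C)     # concat heads
        return self.wo(z)

class Block(nn.Module):
    """Pre-norm block: z' = MHSA(LN(z)) + z, z = MLP(LN(z')) + z'."""
    def __init__(self, dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.n1, self.n2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = MHSA(dim, num_heads)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
                                 nn.Linear(dim * mlp_ratio, dim))

    def forward(self, z):
        z = self.attn(self.n1(z)) + z
        return self.mlp(self.n2(z)) + z

print(Block(64, 4)(torch.randn(2, 196, 64)).shape)          # torch.Size([2, 196, 64])
```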

3.2 可变形双层路由注意力(DBRA)

所提出的可变形双层路由注意力(DBRA)的架构如图2所示。我们首先采用一个可变形注意力模块,该模块包含一个偏移网络,该网络基于查询特征为参考点生成偏移量,从而创建可变形点。然而,这些点往往会在重要区域聚集,导致某些区域过度集中。

为解决此问题,我们引入了可变形点感知的区域划分,确保每个可变形点仅与键值对的一个小子集进行交互。然而,仅依赖区域划分可能导致重要区域与不太重要区域之间的不平衡,因此DBRA模块被设计为更有效地分配注意力:在DBRA中,每个可变形点充当代理查询,与语义区域的键值对计算注意力。这种方式保证每个重要区域只分配到少数可变形点,从而使注意力分散到图像的所有关键区域,而不是聚集在一个点上。

通过使用DBRA模块,不太重要区域的注意力减少,更重要区域的注意力增加,确保整个图像中注意力的平衡分布。

可变形注意力模块和输入投影。如图2所示,给定输入特征图 $x \in \mathbb{R}^{H \times W \times C}$,通过以因子 $r$ 对输入特征图进行下采样,生成一个均匀的点网格 $p \in \mathbb{R}^{H_G \times W_G \times 2}$(其中 $H_G = H/r,\ W_G = W/r$)作为参考点。为了获得每个参考点的偏移量,先将特征线性投影得到查询令牌 $q = xW_q$,再将其输入偏移子网络 $\theta_{\text{offset}}(\cdot)$,得到偏移量 $\Delta p = \theta_{\text{offset}}(q)$。随后,在可变形点的位置上采样特征作为键和值,并通过投影矩阵进一步处理:

$$ q = xW_q,\quad \Delta p = \theta_{\text{offset}}(q),\quad \bar{x} = \varphi(x;\, p + \Delta p) $$

其中,$\bar{x}$ 经投影后分别得到变形后的键 $\bar{k}$ 和值 $\bar{v}$ 嵌入。具体来说,我们将采样函数 $\varphi(\cdot;\cdot)$ 设置为双线性插值,使其可微:

$$ \varphi\big(z;(p_x, p_y)\big) = \sum_{(r_x, r_y)} g(p_x, r_x)\, g(p_y, r_y)\, z[r_y, r_x, :] $$

其中,函数 $g(a,b) = \max(0,\ 1-|a-b|)$,$(r_x, r_y)$ 遍历 $z \in \mathbb{R}^{H \times W \times C}$ 上的所有位置索引。由于 $g$ 仅在最接近 $(p_x, p_y)$ 的四个整数点上非零,上式退化为这四个位置上的加权平均,这与可变形注意力中的设置一致。
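作为参考,下面给出参考点网格、偏移量预测与双线性采样 $\varphi(\cdot;\cdot)$ 的一个简化示意,借助 `F.grid_sample` 实现可微采样;思路与后文代码中的 `_get_ref_points`、`conv_offset_q` 一致,但这里的下采样因子 `r`、通道数等均为示例假设。

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def ref_grid(Hg, Wg, device):
    # normalized reference points in [-1, 1], shape (Hg, Wg, 2) ordered as (y, x)
    ys = torch.linspace(0.5, Hg - 0.5, Hg, device=device) / Hg * 2 - 1
    xs = torch.linspace(0.5, Wg - 0.5, Wg, device=device) / Wg * 2 - 1
    return torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)

class DeformableSampler(nn.Module):
    """Predict offsets from queries and bilinearly sample deformed features."""
    def __init__(self, dim, r=8):
        super().__init__()
        self.proj_q = nn.Conv2d(dim, dim, 1)
        # offset net: downsample by r, then predict one (dy, dx) per reference point
        self.offset = nn.Sequential(
            nn.Conv2d(dim, dim, r, stride=r, groups=dim), nn.GELU(),
            nn.Conv2d(dim, 2, 1))

    def forward(self, x):                                    # x: (B, C, H, W)
        q = self.proj_q(x)
        off = self.offset(q).permute(0, 2, 3, 1)             # (B, Hg, Wg, 2)
        pos = (ref_grid(*off.shape[1:3], x.device) + off).clamp(-1, 1)
        # grid_sample expects (x, y) order -> flip last dim; bilinear = phi(.;.)
        return F.grid_sample(x, pos[..., (1, 0)], mode="bilinear", align_corners=True)

feat = torch.randn(1, 64, 56, 56)
print(DeformableSampler(64)(feat).shape)                     # torch.Size([1, 64, 7, 7])
```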

区域划分与区域间路由。给定可变形注意力输出的特征图 $\bar{x} \in \mathbb{R}^{H_G \times W_G \times C}$ 和原始特征图 $x \in \mathbb{R}^{H \times W \times C}$,首先将它们划分为 $S \times S$ 个非重叠区域,使每个区域包含 $\frac{H_G W_G}{S^2}$ 个特征向量,并将重塑后的 $\bar{x}$ 记为 $\overline{x^r} \in \mathbb{R}^{S^2 \times \frac{H_G W_G}{S^2} \times C}$,将 $x$ 记为 $x^r \in \mathbb{R}^{S^2 \times \frac{HW}{S^2} \times C}$。然后,我们通过线性投影得到查询、键和值:

$$ \hat{q} = \overline{x^r} W_q,\quad \hat{k} = x^r W_k,\quad \hat{v} = x^r W_v $$

接下来,我们采用BiFormer[56]中介绍的区域间路由方法,通过构建有向图来建立注意关系。首先,对每个区域内的查询和键取平均,得到区域级查询和键 $\hat{q}^r, \hat{k}^r \in \mathbb{R}^{S^2 \times C}$。然后,通过 $\hat{q}^r$ 与 $(\hat{k}^r)^\top$ 的矩阵乘法得到区域间亲和图的邻接矩阵 $A^r \in \mathbb{R}^{S^2 \times S^2}$:

$$ A^r = \hat{q}^r \left(\hat{k}^r\right)^\top $$

其中,邻接矩阵 $A^r$ 量化了任意两个区域之间的语义相关性。该方法的关键步骤是使用 topk 算子对亲和图进行修剪,只保留每个区域的前 $k$ 个连接,得到路由索引矩阵 $I^r \in \mathbb{N}^{S^2 \times k}$:

$$ I^r = \operatorname{topk}\left(A^r\right) $$
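区域划分与top-$k$路由的核心逻辑可以用如下简化代码示意(与后文的 `TopkRouting` 类对应;这里的区域数 `S`、`topk` 取值仅为示例假设):

```python
import torch

def region_routing(q_hat, k_hat, topk):
    """q_hat, k_hat: (B, S^2, n, C) region-partitioned queries/keys.
    Returns routing index I_r: (B, S^2, topk), one row of region ids per region."""
    q_r = q_hat.mean(dim=2)                   # region-level query, (B, S^2, C)
    k_r = k_hat.mean(dim=2)                   # region-level key,   (B, S^2, C)
    A_r = q_r @ k_r.transpose(-2, -1)         # region-to-region affinity graph A^r
    _, I_r = torch.topk(A_r, k=topk, dim=-1)  # keep the top-k most related regions
    return I_r

B, S, n, C = 2, 7, 16, 64
q_hat, k_hat = torch.randn(B, S * S, n, C), torch.randn(B, S * S, n, C)
print(region_routing(q_hat, k_hat, topk=4).shape)   # torch.Size([2, 49, 4])
```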

双层标记到可变形层标记注意力。利用区域路由索引矩阵 $I^r$,我们即可应用细粒度的标记到标记注意力。对于第 $i$ 个区域内的每个可变形查询标记,其注意力覆盖位于前 $k$ 个路由区域(即由 $I^r_{i,1}, I^r_{i,2}, \ldots, I^r_{i,k}$ 索引的区域)中的所有键值对。因此,我们首先收集对应的键和值:

$$ \hat{k}^g = \operatorname{gather}\left(\hat{k}, I^r\right),\quad \hat{v}^g = \operatorname{gather}\left(\hat{v}, I^r\right) $$

其中,$\hat{k}^g, \hat{v}^g \in \mathbb{R}^{S^2 \times \frac{kHW}{S^2} \times C}$ 是收集到的键和值。然后,我们将注意力应用于 $\hat{k}^g, \hat{v}^g$:

$$ \begin{aligned} \hat{O} &= \hat{x} + W_{o'}\big(\operatorname{Attention}(\hat{q}, \hat{k}^g, \hat{v}^g) + \mathrm{LCE}(\hat{v})\big) \\ O &= \operatorname{MLP}(\mathrm{LN}(\hat{O})) + \hat{O} \end{aligned} $$

其中,$W_{o'}$ 是输出特征的投影权重,$\mathrm{LCE}(\cdot)$ 使用核大小为5的深度卷积。
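按路由索引聚合top-$k$区域的键值并计算注意力的过程可示意如下(对应后文的 `KVGather`;为简洁起见省略了式中的 $\mathrm{LCE}(\hat{v})$ 项与输出投影,张量形状均为示例假设):

```python
import torch

def gather_topk_kv(k_hat, v_hat, I_r):
    """k_hat/v_hat: (B, S^2, n, C); I_r: (B, S^2, topk) routed region indices.
    Returns (B, S^2, topk*n, C) keys/values gathered from the routed regions."""
    B, S2, n, C = k_hat.shape
    topk = I_r.size(-1)
    idx = I_r.view(B, S2, topk, 1, 1).expand(-1, -1, -1, n, C)
    gather = lambda t: torch.gather(
        t.view(B, 1, S2, n, C).expand(-1, S2, -1, -1, -1), dim=2, index=idx
    ).reshape(B, S2, topk * n, C)
    return gather(k_hat), gather(v_hat)

B, S2, n, C, topk = 2, 49, 16, 64, 4
k_hat, v_hat = torch.randn(B, S2, n, C), torch.randn(B, S2, n, C)
q_hat = torch.randn(B, S2, n, C)                                   # deformable proxy queries
k_g, v_g = gather_topk_kv(k_hat, v_hat, torch.randint(0, S2, (B, S2, topk)))
attn = (q_hat @ k_g.transpose(-2, -1) / C ** 0.5).softmax(-1)      # (B, S^2, n, topk*n)
out = attn @ v_g                                                   # (B, S^2, n, C)
print(out.shape)
```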

可变形层标记到标记注意力。之后,经由[56]式路由获得语义信息的可变形特征被重塑为 $O^r \in \mathbb{R}^{H_G \times W_G \times C}$,并投影得到键和值:

$$ k = O^r W_k,\quad v = O^r W_v $$

k k k v v v 分别表示语义变形键和值的嵌入。使用现有方法,我们对 q , k , v q, k, v q,k,v 和相对位置偏移 R R R 执行自注意力。注意力的输出公式如下:

$$ z^{(m)} = W_{\bar{o}}\left(\sigma\!\left(q^{(m)} k^{(m)\top}/\sqrt{d} + \phi(\hat{B}; R)\right) v^{(m)}\right) $$

这里,$\phi(\hat{B}; R) \in \mathbb{R}^{HW \times H_G W_G}$ 对应位置嵌入,遵循先前工作[29]的方法。然后,将各头输出 $z^{(m)}$ 通过 $W_o$ 投影得到最终输出 $z$,如方程3所示。
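该步骤本质上是在标准注意力logits中加入位置偏置项 $\phi(\hat{B};R)$,可用下面的小示意说明(偏置的连续采样方式见7.2节;此处直接以张量形式传入,形状均为示例假设):

```python
import torch

def biased_attention(q, k, v, bias):
    """q: (B, M, HW, d); k, v: (B, M, Hg*Wg, d); bias: (M, HW, Hg*Wg)."""
    d = q.size(-1)
    logits = q @ k.transpose(-2, -1) / d ** 0.5 + bias   # add relative position bias
    return logits.softmax(dim=-1) @ v

B, M, HW, HgWg, d = 2, 4, 196, 49, 32
out = biased_attention(torch.randn(B, M, HW, d), torch.randn(B, M, HgWg, d),
                       torch.randn(B, M, HgWg, d), torch.zeros(M, HW, HgWg))
print(out.shape)                                         # torch.Size([2, 4, 196, 32])
```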

3.3 模型架构

利用DBRA作为基本构建块,我们提出了一种新的视觉Transformer,称为DeBiFormer。如图3所示,我们遵循最新的最先进视觉Transformer[14,29,56,47],采用四阶段金字塔结构:第一阶段使用重叠补丁嵌入,第二到第四阶段使用补丁合并模块[26,34],以降低输入的空间分辨率,同时增加通道数;随后在第 $i$ 阶段使用 $N_i$ 个连续的DeBiFormer块来变换特征。在每个DeBiFormer块内,我们遵循最近的做法[26,40,56],先使用 $3 \times 3$ 深度卷积来隐式编码相对位置信息,之后依次使用一个DBRA模块和一个扩展比为 $e$ 的2-ConvFFN模块,分别用于跨位置关系建模和每个位置的嵌入。DeBiFormer按表1所列的网络宽度和深度进行缩放,实例化为三种不同的模型尺寸。每个注意力头包含32个通道,双层ConvFFN与可变形层ConvFFN的MLP扩展比均为 $e=3$。对于BRA,我们在四个阶段分别使用 topk $=1,4,16,S^2$;对于DBRA,我们使用 topk $=4,8,16,S^2$。此外,我们将区域划分因子 $S$ 设置为:分类任务中 $S=7$,语义分割任务中 $S=8$,目标检测任务中 $S=20$。
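下面给出四阶段金字塔结构的一个骨架示意,用以说明"深度卷积 → 注意力 → ConvFFN"的块内顺序与逐阶段下采样;阶段深度、通道数等配置为示例假设,注意力以占位模块表示,并非论文官方实现。

```python
import torch
import torch.nn as nn

class DeBiFormerBlock(nn.Module):
    """DWConv (implicit positional encoding) -> attention -> FFN, each with a residual."""
    def __init__(self, dim, attn, mlp_ratio=3):
        super().__init__()
        self.pos = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)   # 3x3 depth-wise conv
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = attn                                           # e.g. a DBRA module
        self.ffn = nn.Sequential(nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
                                 nn.Linear(dim * mlp_ratio, dim))

    def forward(self, x):                        # x: (B, C, H, W)
        x = x + self.pos(x)
        t = x.permute(0, 2, 3, 1)                # to (B, H, W, C) for LayerNorm/attention
        t = t + self.attn(self.norm1(t))
        t = t + self.ffn(self.norm2(t))
        return t.permute(0, 3, 1, 2)

# four-stage pyramid: 4x downsampling first, then 2x between stages (example dims/depths)
dims, depths = [64, 128, 256, 512], [2, 2, 6, 2]
attn_placeholder = nn.Identity()                 # stand-in for the DBRA attention
stages, in_ch = [], 3
for d, n in zip(dims, depths):
    stride = 4 if in_ch == 3 else 2
    stages += [nn.Conv2d(in_ch, d, stride, stride)]                # patch embed / merge
    stages += [DeBiFormerBlock(d, attn_placeholder) for _ in range(n)]
    in_ch = d
model = nn.Sequential(*stages)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 512, 7, 7])
```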

4 实验

我们通过实验评估了所提出的DeBiFormer在各种主流计算机视觉任务上的有效性,包括图像分类(第4.1节)、语义分割(第4.2节)和目标检测以及实例分割(第4.3节)。在我们的方法中,我们从ImageNet-1K [35]数据集开始从头训练图像分类模型。随后,我们在ADE20K [55]数据集上对预训练的主干网络进行微调,以进行语义分割,并在COCO [17]数据集上进行微调,以进行目标检测和实例分割。此外,我们进行了消融研究,以验证所提出的可变形双级路由注意力(Deformable Bi-level Routing Attention)和DeBiFormer的top-k选择的有效性(第4.4节)。最后,为了验证我们DeBiFormer的识别能力和可解释性,我们对注意力图进行了可视化(第5节)。

4.1 在ImageNet-1K上的图像分类

设置。我们在ImageNet-1K [35]数据集上进行了图像分类实验,遵循DeiT [39]的实验设置以进行公平比较。具体来说,每个模型在8个V100 GPU上以224×224的输入大小训练300个epoch。我们使用AdamW作为优化器,权重衰减为0.05,并采用余弦衰减学习率调度策略,初始学习率为0.001,前五个epoch用于线性预热。批量大小设置为1024。为避免过拟合,我们使用了正则化技术,包括RandAugment [9](rand-m9-mstd0.5-inc1)、MixUp [54](prob=0.8)、CutMix [52](prob=1.0)、随机擦除(prob=0.25)以及随机深度[23](对于DeBiFormer-T/S/B,prob分别为0.1/0.2/0.4)。

结果。我们在表2中报告了结果,展示了具有相似计算复杂度的top-1准确率。我们的DeBiFormer在所有三个尺度上都优于Swin Transformer [29]、PVT [44]、DeiT [39]、DAT [47]和BiFormer [56]。在不将卷积插入Transformer块、也不使用重叠卷积进行块嵌入的情况下,DeBiFormer相对于BiFormer [56]对应版本分别实现了0.5pt、0.1pt和0.1pt的增益。
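按上述训练设置(300个epoch、AdamW、余弦退火、前5个epoch线性预热),优化器与学习率调度可以用PyTorch的常见写法示意如下;这只是按文中超参数拼出的示例骨架,`model`、训练循环等均为假设占位,并非论文官方训练脚本。

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

model = torch.nn.Linear(224 * 224 * 3, 1000)          # placeholder for DeBiFormer-T/S/B
epochs, warmup_epochs, base_lr = 300, 5, 1e-3

optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=0.05)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-3, total_iters=warmup_epochs),  # linear warmup
        CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs),         # cosine decay
    ],
    milestones=[warmup_epochs],
)

for epoch in range(epochs):
    # ... one training epoch over ImageNet-1K (batch size 1024) with RandAugment,
    #     MixUp, CutMix, random erasing and stochastic depth ...
    scheduler.step()
```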

4.2 在ADE20K上的语义分割

设置。与现有工作相同,我们在SemanticFPN [46]和UperNet [48]上使用了我们的DeBiFormer。在这两种情况下,主干网络都使用ImageNet-1K预训练权重进行初始化。优化器是AdamW [31],批量大小为32。为进行公平比较,我们遵循PVT [44]的相同设置,用80k步训练模型,并遵循Swin Transformer [29]的相同设置,用160k步训练模型。

结果。表8展示了两个不同框架的结果。结果表明,在使用Semantic FPN框架的情况下,我们的DeBiFormer-S/B分别实现了49.2/50.6 mIoU,比BiFormer提高了0.3pt/0.7pt。对于UperNet框架,也观察到了类似的性能增益。通过使用DBRA模块,我们的DeBiFormer能够捕获最多的语义键值对,这使得注意力选择更加合理,并在下游语义任务上实现了更高的性能。

4.3 目标检测和实例分割

设置。我们使用DeBiFormer作为Mask RCNN [19]和RetinaNet [16]框架中的主干网络,以评估模型在COCO 2017 [17]数据集上对于目标检测和实例分割的有效性。实验使用MMDetection [3]工具箱进行。在COCO上进行训练之前,我们使用ImageNet-1K预训练权重对主干网络进行初始化,并遵循与BiFormer [56]相同的训练策略以公平比较我们的方法。请注意,由于设备限制,我们在这些实验中设置小批量大小为4,而在BiFormer中此值为16。有关实验具体设置的详细信息,请参阅补充论文。

结果。我们在表4.2中列出了结果。对于使用RetinaNet进行的目标检测,我们报告了平均精度均值(mAP)、不同IoU阈值(50%、75%)下的AP,以及小、中、大(S/M/L)三种目标尺寸下的AP。从结果中可以看出,尽管我们使用的资源有限,DeBiFormer的整体性能仅与一些最具竞争力的现有方法相当,但其在大目标上的性能(AP_L)优于这些方法。这可能是因为DBRA更合理地分配了可变形点:这些点不仅关注小物体,也关注图像中重要的物体,注意力不再局限于小区域,从而提高了大目标的检测准确率。对于使用Mask R-CNN进行的实例分割,我们报告了不同IoU阈值(50%、75%)下的边界框和掩码平均精度(AP_b和AP_m)。请注意,尽管受到设备限制(小批量大小为4),我们的DeBiFormer仍然取得了出色的性能。我们认为,如果小批量大小能与其他方法相同,将能取得更好的结果,这一点在语义分割任务中已经得到了证明。

4.4 消融研究

DBRA的有效性。我们将DBRA与几种现有的稀疏注意力机制进行了比较。遵循CSWin [14]的做法,为公平比较,我们将宏观架构设计与Swin-T [29]对齐。具体来说,我们在四个阶段分别使用2、2、6、2个块和非重叠的补丁嵌入,并将初始补丁嵌入维度设置为 $C=96$,MLP扩展比设置为 $e=4$。结果如表5所示。在图像分类和语义分割任务上,我们的可变形双级路由注意力(Deformable Bi-level Routing Attention)性能明显优于现有的稀疏注意力机制。

分区因子 $S$。与BiFormer类似,我们选择让 $S$ 能整除训练分辨率,以避免填充。在图像分类中,我们使用 $224 = 7 \times 32$ 的输入分辨率并设置 $S=7$,以确保每个阶段的特征图尺寸都能被整除。这一选择与Swin Transformer [29]中窗口大小为7的策略一致。
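以 $224$ 分辨率、$S=7$ 为例,可以快速验证各阶段特征图尺寸均能被 $S$ 整除(各阶段下采样倍率按四阶段金字塔的常见设置 4/8/16/32 假设):

```python
S, resolution = 7, 224
for stride in (4, 8, 16, 32):               # per-stage downsampling ratios
    size = resolution // stride             # 56, 28, 14, 7
    assert size % S == 0, f"stage feature map {size} not divisible by S={S}"
    print(f"stride {stride:2d}: {size}x{size} feature map -> {S}x{S} regions of {size // S}x{size // S}")
```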

Top-k选择。我们系统地调整了 $k$,以确保在后续阶段区域尺寸减小时,仍有合理数量的令牌被可变形查询关注。探索 $k$ 的各种组合也是一个可行的选择。在表9中,我们按照DeBiFormer-STL("STL"表示Swin-T布局)报告了在IN-1K上的消融结果。从这些实验中得出的一个关键观察是:增加可变形查询所关注的令牌数量会对准确率和延迟产生不利影响,而在第1和第2阶段增加被关注的令牌数量也会影响准确率。

不同阶段的可变形双级路由多头注意力(DBRMHA)。为了评估设计选择的影响,我们系统地用DBRMHA块替换不同阶段中的双级路由注意力块,如表7所示。最初,所有阶段都使用双级路由注意力,类似于BiFormer-T [56],在图像分类中实现了 $81.3\%$ 的准确率。仅将第4阶段的一个块替换为DBRMHA,准确率立即提高了 $+0.21$;将第4阶段的所有块都替换为DBRMHA,又带来 $+0.05$ 的提升。在第3阶段进一步替换DBRMHA块继续提高了各项任务的性能。尽管在更早阶段进行替换带来的增益逐渐减少,但为了保持简洁,我们最终确定的DeBiFormer在所有阶段都使用可变形双级路由注意力。

5 Grad-CAM可视化

为了进一步说明所提出的DeBiFormer识别重要区域的能力,我们使用Grad-CAM [36]可视化了BiFormer-Base和DeBiFormer-Base最关注的区域。如图4所示,得益于DBRA模块,我们的DeBiFormer-Base模型能更好地定位目标对象,关注到更多相关区域;同时,它降低了对不必要区域的注意力,将更多注意力集中在必要区域。凭借对更多必要区域的关注,DeBiFormer对语义区域的关注更加连续和完整,表明其具有更强的识别能力,这也带来了相比BiFormer-Base更好的分类和语义分割性能。

6 结论

本文介绍了可变形双级路由注意力Transformer(Deformable Bi-level Routing Attention Transformer),这是一种专为图像分类和密集预测任务设计的新型分层视觉Transformer。通过可变形双级路由注意力,我们的模型优化了查询-键-值交互,同时自适应地选择语义相关区域。这实现了更高效和有意义的注意力。大量实验表明,与强大的基线相比,我们的模型具有有效性。我们希望这项工作能为设计灵活且语义感知的注意力机制提供见解。

7 补充材料

7.1 偏移组

与[47]类似,为了促进变形点之间的多样性,我们遵循与MHSA中相似的范式,即把通道分成多个头来计算不同的注意力。因此,我们将通道分成 $G$ 组以生成不同的偏移量。偏移生成网络对来自不同组的特征共享权重。
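分组偏移的实现思路与后文 `DeBiLevelRoutingAttention` 中的 `n_groups`、`conv_offset_q` 一致:把通道分成 $G$ 组后,用同一个偏移网络分别为每组预测偏移场。下面是一个最小示意(组数、通道数等均为示例假设):

```python
import torch
import torch.nn as nn

B, C, H, W, G = 2, 64, 14, 14, 4
x = torch.randn(B, C, H, W)

# one shared offset head applied to each of the G channel groups
offset_net = nn.Conv2d(C // G, 2, kernel_size=3, padding=1)

x_groups = x.reshape(B * G, C // G, H, W)      # (B*G, C/G, H, W)
offsets = offset_net(x_groups)                 # (B*G, 2, H, W): one (dy, dx) field per group
print(offsets.shape)                           # torch.Size([8, 2, 14, 14])
```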

7.2 可变形相对位置偏置

众所周知,将位置信息融入注意力机制对模型性能有益,APE[15]、RPE[29]、CPE[8]、LogCPB[28]等方法均证明了这一点。Swin Transformer中引入的相对位置嵌入(RPE)显式编码了每对查询和键之间的相对位置,从而通过空间归纳偏置增强了普通注意力[29]。相对位置的显式建模特别适合可变形级别的注意力头:在这种情况下,变形后的键可以位于任意连续位置,而不局限于固定的离散网格。

根据[47],相对坐标位移在空间维度上分别被限制在 $[-H,+H]$ 和 $[-W,+W]$ 范围内,并带有一个相对位置偏置(RPB)表 $\hat{B}$,其维度为 $(2H-1) \times (2W-1)$。

然后,使用带参数化偏置表的双线性插值 $\varphi(\hat{B}; R)$,对归一化到 $[-1,+1]$ 范围内的连续相对位移进行采样,从而覆盖所有可能的偏移值。
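连续相对位移下的偏置采样同样可以用 `F.grid_sample` 对偏置表 $\hat{B}$ 做双线性插值来示意(对应后文的 `rpe_table`;表的尺寸与采样点数均为示例假设):

```python
import torch
import torch.nn.functional as F

M, H, W = 4, 14, 14                                  # heads and query grid size
rpb_table = torch.zeros(M, 2 * H - 1, 2 * W - 1)     # learnable table B_hat, (M, 2H-1, 2W-1)

# continuous relative displacements R in [-1, 1]: query position minus deformed key position
n_q, n_k = H * W, 49
R = torch.rand(n_q, n_k, 2) * 2 - 1                  # (n_q, n_k, 2) ordered as (dy, dx)

grid = R[None, ..., (1, 0)].expand(M, -1, -1, -1)    # grid_sample wants (x, y) order
bias = F.grid_sample(rpb_table[:, None], grid,       # treat the table as an (M, 1, 2H-1, 2W-1) image
                     mode="bilinear", align_corners=True).squeeze(1)
print(bias.shape)                                    # torch.Size([4, 196, 49]) ~ phi(B_hat; R)
```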

7.3 计算复杂度

可变形双层路由注意力(DBRA)的计算成本与Swin Transformer中的对应机制相当。DBRA中可变形注意力部分的计算包括两部分:令牌到令牌的注意力,以及偏移量与采样。因此,这部分的计算量为:

$$ \begin{aligned} FLOPs_{\text{def}} &= FLOPs_{\text{attn}} + FLOPs_{\text{offset\&sampling}} \\ &= 2HWN_sC + 2HWC^2 + 2N_sC^2 + (k^2+6)N_sC \end{aligned} $$

其中,$N_s = HW/r^2$ 是采样点的数量,$C$ 是令牌嵌入维度。双层路由多头注意力的计算包括三部分:线性投影、区域到区域的路由和令牌到令牌的注意力。因此,这部分的计算量为:

$$ \begin{aligned} FLOPs_{bi} &= FLOPs_{\text{proj}} + FLOPs_{\text{routing}} + FLOPs_{\text{attn}} \\ &= 2HWC^2 + 2N_sC^2 + 2\left(S^2\right)^2 C + 2HWk\frac{N_s}{S^2}C \\ &= 2HWC^2 + 2N_sC^2 + C\left\{2S^4 + 2HWk\frac{N_s}{S^2}\right\} \\ &= 2HWC^2 + 2N_sC^2 + C\left\{2S^4 + \frac{kHWN_s}{S^2} + \frac{kHWN_s}{S^2}\right\} \\ &\geq 2HWC^2 + 2N_sC^2 + 3C\left\{2S^4 \cdot \frac{kHWN_s}{S^2} \cdot \frac{kHWN_s}{S^2}\right\}^{\frac{1}{3}} \\ &= 2HWC^2 + 2N_sC^2 + 3Ck^{\frac{2}{3}}\left\{2HWN_s\right\}^{\frac{2}{3}} \end{aligned} $$

其中,$k$ 是要关注的区域数量,$S$ 是区域划分因子。最后,DBRA的总计算量包括两部分:可变形级别的多头注意力和双层路由多头注意力。因此,总计算量为:

$$ \begin{aligned} FLOPs &= FLOPs_{bi} + FLOPs_{def} \\ &= 2HWN_sC + 2HWC^2 + 2N_sC^2 + (k^2+6)N_sC \\ &\quad + 2HWC^2 + 2N_sC^2 + 3Ck^{\frac{2}{3}}\left\{2HWN_s\right\}^{\frac{2}{3}} \\ &= 2HWN_sC + 4HWC^2 + 4N_sC^2 + (k^2+6)N_sC + 3Ck^{\frac{2}{3}}\left\{2HWN_s\right\}^{\frac{2}{3}} \end{aligned} $$

换句话说,DBRA实现了 $O\big((HWN_s)^{\frac{2}{3}}\big)$ 的复杂度。例如,对于 $224 \times 224$ 输入的图像分类层次模型,其第三阶段通常为 $H=W=14$、$S^2=49$、$N_s=1$、$C=384$ 的计算规模,因此其计算复杂度与多头自注意力相当。此外,通过增大下采样因子 $r$ 并按区域划分因子 $S$ 进行缩放,可以进一步降低复杂度,使其适用于目标检测和实例分割等具有更高分辨率输入的任务。
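上述FLOPs公式可以直接写成一个小函数,便于对不同输入分辨率、采样点数和top-$k$做粗略估算(仅按文中闭式表达逐项累加的示意,示例取值不代表任何实际配置):

```python
def dbra_flops(H, W, C, r=8, k=4, S=7):
    """Rough FLOPs of one DBRA block following the closed-form expressions above."""
    Ns = (H * W) // (r * r)                         # number of sampled (deformable) points
    flops_def = 2 * H * W * Ns * C + 2 * H * W * C**2 \
        + 2 * Ns * C**2 + (k**2 + 6) * Ns * C       # deformable attention + offsets & sampling
    flops_bi = 2 * H * W * C**2 + 2 * Ns * C**2 \
        + 2 * S**4 * C + 2 * H * W * k * Ns // S**2 * C   # projection + routing + attention
    return flops_def + flops_bi

# example: a 14x14 stage with C=384 (values only for illustration)
print(f"{dbra_flops(14, 14, 384):,} FLOPs")
```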

7.4 要关注的令牌

在表9中,我们列出了每个查询需要关注的令牌数量以及每个可变形点需要关注的令牌数量。与其他方法相比,DeBiFormer中每个查询需要关注的令牌最少,却在ImageNet-1K、ADE20K(Semantic FPN头)和COCO(RetinaNet头)上表现出较高的性能。

7.5 更多可视化结果

有效感受野分析 为了评估不同模型中,输入尺寸为224x224时中心像素的有效感受野(ERF)[32],我们在图5中展示了比较分析。为了证明我们DeBiFormer的强大表示能力,我们还比较了具有相似计算成本的几种SOTA(state-of-the-art,当前最优)方法的有效感受野。如图5所示,我们的DeBiFormer在这些方法中拥有最大且最一致的有效感受野,同时保持了强大的局部敏感性,这是很难实现的。

Grad-CAM分析 为了进一步展示DBRA(可变形双级路由注意力,Deformable Bi-level Routing Attention)的工作原理,我们在图6中展示了更多的可视化结果。得益于灵活的键值对选择,在大多数情况下,我们的DeBiFormer在早期阶段就能关注到重要对象;同时,由于可变形点的合理分配,它在多目标场景中也能更早地关注到不同的重要区域。凭借强大的DBRA模块,我们的DeBiFormer在最后两个阶段具有更大的热图响应区域,这代表了更强的识别能力。

7.6 详细实验设置

ImageNet-1K上的图像分类 如主文所述,每个模型在8个V100 GPU上以224x224的输入尺寸训练300个epoch。实验设置严格遵循DeiT[39]以进行公平比较。更多详细信息,请参阅提供的表10。

目标检测和实例分割 当将我们的DeBiFormer微调至COCO[17]上的目标检测和实例分割时,我们考虑了两种常见框架:Mask R-CNN[19]和RetinaNet[16]。对于优化,我们采用AdamW优化器,初始学习率为0.0002,由于设备限制,小批量大小为4。当训练不同大小的模型时,我们根据图像分类中使用的设置调整训练设置。训练模型时使用的详细超参数见表11。

语义分割 对于ADE20K,所有模型均使用AdamW优化器训练160K次迭代,初始学习率为0.00006,权重衰减为0.01,小批量大小为16。测试时,我们在主要比较中同时报告单尺度(SS)和多尺度(MS)测试的结果;对于多尺度测试,我们尝试了训练分辨率0.5倍到1.75倍之间的分辨率。不同模型的路径丢弃率(drop path rate)沿用了目标检测和实例分割中使用的超参数。表8给出了UperNet框架在单尺度和多尺度mIoU下的结果。

7.7 限制和未来工作

与采用简单静态稀疏模式的注意力相比,我们提出的注意力方法由两个部分组成:首先修剪区域级图,为由高度灵活的可变形点所关注的重要区域收集键值对;然后在其上应用令牌到令牌的注意力。虽然该方法只在top-$k$路由选出的语义相关区域和可变形点所指向的重要区域上进行计算,不会带来太多额外计算量,但在线性投影阶段不可避免地引入了额外的参数开销。在未来的工作中,我们计划研究更高效的稀疏注意力机制,并构建对参数量更加友好的视觉Transformer。

代码

import gc
import numbers
from collections import OrderedDict, defaultdict
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch import Tensor
from torchvision import datasets, transforms

from einops import rearrange
from einops.layers.torch import Rearrange
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.vision_transformer import _cfg


class LayerNorm2d(nn.Module):

  def __init__(self, 
               channels
               ):
    super().__init__()
    self.ln = nn.LayerNorm(channels)

  def forward(self, x):
    x = rearrange(x, "N C H W -> N H W C")
    x = self.ln(x)
    x = rearrange(x, "N H W C -> N C H W")
    return x



def init_linear(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

def to_4d(x,h,w):
    return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)

#def to_4d(x,s,h,w):
#    return rearrange(x, 'b (s h w) c -> b c s h w',s=s,h=h,w=w)

def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

#def to_3d(x):
#    return rearrange(x, 'b c s h w -> b (s h w) c')

class Partial:
    def __init__(self, module, *args, **kwargs):
        self.module = module
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args_c, **kwargs_c):
        return self.module(*args_c, *self.args, **kwargs_c, **self.kwargs)



class LayerNormChannels(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        x = x.transpose(1, -1)
        x = self.norm(x)
        x = x.transpose(-1, 1)
        return x

class LayerNormProxy(nn.Module):
    
    def __init__(self, dim):
        
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):

        x = rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return rearrange(x, 'b h w c -> b c h w')
    
    
class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma+1e-5) * self.weight

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)
    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
    
    
#class LayerNorm(nn.Module):
#    def __init__(self, dim, LayerNorm_type):
#        super(LayerNorm, self).__init__()
#        if LayerNorm_type =='BiasFree':
#            self.body = BiasFree_LayerNorm(dim)
#        else:
#            self.body = WithBias_LayerNorm(dim)
#    def forward(self, x):
#        s, h, w = x.shape[-3:]
#        return to_4d(self.body(to_3d(x)),s, h, w)
    

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        """
        x: NHWC tensor
        """
        x = x.permute(0, 3, 1, 2) #NCHW
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) #NHWC

        return x
    
class ConvFFN(nn.Module):
    def __init__(self, dim=768):
        super(ConvFFN, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 1, 1, 0)

    def forward(self, x):
        """
        x: NHWC tensor
        """
        x = x.permute(0, 3, 1, 2) #NCHW
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) #NHWC

        return x
    


class Attention(nn.Module):
    """
    vanilla attention
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        args:
            x: NHWC tensor
        return:
            NHWC tensor
        """
        _, H, W, _ = x.size()
        x = rearrange(x, 'n h w c -> n (h w) c')

        #######################################
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W)
        return x

class AttentionLePE(nn.Module):
    """
    vanilla attention
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)

    def forward(self, x):
        """
        args:
            x: NHWC tensor
        return:
            NHWC tensor
        """
        _, H, W, _ = x.size()
        x = rearrange(x, 'n h w c -> n (h w) c')

        #######################################
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))
        lepe = rearrange(lepe, 'n c h w -> n (h w) c')

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = x + lepe

        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W)
        return x



class nchwAttentionLePE(nn.Module):
    """
    Attention with LePE, takes nchw input
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)

    def forward(self, x:torch.Tensor):
        """
        args:
            x: NCHW tensor
        return:
            NCHW tensor
        """
        B, C, H, W = x.size()
        q, k, v = self.qkv.forward(x).chunk(3, dim=1) # B, C, H, W

        attn = q.view(B, self.num_heads, self.head_dim, H*W).transpose(-1, -2) @ \
               k.view(B, self.num_heads, self.head_dim, H*W)
        attn = torch.softmax(attn*self.scale, dim=-1)
        attn = self.attn_drop(attn)

        # (B, nhead, HW, HW) @ (B, nhead, HW, head_dim) -> (B, nhead, HW, head_dim)
        output:torch.Tensor = attn @ v.view(B, self.num_heads, self.head_dim, H*W).transpose(-1, -2)
        output = output.permute(0, 1, 3, 2).reshape(B, C, H, W)
        output = output + self.lepe(v)

        output = self.proj_drop(self.proj(output))

        return output


class TopkRouting(nn.Module):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        with_param: bool, wether inorporate learnable params in routing unit
        diff_routing: bool, wether make routing differentiable
        soft_routing: bool, wether make output value multiplied by routing weights
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor]:
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)

        return r_weight, topk_index


class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)
        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpy
                                dim=2,
                                index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv)
                               )

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # else: #'none'
        #     topk_kv = topk_kv # do nothing

        return topk_kv

class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v


class QKVConv(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Conv2d(dim,  qk_dim + qk_dim + dim, 1, 1, 0)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=1)
        return q, kv
    

        
class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: wether to set routing differentiable
    soft_routing: wether to multiply soft routing weights 
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5


        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)
        
        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always diffrentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')
        
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v 
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not supported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NHWC tensor

        Return:
            NHWC tensor
        """
         # NOTE: use padding for semantic segmentation
        
        ###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, # dim=-1
                          pad_l, pad_r, # dim=-2
                          pad_t, pad_b)) # dim=-3
            _, H, W, _ = x.size() # padded size
        else:
            N, H, W, C = x.size()
            
            #assert H%self.n_win == 0 and W%self.n_win == 0 #
        ###################################################


        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        #################qkv projection###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separte kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x) 

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)

        ##################side_dwconv(lepe)##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################

        r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)
        
        ######### do attention as normal ####################
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        # param-free multihead attention
        attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H//self.n_win, w=W//self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out
        
        
class TransformerMLPWithConv(nn.Module):

    def __init__(self, channels, expansion, drop):
        
        super().__init__()
        
        self.dim1 = channels
        self.dim2 = channels * expansion
        self.linear1 = nn.Sequential(
            nn.Conv2d(self.dim1, self.dim2, 1, 1, 0),
            # nn.GELU(),
            # nn.BatchNorm2d(self.dim2, eps=1e-5)
        )
        self.drop1 = nn.Dropout(drop, inplace=True)
        self.act = nn.GELU()
        # self.bn = nn.BatchNorm2d(self.dim2, eps=1e-5)
        self.linear2 = nn.Sequential(
            nn.Conv2d(self.dim2, self.dim1, 1, 1, 0),
            # nn.BatchNorm2d(self.dim1, eps=1e-5)
        )
        self.drop2 = nn.Dropout(drop, inplace=True)
        self.dwc = nn.Conv2d(self.dim2, self.dim2, 3, 1, 1, groups=self.dim2)
    
    def forward(self, x):
        
        x = self.linear1(x)
        x = self.drop1(x)
        x = x + self.dwc(x)
        x = self.act(x)
        # x = self.bn(x)
        x = self.linear2(x)
        x = self.drop2(x)
        
        return x
    
        
class DeBiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: wether to set routing differentiable
    soft_routing: wether to multiply soft routing weights
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False, param_size='small'):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim

#############################################################
        if param_size=='tiny':
            if self.dim == 64 :
                self.n_groups = 1
                self.top_k_def = 16   # 2    128
                self.kk = 9
                self.stride_def = 8
                self.expain_ratio = 3
                self.q_size=to_2tuple(56)

            if self.dim == 128 :
                self.n_groups = 2
                self.top_k_def = 16   # 4    256
                self.kk = 7
                self.stride_def = 4
                self.expain_ratio = 3
                self.q_size=to_2tuple(28)

            if self.dim == 256 :
                self.n_groups = 4
                self.top_k_def = 4   # 8    512
                self.kk = 5
                self.stride_def = 2
                self.expain_ratio = 3
                self.q_size=to_2tuple(14)

            if self.dim == 512 :
                self.n_groups = 8
                self.top_k_def = 49   # 8    512
                self.kk = 3
                self.stride_def = 1
                self.expain_ratio = 3
                self.q_size=to_2tuple(7)
#############################################################
        if param_size=='small':
            if self.dim == 64 :
                self.n_groups = 1
                self.top_k_def = 16   # 2    128
                self.kk = 9
                self.stride_def = 8
                self.expain_ratio = 3
                self.q_size=to_2tuple(56)

            if self.dim == 128 :
                self.n_groups = 2
                self.top_k_def = 16   # 4    256
                self.kk = 7
                self.stride_def = 4
                self.expain_ratio = 3
                self.q_size=to_2tuple(28)

            if self.dim == 256 :
                self.n_groups = 4
                self.top_k_def = 4   # 8    512
                self.kk = 5
                self.stride_def = 2
                self.expain_ratio = 3
                self.q_size=to_2tuple(14)

            if self.dim == 512 :
                self.n_groups = 8
                self.top_k_def = 49   # 8    512
                self.kk = 3
                self.stride_def = 1
                self.expain_ratio = 1
                self.q_size=to_2tuple(7)
#############################################################
        if param_size=='base':
            if self.dim == 96 :
                self.n_groups = 1
                self.top_k_def = 16   # 2    128
                self.kk = 9
                self.stride_def = 8
                self.expain_ratio = 3
                self.q_size=to_2tuple(56)

            if self.dim == 192 :
                self.n_groups = 2
                self.top_k_def = 16   # 4    256
                self.kk = 7
                self.stride_def = 4
                self.expain_ratio = 3
                self.q_size=to_2tuple(28)

            if self.dim == 384 :
                self.n_groups = 3
                self.top_k_def = 4   # 8    512
                self.kk = 5
                self.stride_def = 2
                self.expain_ratio = 3
                self.q_size=to_2tuple(14)
                
            if self.dim == 768 :
                self.n_groups = 6
                self.top_k_def = 49   # 8    512
                self.kk = 3
                self.stride_def = 1
                self.expain_ratio = 3
                self.q_size=to_2tuple(7)
            

        self.q_h, self.q_w = self.q_size
        
        self.kv_h, self.kv_w = self.q_h // self.stride_def, self.q_w // self.stride_def
        self.n_group_channels = self.dim // self.n_groups
        self.n_group_heads = self.num_heads // self.n_groups
        self.n_group_channels = self.dim // self.n_groups
        
        self.offset_range_factor = -1
        self.head_channels = dim // num_heads

        self.n_group_heads = self.num_heads // self.n_groups

        #assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5
        
        
        
        
        self.rpe_table = nn.Parameter(
                    torch.zeros(self.num_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
        trunc_normal_(self.rpe_table, std=0.01)




        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe1 = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=self.stride_def, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)


        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing


        # router
        #assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)

        if self.soft_routing: # soft routing, always diffrentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)




        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            #self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.qkv_conv = QKVConv(self.dim, self.qk_dim)
            #self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            #self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.qkv_conv = QKVConv(self.dim, self.qk_dim)
            #self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')




        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not supported!')

        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad




##########################################################################################

        self.proj_q = nn.Conv2d(
            dim, dim,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_k = nn.Conv2d(
            dim, dim,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_v = nn.Conv2d(
            dim, dim,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv2d(
            dim, dim,
            kernel_size=1, stride=1, padding=0
        )
        
        self.unifyheads1 = nn.Conv2d(
            dim, dim,
            kernel_size=1, stride=1, padding=0
        )

        self.conv_offset_q = nn.Sequential(
                        nn.Conv2d(self.n_group_channels, self.n_group_channels, (self.kk,self.kk), (self.stride_def,self.stride_def), (self.kk//2,self.kk//2), groups=self.n_group_channels, bias=False),
                        LayerNormProxy(self.n_group_channels),
                        nn.GELU(),
                        nn.Conv2d(self.n_group_channels, 1, 1, 1, 0, bias=False),
                )


### FFN

        self.norm = nn.LayerNorm(dim, eps=1e-6)

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

        self.mlp =TransformerMLPWithConv(dim, self.expain_ratio, 0.)


    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref

    def forward(self, x, ret_attn_mask=False):
        dtype, device = x.dtype, x.device
        """
        x: NHWC tensor
        Return:
            NHWC tensor
        """
# NOTE: use padding for semantic segmentation
###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, # dim=-1
                          pad_l, pad_r, # dim=-2
                          pad_t, pad_b)) # dim=-3
            _, H, W, _ = x.size() # padded size
        else:
            N, H, W, C = x.size()
            assert H%self.n_win == 0 and W%self.n_win == 0 #
            
        #print("X_in")
        #print(x.shape)
        
###################################################
        #q=self.proj_q_def(x)
        x_res = rearrange(x, "n h w c -> n c h w")
#################qkv projection###################
        
        q,kv = self.qkv_conv(x.permute(0, 3, 1, 2))
        q_bi = rearrange(q, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win)
        kv = rearrange(kv, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win)
        
        
        q_pix = rearrange(q_bi, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

##################side_dwconv(lepe)##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe1 = self.lepe1(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())


#################################################################   Offset Q
        
        q_off = rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset_q = self.conv_offset_q(q_off).contiguous() # B * g 2 Sg HWg
        Hk, Wk = offset_q.size(2), offset_q.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)
            offset_q = offset_q.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset_q = rearrange(offset_q, 'b p h w -> b h w p') # B * g 2 Hg Wg -> B*g Hg Wg 2
        reference = self._get_ref_points(Hk, Wk, N, dtype, device)

        if self.offset_range_factor >= 0:
            pos_k = offset_q + reference
        else:
            pos_k = (offset_q + reference).clamp(-1., +1.)

        x_sampled_q = F.grid_sample(
            input=x_res.reshape(N * self.n_groups, self.n_group_channels, H, W),
            grid=pos_k[..., (1, 0)], # y, x -> x, y
            mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg

        q_sampled = x_sampled_q.reshape(N, C, Hk, Wk)


########  Bi-LEVEL Gathering
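        # Bi-level gathering: the deformable queries are padded/partitioned into
        # n_win x n_win windows, window-mean descriptors are routed against the key
        # windows to pick the top-k regions, and the key/value tokens of those
        # regions are gathered for a multi-head attention with the deformable
        # queries (plus the depth-wise positional term lepe1, a residual and an MLP).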

        if self.auto_pad:
            q_sampled=q_sampled.permute(0, 2, 3, 1)
            Ng, Hg, Wg, Cg = q_sampled.size()
        
            pad_l = pad_t = 0
            pad_rg = (self.n_win - Wg % self.n_win) % self.n_win
            pad_bg = (self.n_win - Hg % self.n_win) % self.n_win
            q_sampled = F.pad(q_sampled, (0, 0, # dim=-1
                          pad_l, pad_rg, # dim=-2
                          pad_t, pad_bg)) # dim=-3
            _, Hg, Wg, _ = q_sampled.size() # padded size
            
            q_sampled=q_sampled.permute(0, 3, 1, 2)
            
            lepe1 = F.pad(lepe1.permute(0, 2, 3, 1), (0, 0, # dim=-1
                          pad_l, pad_rg, # dim=-2
                          pad_t, pad_bg)) # dim=-3
            lepe1=lepe1.permute(0, 3, 1, 2)
            
            pos_k = F.pad(pos_k, (0, 0, # dim=-1
                          pad_l, pad_rg, # dim=-2
                          pad_t, pad_bg)) # dim=-3

            
        queries_def = self.proj_q(q_sampled)  #Linnear projection

        queries_def = rearrange(queries_def, "n c (j h) (i w) -> n (j i) h w c", j=self.n_win, i=self.n_win).contiguous()

        q_win, k_win = queries_def.mean([2, 3]), kv[..., 0:(self.qk_dim)].mean([2, 3])
        r_weight, r_idx = self.router(q_win, k_win)
        kv_gather = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c )

        k_gather, v_gather = kv_gather.split([self.qk_dim, self.dim], dim=-1)

        ###     Bi-level Routing MHA
        k = rearrange(k_gather, 'n p2 k hw (m c) -> (n p2) m c (k hw)', m=self.num_heads)
        v = rearrange(v_gather, 'n p2 k hw (m c) -> (n p2) m (k hw) c', m=self.num_heads)
        q_def = rearrange(queries_def,  'n p2 h w (m c)-> (n p2) m (h w) c',m=self.num_heads)

        attn_weight = (q_def * self.scale) @ k
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v

        out_def = rearrange(out, '(n j i) m (h w) c -> n (m c) (j h) (i w)', j=self.n_win, i=self.n_win, h=Hg//self.n_win, w=Wg//self.n_win).contiguous()

        out_def = out_def + lepe1

        out_def = self.unifyheads1(out_def)
        
        out_def = q_sampled + out_def
        
        out_def = out_def + self.mlp(self.norm2(out_def.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)) # (N, C, H, W)


#############################################################################################




########   Deformable Gathering
#############################################################################################  
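        # Attention-in-attention: the routed features at the deformable points are
        # re-projected to keys/values, and every original query attends over these
        # Hg*Wg point tokens, propagating the information gathered from the routed
        # regions back to the full-resolution query map.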
 
        out_def = self.norm(out_def.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        
        k = self.proj_k(out_def)
        v = self.proj_v(out_def)
             
        k_pix_sel = rearrange(k, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)
        v_pix_sel = rearrange(v, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)
        q_pix = rearrange(q, 'n (m c) h w -> (n m) c (h w)', m=self.num_heads)
        
        attn = torch.einsum('b c m, b c n -> b m n', q_pix, k_pix_sel) # B * h, HW, Ns
        attn = attn.mul(self.scale)
        
        ### Bias
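        # Relative position bias: rpe_table is bilinearly sampled (grid_sample) at
        # the displacement between each query location and each deformable point,
        # then added to the attention logits before the softmax.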
        rpe_table = self.rpe_table
        rpe_bias = rpe_table[None, ...].expand(N, -1, -1, -1)
        q_grid = self._get_q_grid(H, W, N, dtype, device)
        displacement = (q_grid.reshape(N * self.n_groups, H * W, 2).unsqueeze(2) - pos_k.reshape(N * self.n_groups, Hg*Wg, 2).unsqueeze(1)).mul(0.5)
        attn_bias = F.grid_sample(
                    input=rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True) # B * g, h_g, HW, Ns
        attn_bias = attn_bias.reshape(N * self.num_heads, H * W, Hg*Wg)
        attn = attn + attn_bias
        ### 
        attn = F.softmax(attn, dim=2)
        out = torch.einsum('b m n, b c n -> b c m', attn, v_pix_sel)
        out = out.reshape(N,C,H,W).contiguous()
        out = self.proj_out(out).permute(0,2,3,1)

#############################################################################################
        
        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out




def get_pe_layer(emb_dim, pe_dim=None, name='none'):
    if name == 'none':
        return nn.Identity()
    else:
        raise ValueError(f'PE name {name} is not supported!')


class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=-1,
                       num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                       kv_per_win=4, kv_downsample_ratio=4, 
                       kv_downsample_kernel=None, kv_downsample_mode='ada_avgpool',
                       topk=4, param_attention="qkvo", param_routing=False, 
                       diff_routing=False, soft_routing=False, mlp_ratio=4, param_size='small',mlp_dwconv=False,
                       side_dwconv=5, before_attn_dwconv=3, pre_norm=True, auto_pad=False):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        if before_attn_dwconv > 0:
            self.pos_embed1 = nn.Conv2d(dim, dim,  kernel_size=before_attn_dwconv, padding=1, groups=dim)
            self.pos_embed2 = nn.Conv2d(dim, dim,  kernel_size=before_attn_dwconv, padding=1, groups=dim)
        else:
            self.pos_embed1 = lambda x: 0
            self.pos_embed2 = lambda x: 0
            
        self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing
        #if topk > 0:
        if topk == 4:
            self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=1, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad)
            
            self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=topk, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad,param_size=param_size)
            
            
        elif topk == 8:
            self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=4, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad)
            
            self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=topk, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad,param_size=param_size)
            
        elif topk == 16:
            self.attn1 = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=16, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad)
            
            self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=topk, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad,param_size=param_size)
            
            
            
            
        elif topk == -1:
            self.attn = Attention(dim=dim)
        elif topk == -2:
            self.attn1 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=49, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad,param_size=param_size)
            
            self.attn2 = DeBiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, kv_per_win=kv_per_win, kv_downsample_ratio=kv_downsample_ratio,
                                        kv_downsample_kernel=kv_downsample_kernel, kv_downsample_mode=kv_downsample_mode,
                                        topk=49, param_attention=param_attention, param_routing=param_routing,
                                        diff_routing=diff_routing, soft_routing=soft_routing, side_dwconv=side_dwconv,
                                        auto_pad=auto_pad,param_size=param_size)
            
        elif topk == 0:
            self.attn = nn.Sequential(Rearrange('n h w c -> n c h w'), # compatiability
                                      nn.Conv2d(dim, dim, 1), # pseudo qkv linear
                                      nn.Conv2d(dim, dim, 5, padding=2, groups=dim), # pseudo attention
                                      nn.Conv2d(dim, dim, 1), # pseudo out linear
                                      Rearrange('n c h w -> n h w c')
                                     )
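        # NOTE: the topk == -1 and topk == 0 fallbacks define self.attn only, while
        # forward() expects the attn1/attn2 pair; the configurations in this file use
        # topk in {4, 8, 16, -2}, which all build attn1 and attn2.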
            
        
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        
        self.mlp1 = TransformerMLPWithConv(dim, mlp_ratio, 0.)
        
        
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        
        self.norm3 = nn.LayerNorm(dim, eps=1e-6)
        self.norm4 = nn.LayerNorm(dim, eps=1e-6)
        
        self.mlp2 = TransformerMLPWithConv(dim, mlp_ratio, 0.)

        
        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.gamma4 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm


    def forward(self, x):
        """
        x: NCHW tensor
        """
        # conv pos embedding
        x = x + self.pos_embed1(x)
        # permute to NHWC tensor for attention & mlp
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                x = x + self.drop_path1(self.gamma1 * self.attn1(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path1(self.gamma2 * self.mlp1(self.norm2(x))) # (N, H, W, C)
                
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        
                x = x + self.drop_path2(self.gamma3 * self.attn2(self.norm3(x))) # (N, H, W, C)
                x = x + self.drop_path2(self.gamma4 * self.mlp2(self.norm4(x))) # (N, H, W, C)
                
            else:
                x = x + self.drop_path1(self.attn1(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path1(self.mlp1(self.norm2(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)) # (N, H, W, C)
                
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                
                x = x + self.drop_path2(self.attn2(self.norm3(x))) # (N, H, W, C)
                x = x + self.drop_path2(self.mlp2(self.norm4(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)) # (N, H, W, C)
                
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                x = self.norm1(x + self.drop_path1(self.gamma1 * self.attn1(x))) # (N, H, W, C)
                x = self.norm2(x + self.drop_path1(self.gamma2 * self.mlp1(x))) # (N, H, W, C)
                
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                
                x = self.norm3(x + self.drop_path2(self.gamma3 * self.attn2(x))) # (N, H, W, C)
                x = self.norm4(x + self.drop_path2(self.gamma4 * self.mlp2(x))) # (N, H, W, C)
                
                
            else:
                x = self.norm1(x + self.drop_path1(self.attn1(x))) # (N, H, W, C)
                x = x + self.drop_path1(self.mlp1(self.norm2(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)) # (N, H, W, C)
                
                # conv pos embedding
                x = x + self.pos_embed2(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                
                x = self.norm3(x + self.drop_path2(self.attn2(x))) # (N, H, W, C)
                x = x + self.drop_path2(self.mlp2(self.norm4(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)) # (N, H, W, C)

        # permute back
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x


class DeBiFormer(nn.Module):
    def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
                 head_dim=64, qk_scale=None, representation_size=None,
                 drop_path_rate=0., drop_rate=0.,
                 use_checkpoint_stages=[],
                 ########
                 n_win=7,
                 kv_downsample_mode='ada_avgpool',
                 kv_per_wins=[2, 2, -1, -1],
                 topks=[8, 8, -1, -1],
                 side_dwconv=5,
                 layer_scale_init_value=-1,
                 qk_dims=[None, None, None, None],
                 param_routing=False, diff_routing=False, soft_routing=False,
                 pre_norm=True,
                 pe=None,
                 pe_stages=[0],
                 before_attn_dwconv=3,
                 auto_pad=True,
                 #-----------------------
                 kv_downsample_kernels=[4, 2, 1, 1],
                 kv_downsample_ratios=[4, 2, 1, 1], # -> kv_per_win = [2, 2, 2, 1]
                 mlp_ratios=[4, 4, 4, 4],
                 param_attention='qkvo',
                 param_size='small',
                 mlp_dwconv=False):
        """
        Args:
            depth (list): depth of each stage
            img_size (int, tuple): input image size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (list): embedding dimension of each stage
            head_dim (int): head dimension
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer (nn.Module): normalization layer
            conv_stem (bool): whether use overlapped patch stem
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.ModuleList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv
        stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0]),
        )
        if (pe is not None) and 0 in pe_stages:
            stem.append(get_pe_layer(emb_dim=embed_dim[0], name=pe))
        if use_checkpoint_stages:
            stem = checkpoint_wrapper(stem)
        self.downsample_layers.append(stem)

        for i in range(3):
            downsample_layer = nn.Sequential(
                nn.Conv2d(embed_dim[i], embed_dim[i+1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                nn.BatchNorm2d(embed_dim[i+1])
            )
            if (pe is not None) and i+1 in pe_stages:
                downsample_layer.append(get_pe_layer(emb_dim=embed_dim[i+1], name=pe))
            if use_checkpoint_stages:
                downsample_layer = checkpoint_wrapper(downsample_layer)
            self.downsample_layers.append(downsample_layer)
        ##########################################################################

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        nheads= [dim // head_dim for dim in qk_dims]
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=embed_dim[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        topk=topks[i],
                        num_heads=nheads[i],
                        n_win=n_win,
                        qk_dim=qk_dims[i],
                        qk_scale=qk_scale,
                        kv_per_win=kv_per_wins[i],
                        kv_downsample_ratio=kv_downsample_ratios[i],
                        kv_downsample_kernel=kv_downsample_kernels[i],
                        kv_downsample_mode=kv_downsample_mode,
                        param_attention=param_attention,
                        param_size=param_size,
                        param_routing=param_routing,
                        diff_routing=diff_routing,
                        soft_routing=soft_routing,
                        mlp_ratio=mlp_ratios[i],
                        mlp_dwconv=mlp_dwconv,
                        side_dwconv=side_dwconv,
                        before_attn_dwconv=before_attn_dwconv,
                        pre_norm=pre_norm,
                        auto_pad=auto_pad) for j in range(depth[i])],
            )
            if i in use_checkpoint_stages:
                stage = checkpoint_wrapper(stage)
            self.stages.append(stage)
            cur += depth[i]

        ##########################################################################
        self.norm = nn.BatchNorm2d(embed_dim[-1])
        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim[-1], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()
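        # Probe the backbone once with a dummy 640x640 input to record the output
        # channels of each stage (used when wiring the backbone into the YoloV9 neck).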
        self.channels = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

            
        self.reset_parameters()
    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head


    def forward(self, x):
        out=[]
        for i in range(4):
            x = self.downsample_layers[i](x) # res = (56, 28, 14, 7), wins = (64, 16, 4, 1)
            x = self.stages[i](x)
            out.append(x)
        return out
    


@register_model
def debi_tiny(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
            depth=[1, 1, 4, 1],
            embed_dim=[64, 128, 256, 512], mlp_ratios=[2, 2, 2, 2],
            param_size='tiny',
            drop_path_rate=0.,  #Drop rate
            #------------------------------
            n_win=7,
            kv_downsample_mode='identity',
            kv_per_wins=[-1, -1, -1, -1],
            topks=[4, 8, 16, -2],
            side_dwconv=5,
            before_attn_dwconv=3,
            layer_scale_init_value=-1,
            qk_dims=[64, 128, 256, 512],
            head_dim=32,
            param_routing=False, diff_routing=False, soft_routing=False,
            pre_norm=True,
            pe=None)
    return model




@register_model
def debi_small(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
            depth=[2, 2, 9, 3],
            embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 2],
            param_size='small',
            drop_path_rate=0.3,  #Drop rate
            #------------------------------
            n_win=7,
            kv_downsample_mode='identity',
            kv_per_wins=[-1, -1, -1, -1],
            topks=[4, 8, 16, -2],
            side_dwconv=5,
            before_attn_dwconv=3,
            layer_scale_init_value=-1,
            qk_dims=[64, 128, 256, 512],
            head_dim=32,
            param_routing=False, diff_routing=False, soft_routing=False,
            pre_norm=True,
            pe=None)
    return model



@register_model
def debi_base(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DeBiFormer(
            depth=[2, 2, 9, 2],
            embed_dim=[96, 192, 384, 768], mlp_ratios=[3, 3, 3, 3],
            param_size='base',
            drop_path_rate=0.4,  #Drop rate
            #------------------------------
            n_win=7,
            kv_downsample_mode='identity',
            kv_per_wins=[-1, -1, -1, -1],
            topks=[4, 8, 16, -2],
            side_dwconv=5,
            before_attn_dwconv=3,
            layer_scale_init_value=-1,
            qk_dims=[96, 192, 384, 768],
            head_dim=32,
            param_routing=False, diff_routing=False, soft_routing=False,
            pre_norm=True,
            pe=None)
    return model


if __name__ == '__main__':
    model = debi_tiny(pretrained=True)
    inputs = torch.randn((1, 3, 640, 640))
    for i in model(inputs):
        print(i.size())
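
A quick way to sanity-check the backbone interface is to compare the per-stage channel list recorded in `self.channels` with the four feature maps returned by `forward`. The snippet below is a minimal sketch, not part of the original file: it assumes the `debi_tiny` configuration above, a 640x640 input, and that `torch` is already imported in this module.

# Minimal sketch: check that the channel list recorded by the constructor matches
# the feature maps returned by forward(); the printed values assume debi_tiny
# (embed_dim=[64, 128, 256, 512]) and a 640x640 input.
backbone = debi_tiny()
feats = backbone(torch.randn(1, 3, 640, 640))
print(backbone.channels)                  # expected: [64, 128, 256, 512]
for c, f in zip(backbone.channels, feats):
    assert f.size(1) == c                 # per-stage channel counts line up
    print(f.shape)                        # spatial sizes 160/80/40/20 (strides 4/8/16/32)

These four strided feature maps are the interface a YoloV9-style neck expects when this backbone is swapped in.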

