paper:Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition
official implementation:https://github.com/houqb/VisionPermutator
出发点
现有的MLP模型在编码空间信息时通常会将空间维度展开并沿着展平的维度进行线性投影,这样会丢失由二维特征表示携带的位置信息。 为了解决这个问题,本文提出了Vision Permutator,一种新的纯MLP结构的网络,它分别沿着高度和宽度维度进行线性投影,从而保留精确的位置信息并捕获长距离依赖关系。 在不依赖于空间卷积或注意力机制的情况下,达到或超过了大多数CNN和视觉Transformer的性能。
创新点
- 新颖的Permute-MLP层:与现有的MLP模型不同,Vision Permutator提出了一种新的层结构,即Permute-MLP层。该层包含三个独立的分支,分别负责沿高度、宽度和通道维度编码特征。
- 位置敏感的输出:通过分别沿高度和宽度维度进行线性投影,Vision Permutator生成的位置敏感输出以互补的方式聚合,从而形成对目标对象的有效表示。
- 高效的性能:在不使用额外大规模训练数据的情况下,Vision Permutator在ImageNet上达到了81.5%的top-1准确率,并且参数量仅为25M。模型扩展到88M参数时,准确率进一步提升到83.2%。
方法介绍
Permutator block的结构如图1左所示,可以看到和Transformer block相似,只是将其中的self-attention换成了Permute-MLP层,Channel-MLP和Transformer block中的FFN类似,都是由两个全连接层和一个GELU激活函数组成。对于空间信息的编码,和最近的Mixer(具体介绍见MLP-Mixer(NeurIPS 2021, Google)论文与源码解读-CSDN博客)不同,它沿着空间维度对所有的token进行线性投影,而本文提出分别沿着宽度和高度维度来处理token。
具体来说,给定一个C维的输入token \(\mathbf{X}\in \mathbb{R}^{H\times W\times C}\),Permutator可以表示如下
其中LN指Layer Norm,输出 \(\mathbf{Z}\) 作为下一个Permutator block的输入。
Permute-MLP
Permute-MLP的过程如图2所示,与vision transformer和Mixer接收一个二维("tokens x channels",即 \(HW\times C\))的输入不同,Permute-MLP接收一个三维的输入。
如图2所示,Permute-MLP包括三个分支,分别负责沿高度、宽度、通道维度编码信息。通道信息编码很简单,只需要一个权重为 \(\mathbf{W}_C\in \mathbb{R}^{C\times C}\) 的全连接层就可以对输入 \(\mathbf{X}\) 进行线性投影得到输出 \(\mathbf{X}_C\)。接下来我们详细介绍下如何通过维度之间的permutation操作来编码空间信息。
假设隐藏维度C为384,输入图像的分辨率为224x224。为了沿高度维度对空间信息进行编码,我们首先进行一个height-channel维度的permutation操作。给定输入 \(\mathbf{X}\in \mathbb{R}^{H\times W\times C}\),我们首先沿通道维度将其均分成 \(S\) 份,得到 \([\mathbf{X}_{H_1},\mathbf{X}_{H_2},...,\mathbf{X}_{H_S}]\),且满足 \(C=N*S\)(本文\(N=H=W\))。如果patch大小设置为14x14,则 \(N=16\) 且 \(\mathbf{X}_{H_i}\in \mathbb{R}^{H\times W\times N}, \ (i\in \{1,...,S\})\)。然后我们对每个 \(\mathbf{X}_{H_i}\) 进行height-channel的permutation操作(就是转换第一个高度维度和第三个通道维度,\((H, W, C)\rightarrow(C,W,H)\),得到输出 \([\mathbf{X}_{H_1}^{\top}, \mathbf{X}_{H_2}^{\top}, \cdots, \mathbf{X}_{H_S}^{\top}]\),然后沿通道维度拼接。接着一个权重为 \(\mathbf{W}_H\in \mathbb{R}^{C\times C}\) 的全连接层用来混合高度信息。为了恢复到原始维度,只需要再执行一次height-channel permutation操作即可,得到最终输出 \(\mathbf{X}_{H}\)。类似的,在第二个分支,我们执行width-channel的permuation操作,然后得到输出 \(\mathbf{X}_{W}\)。然后将三个分支的输出进行element-wise summation,再通过一个全连接层得到Permute-MLP层的输出,如下
其中 \(FC(\cdot)\) 表示一个权重为 \(\mathbf{W}_P\in \mathbb{R}^{C\times C}\) 的全连接层。Permute-MLP的PyTorch代码如下所示
Weighted Permute-MLP
在式(3)中我们只是简单地将三个分支的输出进行相加,作者进一步提出了Weighted Permute-MLP来重新校正三个分支的重要性。具体采用ResNeSt中的split attention(具体介绍见ResNeSt-CSDN博客)来得到加权权重,区别在于ResNeSt中的split attention是在每个cardinal group内进行的,对每个radix group求一个权重。而这里是对 \(\mathbf{X}_{H},\mathbf{X}_{W},\mathbf{X}_{C}\) 进行的,求得三个权重。
实验结果
作者设计了5个不同大小的ViP,具体配置如下
作者比较了ViP和CNN、Transformer以及MLP类模型在ImageNet数据集上性能,首先是MLP类的性能如下表所示,可以看到ViP在MLP类的backbone中取得了最优的性能。
下表是和CNN、Transformer的代表网络的性能对比,可以看到和一些经典的卷积网络例如ResNet、RegNet相比在相似的模型大小下取得了更好的结果。和一些transformer模型例如DeiT、Swin Transformer相比效果也更好。但和一些最新的SOTA模型例如NFNet(86.5%)、CaiT(86.5%)相比,还有较大的差距。
代码解析
具体实现非常简单,官方实现的weighted permute mlp的代码如下,其中self.segment_dim就是文章中的N,即沿通道划分成S份后每份的维度。然后就是将H维度与channel维度调换,执行MLP;将W维度与channel维度调换,执行MLP;直接沿输入channel维度执行MLP。最后通过split attention得到三者的权重,最后加权求和得到最终输出结果。
class WeightedPermuteMLP(nn.Module):
def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.segment_dim = segment_dim
self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)
self.reweight = Mlp(dim, dim // 4, dim * 3)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, H, W, C = x.shape
S = C // self.segment_dim
h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S) # (B, seg_dim, W, H, S)->(B, seg_dim, W, H*S)
h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)
c = self.mlp_c(x)
a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) # (B,H,W,C)->(B,C,H,W)->(B,C,H*W)->(B,C)
a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
# (B,3C)->(B,C,3)->(3,B,C)->(3,B,C)->(3,B,1,C)->(3,B,1,1,C)
x = h * a[0] + w * a[1] + c * a[2]
x = self.proj(x)
x = self.proj_drop(x)
return x