Cswin提出了上图中使用交叉形状局部attention,为了解决VIT模型中局部自注意力感受野进一步增长受限的问题,同时提出了局部增强位置编码模块,超越了Swin等模型,在多个任务上效果SOTA(当时的SOTA,已经被SG Former超越,感兴趣的可以看看SG Former)。
论文地址:https://arxiv.org/abs/2107.00652
代码地址:https://github.com/microsoft/CSWin-Transformer
模型整体结构如上所示,由token embeeding layer和4个stageblock所堆叠而成,每个stage block后面都会接入一个conv层,用来对featuremap进行下采样。和典型的R50设计类似,每次下采样后,会增加dim的数量,一是为了提升感受野,二是为了增加特征性。
研究动机:
- 基于global attention的transformer效果虽然好但是计算复杂度与特征图大小平方(H==W的情况)成正比。
- 基于local attention的transformer的会限制每个token的感受野的交互,减缓感受野的增长,需要堆叠大量的block来实现全局自注意力。
解决办法:
- 提出了Cross-Shaped Window self-attention机制,对注意力头进行分组,并行计算水平和竖直方向的self-attention,可以在更小的计算量条件下获得更好的效果。
- 提出了Locally-enhanced Positional Encoding(LePE), 可以更好的处理局部位置信息,并且支持任意形状的输入。
1.1 Convolutional Token Embedding
用convolution来做embedding,为了减少计算量,本文直接采用了7x7的卷积核,stride为4的卷积来直接对输入进行embedding,之后再对最后一维进行layernorm。
self.stage1_conv_embed = nn.Sequential(
nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
nn.LayerNorm(embed_dim)
)
1.2 Cross-Shaped Window Self-Attention
具体来讲,假设原始的Feature Map为,为了计算它在横向上的自注意力,它首先被拆分成个横条的数据(实际代码先进行竖列处理),其中是横条的宽度。在这4个不同的Stage中取不同的值,实验结果表明[1,2,7,7]这组值在速度和精度上取得了比较好的均衡。
对于每个条状特征,使用Transformer可以得到它的特征,最后将这个特征拼接到一起便得到了这个head的输入。假设它属于第个head,那么横向自注意力的计算方式为:
纵向自注意力V-Attention 和H-Attention的计算方式类似,不同的是它是取的宽度为的竖条。
最终,这个block的输出表示为:
CSWin self-attention计算复杂度分析:
对于高分辨率输入,H,W早期大于C,后期小于C,因此早期sw小,后期大。即,调整sw可以有效地扩大后期每个token的attention区域。为了使224×224输入的中间特征图大小可被sw整除,默认将4个阶段的sw设置为1、2、7、7。
def img2windows(img, H_sp, W_sp):
"""
img: B C H W
"""
B, C, H, W = img.shape
img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) # [N*56*1 56 32] [N*56*1 56 32] / [N*14*1 56 64] [N*14*1 56 64] / [N*2*1 98 128] [N*2*1 98 128] / [N*1*1 49 512]
return img_perm
def windows2img(img_splits_hw, H_sp, W_sp, H, W):
"""
img_splits_hw: B' H W C
"""
B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) # [N*56*1 56 32]->[N 1 56 56 1 32] [N*56*1 56 32]->[N 56 1 1 56 32] / [N*14*1 56 64]->[N 1 14 28 2 64] [N*14*1 56 64]->[N 14 1 2 28 64] / [N*2*1 98 128]->[N 1 2 14 7 128] [N*2*1 98 128]->[N 2 1 7 14 128] / [N*1*1 49 512]->[N 1 1 7 7 512]
img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # [N 56 56 32] [N 28 28 64] [N 14 14 128] [N 7 7 512]
return img
class LePEAttention(nn.Module):
def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0.,
qk_scale=None):
super().__init__()
self.dim = dim
self.dim_out = dim_out or dim
self.resolution = resolution
self.split_size = split_size
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
if idx == -1:
H_sp, W_sp = self.resolution, self.resolution
elif idx == 0:
H_sp, W_sp = self.resolution, self.split_size
elif idx == 1:
W_sp, H_sp = self.resolution, self.split_size
else:
print("ERROR MODE", idx)
exit(0)
self.H_sp = H_sp
self.W_sp = W_sp
stride = 1
self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
self.attn_drop = nn.Dropout(attn_drop)
def im2cswin(self, x):
B, N, C = x.shape
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [B, N, C] -> [B, C, N] -> [B, C, H, W]
x = img2windows(x, self.H_sp, self.W_sp) # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]
x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1,
3).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x
def get_lepe(self, x, func):
B, N, C = x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]
H_sp, W_sp = self.H_sp, self.W_sp
x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
W_sp) ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = func(
x) ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x, lepe
def forward(self, qkv):
"""
x: B L C
"""
q, k, v = qkv[0], qkv[1], qkv[2] # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
### Img2Window
H = W = self.resolution # 56 28 14 7
B, L, C = q.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
assert L == H * W, "flatten img_tokens has wrong size"
q = self.im2cswin(q) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
k = self.im2cswin(k) # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
v, lepe = self.get_lepe(v, self.get_v)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N
attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
attn = self.attn_drop(attn)
x = (attn @ v) + lepe
x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp,
C) # B head N N @ B head N C # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]
### Window2Img
x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C) # B H' W' C
return x # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
代码部分其实和Swin类似,如果理解了swin的分窗机制,再加上head分组,基本上就能很快理解论文中思想。
1.3 Locally-Enhanced Positional Encoding(LePE)
因为Transformer是输入顺序无关的,因此需要向其中加入位置编码。上图左边为ViT模型的PE,使用的绝对位置编码或者是条件位置编码,只在embedding的时候与token一起进入transformer,中间的是Swin,CrossFormer等模型的PE,使用相对位置编码偏差,通过引入token图的权重来和attention一起计算,灵活度更好,相对APE效果更好。
本文所提出的LePE,相比于RPE更加直接,将位置信息施加到线性投影中,同时注意到RPE以head方式引入偏差,而LepE是per-channel bias,这可能显示出更强大的潜力来充当位置嵌入。也就是直接将位置编码添加加到了Value向量上,假设位置编码为,它的添加方式是通过将位置编码和相乘完成的。然后通过一个short-cut将添加了位置编码的和通过自注意力加权的单位加到一起,公式如下:
这里作者基于一个假设:对于一个输入元素,他附近的元素提供最重要的位置信息。所以对V做一个深度卷积,加到softmax之后的结果上。公式为:
这样,LePE可以友好地应用于将任意输入分辨率作为输入的下游任务。
def get_lepe(self, x, func):
# func -> self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)
B, N, C = x.shape # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
H = W = int(np.sqrt(N))
x = x.transpose(-2, -1).contiguous().view(B, C, H, W) # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]
H_sp, W_sp = self.H_sp, self.W_sp
x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
W_sp) ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = func(
x) ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]
lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
2).contiguous() # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
return x, lepe
1.4 CSWin Transformer Block
CSWin Transformer Block的结构如图所示,它最显著的特点是添加了两个shortcut,并使用LN对特征做归一化.
网络结构配置:
其中为第 个Transformer block的输出或各stage的卷积层。
CSwin的block有两个部分,一个是做LayerNorm和Cross-shaped window self-attention并接一个shortcut,另一个则是做LayerNorm和MLP,相比于Swin和Twins来说,block的计算量大大的降低了(swin,twins则是有两个attention+两个MLP堆叠一个block)。
class CSWinBlock(nn.Module):
def __init__(self, dim, reso, num_heads,
split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
last_stage=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.patches_resolution = reso
self.split_size = split_size
self.mlp_ratio = mlp_ratio
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm1 = norm_layer(dim)
if self.patches_resolution == split_size:
last_stage = True
if last_stage:
self.branch_num = 1
else:
self.branch_num = 2
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(drop)
if last_stage:
self.attns = nn.ModuleList([
LePEAttention(
dim, resolution=self.patches_resolution, idx = -1,
split_size=split_size, num_heads=num_heads, dim_out=dim,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
for i in range(self.branch_num)])
else:
self.attns = nn.ModuleList([
LePEAttention(
dim//2, resolution=self.patches_resolution, idx = i,
split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
for i in range(self.branch_num)])
mlp_hidden_dim = int(dim * mlp_ratio)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H = W = self.patches_resolution # 56
B, L, C = x.shape # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
assert L == H * W, "flatten img_tokens has wrong size"
img = self.norm1(x)
qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # [3 N 3136 64] [3 N 784 128] [3 N 196 256] [3 N 49 512]
if self.branch_num == 2:
x1 = self.attns[0](qkv[:,:,:,:C//2]) # qkv[3 N 3136 32]->x1[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
x2 = self.attns[1](qkv[:,:,:,C//2:]) # qkv[3 N 3136 32]->x2[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
attened_x = torch.cat([x1,x2], dim=2)
else:
attened_x = self.attns[0](qkv) # [3 N 49 512]->[N 49 512]
attened_x = self.proj(attened_x)
x = x + self.drop_path(attened_x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
在相似网络参数和计算量的模型中,cswin在分类任务和各类下游任务中都做到了SOTA
检测:
分割: