Paper: https://arxiv.org/abs/2108.10257
Code: https://github.com/JingyunLiang/SwinIR
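Contents
- 1. Swin Transformer layers
  - 1.1 Local attention
  - 1.2 Shifted window mechanism
  - 1.3 Understanding the key code
- 2. Overall network architecture
  - 2.1 Shallow feature extraction
  - 2.2 Deep feature extraction
  - 2.3 Image reconstruction
- 3. Summary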
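SwinIR applies Swin Transformer [1] to image restoration, a low-level vision task, and combines it with convolutions in its network design. It achieves strong results on three tasks:
- image super-resolution (classical, lightweight and real-world SR);
- image denoising (grayscale and color);
- JPEG compression artifact reduction.

This post explains SwinIR together with its code.
The SwinIR architecture is not complicated. Its key components are Swin Transformer layers (STL), convolution layers and residual connections. Most readers already know convolutions and residual connections well. I therefore first introduce the Swin Transformer layer alongside its code, and then build up the full SwinIR architecture from the bottom up.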
1. Swin Transformer layers
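SwinIR uses the Swin Transformer layers (STL) proposed in Swin Transformer without modification. Compared with the standard multi-head self-attention transformer layer, STL differs in two main respects:
1. local attention;
2. the shifted window mechanism.
1.1 Local attention
Standard global attention splits the image into patches and computes self-attention between every pair of patches. Local attention instead first divides the image into windows, each containing several patches, and computes self-attention only inside each window. Patches in different windows do not interact. In other words, the computation within a single window is exactly the same as global attention.
A good way to understand exactly how local attention works is to read the code and write down how the tensor shapes change from layer to layer. The shape changes I worked out are shown below:
where b: batch size, h: input height, w: input width, ws: window size, C: number of channels, num_heads: number of attention heads.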
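To make these shapes concrete, here is a minimal NumPy sketch of the window partition step. It mirrors the view/permute logic of `window_partition` and `window_reverse` in the SwinIR code. It is my own re-implementation on NumPy arrays rather than torch tensors, and it assumes h and w are multiples of ws. The sizes (b=2, h=w=8, C=6, ws=4, num_heads=3) are arbitrary.

```python
import numpy as np

def window_partition(x, ws):
    """(b, h, w, C) -> (num_windows*b, ws, ws, C); h and w must be multiples of ws."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // ws, ws, w // ws, ws, c)
    # move the two window-grid axes next to the batch axis, then merge them into it
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, c)

def window_reverse(windows, ws, h, w):
    """(num_windows*b, ws, ws, C) -> (b, h, w, C); inverse of window_partition."""
    c = windows.shape[-1]
    b = windows.shape[0] // ((h // ws) * (w // ws))
    x = windows.reshape(b, h // ws, w // ws, ws, ws, c)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)

b, h, w, C, ws, num_heads = 2, 8, 8, 6, 4, 3
x = np.random.default_rng(0).random((b, h, w, C))
windows = window_partition(x, ws)         # (num_windows*b, ws, ws, C)
tokens = windows.reshape(-1, ws * ws, C)  # (num_windows*b, ws*ws, C): one sequence per window
# split channels into heads (the real code does this on the qkv projection output, skipped here)
per_head = tokens.reshape(-1, ws * ws, num_heads, C // num_heads).transpose(0, 2, 1, 3)
print(windows.shape, tokens.shape, per_head.shape)  # (8, 4, 4, 6) (8, 16, 6) (8, 3, 16, 2)
```

With an 8×8 map and ws=4 there are 4 windows per image. The batch axis therefore grows from 2 to 8, and each window becomes a sequence of 16 tokens that attend only to each other.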
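1.2 Shifted window mechanism
Window-based multi-head self-attention (W-MSA) has no connections across windows, which limits the model's ability to capture long-range dependencies. Swin Transformer therefore proposes shifted window multi-head self-attention (SW-MSA). SW-MSA enlarges the receptive field while keeping the computation efficient.
As shown in the figure below, suppose the W-MSA window size is M×M (M = 4 in the figure). SW-MSA then shifts the window partition toward the bottom right by $(\lfloor M/2 \rfloor, \lfloor M/2 \rfloor)$ pixels.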
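After this shift, however, the number of windows increases from $\lfloor h/M \rfloor \times \lfloor w/M \rfloor$ to $(\lfloor h/M \rfloor + 1) \times (\lfloor w/M \rfloor + 1)$, and the windows no longer have the same size. Swin Transformer therefore proposes a cyclic shift. The cyclic shift restores the original number of windows and yields equal-sized windows that can be computed in parallel. It is shown in the figure below.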
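You can check the window-count arithmetic by listing where the window borders fall along one axis. This illustrative sketch uses a hypothetical helper, `segments`. It assumes the side length is a multiple of M and that the partition is shifted by M/2.

```python
def segments(length, M, shift):
    """Lengths of the pieces one axis is cut into when window borders sit at shift, shift + M, ..."""
    borders = sorted({0, length} | {p for p in range(shift, length, M) if p > 0})
    return [b - a for a, b in zip(borders, borders[1:])]

h = w = 8
M = 4
for shift in (0, M // 2):
    rows, cols = segments(h, M, shift), segments(w, M, shift)
    print(shift, len(rows) * len(cols), rows)
# 0 4 [4, 4]       regular W-MSA partition: 2 x 2 equal windows
# 2 9 [2, 4, 2]    shifted partition without cyclic shift: 3 x 3 windows of unequal size
```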
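In the code, the cyclic shift is implemented with torch.roll. Negative shifts move the content upward and to the left. The rows at the top and the columns on the left wrap around to the bottom and the right.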
```python
# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
```
For more on torch.roll, see: https://blog.csdn.net/weixin_42899627/article/details/116095067
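A small array makes the sign convention visible. The sketch below uses `np.roll`, which follows the same convention as `torch.roll`: positive shifts move elements toward higher indices, and elements that fall off the end wrap around. Using NumPy instead of PyTorch is my substitution, and the 4×4 array with shift 1 is only for display.

```python
import numpy as np

x = np.arange(4)[:, None] * 10 + np.arange(4)[None, :]  # each value encodes 10*row + col
s = 1
# same convention as torch.roll(x, shifts=(-s, -s), dims=(1, 2)) on a (B, H, W, C) tensor
shifted = np.roll(x, shift=(-s, -s), axis=(0, 1))
# positive shifts undo it, which is what the "reverse cyclic shift" step does
restored = np.roll(shifted, shift=(s, s), axis=(0, 1))
print(shifted)
# [[11 12 13 10]
#  [21 22 23 20]
#  [31 32 33 30]
#  [ 1  2  3  0]]
```

The original top-left element (0) ends up in the bottom-right corner, and rolling back with positive shifts restores the original layout exactly.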
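As shown in the figure above, after the cyclic shift three of the windows contain patches that were not adjacent in the original image. Such patches should not attend to each other, so Swin Transformer uses a mask when computing the final attention.
For an explanation of the mask, see https://github.com/microsoft/Swin-Transformer/issues/38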
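Here is a NumPy sketch of how such a mask can be built, modeled on the `calculate_mask` logic in the Swin/SwinIR code as I understand it. After the roll, each axis splits into three bands, every region gets its own label, and token pairs with different labels receive -100. The function names and the sizes (an 8×8 map, window size 4, shift 2) are illustrative assumptions, and the sketch assumes shift > 0.

```python
import numpy as np

def window_partition(x, ws):
    """(B, H, W, C) -> (num_windows*B, ws, ws, C)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // ws, ws, w // ws, ws, c)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, c)

def calculate_mask(h, w, ws, shift):
    """Shifted-window attention mask of shape (num_windows, ws*ws, ws*ws); requires shift > 0."""
    img_mask = np.zeros((1, h, w, 1))
    # after the roll, each axis consists of three bands; label every (row band, col band) region
    bands = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    label = 0
    for row_band in bands:
        for col_band in bands:
            img_mask[:, row_band, col_band, :] = label
            label += 1
    mask_windows = window_partition(img_mask, ws).reshape(-1, ws * ws)  # (nW, ws*ws)
    # pairwise label difference: nonzero means the two tokens come from different regions
    diff = mask_windows[:, None, :] - mask_windows[:, :, None]  # (nW, ws*ws, ws*ws)
    return np.where(diff != 0, -100.0, 0.0)

mask = calculate_mask(8, 8, ws=4, shift=2)
masked = (mask != 0).any(axis=(1, 2))
print(mask.shape, masked.tolist())  # (4, 16, 16) [False, True, True, True]
```

With these sizes only the top-left window needs no masking, which matches the three affected windows described above.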
1.3 Understanding the key code
Below are the key pieces of code with comments, starting with the forward function of WindowAttention:
```python
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-100) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape  # the input has already been split into windows by window_partition
    # self.qkv(x): (num_windows*B, window_size*window_size, 3*C)
    # a single linear layer produces q, k, v for all heads:
    # (3, num_windows*B, num_heads, window_size*window_size, C // num_heads)
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))  # (num_windows*B, num_heads, window_size*window_size, window_size*window_size)

    # learnable relative position bias
    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)  # (num_windows*B, num_heads, window_size*window_size, window_size*window_size)

    if mask is not None:
        nW = mask.shape[0]
        # add the mask to attn: the mask only holds 0 and -100, so 0 leaves attn unchanged,
        # while -100 makes the corresponding softmax output close to 0
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)  # (num_windows*B, num_heads, window_size*window_size, window_size*window_size)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)
    # v:        (num_windows*B, num_heads, window_size*window_size, C // num_heads)
    # attn:     (num_windows*B, num_heads, window_size*window_size, window_size*window_size)
    # attn @ v: (num_windows*B, num_heads, window_size*window_size, C // num_heads)
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (num_windows*B, window_size*window_size, C)
    x = self.proj(x)  # fully connected output projection
    x = self.proj_drop(x)
    return x
```
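You can reproduce the same computation without PyTorch to check the shapes and the effect of the mask. This NumPy sketch is a simplification: it omits the relative position bias, the qkv/projection biases and dropout, and it uses random weights. `window_attention` is a hypothetical helper, not part of SwinIR. In the toy mask, window 1 contains two regions, as a shifted window would.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def window_attention(x, w_qkv, w_proj, num_heads, mask=None):
    """x: (num_windows*B, N, C); mask: (num_windows, N, N) with values 0 / -100, or None."""
    B_, N, C = x.shape
    head_dim = C // num_heads
    # (3, num_windows*B, num_heads, N, head_dim)
    qkv = (x @ w_qkv).reshape(B_, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = (q * head_dim ** -0.5) @ k.transpose(0, 1, 3, 2)  # (num_windows*B, num_heads, N, N)
    if mask is not None:
        nW = mask.shape[0]
        # broadcast the per-window mask over the batch and head axes
        attn = attn.reshape(B_ // nW, nW, num_heads, N, N) + mask[None, :, None]
        attn = attn.reshape(B_, num_heads, N, N)
    attn = softmax(attn)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B_, N, C)  # (num_windows*B, N, C)
    return out @ w_proj, attn

rng = np.random.default_rng(0)
nW, B, N, C, heads = 2, 1, 4, 4, 2
x = rng.standard_normal((nW * B, N, C))
w_qkv = 0.1 * rng.standard_normal((C, 3 * C))
w_proj = 0.1 * rng.standard_normal((C, C))
# window 0: one region; window 1: tokens {0, 1} and {2, 3} come from different regions
labels = np.array([[0, 0, 0, 0], [0, 0, 1, 1]])
mask = np.where(labels[:, None, :] != labels[:, :, None], -100.0, 0.0)
out, attn = window_attention(x, w_qkv, w_proj, heads, mask)
print(out.shape)               # (2, 4, 4)
print(attn[1, 0, 0].round(3))  # the weights on tokens 2 and 3 round to 0
```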
Next is the forward function of SwinTransformerBlock:
```python
def forward(self, x, x_size):
    H, W = x_size
    B, L, C = x.shape
    # assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # (num_windows*B, window_size, window_size, C)
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # (num_windows*B, window_size*window_size, C)

    # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size):
    # reuse the precomputed mask at the training resolution, otherwise recompute it for the new size
    if self.input_resolution == x_size:
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # (nW*B, window_size*window_size, C)
    else:
        attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # (B, H, W, C)

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)

    # residual connection around (S)W-MSA, then the FFN with its own residual connection
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x
```
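The shift, partition, merge and reverse-shift steps of the block can be isolated from the attention itself. This NumPy sketch uses a hypothetical helper, `block_token_path`, and replaces attention with a stand-in function. An identity stand-in shows that the reverse roll exactly undoes the cyclic shift. A window-mean stand-in shows which pixels end up sharing a window when shift = ws // 2.

```python
import numpy as np

def window_partition(x, ws):
    b, h, w, c = x.shape
    x = x.reshape(b, h // ws, ws, w // ws, ws, c)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, c)

def window_reverse(windows, ws, h, w):
    c = windows.shape[-1]
    b = windows.shape[0] // ((h // ws) * (w // ws))
    x = windows.reshape(b, h // ws, w // ws, ws, ws, c)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(b, h, w, c)

def block_token_path(x, ws, shift, attn_fn):
    """(B, H, W, C) -> (B, H, W, C): roll, partition, attn_fn on (nW*B, ws*ws, C), merge, roll back."""
    _, h, w, c = x.shape
    shifted = np.roll(x, shift=(-shift, -shift), axis=(1, 2)) if shift > 0 else x
    tokens = window_partition(shifted, ws).reshape(-1, ws * ws, c)
    tokens = attn_fn(tokens)
    merged = window_reverse(tokens.reshape(-1, ws, ws, c), ws, h, w)
    return np.roll(merged, shift=(shift, shift), axis=(1, 2)) if shift > 0 else merged

def identity(t):
    return t

def window_mean(t):
    # every token is replaced by the mean of its window (a crude stand-in for attention)
    return np.repeat(t.mean(axis=1, keepdims=True), t.shape[1], axis=1)

x = np.random.default_rng(0).random((1, 8, 8, 3))
same = np.array_equal(block_token_path(x, 4, 2, identity), x)
out = block_token_path(x, 4, 2, window_mean)
print(same, np.allclose(out[0, 2:6, 2:6], x[0, 2:6, 2:6].mean(axis=(0, 1))))  # True True
```

With shift 2, the pixels in rows and columns 2 to 5 of the original map form a single window. In other words, SW-MSA windows straddle the borders of the W-MSA windows.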
So each SwinTransformerBlock computes:
$$X = \mathrm{MSA}(\mathrm{LN}(X)) + X$$

$$X = \mathrm{MLP}(\mathrm{LN}(X)) + X$$
where MSA alternates between W-MSA and SW-MSA in consecutive blocks.
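The alternation can be written as a per-block shift schedule. As far as I can tell, SwinIR's `BasicLayer` sets `shift_size` to 0 for even-indexed blocks and to `window_size // 2` for odd-indexed ones. The helper below is a hypothetical illustration of that rule, with depth 6 and window size 8 chosen as examples.

```python
def shift_schedule(depth, window_size):
    """shift_size per block: 0 (W-MSA) for even indices, window_size // 2 (SW-MSA) for odd ones."""
    return [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]

print(shift_schedule(6, 8))  # [0, 4, 0, 4, 0, 4]
```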
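2. Overall network architecture
As shown in the figure above, SwinIR consists of three modules: shallow feature extraction, deep feature extraction and image reconstruction. The two feature extraction modules are the same for all tasks, whereas the image reconstruction module differs from task to task.
2.1 Shallow feature extraction
A single 3×3 convolution maps the input channels to embed_dim, producing a feature map of shape (b, embed_dim, h, w):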
```python
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
```
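The spatial size stays h×w because of the kernel 3, stride 1 and padding 1 passed as the last three arguments of `nn.Conv2d` above. With those settings, the standard convolution output-size formula returns the input size. `conv2d_out` is a hypothetical helper for this arithmetic.

```python
def conv2d_out(size, kernel=3, stride=1, padding=1, dilation=1):
    """Output length of a convolution along one spatial axis (standard formula, as for nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

print(conv2d_out(64), conv2d_out(48))  # 64 48
```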
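2.2 Deep feature extraction
Deep feature extraction is built from the STLs described in Section 1, convolution layers and residual connections. Several STLs followed by a convolution, wrapped in a residual connection, form a residual Swin Transformer block (RSTB). Several RSTBs followed by a convolution form the deep feature extraction module.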
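The nesting can be summarized in equations. The notation below follows the SwinIR paper as I recall it and should be read as a sketch. There are K RSTBs, each with L STLs. $H_{SF}$ is the shallow 3×3 convolution and $H_{REC}$ is the reconstruction module.

$$
\begin{aligned}
F_0 &= H_{SF}(I_{LQ}) \\
F_i &= H_{RSTB_i}(F_{i-1}), \quad i = 1, \dots, K \\
F_{DF} &= H_{CONV}(F_K) \\
I_{RHQ} &= H_{REC}(F_0 + F_{DF})
\end{aligned}
$$

Inside the $i$-th RSTB, with $F_{i,0} = F_{i-1}$:

$$
F_{i,j} = H_{STL_{i,j}}(F_{i,j-1}), \quad j = 1, \dots, L, \qquad F_{i,out} = H_{CONV_i}(F_{i,L}) + F_{i,0}
$$

The long skip $F_0 + F_{DF}$ corresponds to `self.conv_after_body(self.forward_features(x)) + x` in the reconstruction code.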
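2.3 Image reconstruction
The code below shows that the image reconstruction module depends on the task:
- real-world SR uses nearest-neighbor interpolation + convolution;
- classical and lightweight SR use pixel shuffle + convolution;
- denoising and JPEG artifact reduction use convolution only.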
```python
if self.upsampler == 'pixelshuffle':
    # for classical SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
    # for lightweight SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
    # for real-world SR
    x = self.conv_first(x)  # (b, embed_dim, h, w)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
    # for image denoising and JPEG compression artifact reduction
    x_first = self.conv_first(x)
    res = self.conv_after_body(self.forward_features(x_first)) + x_first
    x = x + self.conv_last(res)
```
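The two upsampling primitives used above can be sketched in NumPy. `pixel_shuffle` reproduces the element ordering of `torch.nn.PixelShuffle`, folding C·r² channels into an r× larger grid. `nearest_upsample2x` mimics `interpolate(..., scale_factor=2, mode='nearest')` by copying each pixel into a 2×2 block. Both are my own re-implementations for illustration and work on a single image without a batch axis.

```python
import numpy as np

def pixel_shuffle(x, r):
    """(C*r*r, H, W) -> (C, H*r, W*r); out[c, h*r + i, w*r + j] = x[c*r*r + i*r + j, h, w]."""
    c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    return x.reshape(c, r, r, h, w).transpose(0, 3, 1, 4, 2).reshape(c, h * r, w * r)

def nearest_upsample2x(x):
    """(C, H, W) -> (C, 2H, 2W): every pixel is copied into a 2x2 block."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

x = np.arange(16).reshape(4, 2, 2)  # 4 = 1 * 2 * 2 channels for r = 2
y = pixel_shuffle(x, 2)
print(y.shape, y[0, 0].tolist())  # (1, 4, 4) [0, 4, 1, 5]
print(nearest_upsample2x(np.arange(4).reshape(1, 2, 2))[0].tolist())
# [[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]
```

Pixel shuffle gets its extra resolution from learned channels, whereas nearest interpolation only copies pixels and relies on the following convolutions (`conv_up1`, `conv_up2`) to refine them.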
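The complexity of SwinIR can be configured flexibly. For an h×w feature map with C channels and window size M, the computational complexity of W-MSA is

$$4hwC^2 + 2M^2hwC$$

This is linear in the number of tokens hw for a fixed M.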
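To put the formula into numbers, the snippet compares it with the global MSA cost given in the Swin Transformer paper, $4hwC^2 + 2(hw)^2C$. The sizes h = w = 64, C = 180, M = 8 are illustrative assumptions, not values taken from this post.

```python
def msa_cost(h, w, C):
    """Global multi-head self-attention: 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_cost(h, w, C, M):
    """Window-based multi-head self-attention: 4hwC^2 + 2M^2 hwC."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

print(wmsa_cost(64, 64, 180, 8), msa_cost(64, 64, 180))  # 625213440 6570639360
```

The $4hwC^2$ term (the qkv and output projections) is the same in both. Only the attention term changes, and it goes from quadratic in hw to linear in hw.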
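3. Summary
- The architecture is simple, outperforms CNN-based methods across the board and applies to several tasks, so it can serve as a baseline model for low-level vision.
- The authors found that, unlike earlier transformer-based methods, SwinIR does not need more training data than CNNs and also converges faster.
- The structure is modular, so models of different complexity are easy to configure.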
[1] Liu Z, Lin Y, Cao Y, et al. Swin Transformer: Hierarchical vision transformer using shifted windows[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021: 10012-10022.