补充1:
-
局部窗口内的自注意力(W-MSA):
- 在 Swin Transformer 中,输入特征图被划分为多个小的窗口(例如 7x7 的窗口)。在每个窗口内,计算自注意力机制(W-MSA, Window-based Multi-Head Self-Attention),这意味着每个 token 只和同一窗口内的其他 token 进行交互。
- 由于计算只发生在局部窗口内,所以计算复杂度大大降低,特别是对于高分辨率的输入图像来说,这种方式更加高效。
-
滑动窗口机制(Shifted Window Attention):
- 为了在局部窗口之间传递信息,Swin Transformer 引入了滑动窗口机制。通过在不同的层中移动窗口的位置,使得相邻窗口之间的特征可以进行交流,从而保证全局上下文的信息逐步整合到特征中。
-
计算量的比较:
- 全局自注意力(MSA):像 Vision Transformer (ViT) 这样的方法在整个特征图上计算自注意力,计算复杂度是 O((hw)²)。
- 窗口内自注意力(W-MSA):Swin Transformer 仅在每个窗口内计算,计算复杂度降低为 O(W * M²),其中 W 是窗口的数量,M 是窗口内的 token 数量(例如 7x7 = 49 个 token)。
-
滑动窗口的好处:
- 滑动窗口机制允许信息在不同窗口之间传递,而不仅仅局限在窗口内部。这种设计平衡了计算效率和特征提取的全局性,确保 Swin Transformer 可以在较低的计算复杂度下仍然获得良好的表现。
补充2:
关于特征图大小的解释:
-
输入图像大小(224x224):
- 在大多数计算机视觉任务中,输入图像通常会被调整为 224x224 像素。
-
Patch Embedding 和 Stride=16:
- Vision Transformer (ViT) 通常会将输入图像划分为
16x16
的 non-overlapping patches,然后将每个 patch 展平并映射到一个高维的特征空间。 - 因为每个 patch 的大小是 16x16,且是 non-overlapping 的,这相当于对输入图像应用了一个
Stride=16
的卷积操作,将图像的空间分辨率从 224x224 减少到 14x14。
- Vision Transformer (ViT) 通常会将输入图像划分为
-
特征图大小(14x14):
- 因此,经过
Stride=16
的操作后,原始图像被划分为 14x14 个 patch,每个 patch 被视为一个 token。在 Vision Transformer 中,这 14x14 个 token 会形成一个 196 维的 token 序列。
- 因此,经过
Swin Transformer 的不同点:
Swin Transformer 在一些细节上和 ViT 有所不同:
- 多级特征图:
- Swin Transformer 处理的是逐级降低空间分辨率的特征图(类似于卷积神经网络中的多尺度特征),比如从最开始的较大特征图(例如
h=w=56
)到最后的较小特征图。
- Swin Transformer 处理的是逐级降低空间分辨率的特征图(类似于卷积神经网络中的多尺度特征),比如从最开始的较大特征图(例如
- 滑动窗口与局部自注意力:
- 在 Swin Transformer 中,通过窗口内自注意力(W-MSA)和滑动窗口(Shifted Window)机制来逐步处理这些特征图,计算局部区域内的自注意力。
图中的特征图大小与实际应用:
在你提供的图片中,h=w=56
可能指的是在 Swin Transformer 的某个阶段,特征图被处理时的空间分辨率。例如,在较早的阶段,特征图的空间分辨率较高,经过几次降采样后,可能从 224x224 降到 56x56 甚至更低。
因此,特征图的大小 (14x14 或 56x56) 取决于模型的阶段以及具体的网络结构。在 Swin Transformer 中,早期层的特征图可能较大,而后期层的特征图可能较小,这与 Vision Transformer 中固定的 14x14
特征图有所不同。
补充3:
关于正文中h=w=56, m=7 的补充:
在 Swin Transformer 中,h=w=56
和 m=7
是针对特定阶段的特征图大小和窗口大小。这些参数在 Swin Transformer 中是有具体含义的:
1. h=w=56
的解释
- 初始阶段的特征图大小:
- Swin Transformer 通常会通过多级特征提取器(类似卷积神经网络中的多尺度特征提取),逐步缩小特征图的空间分辨率。
- 例如,在初始阶段,输入图像可能会被划分成大小为
4x4
的 patch(相当于Stride=4
),并将输入图像从原始的 224x224 分辨率降采样到 56x56 的特征图。 - 具体来说,224x224 的输入图像通过
Stride=4
的操作后,特征图的大小变成 224/4 = 56,既h=w=56
。
2. m=7
的解释
- 窗口大小(Window Size):
- Swin Transformer 的一个关键特点是,它引入了基于窗口的多头自注意力(W-MSA),这个窗口是在特定大小的局部区域内进行自注意力计算的。
m=7
表示窗口的大小为7x7
,也就是说,在每一个7x7
的局部区域内计算自注意力,而不是在整个56x56
的全局上计算。- 通过将大的特征图(例如
56x56
)划分为多个7x7
的窗口,Swin Transformer 可以在保持计算量可控的前提下,捕捉局部的相关性。
Swin Transformer 的多级结构
Swin Transformer 的网络结构通常分为多个阶段,每个阶段的特征图大小和窗口大小可能有所不同:
- Stage 1: 假设输入图像为 224x224,通过
Stride=4
的 patch embedding 操作,特征图的大小变为 56x56。 - Stage 2: 在 Stage 1 处理后的 56x56 特征图基础上,应用
7x7
的窗口来进行局部自注意力计算。 - Stage 3: 特征图继续下采样到更小的分辨率(例如 28x28 或 14x14),然后继续应用更小的窗口进行计算。
注1:
注2:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
# 如果输入feature map的H,W不是2的整数倍,需要进行padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
正文: