写在前面
Swin Transformer(Shifted Window Transformer)是一种新颖的视觉Transformer模型,在2021年由微软亚洲研究院提出。这一模型提出了一种基于局部窗口的自注意力机制,显著改善了Vision Transformer(ViT)在处理高分辨率图像时的性能,尤其是在图像分类、物体检测等计算机视觉任务中表现出色。
Swin Transformer的最大创新之一是其引入了“平移窗口”机制,克服了传统自注意力方法在大图像处理时计算资源消耗过大的问题。这一机制使得模型能够在不同层次上以局部的方式计算自注意力,同时保持全局信息的处理能力。
在本文中,我们将通过详细的分析,介绍Swin Transformer的模型结构、核心思想及其实现,最后提供一个基于PyTorch的简单实现。
论文地址:https://arxiv.org/pdf/2103.14030
官方代码实现:https://github.com/microsoft/Swin-Transformer
Swin网络结构
如下图所示,Swin Transformer的Encoder采用分层的方式,通过多个阶段(Stage)逐渐减少特征图的分辨率,同时增加特征维度。每个Stage包含若干个Transformer Block。
每个Block通常由以下几个部分组成:
- Window-based Self-Attention:每个Block使用窗口自注意力机制,在每个窗口内计算自注意力。这种方式减少了计算量,因为自注意力只在局部窗口内进行计算,而不是整个图像。
- Shifted Window:为了增强不同窗口之间的联系,Swin Transformer在每一层的Block中采用了“窗口位移”策略。每一层中的窗口会偏移一定的步长,使得窗口之间的重叠区域增加,从而促进信息交流。
Patch Partition
Patch Partition 是将输入图像分割成固定大小的块(patch)并将其映射到高维空间的操作。就相当于是VIT模型当中的 Patch Embedding。
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
LayerNorm = partial(nn.LayerNorm, eps=1e-6)
class PatchPartition(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_channels, self.embed_dim,
kernel_size=self.patch_size, stride=self.patch_size)
self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
if H % self.patch_size[0] != 0:
pad_h = self.patch_size[0] - H % self.patch_size[0]
x = F.pad(x, (0, 0, 0, pad_h))
if W % self.patch_size[1] != 0:
pad_w = self.patch_size[1] - W % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, 0))
x = self.proj(x) # [B, embed_dim, H/patch_size, W/patch_size]
Wh, Ww = x.shape[2:]
x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
# Linear Embedding
x = self.norm(x)
# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x, Wh, Ww
if __name__=="__main__":
batch_size = 1
in_channels = 3
height, width = 30, 32
patch_size = 4
embed_dim = 96
x = torch.randn(batch_size, in_channels, height, width)
patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
output,_ ,_ = patch_partition(x)
print(f"Output shape: {output.shape}")
Patch Merging
PatchMerging 这一层用于将输入的特征图进行下采样,类似于卷积神经网络中的池化层。
如果图像的高度或宽度是奇数,PatchMerging 会进行填充,使得其变为偶数。这是因为下采样操作需要将图像分割为以2为步长的区域。如果图像的高度或宽度是奇数,直接进行切片会导致不均匀的分割,因此需要填充以保证每个块的大小一致。
这里我们在吧如上图的相同颜色块提取并进行拼接,沿着通道维度合并成一个更大的特征,将合并后的张量重新调整形状,新的空间分辨率是原来的一半(H/2 和 W/2)。
class PatchMerging(nn.Module):
def __init__(self, dim, norm_layer=LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
if H % 2 == 1 or W % 2 == 1:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
if __name__=="__main__":
batch_size = 1
in_channels = 3
height, width = 30, 32
patch_size = 4
embed_dim = 96
x = torch.randn(batch_size, in_channels, height, width)
patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
output, Wh, Ww = patch_partition(x)
patch_merging = PatchMerging(dim=embed_dim)
output = patch_merging(output, Wh, Ww)
print(output.shape)
在代码中呢就是在高和宽的维度通过切片的形式获得,x0表示的是左上角,x1表示的是右上角,x2表示的是左下角,x3表示的是右下角。经过一系列操作后,最后通过线性层实现通道数翻倍。
W-MSA
W-MSA(Window-based Multi-Head Self-Attention)是Swin Transformer中的一个核心创新,它是为了优化传统自注意力机制在高分辨率输入图像处理中的效率问题而提出的。
这是原论文当中给出的计算公式,h,w和C分别表示特征的高度,宽度和深度,M表示窗口的大小。在标准的 Transformer 模型中,自注意力机制需要对整个输入进行计算,这使得计算和内存的消耗随着输入的增大而急剧增长。而在图像任务中,输入图像往往具有非常高的分辨率,因此直接应用标准的全局自注意力在计算上不可行。
W-MSA 通过在局部窗口内进行自注意力计算来解决这一问题,极大地减少了计算和内存开销,同时保持了模型的表示能力。
SW-MSA
SW-MSA (Shifted Window-based Multi-Head Self-Attention)结合了局部窗口化自注意力和窗口偏移(shifted)策略,既提升了计算效率,又能在捕捉局部信息的基础上,保持对全局信息的建模能力。
左侧就是刚刚说到的W-MSA,经过窗口的偏移变成了右边的SW-MSA,偏移的策略能够让模型在每一层的计算中捕捉到不同窗口之间的依赖关系,避免了 W-MSA 只能在单一窗口内计算的局限。这样,相邻窗口之间的信息就能够通过偏移和交错的方式进行交流,增强了模型的全局感知能力。
但是,现在的窗口从原来的四个变成了九个,如果对每一个窗口再进行W-MSA那就太麻烦了。为了应对这种情况,作者提出了一种 高效批处理计算方法,旨在优化窗口偏移后的大规模窗口计算。其核心思想是:通过批处理计算的方式来有效地处理这些偏移后的窗口,而不是每个窗口单独计算。
意思就是说将图中的A,B,C的位置通过偏移和交错方式变化后,可以将这些窗口的计算统一进行批处理,而不是一个一个地处理。这样可以显著减少计算时间和内存占用。
这个过程我个人感觉比较像是卡诺图,具体的过程可以看我下面画的图:
然后这里的4还和原来的一样,5和3组合成一个窗口,1和7组合成一个窗口,8、2、6、0又组合成一个窗口,这样就和原来一样是4个4x4的窗口了,保证了计算量的不变。但是如果这样做了就会将不相邻的信息混合在一起了。作者这里采用掩蔽机制将自注意力计算限制在每个子窗口内,其实就是创建一个蒙板来屏蔽信息。
Relative Position Bias
关于这一部分,作者没有怎么提,只是经过了相对位置偏移,指标有明显的提示。
关于这一部分,我是参考的官方代码以及b站的讲解视频理解的。首先需要创建一个相对位置偏置的参数表,它的范围是从[-Wh+1, Wh-1],这里的 +1 和 -1 是因为偏移量是相对于当前元素的位置而言的,当前元素自身的偏移量为0,但我们不包括0在偏移量的计算中(因为0表示没有偏移,通常会在自注意力机制中以其他方式处理)。因此,对于垂直方向(或水平方向),总的偏移量数量是 win_h(或 win_w)的正偏移量数量加上 win_h(或 win_w)的负偏移量数量,再减去一个(因为我们不计算0偏移量)。因此,相对位置偏置表的尺寸为:
[(2 * Wh - 1) * (2 * Ww - 1), num_heads]
每个元素的查询(Query)和键(Key)之间的内积会得到一个相似度分数,在这些分数的基础上,会加入相对位置偏置,调整相似度:
Attention = softmax((QK^T + Relative_Position_Bias) / sqrt(d_k))
其中,Q 是查询向量,K 是键向量,Relative_Position_Bias 是根据相对位置计算得到的偏置。加入相对位置偏置后,模型可以更好地捕捉到局部结构的依赖关系。
网络实现
"""
Copyright (c) 2025, Auorui.
All rights reserved.
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.py
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partial
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_
LayerNorm = partial(nn.LayerNorm, eps=1e-6)
class PatchPartition(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_channels, self.embed_dim,
kernel_size=self.patch_size, stride=self.patch_size)
self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
if H % self.patch_size[0] != 0:
pad_h = self.patch_size[0] - H % self.patch_size[0]
x = F.pad(x, (0, 0, 0, pad_h))
if W % self.patch_size[1] != 0:
pad_w = self.patch_size[1] - W % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, 0))
x = self.proj(x) # [B, embed_dim, H/patch_size, W/patch_size]
Wh, Ww = x.shape[2:]
x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
# Linear Embedding
x = self.norm(x)
# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x, Wh, Ww
class MLP(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop_ratio)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class PatchMerging(nn.Module):
def __init__(self, dim, norm_layer=LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
if H % 2 == 1 or W % 2 == 1:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class WindowAttention(nn.Module):
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports shifted and non-shifted windows.
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
proj_bias=True,
attention_dropout_ratio=0.,
proj_drop=0.,
):
super().__init__()
self.dim = dim
self.window_size = to_2tuple(window_size)
win_h, win_w = self.window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
) # [2*Wh-1 * 2*Ww-1, nHeads] Offset Range: -Wh+1, Wh-1
self.register_buffer("relative_position_index",
self.get_relative_position_index(win_h, win_w), persistent=False)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attention_dropout_ratio)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def get_relative_position_index(self, win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[:3]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def window_partition(x, window_size: int):
"""
将feature map按照window_size划分成一个个没有重叠的window
Args:
x: (B, H, W, C)
window_size (int): window size(M)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
"""
将一个个window还原成一个feature map
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size(M)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block."""
mlp_ratio = 4
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
qkv_bias=True,
proj_bias=True,
attention_dropout_ratio=0.,
proj_drop=0.,
drop_path_ratio=0.,
norm_layer=LayerNorm,
act_layer=nn.GELU,
):
super(SwinTransformerBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout_ratio=attention_dropout_ratio,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * self.mlp_ratio)
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_bias)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
"""
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage."""
def __init__(self,
dim,
num_layers,
num_heads,
drop_path,
window_size=7,
qkv_bias=True,
proj_bias=True,
attention_dropout_ratio=0.,
proj_drop=0.,
norm_layer=LayerNorm,
act_layer=nn.GELU,
downsample=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.num_layers = num_layers
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout_ratio=attention_dropout_ratio,
proj_drop=proj_drop,
drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
act_layer=act_layer)
for i in range(num_layers)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x = self.downsample(x, H, W)
H, W = (H + 1) // 2, (W + 1) // 2
return x, H, W
class SwinTransformer(nn.Module):
""" Swin Transformer backbone."""
def __init__(self,
patch_size=4,
in_channels=3,
num_classes=1000,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
qkv_bias=True,
proj_bias=True,
attention_dropout_ratio=0.,
proj_drop=0.,
drop_path_rate=0.2,
norm_layer=LayerNorm,
patch_norm=True,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
# stage4输出特征矩阵的channels
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
# split image into non-overlapping patches
self.patch_embed = PatchPartition(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=proj_drop)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
layers = []
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
num_layers=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attention_dropout_ratio=attention_dropout_ratio,
proj_drop=proj_drop,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
)
layers.append(layer)
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
# x: [B, L, C]
x, H, W = self.patch_embed(x)
x = self.pos_drop(x)
for layer in self.layers:
x, H, W = layer(x, H, W)
x = self.norm(x) # [B, L, C]
x = self.avgpool(x.transpose(1, 2))
x = torch.flatten(x, 1)
x = self.head(x)
return x
def swin_t(num_classes) -> SwinTransformer:
model = SwinTransformer(in_channels=3,
patch_size=4,
window_size=7,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
num_classes=num_classes)
return model
def swin_s(num_classes) -> SwinTransformer:
model = SwinTransformer(in_channels=3,
patch_size=4,
window_size=7,
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
num_classes=num_classes)
return model
def swin_b(num_classes) -> SwinTransformer:
model = SwinTransformer(in_channels=3,
patch_size=4,
window_size=7,
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
num_classes=num_classes)
return model
def swin_l(num_classes) -> SwinTransformer:
model = SwinTransformer(in_channels=3,
patch_size=4,
window_size=7,
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
num_classes=num_classes)
return model
if __name__=="__main__":
import pyzjr
device = 'cuda' if torch.cuda.is_available() else 'cpu'
input = torch.ones(2, 3, 224, 224).to(device)
net = swin_l(num_classes=4)
net = net.to(device)
out = net(input)
print(out)
print(out.shape)
pyzjr.summary_1(net, input_size=(3, 224, 224))
# swin_t Total params: 27,499,108
# swin_s Total params: 48,792,676
# swin_b Total params: 86,683,780
# swin_l Total params: 194,906,308
参考文章
Swin-Transformer网络结构详解_swin transformer-CSDN博客
Swin-transformer详解_swin transformer-CSDN博客
【深度学习】详解 Swin Transformer (SwinT)-CSDN博客
推荐的视频:12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili