1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。
2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:
from torchvision.models.swin_transformer import swin_b, Swin_B_Weights
model = swin_b(weights=Swin_B_Weights.DEFAULT)
这里调用的就是swin_b在imagenet上的预训练模型
3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。
SwinTransformer(
(features): Sequential(
(0): Sequential(
(0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
(1): Permute()
(2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(1): Sequential(
(0): SwinTransformerBlock(
(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=128, out_features=384, bias=True)
(proj): Linear(in_features=128, out_features=128, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.0, mode=row)
(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=128, out_features=512, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=512, out_features=128, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(attn): ShiftedWindowAttention(
(qkv): Linear(in_features=128, out_features=384, bias=True)
(proj): Linear(in_features=128, out_features=128, bias=True)
)
(stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)
(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(mlp): MLP(
(0): Linear(in_features=128, out_features=512, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=512, out_features=128, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
(2): PatchMerging(
(reduction): Linear(in_features=512, out_features=256, bias=False)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
那么如何调用其中的SwinTransformerBlock呢。
由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。
因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:
swinblock = model.features[1][0]
结果如下,调用成功: