Swin Transformer Explained (with a PyTorch Implementation)

Preface

Swin Transformer (Shifted Window Transformer) is a vision Transformer model proposed by Microsoft Research Asia in 2021. It introduces a self-attention mechanism based on local windows, which substantially improves on Vision Transformer (ViT) when processing high-resolution images, and it performs strongly on computer-vision tasks such as image classification and object detection.

One of its biggest innovations is the shifted-window mechanism, which overcomes the excessive computational cost of conventional self-attention on large images. Self-attention is computed locally within windows at every level of the hierarchy, while the network as a whole retains the ability to model global information.

In this article we walk through the Swin Transformer's architecture, core ideas, and implementation in detail, and finish with a simple PyTorch implementation.

Paper: https://arxiv.org/pdf/2103.14030

Official implementation: https://github.com/microsoft/Swin-Transformer

Swin Network Architecture

As shown in the architecture figure from the paper, the Swin Transformer encoder is hierarchical: across several stages it progressively reduces the feature-map resolution while increasing the feature dimension. Each stage contains a number of Transformer blocks.

Each block typically consists of the following parts:

  • Window-based Self-Attention: each block computes self-attention inside local windows rather than across the whole image, which greatly reduces the amount of computation (see the sketch after this list).
  • Shifted Window: to strengthen connections between windows, alternating blocks shift the window grid by a fixed offset, so tokens that fell into different windows in the previous layer can interact and information flows across window boundaries.
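
A minimal sketch of the window partitioning itself (shapes assumed: an 8x8 feature map with window_size M = 4); attention is then computed inside each of the 4 windows independently:

import torch

B, H, W, C, M = 1, 8, 8, 1, 4
x = torch.randn(B, H, W, C)
windows = (x.view(B, H // M, M, W // M, M, C)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(-1, M, M, C))
print(windows.shape)    # torch.Size([4, 4, 4, 1]) -> [num_windows*B, M, M, C]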

Patch Partition

Patch Partition splits the input image into fixed-size patches and projects them into a high-dimensional space. It plays the same role as Patch Embedding in the ViT model.

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath

LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class PatchPartition(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, self.embed_dim,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        if H % self.patch_size[0] != 0:
            pad_h = self.patch_size[0] - H % self.patch_size[0]
            x = F.pad(x, (0, 0, 0, pad_h))
        if W % self.patch_size[1] != 0:
            pad_w = self.patch_size[1] - W % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, 0))
        x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]
        Wh, Ww = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]
        # Linear Embedding
        x = self.norm(x)
        # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x, Wh, Ww


if __name__=="__main__":
    batch_size = 1
    in_channels = 3
    height, width = 30, 32
    patch_size = 4
    embed_dim = 96

    x = torch.randn(batch_size, in_channels, height, width)
    patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
    output,_ ,_ = patch_partition(x)

    print(f"Output shape: {output.shape}")

Patch Merging

The PatchMerging layer downsamples the input feature map, similar to a pooling layer in a convolutional neural network.

If the feature map's height or width is odd, PatchMerging pads it to an even size first. The downsampling splits the map into 2x2 regions with stride 2, and an odd height or width would make that split uneven, so padding keeps every block the same size.
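
For reference, F.pad pads the trailing dimensions first, so on a (B, H, W, C) tensor the tuple (0, 0, 0, W % 2, 0, H % 2) leaves C untouched, pads W on the right, and pads H at the bottom (a toy check):

import torch
import torch.nn.functional as F

x = torch.randn(1, 7, 7, 96)              # odd H and W
x = F.pad(x, (0, 0, 0, 7 % 2, 0, 7 % 2))  # pad order: (C_l, C_r, W_l, W_r, H_t, H_b)
print(x.shape)                            # torch.Size([1, 8, 8, 96])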

We then take the same-colored blocks shown in the figure above, concatenate them along the channel dimension into one larger feature, and reshape the merged tensor so that the new spatial resolution is half the original (H/2 and W/2).

class PatchMerging(nn.Module):
    def __init__(self, dim, norm_layer=LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        if H % 2 == 1 or W % 2 == 1:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

if __name__=="__main__":
    batch_size = 1
    in_channels = 3
    height, width = 30, 32
    patch_size = 4
    embed_dim = 96

    x = torch.randn(batch_size, in_channels, height, width)
    patch_partition = PatchPartition(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
    output, Wh, Ww = patch_partition(x)

    patch_merging = PatchMerging(dim=embed_dim)
    output = patch_merging(output, Wh, Ww)
    print(output.shape)    # torch.Size([1, 16, 192])

In the code, the four groups are obtained by strided slicing along the height and width dimensions: x0 is the top-left element of each 2x2 block, x1 the bottom-left, x2 the top-right, and x3 the bottom-right. After concatenation and normalization, a linear layer doubles the channel count (4C in, 2C out). The toy example below makes the slicing easy to verify.
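
A toy check on a labeled 4x4 single-channel map: each slice picks one corner of every 2x2 block, and concatenation quadruples the channels.

import torch

t = torch.arange(16).view(1, 4, 4, 1)   # B=1, H=4, W=4, C=1
x0 = t[:, 0::2, 0::2, :]   # top-left of each 2x2 block:  0,  2,  8, 10
x1 = t[:, 1::2, 0::2, :]   # bottom-left:                 4,  6, 12, 14
x2 = t[:, 0::2, 1::2, :]   # top-right:                   1,  3,  9, 11
x3 = t[:, 1::2, 1::2, :]   # bottom-right:                5,  7, 13, 15
print(torch.cat([x0, x1, x2, x3], -1).shape)   # torch.Size([1, 2, 2, 4])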

W-MSA

W-MSA (Window-based Multi-Head Self-Attention) is one of the core innovations of the Swin Transformer, proposed to address the efficiency problems of standard self-attention on high-resolution input images.

\Omega(\text{MSA}) = 4hwC^{2} + 2(hw)^{2}C

\Omega(\text{W-MSA}) = 4hwC^{2} + 2M^{2}hwC

These are the complexity formulas given in the paper, where h, w, and C are the height, width, and channel depth of the feature map and M is the window size. In a standard Transformer, self-attention is computed over the entire input, so compute and memory grow quadratically as the input gets larger; image inputs often have very high resolution, which makes global self-attention computationally infeasible.

W-MSA solves this by computing self-attention within local windows, which drastically cuts the compute and memory cost while preserving the model's representational power.
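
To get a feel for the gap, here is a quick numeric check of the two formulas, plugging in the Swin-T stage-1 setting (h = w = 56, C = 96, M = 7) as an assumed example:

h, w, C, M = 56, 56, 96, 7

msa   = 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C
w_msa = 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C

print(f"MSA:   {msa / 1e9:.2f} GFLOPs")    # MSA:   2.00 GFLOPs
print(f"W-MSA: {w_msa / 1e9:.2f} GFLOPs")  # W-MSA: 0.15 GFLOPs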

SW-MSA

SW-MSA (Shifted Window-based Multi-Head Self-Attention) combines windowed self-attention with a window-shifting strategy: it keeps the efficiency of local windows while restoring the ability to model information across them.

On the left of the paper's figure is the W-MSA just described; shifting the windows produces the SW-MSA on the right. The shift lets each layer capture dependencies between different windows, avoiding W-MSA's limitation of attending within a single window only. Adjacent windows can then exchange information through the shifted, interleaved layout, strengthening the model's global awareness.

However, the shift turns the original four windows into nine, and running W-MSA on each of them separately would be wasteful. To handle this, the authors propose an efficient batch computation: instead of processing each shifted window individually, the shifted windows are rearranged so they can be processed together in one batch.

In other words, after regions A, B, and C are moved by a cyclic shift, the resulting windows can all be computed in a single batched attention call rather than one by one, which significantly reduces time and memory.

I personally find this process reminiscent of a Karnaugh map; the diagram I drew below walks through it step by step:

After the shift, region 4 stays where it was, 5 and 3 combine into one window, 1 and 7 into another, and 8, 2, 6, and 0 into a third, so we are back to four 4x4 windows and the total amount of computation is unchanged. The catch is that this mixes tokens that were never adjacent in the original image, so the authors use a masking mechanism to restrict self-attention to each sub-window: essentially a mask that blocks out cross-region attention.
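
A minimal sketch of the cyclic shift on a labeled 8x8 grid (assuming window_size = 4, hence shift_size = 2) shows how the shifted layout is folded back into regular windows via torch.roll:

import torch

H = W = 8
window_size = 4
shift_size = window_size // 2

x = torch.arange(H * W).view(1, H, W, 1)   # toy feature map, B = 1, C = 1
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
print(shifted[0, :, :, 0])
# The first two rows/columns have wrapped around to the bottom/right edge,
# so tokens that were not spatially adjacent now share a window -- which is
# exactly why the attention mask is needed.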

Relative Position Bias

The paper does not say much about this component; it simply reports that adding the relative position bias gives a clear improvement on the benchmarks.

I understood this part from the official code and a walkthrough video on Bilibili. First, a learnable relative position bias table is created. Within a window of height Wh, the row offset between any two tokens ranges from -(Wh - 1) to Wh - 1, with 0 included (a token relative to itself), giving 2*Wh - 1 possible values; likewise there are 2*Ww - 1 possible column offsets. Storing one bias per attention head for every (row offset, column offset) pair, the table therefore has shape:

[(2 * Wh - 1) * (2 * Ww - 1), num_heads]
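
As a sanity check, the snippet below (mirroring the get_relative_position_index method in the implementation later on) builds the index for a 2x2 window: every pairwise offset maps to a value in [0, 8], indexing the 9 rows of the table.

import torch

win_h = win_w = 2
coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))
coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # N, N, 2
relative_coords[:, :, 0] += win_h - 1        # shift row offsets to start at 0
relative_coords[:, :, 1] += win_w - 1        # shift col offsets to start at 0
relative_coords[:, :, 0] *= 2 * win_w - 1    # row-major flattening of the 2D offset
print(relative_coords.sum(-1))               # 4x4 index, values in [0, 8]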

Each query-key inner product yields a similarity score, and the relative position bias is added to these scores before the softmax:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V

where Q and K are the query and key matrices, d is the per-head dimension, and B is the bias gathered from the table using the relative position index. With this bias the model captures local structural dependencies more effectively.
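
A tiny end-to-end check of the lookup and addition (shapes assumed: one head, a 2x2 window, so N = 4 tokens and a 9-row table; the random index stands in for the real one computed above):

import torch

N, d = 4, 8
table = torch.randn(9, 1)                    # [(2*Wh-1)*(2*Ww-1), nHeads]
index = torch.randint(0, 9, (N, N))          # placeholder for the real index
q, k = torch.randn(N, d), torch.randn(N, d)

bias = table[index.view(-1)].view(N, N, 1).permute(2, 0, 1)   # nH, N, N
attn = torch.softmax(q @ k.t() / d ** 0.5 + bias[0], dim=-1)
print(attn.sum(-1))                          # each row sums to 1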

Full Network Implementation

"""
Copyright (c) 2025, Auorui.
All rights reserved.

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/pdf/2103.14030>
use for reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/swin_transformer.py
                   https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/swin_transformer/model.py
"""
from functools import partial

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pyzjr.utils.FormatConver import to_2tuple
from pyzjr.nn.models.bricks.drop import DropPath
from pyzjr.nn.models.bricks.initer import trunc_normal_

LayerNorm = partial(nn.LayerNorm, eps=1e-6)

class PatchPartition(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, self.embed_dim,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(self.embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        if H % self.patch_size[0] != 0:
            pad_h = self.patch_size[0] - H % self.patch_size[0]
            x = F.pad(x, (0, 0, 0, pad_h))
        if W % self.patch_size[1] != 0:
            pad_w = self.patch_size[1] - W % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, 0))
        x = self.proj(x)     # [B, embed_dim, H/patch_size, W/patch_size]
        Wh, Ww = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)    # [B, num_patches, embed_dim]
        # Linear Embedding
        x = self.norm(x)
        # x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x, Wh, Ww

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_ratio=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop_ratio)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class PatchMerging(nn.Module):
    def __init__(self, dim, norm_layer=LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        if H % 2 == 1 or W % 2 == 1:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports shifted and non-shifted windows.
    """
    def __init__(
            self,
            dim,
            window_size,
            num_heads,
            qkv_bias=True,
            proj_bias=True,
            attention_dropout_ratio=0.,
            proj_drop=0.,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)
        win_h, win_w = self.window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)
        )   # [2*Wh-1 * 2*Ww-1, nHeads]   Offset Range: -Wh+1, Wh-1

        self.register_buffer("relative_position_index",
                             self.get_relative_position_index(win_h, win_w), persistent=False)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attention_dropout_ratio)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def get_relative_position_index(self, win_h: int, win_w: int):
        # get pair-wise relative position index for each token inside the window
        coords = torch.stack(torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
        relative_coords[:, :, 1] += win_w - 1
        relative_coords[:, :, 0] *= 2 * win_w - 1
        return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)    # each: [B, num_heads, N, head_dim]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def window_partition(x, window_size: int):
    """
    Partition a feature map into non-overlapping windows of size window_size.
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view:    [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Reverse window partition: merge the windows back into a feature map.
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block."""
    mlp_ratio = 4
    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        qkv_bias=True,
        proj_bias=True,
        attention_dropout_ratio=0.,
        proj_drop=0.,
        drop_path_ratio=0.,
        norm_layer=LayerNorm,
        act_layer=nn.GELU,
    ):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attention_dropout_ratio=attention_dropout_ratio,
            proj_drop=proj_drop,
        )

        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * self.mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_ratio=proj_drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            mask_matrix: Attention mask for cyclic shift (SW-MSA).
        Note: self.H and self.W must be set to the spatial resolution of
        the input feature before calling forward.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage."""
    def __init__(self,
                 dim,
                 num_layers,
                 num_heads,
                 drop_path,
                 window_size=7,
                 qkv_bias=True,
                 proj_bias=True,
                 attention_dropout_ratio=0.,
                 proj_drop=0.,
                 norm_layer=LayerNorm,
                 act_layer=nn.GELU,
                 downsample=None):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.num_layers = num_layers

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                attention_dropout_ratio=attention_dropout_ratio,
                proj_drop=proj_drop,
                drop_path_ratio=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                act_layer=act_layer)
            for i in range(num_layers)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W




class SwinTransformer(nn.Module):
    """ Swin Transformer backbone."""
    def __init__(self,
                 patch_size=4,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=7,
                 qkv_bias=True,
                 proj_bias=True,
                 attention_dropout_ratio=0.,
                 proj_drop=0.,
                 drop_path_rate=0.2,
                 norm_layer=LayerNorm,
                 patch_norm=True,
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the stage-4 output feature map
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # split image into non-overlapping patches
        self.patch_embed = PatchPartition(
            patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        self.pos_drop = nn.Dropout(p=proj_drop)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        layers = []
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                num_layers=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                attention_dropout_ratio=attention_dropout_ratio,
                proj_drop=proj_drop,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                )
            layers.append(layer)

        self.layers = nn.Sequential(*layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

def swin_t(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes)
    return model

def swin_s(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes)
    return model


def swin_b(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes)
    return model

def swin_l(num_classes) -> SwinTransformer:
    model = SwinTransformer(in_channels=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes)
    return model

if __name__=="__main__":
    import pyzjr
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = swin_l(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    pyzjr.summary_1(net, input_size=(3, 224, 224))
    # swin_t Total params: 27,499,108
    # swin_s Total params: 48,792,676
    # swin_b Total params: 86,683,780
    # swin_l Total params: 194,906,308
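
The listing above pulls a few small utilities (to_2tuple, DropPath, trunc_normal_) from the author's pyzjr package. If you don't have it installed, the following stand-ins, a sketch using only standard PyTorch, should behave equivalently (the pyzjr.summary_1 call in the demo is just a parameter summary and can be dropped or replaced with a package such as torchinfo):

import collections.abc
from itertools import repeat

import torch
import torch.nn as nn

def to_2tuple(x):
    # Return x as a 2-tuple: scalars are repeated, iterables passed through.
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    return tuple(repeat(x, 2))

class DropPath(nn.Module):
    """Stochastic depth: randomly drop the whole residual branch per sample."""
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)   # broadcast over all but batch
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # PyTorch ships an equivalent truncated-normal initializer.
    return nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)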

References

Swin-Transformer网络结构详解 (CSDN blog)

Swin-transformer详解 (CSDN blog)

【深度学习】详解 Swin Transformer (SwinT) (CSDN blog)

Recommended video: 12.1 Swin-Transformer网络结构详解 (Bilibili)
