机器学习周报(12.9-12.15)

news2025/1/22 15:55:55

文章目录

    • 摘要
    • Abstract
  • 1 Swin Transformer
    • 1.1 输入
    • 1.2 Patch Partition
    • 1.3 Linear Embedding
    • 1.4 Patch Merging
    • 1.5 Swin Transformer Block
    • 1.6 代码
    • 总结

摘要

本篇博客介绍了采用类似于卷积核的移动窗口进行图像特征提取的Swin Transformer网络模型,这是一种基于Transformer的计算机视觉模型,通过引入类似CNN卷积核的Shifted Windows(移动窗口),其通过将 self-attention 计算限制在不重叠的局部窗口上,由于窗口大小固定,每个窗口内计算复杂度固定,计算复杂度跟图像大小呈线性关系,从而带来更高的效率。Swin Transformer Block的核心机制,包括W-MSA(窗口多头自注意力)和SW-MSA(移动窗口多头自注意力),这些机制从全局的自注意力变为了窗口内的自注意力,使得模型在减少计算量的同时捕捉多尺度特征。最后,通过使用预训练模型对花类数据集进行微调;通过预测代码和结果展示了模型对图像分类的高准确率。

Abstract

This blog post introduces the Swin Transformer network model, which adopts shifting windows similar to convolutional kernels for image feature extraction. This computer vision model, based on the Transformer architecture, introduces Shifted Windows akin to CNN convolutional kernels. By limiting self-attention computations to non-overlapping local windows, and with a fixed window size, the computational complexity within each window is constant, and the overall computational complexity is linear with respect to the image size, thus offering greater efficiency. The core mechanisms of the Swin Transformer Block include W-MSA (Window Multi-Head Self-Attention) and SW-MSA (Shifted Window Multi-Head Self-Attention), which transition from global self-attention to window-internal self-attention, allowing the model to capture multi-scale features while reducing computational load. Finally, the post demonstrates the model’s high accuracy in image classification by fine-tuning a pre-trained model on a flower dataset; the prediction code and results showcase the model’s high accuracy in image classification tasks.

1 Swin Transformer

Swin Transformer提出了类似CNN卷积核的Shifted Windows(移动窗口),以减少patch之间做自注意力的次数,达到节省时间的目的。

接下来我将参照Swin Transformer模型整体结构图来讲解该模型,网络结构图如下所示:

在这里插入图片描述

1.1 输入

以输入图像大小为 224x224 为例,将图像tensor 224x224x3 从1处传入模型作为开始。

1.2 Patch Partition

到达第2步Patch Partiton,因为是以Transformer为基础,如同ViT一样需要将图片打成patch传入模型。这里将图像大小 224x224 通过卷积核大小为 4x4,步伐为4的卷积将图像转换为 56x56 的大小,通道数从3变为3x4x4=48维。

1.3 Linear Embedding

接下来数据将传入Swin Transformer最重要的第3大部分,首先在第一个stage会经过全连接层,这里的通道数在论文中是设置好的,即C=96。如上图所示,56x56x48 的tensor经过Linear Embedding维度变化为56x56x96。

1.4 Patch Merging

这里先说一下Patch Merging,因为只要第一个stage先经过Linear Embedding,后3个stage都是Patch Merging。这里的Patch Merging类似于CNN中的池化,为了使窗口获得多尺度的特征,需要在每一个block之前进行下采样,这样在窗口大小不变的情况下图像变小、维度增加,可以获得更多的特征信息。如下图所示:

在这里插入图片描述

论文中提到将图像下采样2倍,在Patch Merging中会隔一个点采样。细心的人这时会发现,我们图像的维度从 4x4x1 变为了 2x2x4,但是在模型中是从变为,即大小缩小一半,维度扩大2倍。但是在上图的操作中维度扩大了4倍,论文中为了使其与CNN中池化保持一致,在Patch Merging之后会在再通过一次卷积使维度缩小一半,以达到大小缩小一半,维度扩大2倍的效果。

到这一步Swin Transformer已经拥有了感知多尺度特征的手段。

1.5 Swin Transformer Block

在这里插入图片描述
Swin Transformer中最重要的模块,这里是用到了Transformer中的编码器,但是做了一些改动,Transformer中是单纯的多头自注意力(MSA),而这里用到了W-MSA和SW-MSA。

  1. W-MSA
    引入Windows Multi-head Self-Attention(窗口多头自注意力)就是为了解决开篇提到的计算量问题。W-MSA会将图像划分为 MxM (在论文中M默认为7)的不重叠窗口,然后多头自注意力只在同一个窗口内做,这样就从全局的自注意力变为了窗口内的自注意力,大大减小了计算量

在这里插入图片描述
如上公式,若图像长宽为 224x224 、M=7,则从 22 4 4 224^4 2244降为 22 4 4 × 49 224^4×49 2244×49

  1. SW-MSA

    如果只进行W-MSA会发现一个问题,窗口之间是没用通信的,也就是窗口与窗口之间是独立计算的,就不能像卷积一样移动去感受不同像素范围。于是,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块让窗口移动起来。

在这里插入图片描述

例如上图所示,将图像分为了4个窗口,在W-MSA中4个窗口分别做MSA;然后在下一层SW-MSA左上角的窗口会向右和向下移动个步伐。如右图所示将图像分割为9个窗口,会发现上一层不属于同一个窗口的patch,经过移动后融合为一个窗口,窗口与窗口之间成功进行了信息交流。

如果9个窗口分别进行SW-MSA,那么从上一层4个相同的窗口变为9个不完全相同的窗口进行计算,不但不利于并行计算,而且窗口数量也增大两倍多,通过W-MSA积攒的计算优势一去不复返。

那么该如何解决呢?有人想到padding补0,这样虽然可以保证窗口大小相同进行并行计算,但是数量是增多的。论文中,作者提出通过掩码的方式进行自注意力的计算

在这里插入图片描述
A、B、C经过顺时针循环位移,将图像窗口补为大小相同的4个窗口。但是同一个窗口中,颜色不同的部分不能直接进行自注意力,需要在计算之后分别乘上对应的掩码。

在这里插入图片描述

掩码是加上一个很大的负值,在经过softmax之后该部分就会变为0。1的部分全属于同一窗口可直接进行自注意力计算。

1.6 代码

Swin Transformer网络模型如下:

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer
"""
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
 
 
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
 
 
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
 
    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)
 
 
def window_partition(x, window_size: int):
    """
    将feature map按照window_size划分成一个个没有重叠的window
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
 
 
def window_reverse(windows, window_size: int, H: int, W: int):
    """
    将一个个window还原成一个feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
 
 
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
    def forward(self, x):
        _, _, H, W = x.shape
 
        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
 
        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
 
 
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
 
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
 
    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        x = x.view(B, H, W, C)
 
        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
 
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
 
        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]
 
        return x
 
 
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
 
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)
 
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
 
 
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
 
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
 
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
 
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]
 
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)
 
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
 
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
 
        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)
 
        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
 
        attn = self.attn_drop(attn)
 
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
 
 
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
 
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
 
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
 
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
 
    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
 
        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
 
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None
 
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]
 
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]
 
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]
 
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
 
        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()
 
        x = x.view(B, H * W, C)
 
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
 
        return x
 
 
class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """
 
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2
 
        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
 
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
 
    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
 
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask
 
    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2
 
        return x, H, W
 
 
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """
 
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
 
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # stage4输出特征矩阵的channels
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
 
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)
 
        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
 
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # 注意这里构建的stage和论文图中有些差异
            # 这里的stage不包含该stage的patch_merging层,包含的是下个stage的
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)
 
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
        self.apply(self._init_weights)
 
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
 
    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)
 
        for layer in self.layers:
            x, H, W = layer(x, H, W)
 
        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
 
 
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model

引入预训练模型之后,进行花类数据集(上篇博客提到)进行微调,代码如下:

import os
import argparse
 
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
 
from my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate
 
 
def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 
    if os.path.exists("../weights") is False:
        os.makedirs("../weights")
 
    tb_writer = SummaryWriter()
 
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
 
    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
 
    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
 
    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])
 
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
 
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
 
    model = create_model(num_classes=args.num_classes).to(device)
 
    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))
 
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))
 
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)
 
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)
 
        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
 
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
 
        torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))
 
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)
 
    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="../data/flower_photos")
 
    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='../weights/swin_tiny_patch4_window7_224.pth',
                        help='initial weights path')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
 
    opt = parser.parse_args()
 
    main(opt)

模型预测代码:

import os
import json
 
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
 
from model import swin_tiny_patch4_window7_224 as create_model
 
 
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 
    # load image
    img_path = "../data/Image/flower.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    plt.show()
    img2 = img
    # [N, C, H, W]
    img = img.convert('RGB')
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
 
    # read class_indict
    json_path = 'class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
 
    with open(json_path, "r") as f:
        class_indict = json.load(f)
 
    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "../weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
 
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
 
    plt.imshow(img2)
    plt.show()
 
 
if __name__ == '__main__':
    main()

代码预测结果如下图所示:

在这里插入图片描述
大约99.8%的可能性为玫瑰类别

总结

Swin Transformer模型通过创新的窗口机制在保持Transformer模型优势的同时,显著降低了计算成本,使其在计算机视觉任务中表现出色。模型的层次化特征提取和多尺度感知能力使其在图像分类、目标检测和语义分割等任务中取得了优异的性能。通过实际代码的展示和微调应用,预测结果的高准确率进一步证明了模型的强大能力,特别是在处理高分辨率图像时的效率和准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2262616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Spring Boot的校园车辆管理系统

一、系统背景与意义 随着校园规模的不断扩大和车辆数量的增加&#xff0c;传统的车辆管理方式已经难以满足高效、准确管理车辆的需求。因此&#xff0c;开发一个基于Spring Boot的校园车辆管理系统具有重要的现实意义。该系统可以实现对校园车辆的信息化管理&#xff0c;提高车…

SpringBoot3整合FastJSON2如何配置configureMessageConverters

在 Spring Boot 3 中整合 FastJSON 2 主要涉及到以下几个步骤&#xff0c;包括添加依赖、配置 FastJSON 作为 JSON 处理器等。下面是详细的步骤&#xff1a; 1. 添加依赖 首先&#xff0c;你需要在你的 pom.xml 文件中添加 FastJSON 2 的依赖。以下是 Maven 依赖的示例&#…

隐私清理工具Goversoft Privazer

PrivaZer 是一款专为隐私保护而生的 Windows 系统清理工具&#xff0c;支持深度扫描、清除无用文件和隐私痕迹。 PrivaZer - 深度扫描磁盘&#xff0c;自动清理上网痕迹&#xff0c;全面保护 Windows 的网络隐私 释放磁盘空间 硬盘空间告急&#xff0c;想清理却又无从下手&…

UDP网络编程套接

目录 本文核心 预备知识 1.端口号 认识TCP协议 认识UDP协议 网络字节序 socket编程接口 sockaddr结构 UDP套接字编程 服务端 客户端 TCP与UDP传输的区别 可靠性&#xff1a; 传输方式&#xff1a; 用途&#xff1a; 头部开销&#xff1a; 速度&#xff1a; li…

3.zabbix中文设置

1、zabbix中文设置 2、中文乱码的原因 zabbix使用DejaVuSan.ttf字体&#xff0c;不支持中文&#xff0c;导致中文出现乱码。解决方法很简单&#xff0c;把我们电脑里面字体文件传到zabbix服务器上。 3、解决zabbix乱码方法 3.1、从Window服务器找到相应的字休复制到zabbix S…

【Vue3学习】setup语法糖中的ref,reactive,toRef,toRefs

在 Vue 3 的组合式 API&#xff08;Composition API&#xff09;中&#xff0c;ref、reactive、toRef 和 toRefs 是四个非常重要的工具函数&#xff0c;用于创建和管理响应式数据。 一、ref 用ref()包裹数据,返回的响应式引用对象&#xff0c;包含一个 .value 属性&#xff0…

图书馆管理系统(三)基于jquery、ajax

任务3.4 借书还书页面 任务描述 这部分主要是制作借书还书的界面&#xff0c;这里我分别制作了两个网页分别用来借书和还书。此页面&#xff0c;也是通过获取books.txt内容然后添加到表格中&#xff0c;但是借还的操作没有添加到后端中去&#xff0c;只是一个简单的前端操作。…

springboot450房屋租赁管理系统(论文+源码)_kaic

摘 要 如今社会上各行各业&#xff0c;都喜欢用自己行业的专属软件工作&#xff0c;互联网发展到这个时候&#xff0c;人们已经发现离不开了互联网。新技术的产生&#xff0c;往往能解决一些老技术的弊端问题。因为传统房屋租赁管理系统信息管理难度大&#xff0c;容错率低&am…

vue框架的搭建

1什么是Node.js&#xff1b; Node.js 是一个免费、开源、跨平台的 JavaScript 运行时环境, 它让开发人员能够创建服务器 Web 应用、命令行工具和脚本 Node.js下载&#xff1a; 下载Node 16.20.2 安装Node.js 安装Node.js 测试安装 运行命令行 win键R 查看node版本 输入&am…

TCP基础了解

什么是 TCP &#xff1f; TCP 是面向连接的、可靠的、基于字节流的传输层通信协议。 面向连接&#xff1a;一定是「一对一」才能连接&#xff0c;不能像 UDP 协议可以一个主机同时向多个主机发送消息&#xff0c;也就是一对多是无法做到的&#xff1b; 可靠的&#xff1a;无论…

电商数据采集电商,行业数据分析,平台数据获取|稳定的API接口数据

电商数据采集可以通过多种方式完成&#xff0c;其中包括人工采集、使用电商平台提供的API接口、以及利用爬虫技术等自动化工具。以下是一些常用的电商数据采集方法&#xff1a; 人工采集&#xff1a;人工采集主要是通过基本的“复制粘贴”的方式在电商平台上进行数据的收集&am…

C语言-稀疏数组转置

1.题目要求 2.代码实现 #include <stdio.h> #define MAX_TERM 80// 定义稀疏矩阵结构体 typedef struct juzhen {int row;int col;int value; } Juzhen;// 显示稀疏矩阵 void show(Juzhen a[], int count_a) {printf(" i row col val\n");for (int i 1; i &…

在 Spring Boot 3 中实现基于角色的访问控制

基于角色的访问控制 (RBAC) 是一种有价值的访问控制模型,可增强安全性、简化访问管理并提高效率。它在管理资源访问对安全和运营至关重要的复杂环境中尤其有益。 我们将做什么 我们有一个包含公共路由和受限路由的 Web API。受限路由需要数据库中用户的有效 JWT。 现在用户…

计算机毕业设计python+spark+hive动漫推荐系统 漫画推荐系统 漫画分析可视化大屏 漫画爬虫 漫画推荐系统 漫画爬虫 知识图谱 大数据毕设

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

SAP抓取外部https报错SSL handshake处理方法

一、问题描述 SAP执行报表抓取https第三方数据,数据获取失败。 报错消息: SSL handshake with XXX.COM:449 failed: SSSLERR_SSL_READ (-58)#SAPCRYPTO:SSL_read() failed##SapSSLSessionStartNB()==SSSLERR_SSL_READ# SSL:SSL_read() failed (536875120/0x20001070)# …

OpenCV基本图像处理操作(三)——图像轮廓

轮廓 cv2.findContours(img,mode,method) mode:轮廓检索模式 RETR_EXTERNAL &#xff1a;只检索最外面的轮廓&#xff1b;RETR_LIST&#xff1a;检索所有的轮廓&#xff0c;并将其保存到一条链表当中&#xff1b;RETR_CCOMP&#xff1a;检索所有的轮廓&#xff0c;并将他们组…

告别机器人味:如何让ChatGPT写出有灵魂的内容

目录 ChatGPT的一些AI味道小问题 1.提供编辑指南 2.提供样本 3.思维链大纲 4.融入自己的想法 5.去除重复增加多样性 6.删除废话 ChatGPT的一些AI味道小问题 大多数宝子们再使用ChatGPT进行写作时&#xff0c;发现我们的老朋友ChatGPT在各类写作上还有点“机器人味”太重…

对于给定PI参数的锁相环带宽简单计算方法

锁相环的控制框图一般为&#xff1a; 对于锁相环的闭环传递函数&#xff1a; H ( s ) K P L L p s K P L L i s 2 K P L L p s K P L L i H(s)\frac{K_{PLLp}sK_{PLLi}}{s^2K_{PLLp}sK_{PLLi}} H(s)s2KPLLp​sKPLLi​KPLLp​sKPLLi​​ 我们可以通过分析系统的特征方程&a…

day14-16系统服务管理和ntp和防火墙

一、自有服务概述 服务是一些特定的进程&#xff0c;自有服务就是系统开机后就自动运行的一些进程&#xff0c;一旦客户发出请求&#xff0c;这些进程就自动为他们提供服务&#xff0c;windows系统中&#xff0c;把这些自动运行的进程&#xff0c;称为"服务" window…

【数据集】玻璃门窗缺陷检测数据集3085张5类YIOLO+VOC格式

数据集格式&#xff1a;VOC格式YOLO格式 压缩包内含&#xff1a;3个文件夹&#xff0c;分别存储图片、xml、txt文件 JPEGImages文件夹中jpg图片总计&#xff1a;3085 Annotations文件夹中xml文件总计&#xff1a;3085 labels文件夹中txt文件总计&#xff1a;3085 标签种类数&am…