Swin Transfromer在2021年首次发表于论文《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》,目前已用于图像分割、分类等计算机视觉领域的各项任务中。该模型借鉴了Vision Transformer模型的思想,将二维图像加工成transformer可处理的一维数据,试图将transformer这一自然语言处理领域的利器,迁移至计算机视觉领域,以获得较优的模型性能。


  • 随着图像的分辨率增加,展平后的序列长度呈平方级别增加,是模型不可接受的,将严重影响transformer中自注意力的计算效率;
  • 不同于NLP模型的传统输入特征,同一物体的图像因拍摄角度等原因,尺度和内容特征方差较大。

Swin Transformer创新地引入了滑动窗口机制,在窗口内进行自注意力计算,使计算复杂度随图片分辨率平方级别增长降低为线性增长,并使模型可以学习到跨窗口的图像信息;参考了传统的CNN结构,进行类似的级联式局部特征提取工作,在每一层次进行下采样以扩大感受野,并可提取到多尺度的图像信息。


Swin Transfromer的基本结构如图1(d)所示,由4个层级式模块组成,每一模块内包含一个Swin Transformer Block。原始输入图像尺寸为H×W×3(3为RGB通道数),经过Patch Partition层分为大小为4×4的patch,尺寸转换为(H/4)×(W/4)×48,其功能等价于卷积核大小为4×4,步长为4,卷积核个数为48的二维卷积操作。

随后,每个Stage内部由Patch Merging模块(Stage1为Linear Embedding)、Swin Transformer模块组成。以Stage1、Stage2为例:

  • Linear Embedding

    线性嵌入模块(Linear Embedding)将图像的通道数调整为嵌入长度C,调整后的图像尺寸为(H/4)×(W/4)×C。

  • Swin Transformer

    Linear Embedding模块的输出序列长度往往较大,无法直接输入transformer。在本模型中,将输入张量划分为m×m大小的窗口,其中m为每个窗口内的patch数量,原文模型中默认为7。自注意力计算将在每个窗口内展开。为提取窗口间的信息,对窗口进行滑动,并再次计算自注意力。有关Swin Transformer模块的自注意力计算等实现细节见下文。

  • Patch Merging

    下采样模块(Patch Merging)的作用是降低图片分辨率,扩大感受野,捕获多尺寸的图片信息,同时降低计算量,如图1(a)所示。类比于下采样在CNN模型中的实现,Swin Transformer模块的输出经过一次下采样,由(H/4)×(W/4)×C转换为(H/8)×(W/8)×2C,将临近的2个小patch合并成一个大patch,相当于进行了步长为2的二维卷积操作,后经过线性层(或1×1卷积层)调整通道数至2C,功能等价于Patch Partition模块与Linear Embedding模块的先后组合。此后再经过多个Swin Transformer模块、Patch Merging模块,进行多尺度特征的提取。


图1 Swin Transformer模型结构图

Swin Transformer模块原理

Swin Transformer模块的细节实现如图1(c)所示,核心部分为多头自注意力的计算。如上所述,基于全尺寸图像的自注意力在面对密集计算型任务时具有较高的计算复杂度,因此本模型采用基于窗口的自注意力计算(W-MSA),输入张量先被分割成窗口,在窗口内的patch之间计算多头自注意力。相比于传统的自注意力计算,此处在计算q、k时额外增加了相对位置编码,以提供窗口内的位置信息。

Swin Transformer的亮点在于采用滑动窗口机制,提取窗口之间的信息,达到全局自注意力的效果。以4个4×4大小的窗口为例,在计算滑动窗口多头自注意力(SW-MSA)时,将窗口向右下移动2个patch,得到9个新窗口,其中上、下、左、右4个新窗口包含原窗口划分中2个窗口的图像信息,中间的新窗口包含原窗口划分中4个窗口的图像信息,从而实现窗口之间的通信。但随之而来的是窗口数量的成倍增加和不同的窗口大小,反而增大了计算难度。

为降低计算量,本模型对新窗口进行循环移位处理,组成新的窗口布局,将4×4大小的窗口数降低为4个,如图1(b)所示。此时除左上角的窗口外,其它3个窗口包含的部分patch原本并不属于同一区域,不应计算之间的自注意力,因此本模型创新性地提出了masked MSA机制,原理在自注意力计算结果上加上mask矩阵,目的是只取相同原窗口内的patch间的q、k计算结果,而不同原窗口内的patch间q、k计算结果添加原码后得到一个较大的负数,在随后的softmax层计算中,掩码部分输出将会是0,达到忽略其值的目的。


  • 基于滑动窗口的自注意力计算,交替使用W-MSA、SW-MSA,解决transformer应用于计算机视觉领域所产生的计算复杂度高的问题;
  • 借鉴CNN架构,对图像进行降采样,捕获多尺寸层次的全局特征。



└─ swin_transformer
    ├─ src
        ├─ configs                // SwinTransformer的配置文件
            ├─ args.py
            └─ swin_tiny_patch4_window7_224.yaml
        └─ data
            ├─ augment            // 数据增强函数文件
                ├─ auto_augment.py
                ├─ mixup.py
                └─ random_erasing.py
            └─ imagenet           // miniImageNet数据集
                ├─ train
                ├─ val
                └─ test
    ├─ swin_transformer.ipynb     // 端到端可执行的Notebook文件
    └─ README.md


In [1]:

import numpy as np

import mindspore.common.initializer as weight_init
import mindspore.ops.operations as P
from mindspore import Parameter
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import nn
from mindspore import numpy
from mindspore import ops

import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
import mindspore.dataset.vision as vision
from mindspore.dataset.vision.utils import Inter

import collections.abc
from itertools import repeat

import os

import mindspore
from mindspore import Model
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
from mindspore.nn.loss.loss import LossBase
from mindspore.common import RowTensor
from mindspore.ops import composite as Cps
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.optim.momentum import Momentum



    ├─ PatchEmbed
    └─ BasicLayer
        ├─ PatchMerging
        └─ SwinTransformerBlock
            ├─ WindowAttention
            ├─ RelativeBias
            ├─ DropPath1D
            ├─ Mlp
            ├─ Roll
            ├─ WindowPartitionConstruct
            └─ WindowReverseConstruct

In [2]:

from src.configs.args import args

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse

class Identity(nn.Cell):
    def construct(self, x):
        return x

class DropPath(nn.Cell):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    def __init__(self, drop_prob, ndim):
        super(DropPath, self).__init__()
        self.drop = nn.Dropout(keep_prob=1 - drop_prob)
        shape = (1,) + (1,) * (ndim + 1)
        self.ndim = ndim
        self.mask = Tensor(np.ones(shape), dtype=mstype.float32)

    def construct(self, x):
        if not self.training:
            return x
        mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
        out = self.drop(mask)
        out = out * x
        return out

class DropPath1D(DropPath):
    def __init__(self, drop_prob):
        super(DropPath1D, self).__init__(drop_prob=drop_prob, ndim=1)

to_2tuple = _ntuple(2)

act_layers = {
    "GELU": nn.GELU,
    "gelu": nn.GELU,
=> Reading YAML config from ./src/configs/swin_tiny_patch4_window7_224.yaml
Namespace(arch='swin_tiny_patch4_window7_224', accumulation_step=1, amp_level='O1', ape=False, batch_size=128, beta=[0.9, 0.999], clip_global_norm_value=5.0, crop=True, data_url='./src/data/imagenet', device_id=1, device_num=1, device_target='GPU', epochs=100, eps=1e-08, file_format='MINDIR', in_channel=3, is_dynamic_loss_scale=True, keep_checkpoint_max=20, optimizer='adamw', set='ImageNet', graph_mode=0, mix_up=0.8, mlp_ratio=4.0, num_parallel_workers=16, start_epoch=0, warmup_length=20, warmup_lr=7e-08, weight_decay=0.05, loss_scale=1024, lr=0.0005, lr_scheduler='cosine_lr', lr_adjust=30, lr_gamma=0.97, momentum=0.9, num_classes=100, patch_size=4, patch_norm=True, swin_config='./src/configs/swin_tiny_patch4_window7_224.yaml', seed=0, save_every=5, label_smoothing=0.1, image_size=224, train_url='./', cutmix=1.0, auto_augment='rand-m9-mstd0.5-inc1', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1, mixup_prob=1.0, switch_prob=0.5, mixup_mode='batch', base_lr=0.0005, min_lr=6e-06, nonlinearity='GELU', keep_bn_fp32=True, drop_path_rate=0.2, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)

In [3]:

class Mlp(nn.Cell):
    """MLP Cell"""

    def __init__(self, in_features, hidden_features=None,
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True)
        self.act = act_layer()
        self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True)
        self.drop = nn.Dropout(keep_prob=1.0 - drop)

    def construct(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [4]:

def window_partition(x, window_size):
        x: (B, H, W, C)
        window_size (int): window size

        windows: (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = np.reshape(x, (B, H // window_size, window_size, W // window_size, window_size, C))
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows

In [5]:

class WindowPartitionConstruct(nn.Cell):
    """WindowPartitionConstruct Cell"""

    def __init__(self, window_size):
        super(WindowPartitionConstruct, self).__init__()

        self.window_size = window_size

    def construct(self, x):
            x: (B, H, W, C)
            window_size (int): window size

            windows: (num_windows*B, window_size, window_size, C)
        B, H, W, C = x.shape
        x = P.Reshape()(x, (B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C))
        x = P.Transpose()(x, (0, 1, 3, 2, 4, 5))
        x = P.Reshape()(x, (B * H * W // (self.window_size ** 2), self.window_size, self.window_size, C))

        return x

In [6]:

class WindowReverseConstruct(nn.Cell):
    """WindowReverseConstruct Cell"""

    def construct(self, windows, window_size, H, W):
            windows: (num_windows*B, window_size, window_size, C)
            window_size (int): Window size
            H (int): Height of image
            W (int): Width of image

            x: (B, H, W, C)
        B = windows.shape[0] // (H * W // window_size // window_size)
        x = ops.Reshape()(windows, (B, H // window_size, W // window_size, window_size, window_size, -1))
        x = ops.Transpose()(x, (0, 1, 3, 2, 4, 5))
        x = ops.Reshape()(x, (B, H, W, -1))
        return x

In [7]:

class RelativeBias(nn.Cell):
    """RelativeBias Cell"""

    def __init__(self, window_size, num_heads):
        super(RelativeBias, self).__init__()
        self.window_size = window_size
        # define a parameter table of relative position bias
        coords_h = np.arange(self.window_size[0]).reshape(self.window_size[0], 1).repeat(self.window_size[0],
                                                                                         1).reshape(1, -1)
        coords_w = np.arange(self.window_size[1]).reshape(1, self.window_size[1]).repeat(self.window_size[1],
                                                                                         0).reshape(1, -1)
        coords_flatten = np.concatenate([coords_h, coords_w], axis=0)  # 2, Wh, Ww
        relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.transpose(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        self.relative_position_index = Tensor(relative_coords.sum(-1).reshape(-1))  # Wh*Ww, Wh*Ww
        self.relative_position_bias_table = Parameter(
            Tensor(np.random.randn((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads),
                   dtype=mstype.float32))  # 2*Wh-1 * 2*Ww-1, nH
        self.one_hot = nn.OneHot(axis=-1, depth=(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
        self.index = Parameter(self.one_hot(self.relative_position_index), requires_grad=False)

    def construct(self, axis=0):
        out = ops.MatMul()(self.index, self.relative_position_bias_table)
        out = P.Reshape()(out, (self.window_size[0] * self.window_size[1],
                                self.window_size[0] * self.window_size[1], -1))
        out = P.Transpose()(out, (2, 0, 1))
        out = ops.ExpandDims()(out, 0)
        return out

In [8]:

class WindowAttention(nn.Cell):
    r""" Window based multi-head self attention (W-MSA) Cell with relative position bias.
    It supports both of shifted and non-shifted window.

        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qZk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        if isinstance(dim, tuple) and len(dim) == 1:
            dim = dim[0]
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = Tensor(qk_scale or head_dim ** -0.5, mstype.float32)
        self.relative_bias = RelativeBias(self.window_size, num_heads)

        # get pair-wise relative position index for each token inside the window
        self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)

        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
        self.proj = nn.Dense(in_channels=dim, out_channels=dim, has_bias=True)
        self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x, mask=None):
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        B_, N, C = x.shape
        q = ops.Reshape()(self.q(x), (B_, N, self.num_heads, C // self.num_heads)) * self.scale
        q = ops.Transpose()(q, (0, 2, 1, 3))
        k = ops.Reshape()(self.k(x), (B_, N, self.num_heads, C // self.num_heads))
        k = ops.Transpose()(k, (0, 2, 3, 1))
        v = ops.Reshape()(self.v(x), (B_, N, self.num_heads, C // self.num_heads))
        v = ops.Transpose()(v, (0, 2, 1, 3))

        attn = ops.BatchMatMul()(q, k)
        attn = attn + self.relative_bias()

        if mask is not None:
            nW = mask.shape[1]
            attn = P.Reshape()(attn, (B_ // nW, nW, self.num_heads, N, N,)) + mask
            attn = P.Reshape()(attn, (-1, self.num_heads, N, N,))
            attn = self.softmax(attn)
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = ops.Reshape()(ops.Transpose()(ops.BatchMatMul()(attn, v), (0, 2, 1, 3)), (B_, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

In [9]:

class Roll(nn.Cell):
    """Roll Cell"""

    def __init__(self, shift_size, shift_axis=(1, 2)):
        super(Roll, self).__init__()
        self.shift_size = to_2tuple(shift_size)
        self.shift_axis = shift_axis

    def construct(self, x):
        x = numpy.roll(x, self.shift_size, self.shift_axis)
        return x

In [10]:

class SwinTransformerBlock(nn.Cell):
    """ Swin Transformer Block.

        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Cell, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Cell, optional): Normalization layer.  Default: nn.LayerNorm

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=act_layers[args.nonlinearity], norm_layer=nn.LayerNorm):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        if isinstance(dim, int):
            dim = (dim,)

        self.norm1 = norm_layer(dim, epsilon=1e-5)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath1D(drop_path) if drop_path > 0. else Identity()
        self.norm2 = norm_layer(dim, epsilon=1e-5)
        mlp_hidden_dim = int((dim[0] if isinstance(dim, tuple) else dim) * mlp_ratio)
        self.mlp = Mlp(in_features=dim[0] if isinstance(dim, tuple) else dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = np.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # img_mask: [1, 56, 56, 1] window_size: 7
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.reshape(-1, self.window_size * self.window_size)
            attn_mask = mask_windows[:, np.newaxis] - mask_windows[:, :, np.newaxis]
            # [64, 49, 49] ==> [1, 64, 1, 49, 49]
            attn_mask = np.expand_dims(attn_mask, axis=1)
            attn_mask = np.expand_dims(attn_mask, axis=0)
            attn_mask = Tensor(np.where(attn_mask == 0, 0., -100.), dtype=mstype.float32)
            self.attn_mask = Parameter(attn_mask, requires_grad=False)
            self.roll_pos = Roll(self.shift_size)
            self.roll_neg = Roll(-self.shift_size)
            self.attn_mask = None

        self.window_partition = WindowPartitionConstruct(self.window_size)
        self.window_reverse = WindowReverseConstruct()

    def construct(self, x):
        """construct function"""
        H, W = self.input_resolution
        B, _, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = P.Reshape()(x, (B, H, W, C,))

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = self.roll_neg(x)
            # shifted_x = numpy.roll(x, (-self.shift_size, -self.shift_size), (1, 2))
            shifted_x = x

        # partition windows
        x_windows = self.window_partition(shifted_x)  # nW*B, window_size, window_size, C
        x_windows = ops.Reshape()(x_windows,
                                  (-1, self.window_size * self.window_size, C,))  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = P.Reshape()(attn_windows, (-1, self.window_size, self.window_size, C,))
        shifted_x = self.window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = self.roll_pos(shifted_x)
            # x = numpy.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2))  # TODO:Don't stupid
            x = shifted_x

        x = P.Reshape()(x, (B, H * W, C,))

        # FFN
        x = shortcut + self.drop_path(x)

        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

In [11]:

class PatchMerging(nn.Cell):
    """ Patch Merging Layer.

        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        self.input_resolution = input_resolution
        self.dim = dim[0] if isinstance(dim, tuple) and len(dim) == 1 else dim
        # Default False
        self.reduction = nn.Dense(in_channels=4 * dim, out_channels=2 * dim, has_bias=False)
        self.norm = norm_layer([dim * 4,])
        self.H, self.W = self.input_resolution
        self.H_2, self.W_2 = self.H // 2, self.W // 2
        self.H2W2 = int(self.H * self.W // 4)
        self.dim_mul_4 = int(dim * 4)
        self.H2W2 = int(self.H * self.W // 4)

    def construct(self, x):
        x: B, H*W, C
        B = x.shape[0]
        x = P.Reshape()(x, (B, self.H_2, 2, self.W_2, 2, self.dim))
        x = P.Transpose()(x, (0, 1, 3, 4, 2, 5))
        x = P.Reshape()(x, (B, self.H2W2, self.dim_mul_4))
        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

In [12]:

class BasicLayer(nn.Cell):
    """ A basic Swin Transformer layer for one stage.

        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Cell, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Cell | None, optional): Downsample layer at the end of the layer. Default: None

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None):

        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        # build blocks
        self.blocks = nn.CellList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,  # TODO: 这里window_size//2的时候特别慢
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
            self.downsample = None

    def construct(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

In [13]:

class PatchEmbed(nn.Cell):
    """ Image to Patch Embedding

        image_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Cell, optional): Normalization layer. Default: None

    def __init__(self, image_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        image_size = to_2tuple(image_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
        self.image_size = image_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
                              pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")

        if norm_layer is not None:
            if isinstance(embed_dim, int):
                embed_dim = (embed_dim,)
            self.norm = norm_layer(embed_dim, epsilon=1e-5)
            self.norm = None

    def construct(self, x):
        B = x.shape[0]
        # FIXME look at relaxing size constraints
        x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1))  # B Ph*Pw C
        x = ops.Transpose()(x, (0, 2, 1))

        if self.norm is not None:
            x = self.norm(x)
        return x

In [14]:

class SwinTransformer(nn.Cell):
    """ Swin Transformer
        A Pynp impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -

        image_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 100
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Cell): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True

    def __init__(self, image_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=None, num_heads=None, window_size=7,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True):

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            image_size=image_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = Parameter(Tensor(np.zeros(1, num_patches, embed_dim), dtype=mstype.float32))

        self.pos_drop = nn.Dropout(keep_prob=1.0 - drop_rate)

        # stochastic depth
        dpr = [x for x in np.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.CellList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)

        self.norm = norm_layer([self.num_features,], epsilon=1e-5)
        self.avgpool = P.ReduceMean(keep_dims=False)
        self.head = nn.Dense(in_channels=self.num_features,
                             out_channels=num_classes, has_bias=True) if num_classes > 0 else Identity()

    def init_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                if isinstance(cell, nn.Dense) and cell.bias is not None:
            elif isinstance(cell, nn.LayerNorm):

    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(ops.Transpose()(x, (0, 2, 1)), 2)  # B C 1
        return x

    def construct(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x



  • 数据集大小:共100类,60000张图像,每类600张图像
  • 数据格式:JPGE格式,84*84彩色图像
  • 对数据集结构的处理要求:类别分布均衡,训练集 : 验证集 : 测试集 = 7 : 1 : 2


└─ dataset
    ├─ images
        ├─ n0153282900000005.jpg
        ├─ n0153282900000006.jpg
        ├─ ...
    ├─ train.csv
    ├─ val.csv
    └─ test.csv


└─ dataset
    ├─ train
        ├─ 第1类
            └─ 600张图像 
        ├─ ...
        └─ 第64类
            └─ 600张图像 
    ├─ val
        ├─ 第65类
            └─ 600张图像 
        ├─ ...
        └─ 第80类
            └─ 600张图像 
    └─ test
        ├─ 第81类
            └─ 600张图像 
        ├─ ...
        └─ 第100类
            └─ 600张图像


└─ dataset
    ├─ train
        ├─ 第1类
            └─ 420张图像 
        ├─ ...
        └─ 第100类
            └─ 420张图像 
    ├─ val
        ├─ 第1类
            └─ 60张图像 
        ├─ ...
        └─ 第100类
            └─ 60张图像 
    └─ test
        ├─ 第1类
            └─ 120张图像 
        ├─ ...
        └─ 第100类
            └─ 120张图像


  • 链接:百度网盘 请输入提取码 提取码:xqnu


In [15]:

Data operations, will be used in train.py and eval.py

from src.data.augment.auto_augment import _pil_interp, rand_augment_transform
from src.data.augment.mixup import Mixup
from src.data.augment.random_erasing import RandomErasing

In [16]:

def _get_rank_info():
    get rank size and rank id
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        from mindspore.communication.management import get_rank, get_group_size
        rank_size = get_group_size()
        rank_id = get_rank()
        rank_size = rank_id = None

    return rank_size, rank_id

In [17]:

def create_dataset_imagenet(dataset_dir, args, repeat_num=1, training=True):
    create a train or val or test mini-imagenet dataset for SwinTransformer

        dataset_dir(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1


    device_num, rank_id = _get_rank_info()
    shuffle = bool(training)
    if device_num == 1 or not training:
        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers, shuffle=shuffle,
                                         num_shards=device_num, shard_id=rank_id)

    image_size = args.image_size

    # define map operations
    # BICUBIC: 3
    # data augment
    if training:
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        aa_params = dict(
            translate_const=int(image_size * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        interpolation = args.interpolation
        auto_augment = args.auto_augment
        assert auto_augment.startswith('rand')
        aa_params['interpolation'] = _pil_interp(interpolation)

        transform_img = [
            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
        transform_img += [rand_augment_transform(auto_augment, aa_params)]
        transform_img += [
            vision.Normalize(mean=mean, std=std, is_hwc=False),
            RandomErasing(args.re_prob, mode=args.re_mode, max_count=args.re_count)
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
        # test transform complete
        if args.crop:
            transform_img = [
                vision.Resize(int(256 / 224 * image_size), interpolation=Inter.BICUBIC),
                vision.Normalize(mean=mean, std=std, is_hwc=True),
            transform_img = [
                vision.Resize(int(image_size), interpolation=Inter.BICUBIC),
                vision.Normalize(mean=mean, std=std, is_hwc=True),

    transform_label = C.TypeCast(mstype.int32)

    data_set = data_set.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
    data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
    if (args.mix_up > 0. or args.cutmix > 0.)  and not training:
        # if use mixup and not training(False), one hot val data label
        one_hot = C.OneHot(num_classes=args.num_classes)
        data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
    # apply batch operations
    data_set = data_set.batch(args.batch_size, drop_remainder=True,

    if (args.mix_up > 0. or args.cutmix > 0.) and training:
        mixup_fn = Mixup(
            mixup_alpha=args.mix_up, cutmix_alpha=args.cutmix, cutmix_minmax=None,
            prob=args.mixup_prob, switch_prob=args.switch_prob, mode=args.mixup_mode,
            label_smoothing=args.label_smoothing, num_classes=args.num_classes)

        data_set = data_set.map(operations=mixup_fn, input_columns=["image", "label"],

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set

In [18]:

class ImageNet:
    """ImageNet Define"""

    def __init__(self, args, training=True):
        train_dir = os.path.join(args.data_url, "train")
        val_ir = os.path.join(args.data_url, "val")
        test_ir = os.path.join(args.data_url, "test")
        if training:
            self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
            self.val_dataset = create_dataset_imagenet(val_ir, training=False, args=args)
        self.test_dataset = create_dataset_imagenet(test_ir, training=False, args=args)


本案例利用MindSpore框架,基于Callback基于自定义了回调类 EvaluateCallBack,为了满足在模型训练时,能够在每一次epoch结束后计算3个评价指标:Loss、Top1-Acc、Top5-Acc,并根据评价指标的变化保存当前最优的模型的特定需求。



  • 验证集:Top1-Acc:55.15%,Top5-Acc:81.00%
  • 测试集:Top1-Acc:55.26% ,Top5-Acc:81.59%



In [19]:

def do_keep_fp32(network, cell_types):
    """Cast cell to fp32 if cell in cell_types"""
    for _, cell in network.cells_and_names():
        if isinstance(cell, cell_types):

def cast_amp(net):
    """cast network amp_level"""
    if args.amp_level == "O1":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        cell_types = (nn.GELU, nn.Softmax, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d, nn.LayerNorm)
        print(f"=> cast {cell_types} to fp32 back")
        do_keep_fp32(net, cell_types)
    elif args.amp_level == "O2":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        cell_types = (nn.BatchNorm2d, nn.LayerNorm)
        print(f"=> cast {cell_types} to fp32 back")
        do_keep_fp32(net, cell_types)
    elif args.amp_level == "O3":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        print(f"=> using amp_level {args.amp_level}")
        args.loss_scale = 1.
        args.is_dynamic_loss_scale = 0
        print(f"=> When amp_level is O0, using fixed loss_scale with {args.loss_scale}")


In [20]:

class SoftTargetCrossEntropy(LossBase):
    """SoftTargetCrossEntropy for MixUp Augment"""

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()
        self.mean_ops = P.ReduceMean(keep_dims=False)
        self.sum_ops = P.ReduceSum(keep_dims=False)
        self.log_softmax = P.LogSoftmax()

    def construct(self, logit, label):
        logit = P.Cast()(logit, mstype.float32)
        label = P.Cast()(label, mstype.float32)
        loss = self.sum_ops(-label * self.log_softmax(logit), -1)
        return self.mean_ops(loss)

class CrossEntropySmooth(LossBase):

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
        self.cast = ops.Cast()

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss2 = self.ce(logit, label)
        return loss2

def get_criterion(args):
    """Get loss function from args.label_smooth and args.mix_up"""
    assert args.label_smoothing >= 0. and args.label_smoothing <= 1.

    if args.mix_up > 0. or args.cutmix > 0.:
        print(25 * "=" + "Using MixBatch" + 25 * "=")
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.label_smoothing > 0.:
        print(25 * "=" + "Using label smoothing" + 25 * "=")
        criterion = CrossEntropySmooth(sparse=True, reduction="mean",
        print(25 * "=" + "Using Simple CE" + 25 * "=")
        criterion = CrossEntropySmooth(sparse=True, reduction="mean", num_classes=args.num_classes)

    return criterion

class NetWithLoss(nn.Cell):
       NetWithLoss: Only support Network with Classfication

    def __init__(self, model, criterion):
        super(NetWithLoss, self).__init__()
        self.model = model
        self.criterion = criterion

    def construct(self, data, label):
        predict = self.model(data)
        loss = self.criterion(predict, label)
        return loss


In [21]:

_grad_scale = Cps.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))

@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),

class TrainClipGrad(nn.TrainOneStepWithLossScaleCell):
    Encapsulation class of SSD network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
        use_global_nrom(bool): Whether apply global norm before optimizer. Default: False

    def __init__(self, network, optimizer,
                 scale_sense=1.0, use_global_norm=True,
        super(TrainClipGrad, self).__init__(network, optimizer, scale_sense)
        self.use_global_norm = use_global_norm
        self.clip_global_norm_value = clip_global_norm_value
        self.print = P.Print()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        scaling_sens_filled = Cps.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            if self.use_global_norm:
                grads = Cps.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
            loss = F.depend(loss, self.optimizer(grads))
            self.print("=============Over Flow, skipping=============")
        return loss

def get_train_one_step(args, net_with_loss, optimizer):
    """get_train_one_step cell"""
    if args.is_dynamic_loss_scale:
        print(f"=> Using DynamicLossScaleUpdateCell")
        scale_sense = nn.wrap.loss_scale.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 24, scale_factor=2,
        print(f"=> Using FixedLossScaleUpdateCell, loss_scale_value:{args.loss_scale}")
        scale_sense = nn.wrap.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
    net_with_loss = TrainClipGrad(net_with_loss, optimizer, scale_sense=scale_sense,
    return net_with_loss


In [22]:

__all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy", "exp_lr"]

def get_policy(name):
    """get lr policy from name"""
    if name is None:
        return constant_lr

    out_dict = {
        "constant_lr": constant_lr,
        "cosine_lr": cosine_lr,
        "multistep_lr": multistep_lr,
        "exp_lr": exp_lr,

    return out_dict[name]

def constant_lr(args, batch_num):
    """Get constant lr"""
    learning_rate = []

    def _lr_adjuster(epoch):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
            lr = args.base_lr

        return lr

    for epoch in range(args.epochs):
        for batch in range(batch_num):
            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
    return learning_rate

def exp_lr(args, batch_num):
    """Get exp lr """
    learning_rate = []

    def _lr_adjuster(epoch):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
            lr = args.base_lr * args.lr_gamma ** epoch

        return lr

    for epoch in range(args.epochs):
        for batch in range(batch_num):
            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
    return learning_rate

def cosine_lr(args, batch_num):
    """Get cosine lr"""
    learning_rate = []

    def _lr_adjuster(epoch):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
            e = epoch - args.warmup_length
            es = args.epochs - args.warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * args.base_lr

        return lr

    for epoch in range(args.epochs):
        for batch in range(batch_num):
            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
    return learning_rate

def multistep_lr(args, batch_num):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    learning_rate = []

    def _lr_adjuster(epoch):
        lr = args.base_lr * (args.lr_gamma ** (epoch / args.lr_adjust))
        return lr

    for epoch in range(args.epochs):
        for batch in range(batch_num):
            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
    return learning_rate

def _warmup_lr(warmup_lr, base_lr, warmup_length, epoch):
    """Linear warmup"""
    return epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr

def get_learning_rate(args, batch_num):
    """Get learning rate"""
    return get_policy(args.lr_scheduler)(args, batch_num)

In [23]:

def get_optimizer(args, model, batch_num):
    """Get optimizer for training"""
    print(f"=> When using train_wrapper, using optimizer {args.optimizer}")
    args.start_epoch = int(args.start_epoch)
    optim_type = args.optimizer.lower()
    params = get_param_groups(model)
    learning_rate = get_learning_rate(args, batch_num)
    step = int(args.start_epoch * batch_num)
    accumulation_step = int(args.accumulation_step)
    learning_rate = learning_rate[step::accumulation_step]
    train_step = len(learning_rate)
    print(f"=> Get LR from epoch: {args.start_epoch}\n"
          f"=> Start step: {step}\n"
          f"=> Total step: {train_step}\n"
          f"=> Accumulation step:{accumulation_step}")
    learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.
    if accumulation_step > 1:
        learning_rate = learning_rate * accumulation_step

    if optim_type == "momentum":
        optim = Momentum(
    elif optim_type == "adamw":
        optim = AdamWeightDecay(
        raise ValueError(f"optimizer {optim_type} is not supported")

    return optim

def get_param_groups(network):
    """ get param groups """
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith(".weight"):
            # Dense or Conv's weight using weight decay
            # all bias not using weight decay
            # bn weight bias not using weight decay, be carefully for now x not include LN

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]

In [24]:

def get_swintransformer(args):
    """get swintransformer according to args"""
    # override args
    image_size = args.image_size
    patch_size = args.patch_size
    in_chans = args.in_channel
    embed_dim = args.embed_dim
    depths = args.depths
    num_heads = args.num_heads
    window_size = args.window_size
    drop_path_rate = args.drop_path_rate
    mlp_ratio = args.mlp_ratio
    qkv_bias = True
    qk_scale = None
    ape = args.ape
    patch_norm = args.patch_norm
    print(25 * "=" + "MODEL CONFIG" + 25 * "=")
    print(f"==> IMAGE_SIZE:         {image_size}")
    print(f"==> PATCH_SIZE:         {patch_size}")
    print(f"==> NUM_CLASSES:        {args.num_classes}")
    print(f"==> EMBED_DIM:          {embed_dim}")
    print(f"==> NUM_HEADS:          {num_heads}")
    print(f"==> DEPTHS:             {depths}")
    print(f"==> WINDOW_SIZE:        {window_size}")
    print(f"==> MLP_RATIO:          {mlp_ratio}")
    print(f"==> QKV_BIAS:           {qkv_bias}")
    print(f"==> QK_SCALE:           {qk_scale}")
    print(f"==> DROP_PATH_RATE:     {drop_path_rate}")
    print(f"==> APE:                {ape}")
    print(f"==> PATCH_NORM:         {patch_norm}")
    print(25 * "=" + "FINISHED" + 25 * "=")
    model = SwinTransformer(image_size=image_size,
    # print(model)
    return model


In [25]:

class EvaluateCallBack(Callback):
    def __init__(self, model, eval_dataset):
        super(EvaluateCallBack, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.best_acc = 0.

    def epoch_end(self, run_context):
            Test when epoch end, save best model with best.ckpt.
        cb_params = run_context.original_args()
        cur_epoch_num = cb_params.cur_epoch_num
        result = self.model.eval(self.eval_dataset)
        if result["Acc"] > self.best_acc:
            self.best_acc = result["Acc"]
            mindspore.save_checkpoint(cb_params.train_network, 'ckpt/best.ckpt')
        print("epoch: %s acc: %s top5-acc: %s, best acc is %s" %
              (cb_params.cur_epoch_num, result["Acc"], result["Top5-Acc"], self.best_acc), flush=True)


In [26]:

def get_dataset(args, training=True):
    """"Get model according to args.set"""
    print(f"=> Getting {args.set} dataset")
    dataset = ImageNet(args, training)

    return dataset

def set_device(args):
    """Set device and ParallelMode(if device_num > 1)"""
    rank = 0
    # set context and device
    device_target = args.device_target
    device_num = int(os.environ.get("DEVICE_NUM", 1))

    if device_target == "Ascend":
        if device_num > 1:
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
            # context.set_auto_parallel_context(pipeline_stages=2, full_batch=True)

            rank = get_rank()
    elif device_target == "GPU":
        if device_num > 1:
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
            rank = get_rank()
        raise ValueError("Unsupported platform.")

    return rank

In [27]:

assert args.crop, f"{args.arch} is only for evaluation"
mode = {
    0: context.GRAPH_MODE,
    1: context.PYNATIVE_MODE
context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
if args.device_target == "Ascend":
rank = set_device(args)

# get model and cast amp_level
net = get_swintransformer(args)
criterion = get_criterion(args)
net_with_loss = NetWithLoss(net, criterion)

data = get_dataset(args)
batch_num = data.train_dataset.get_dataset_size()
optimizer = get_optimizer(args, net, batch_num)

net_with_loss = get_train_one_step(args, net_with_loss, optimizer)

eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
eval_indexes = [0, 1, 2]

eval_metrics = {'Loss': nn.Loss(),
                'Acc': nn.Accuracy(),
                'Top5-Acc': nn.Top5CategoricalAccuracy()}

model = Model(net_with_loss, metrics=eval_metrics,

# checkpoint and callback settings
# keep_checkpoint_max (int) - 最多保存多少个checkpoint文件。默认值:1。
config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
ckpt_save_dir = "./ckpt"

time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
# prefix (str) - checkpoint文件的前缀名称。默认值:’CKP’。
# directory (str) - 保存checkpoint文件的文件夹路径。默认情况下,文件保存在当前目录下。默认值:None。
# config (CheckpointConfig) - checkpoint策略配置。默认值:None。
ckpoint_cb = ModelCheckpoint(prefix=args.arch, directory=ckpt_save_dir, config=config_ck)
loss_cb = LossMonitor()
eval_cb = EvaluateCallBack(model=model, eval_dataset=data.val_dataset)

print("begin train")
model.train(int(args.epochs - args.start_epoch), data.train_dataset,
            callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
print("train success")
=========================MODEL CONFIG=========================
==> IMAGE_SIZE:         224
==> PATCH_SIZE:         4
==> NUM_CLASSES:        100
==> EMBED_DIM:          96
==> NUM_HEADS:          [3, 6, 12, 24]
==> DEPTHS:             [2, 2, 6, 2]
==> WINDOW_SIZE:        7
==> MLP_RATIO:          4.0
==> QKV_BIAS:           True
==> QK_SCALE:           None
==> DROP_PATH_RATE:     0.2
==> APE:                False
==> PATCH_NORM:         True
[WARNING] ME(2932479:139625079657088,MainProcess):2022-10-22-17:23:09.484.580 [mindspore/train/model.py:1075] For EvaluateCallBack callback, {'epoch_end'} methods may not be supported in later version, Use methods prefixed with 'on_train' or 'on_eval' instead when using customized callbacks.
=> using amp_level O1
=> change swin_tiny_patch4_window7_224 to fp16
=> cast (<class 'mindspore.nn.layer.activation.GELU'>, <class 'mindspore.nn.layer.activation.Softmax'>, <class 'mindspore.nn.layer.conv.Conv2d'>, <class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>) to fp32 back
=========================Using MixBatch=========================
=> Getting ImageNet dataset
=> When using train_wrapper, using optimizer adamw
=> Get LR from epoch: 0
=> Start step: 0
=> Total step: 32800
=> Accumulation step:1
=> Using DynamicLossScaleUpdateCell
begin train
=============Over Flow, skipping=============
=============Over Flow, skipping=============
=============Over Flow, skipping=============
=============Over Flow, skipping=============
=============Over Flow, skipping=============
epoch: 1 step: 328, loss is 4.606683254241943
Train epoch time: 98225.719 ms, per step time: 299.469 ms
epoch: 1 acc: 0.04364809782608696 top5-acc: 0.1559103260869565, best acc is 0.04364809782608696
epoch: 2 step: 328, loss is 4.516416072845459
Train epoch time: 77570.307 ms, per step time: 236.495 ms
epoch: 2 acc: 0.05672554347826087 top5-acc: 0.1875, best acc is 0.05672554347826087
epoch: 3 step: 328, loss is 4.447005748748779
Train epoch time: 77713.935 ms, per step time: 236.933 ms
epoch: 3 acc: 0.06725543478260869 top5-acc: 0.2027853260869565, best acc is 0.06725543478260869
epoch: 4 step: 328, loss is 4.451318740844727
Train epoch time: 77429.944 ms, per step time: 236.067 ms
epoch: 4 acc: 0.07540760869565218 top5-acc: 0.2301290760869565, best acc is 0.07540760869565218
epoch: 5 step: 328, loss is 4.380975246429443
Train epoch time: 77789.145 ms, per step time: 237.162 ms
epoch: 5 acc: 0.08967391304347826 top5-acc: 0.2686820652173913, best acc is 0.08967391304347826
epoch: 6 step: 328, loss is 4.443111896514893
Train epoch time: 78030.254 ms, per step time: 237.897 ms
epoch: 6 acc: 0.11090353260869565 top5-acc: 0.3021399456521739, best acc is 0.11090353260869565
=============Over Flow, skipping=============
epoch: 7 step: 328, loss is 4.38911247253418
Train epoch time: 77914.136 ms, per step time: 237.543 ms
epoch: 7 acc: 0.13671875 top5-acc: 0.33474864130434784, best acc is 0.13671875
epoch: 8 step: 328, loss is 4.416530609130859
Train epoch time: 78155.158 ms, per step time: 238.278 ms
epoch: 8 acc: 0.12516983695652173 top5-acc: 0.335258152173913, best acc is 0.13671875
epoch: 9 step: 328, loss is 4.346778392791748
Train epoch time: 77911.331 ms, per step time: 237.535 ms
epoch: 9 acc: 0.14673913043478262 top5-acc: 0.36905570652173914, best acc is 0.14673913043478262
epoch: 10 step: 328, loss is 4.5370354652404785
Train epoch time: 77869.734 ms, per step time: 237.408 ms
epoch: 10 acc: 0.16559103260869565 top5-acc: 0.41389266304347827, best acc is 0.16559103260869565
epoch: 11 step: 328, loss is 4.40866756439209
Train epoch time: 77792.362 ms, per step time: 237.172 ms
epoch: 11 acc: 0.16932744565217392 top5-acc: 0.42595108695652173, best acc is 0.16932744565217392
epoch: 12 step: 328, loss is 4.258575439453125
Train epoch time: 78045.822 ms, per step time: 237.945 ms
epoch: 12 acc: 0.1883491847826087 top5-acc: 0.4429347826086957, best acc is 0.1883491847826087
=============Over Flow, skipping=============
epoch: 13 step: 328, loss is 4.406404972076416
Train epoch time: 78015.169 ms, per step time: 237.851 ms
epoch: 13 acc: 0.19208559782608695 top5-acc: 0.45023777173913043, best acc is 0.19208559782608695
epoch: 14 step: 328, loss is 4.288840293884277
Train epoch time: 77976.810 ms, per step time: 237.734 ms
epoch: 14 acc: 0.20074728260869565 top5-acc: 0.47452445652173914, best acc is 0.20074728260869565
epoch: 15 step: 328, loss is 4.286988735198975
Train epoch time: 77837.343 ms, per step time: 237.309 ms
epoch: 15 acc: 0.20838994565217392 top5-acc: 0.491508152173913, best acc is 0.20838994565217392
epoch: 16 step: 328, loss is 4.023094654083252
Train epoch time: 78042.042 ms, per step time: 237.933 ms
epoch: 16 acc: 0.21993885869565216 top5-acc: 0.5079823369565217, best acc is 0.21993885869565216
epoch: 17 step: 328, loss is 4.169596195220947
Train epoch time: 77697.700 ms, per step time: 236.883 ms
epoch: 17 acc: 0.22554347826086957 top5-acc: 0.5061141304347826, best acc is 0.22554347826086957
epoch: 18 step: 328, loss is 4.3070878982543945
Train epoch time: 78022.776 ms, per step time: 237.874 ms
epoch: 18 acc: 0.23539402173913043 top5-acc: 0.5283627717391305, best acc is 0.23539402173913043
=============Over Flow, skipping=============
epoch: 19 step: 328, loss is 4.109980583190918
Train epoch time: 78016.262 ms, per step time: 237.854 ms
epoch: 19 acc: 0.25577445652173914 top5-acc: 0.5222486413043478, best acc is 0.25577445652173914
epoch: 20 step: 328, loss is 4.031306743621826
Train epoch time: 77715.961 ms, per step time: 236.939 ms
epoch: 20 acc: 0.25900135869565216 top5-acc: 0.5385529891304348, best acc is 0.25900135869565216
epoch: 21 step: 328, loss is 4.215546607971191
Train epoch time: 77684.027 ms, per step time: 236.842 ms
epoch: 21 acc: 0.25968070652173914 top5-acc: 0.5531589673913043, best acc is 0.25968070652173914
epoch: 22 step: 328, loss is 4.093695640563965
Train epoch time: 78090.945 ms, per step time: 238.082 ms
epoch: 22 acc: 0.2600203804347826 top5-acc: 0.5543478260869565, best acc is 0.2600203804347826
epoch: 23 step: 328, loss is 4.0735273361206055
Train epoch time: 77968.515 ms, per step time: 237.709 ms
epoch: 23 acc: 0.2787024456521739 top5-acc: 0.5686141304347826, best acc is 0.2787024456521739
epoch: 24 step: 328, loss is 4.377508163452148
Train epoch time: 77772.071 ms, per step time: 237.110 ms
epoch: 24 acc: 0.2798913043478261 top5-acc: 0.5782948369565217, best acc is 0.2798913043478261
=============Over Flow, skipping=============
epoch: 25 step: 328, loss is 3.937976360321045
Train epoch time: 77997.789 ms, per step time: 237.798 ms
epoch: 25 acc: 0.2916100543478261 top5-acc: 0.5825407608695652, best acc is 0.2916100543478261
=============Over Flow, skipping=============
epoch: 26 step: 328, loss is 4.130610942840576
Train epoch time: 77919.894 ms, per step time: 237.561 ms
epoch: 26 acc: 0.29483695652173914 top5-acc: 0.5840692934782609, best acc is 0.29483695652173914
epoch: 27 step: 328, loss is 3.963009834289551
Train epoch time: 78296.125 ms, per step time: 238.708 ms
epoch: 27 acc: 0.30570652173913043 top5-acc: 0.6103940217391305, best acc is 0.30570652173913043
epoch: 28 step: 328, loss is 3.94366717338562
Train epoch time: 78107.002 ms, per step time: 238.131 ms
epoch: 28 acc: 0.31266983695652173 top5-acc: 0.6154891304347826, best acc is 0.31266983695652173
epoch: 29 step: 328, loss is 3.7120132446289062
Train epoch time: 77957.414 ms, per step time: 237.675 ms
epoch: 29 acc: 0.32319972826086957 top5-acc: 0.627547554347826, best acc is 0.32319972826086957
epoch: 30 step: 328, loss is 4.277674198150635
Train epoch time: 78170.149 ms, per step time: 238.324 ms
epoch: 30 acc: 0.311820652173913 top5-acc: 0.6129415760869565, best acc is 0.32319972826086957
epoch: 31 step: 328, loss is 4.324661731719971
Train epoch time: 78529.550 ms, per step time: 239.419 ms
epoch: 31 acc: 0.3255774456521739 top5-acc: 0.6299252717391305, best acc is 0.3255774456521739
=============Over Flow, skipping=============
epoch: 32 step: 328, loss is 4.174750804901123
Train epoch time: 77924.515 ms, per step time: 237.575 ms
epoch: 32 acc: 0.33695652173913043 top5-acc: 0.6441915760869565, best acc is 0.33695652173913043
epoch: 33 step: 328, loss is 3.8301479816436768
Train epoch time: 77852.323 ms, per step time: 237.355 ms
epoch: 33 acc: 0.3422214673913043 top5-acc: 0.641983695652174, best acc is 0.3422214673913043
epoch: 34 step: 328, loss is 4.2253923416137695
Train epoch time: 78160.324 ms, per step time: 238.294 ms
epoch: 34 acc: 0.32778532608695654 top5-acc: 0.6323029891304348, best acc is 0.3422214673913043
epoch: 35 step: 328, loss is 4.316101551055908
Train epoch time: 77949.068 ms, per step time: 237.650 ms
epoch: 35 acc: 0.358695652173913 top5-acc: 0.6569293478260869, best acc is 0.358695652173913
epoch: 36 step: 328, loss is 4.160644054412842
Train epoch time: 78140.680 ms, per step time: 238.234 ms
epoch: 36 acc: 0.32999320652173914 top5-acc: 0.6413043478260869, best acc is 0.358695652173913
epoch: 37 step: 328, loss is 3.802969217300415
Train epoch time: 78095.721 ms, per step time: 238.097 ms
epoch: 37 acc: 0.36379076086956524 top5-acc: 0.6569293478260869, best acc is 0.36379076086956524
epoch: 38 step: 328, loss is 4.089756011962891
Train epoch time: 78088.570 ms, per step time: 238.075 ms
epoch: 38 acc: 0.37313179347826086 top5-acc: 0.6798573369565217, best acc is 0.37313179347826086
=============Over Flow, skipping=============
epoch: 39 step: 328, loss is 3.53753662109375
Train epoch time: 77897.645 ms, per step time: 237.493 ms
epoch: 39 acc: 0.3828125 top5-acc: 0.6759510869565217, best acc is 0.3828125
epoch: 40 step: 328, loss is 3.609902858734131
Train epoch time: 78314.654 ms, per step time: 238.764 ms
epoch: 40 acc: 0.3804347826086957 top5-acc: 0.6703464673913043, best acc is 0.3828125
epoch: 41 step: 328, loss is 4.303561210632324
Train epoch time: 78110.555 ms, per step time: 238.142 ms
epoch: 41 acc: 0.39809782608695654 top5-acc: 0.6931046195652174, best acc is 0.39809782608695654
epoch: 42 step: 328, loss is 3.4438838958740234
Train epoch time: 77831.506 ms, per step time: 237.291 ms
epoch: 42 acc: 0.3931725543478261 top5-acc: 0.6961616847826086, best acc is 0.39809782608695654
epoch: 43 step: 328, loss is 3.5862011909484863
Train epoch time: 78138.125 ms, per step time: 238.226 ms
epoch: 43 acc: 0.3967391304347826 top5-acc: 0.6983695652173914, best acc is 0.39809782608695654
epoch: 44 step: 328, loss is 4.128260612487793
Train epoch time: 77968.969 ms, per step time: 237.710 ms
epoch: 44 acc: 0.39249320652173914 top5-acc: 0.6944633152173914, best acc is 0.39809782608695654
=============Over Flow, skipping=============
epoch: 45 step: 328, loss is 4.155826091766357
Train epoch time: 77874.741 ms, per step time: 237.423 ms
epoch: 45 acc: 0.40591032608695654 top5-acc: 0.7055027173913043, best acc is 0.40591032608695654
epoch: 46 step: 328, loss is 3.9320266246795654
Train epoch time: 78020.188 ms, per step time: 237.866 ms
epoch: 46 acc: 0.4154211956521739 top5-acc: 0.7155230978260869, best acc is 0.4154211956521739
epoch: 47 step: 328, loss is 4.123641014099121
Train epoch time: 78344.166 ms, per step time: 238.854 ms
epoch: 47 acc: 0.4155910326086957 top5-acc: 0.7116168478260869, best acc is 0.4155910326086957
epoch: 48 step: 328, loss is 4.011114120483398
Train epoch time: 78071.050 ms, per step time: 238.021 ms
epoch: 48 acc: 0.4232336956521739 top5-acc: 0.7280910326086957, best acc is 0.4232336956521739
epoch: 49 step: 328, loss is 4.114437580108643
Train epoch time: 78166.502 ms, per step time: 238.313 ms
epoch: 49 acc: 0.4327445652173913 top5-acc: 0.7359035326086957, best acc is 0.4327445652173913
epoch: 50 step: 328, loss is 4.075896739959717
Train epoch time: 77756.753 ms, per step time: 237.063 ms
epoch: 50 acc: 0.4201766304347826 top5-acc: 0.7141644021739131, best acc is 0.4327445652173913
=============Over Flow, skipping=============
epoch: 51 step: 328, loss is 4.036683082580566
Train epoch time: 77974.045 ms, per step time: 237.726 ms
epoch: 51 acc: 0.43783967391304346 top5-acc: 0.7260529891304348, best acc is 0.43783967391304346
epoch: 52 step: 328, loss is 4.108072280883789
Train epoch time: 78001.574 ms, per step time: 237.810 ms
epoch: 52 acc: 0.4320652173913043 top5-acc: 0.7331861413043478, best acc is 0.43783967391304346
epoch: 53 step: 328, loss is 3.69376802444458
Train epoch time: 77847.631 ms, per step time: 237.340 ms
epoch: 53 acc: 0.4505774456521739 top5-acc: 0.7437160326086957, best acc is 0.4505774456521739
epoch: 54 step: 328, loss is 4.165959358215332
Train epoch time: 77814.628 ms, per step time: 237.240 ms
epoch: 54 acc: 0.44089673913043476 top5-acc: 0.7328464673913043, best acc is 0.4505774456521739
epoch: 55 step: 328, loss is 4.163305282592773
Train epoch time: 78025.681 ms, per step time: 237.883 ms
epoch: 55 acc: 0.45499320652173914 top5-acc: 0.7367527173913043, best acc is 0.45499320652173914
epoch: 56 step: 328, loss is 4.002499580383301
Train epoch time: 78114.233 ms, per step time: 238.153 ms
epoch: 56 acc: 0.45822010869565216 top5-acc: 0.7454144021739131, best acc is 0.45822010869565216
=============Over Flow, skipping=============
epoch: 57 step: 328, loss is 3.964570999145508
Train epoch time: 77909.596 ms, per step time: 237.529 ms
epoch: 57 acc: 0.47673233695652173 top5-acc: 0.7567934782608695, best acc is 0.47673233695652173
epoch: 58 step: 328, loss is 3.8768999576568604
Train epoch time: 77917.612 ms, per step time: 237.554 ms
epoch: 58 acc: 0.47843070652173914 top5-acc: 0.7554347826086957, best acc is 0.47843070652173914
epoch: 59 step: 328, loss is 3.3266918659210205
Train epoch time: 77987.581 ms, per step time: 237.767 ms
epoch: 59 acc: 0.4782608695652174 top5-acc: 0.7637567934782609, best acc is 0.47843070652173914
epoch: 60 step: 328, loss is 3.9986932277679443
Train epoch time: 78162.117 ms, per step time: 238.299 ms
epoch: 60 acc: 0.4748641304347826 top5-acc: 0.7651154891304348, best acc is 0.47843070652173914
epoch: 61 step: 328, loss is 3.554734230041504
Train epoch time: 79327.994 ms, per step time: 241.854 ms
epoch: 61 acc: 0.47690217391304346 top5-acc: 0.7627377717391305, best acc is 0.47843070652173914
epoch: 62 step: 328, loss is 4.098214149475098
Train epoch time: 79343.368 ms, per step time: 241.901 ms
epoch: 62 acc: 0.48012907608695654 top5-acc: 0.765625, best acc is 0.48012907608695654
=============Over Flow, skipping=============
epoch: 63 step: 328, loss is 3.7967209815979004
Train epoch time: 78342.405 ms, per step time: 238.849 ms
epoch: 63 acc: 0.48556385869565216 top5-acc: 0.7676630434782609, best acc is 0.48556385869565216
epoch: 64 step: 328, loss is 3.398728609085083
Train epoch time: 78146.445 ms, per step time: 238.251 ms
epoch: 64 acc: 0.4835258152173913 top5-acc: 0.7725883152173914, best acc is 0.48556385869565216
epoch: 65 step: 328, loss is 3.307548999786377
Train epoch time: 77922.219 ms, per step time: 237.568 ms
epoch: 65 acc: 0.4984714673913043 top5-acc: 0.778702445652174, best acc is 0.4984714673913043
epoch: 66 step: 328, loss is 4.043028354644775
Train epoch time: 78039.694 ms, per step time: 237.926 ms
epoch: 66 acc: 0.5010190217391305 top5-acc: 0.7793817934782609, best acc is 0.5010190217391305
epoch: 67 step: 328, loss is 3.7040460109710693
Train epoch time: 77856.072 ms, per step time: 237.366 ms
epoch: 67 acc: 0.49966032608695654 top5-acc: 0.7797214673913043, best acc is 0.5010190217391305
epoch: 68 step: 328, loss is 3.765737771987915
Train epoch time: 78541.913 ms, per step time: 239.457 ms
epoch: 68 acc: 0.5047554347826086 top5-acc: 0.783797554347826, best acc is 0.5047554347826086
epoch: 69 step: 328, loss is 3.8921711444854736
Train epoch time: 78037.338 ms, per step time: 237.919 ms
epoch: 69 acc: 0.5137567934782609 top5-acc: 0.7843070652173914, best acc is 0.5137567934782609
epoch: 70 step: 328, loss is 3.996326446533203
Train epoch time: 77872.417 ms, per step time: 237.416 ms
epoch: 70 acc: 0.5125679347826086 top5-acc: 0.7875339673913043, best acc is 0.5137567934782609
epoch: 71 step: 328, loss is 4.144741535186768
Train epoch time: 77957.219 ms, per step time: 237.674 ms
epoch: 71 acc: 0.5137567934782609 top5-acc: 0.790421195652174, best acc is 0.5137567934782609
epoch: 72 step: 328, loss is 3.6382193565368652
Train epoch time: 78040.411 ms, per step time: 237.928 ms
epoch: 72 acc: 0.5127377717391305 top5-acc: 0.7860054347826086, best acc is 0.5137567934782609
epoch: 73 step: 328, loss is 3.5755395889282227
Train epoch time: 78034.322 ms, per step time: 237.910 ms
epoch: 73 acc: 0.5159646739130435 top5-acc: 0.7878736413043478, best acc is 0.5159646739130435
epoch: 74 step: 328, loss is 3.6679434776306152
Train epoch time: 77881.833 ms, per step time: 237.445 ms
epoch: 74 acc: 0.5193614130434783 top5-acc: 0.7931385869565217, best acc is 0.5193614130434783
=============Over Flow, skipping=============
epoch: 75 step: 328, loss is 3.8212318420410156
Train epoch time: 78212.435 ms, per step time: 238.453 ms
epoch: 75 acc: 0.5236073369565217 top5-acc: 0.7936480978260869, best acc is 0.5236073369565217
epoch: 76 step: 328, loss is 3.811330556869507
Train epoch time: 77924.677 ms, per step time: 237.575 ms
epoch: 76 acc: 0.5263247282608695 top5-acc: 0.7961956521739131, best acc is 0.5263247282608695
epoch: 77 step: 328, loss is 3.0529000759124756
Train epoch time: 77880.984 ms, per step time: 237.442 ms
epoch: 77 acc: 0.5264945652173914 top5-acc: 0.7990828804347826, best acc is 0.5264945652173914
epoch: 78 step: 328, loss is 3.051379680633545
Train epoch time: 77631.681 ms, per step time: 236.682 ms
epoch: 78 acc: 0.5320991847826086 top5-acc: 0.8023097826086957, best acc is 0.5320991847826086
epoch: 79 step: 328, loss is 4.120721340179443
Train epoch time: 77824.862 ms, per step time: 237.271 ms
epoch: 79 acc: 0.53125 top5-acc: 0.7989130434782609, best acc is 0.5320991847826086
epoch: 80 step: 328, loss is 3.954958915710449
Train epoch time: 77628.485 ms, per step time: 236.672 ms
epoch: 80 acc: 0.5353260869565217 top5-acc: 0.8002717391304348, best acc is 0.5353260869565217
=============Over Flow, skipping=============
epoch: 81 step: 328, loss is 3.724217176437378
Train epoch time: 78065.812 ms, per step time: 238.006 ms
epoch: 81 acc: 0.5363451086956522 top5-acc: 0.8019701086956522, best acc is 0.5363451086956522
epoch: 82 step: 328, loss is 3.155369997024536
Train epoch time: 77924.482 ms, per step time: 237.575 ms
epoch: 82 acc: 0.537703804347826 top5-acc: 0.8067255434782609, best acc is 0.537703804347826
epoch: 83 step: 328, loss is 3.3825459480285645
Train epoch time: 77953.530 ms, per step time: 237.663 ms
epoch: 83 acc: 0.5354959239130435 top5-acc: 0.803328804347826, best acc is 0.537703804347826
=============Over Flow, skipping=============
epoch: 84 step: 328, loss is 3.9255869388580322
Train epoch time: 77817.596 ms, per step time: 237.249 ms
epoch: 84 acc: 0.5433084239130435 top5-acc: 0.8092730978260869, best acc is 0.5433084239130435
epoch: 85 step: 328, loss is 3.62528133392334
Train epoch time: 78312.543 ms, per step time: 238.758 ms
epoch: 85 acc: 0.5436480978260869 top5-acc: 0.8045176630434783, best acc is 0.5436480978260869
epoch: 86 step: 328, loss is 3.8779077529907227
Train epoch time: 79035.696 ms, per step time: 240.962 ms
epoch: 86 acc: 0.5456861413043478 top5-acc: 0.80859375, best acc is 0.5456861413043478
epoch: 87 step: 328, loss is 4.07357931137085
Train epoch time: 77741.437 ms, per step time: 237.017 ms
epoch: 87 acc: 0.544327445652174 top5-acc: 0.8034986413043478, best acc is 0.5456861413043478
epoch: 88 step: 328, loss is 3.640528917312622
Train epoch time: 77927.202 ms, per step time: 237.583 ms
epoch: 88 acc: 0.5438179347826086 top5-acc: 0.8074048913043478, best acc is 0.5456861413043478
epoch: 89 step: 328, loss is 3.3842782974243164
Train epoch time: 78080.250 ms, per step time: 238.050 ms
epoch: 89 acc: 0.5448369565217391 top5-acc: 0.8108016304347826, best acc is 0.5456861413043478
epoch: 90 step: 328, loss is 3.9701128005981445
Train epoch time: 78034.974 ms, per step time: 237.912 ms
epoch: 90 acc: 0.5473845108695652 top5-acc: 0.8106317934782609, best acc is 0.5473845108695652
epoch: 91 step: 328, loss is 3.93862247467041
Train epoch time: 77777.080 ms, per step time: 237.125 ms
epoch: 91 acc: 0.5487432065217391 top5-acc: 0.8091032608695652, best acc is 0.5487432065217391
epoch: 92 step: 328, loss is 3.3578507900238037
Train epoch time: 77840.785 ms, per step time: 237.319 ms
epoch: 92 acc: 0.5499320652173914 top5-acc: 0.8087635869565217, best acc is 0.5499320652173914
epoch: 93 step: 328, loss is 3.402940034866333
Train epoch time: 77841.572 ms, per step time: 237.322 ms
epoch: 93 acc: 0.5516304347826086 top5-acc: 0.8077445652173914, best acc is 0.5516304347826086
epoch: 94 step: 328, loss is 3.50596022605896
Train epoch time: 78055.788 ms, per step time: 237.975 ms
epoch: 94 acc: 0.5514605978260869 top5-acc: 0.8106317934782609, best acc is 0.5516304347826086
epoch: 95 step: 328, loss is 3.9884400367736816
Train epoch time: 77863.441 ms, per step time: 237.389 ms
epoch: 95 acc: 0.5516304347826086 top5-acc: 0.8092730978260869, best acc is 0.5516304347826086
=============Over Flow, skipping=============
epoch: 96 step: 328, loss is 3.4499902725219727
Train epoch time: 77863.299 ms, per step time: 237.388 ms
epoch: 96 acc: 0.5526494565217391 top5-acc: 0.8079144021739131, best acc is 0.5526494565217391
epoch: 97 step: 328, loss is 3.518087387084961
Train epoch time: 77672.538 ms, per step time: 236.807 ms
epoch: 97 acc: 0.5526494565217391 top5-acc: 0.8102921195652174, best acc is 0.5526494565217391
epoch: 98 step: 328, loss is 2.900960922241211
Train epoch time: 77649.499 ms, per step time: 236.736 ms
epoch: 98 acc: 0.5512907608695652 top5-acc: 0.8084239130434783, best acc is 0.5526494565217391
epoch: 99 step: 328, loss is 3.933858633041382
Train epoch time: 77925.955 ms, per step time: 237.579 ms
epoch: 99 acc: 0.5531589673913043 top5-acc: 0.8104619565217391, best acc is 0.5531589673913043
epoch: 100 step: 328, loss is 2.9689321517944336
Train epoch time: 77380.916 ms, per step time: 235.917 ms
epoch: 100 acc: 0.5514605978260869 top5-acc: 0.809952445652174, best acc is 0.5531589673913043
train success


In [28]:

def test(net, model):
    # load the saved model for evaluation
    param_dict = load_checkpoint("ckpt/best.ckpt")
    # load parameter to the network
    load_param_into_net(net, param_dict)
    # load testing dataset
    data = get_dataset(args, training=False)
    print(f"=> begin eval")
    results = model.eval(data.test_dataset)
    print(f"=> eval results:{results}")
    print(f"=> eval success")

In [29]:

test(net, model)
=> Getting ImageNet dataset
=> begin eval
=> eval results:{'Loss': 1.8583465449271663, 'Acc': 0.5525873655913979, 'Top5-Acc': 0.8158602150537635}
=> eval success



资源Gefore RTX 3090 * 1
验证集miniImageNet Val,共6,000张图像
验证集分类准确率Top1-Acc: 55.32%, Top5-Acc: 81.05%
测试集miniImageNet Test,共12,000张图像
测试集分类准确率Top1-Acc: 55.26%, Top5-Acc: 81.59%





