Pytorch之MobileViT图像分类

news2025/1/15 17:21:51

文章目录

  • 前言
  • 一、Transformer存在的问题
  • 二、MobileViT
    • 1.MobileViT网络结构
      • 🍓 Vision Transformer结构
      • 🍉MobileViT结构
    • 2.MV2(MobileNet v2 block)
    • 3.MobileViT block
      • 🥇Local representations
      • 🥈Transformers as Convolutions (global representations)
      • 🥉Fusion
    • 4.模型配置
    • 5.MobileViT优势
  • 三、MobileViT网络实现
    • 1.构建网络模型
    • 2.训练和测试模型
  • 四、图像分类
  • 结束语


  • 💂 个人主页:风间琉璃
  • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
  • 💬 如果文章对你有帮助欢迎关注点赞收藏(一键三连)订阅专栏

前言

MobileViT是一种基于ViT(Vision Transformer)架构的轻量级视觉模型,旨在适用于移动设备和嵌入式系统。ViT是一种非常成功的深度学习模型,用于图像分类和其他计算机视觉任务,但通常需要大量的计算资源和参数。MobileViT的目标是在保持高性能的同时,减少模型的大小和计算需求,以便在移动设备上运行,据作者介绍,这是第一次基于轻量级CNN网络性能的轻量级ViT工作,性能SOTA。性能优于MobileNetV3、CrossviT等网络。


一、Transformer存在的问题

MobileVitV1是苹果公司2021年发表的一篇轻量型主干网络,它是CNNTransfomrer混合架构模型(CNN的轻量和高效+Transformer的自注意力机制和全局视野),这样的架构模型也是现在很多研究者们青睐的架构之一。

自Vision Transformer出现之后,人们发现Transfomrer也可以应用在计算机视觉领域,并且效果还是非常不错的。但是基于Transformer的网络模型存在着以下问题:

参数多,算力要求高
Transformer模型通常具有数十亿或数百亿个参数,这使得它们的模型文件非常大,不仅占用大量存储空间,而且在训练和部署过程中也需要更多的计算资源。

缺少空间归纳偏置
即纯Transformer对空间位置信息不敏感,但是,我们在进行视觉应用的时位置信息又比较重要,为了解决这个问题就引入了位置编码。

归纳 (Induction) 是自然科学中常用的两大方法之一 (归纳与演绎,Induction & Deduction),指从一些例子中寻找共性、泛化,形成一个较通用的规则的过程。偏置 (Bias) 则是指对模型的偏好,以下展示了 4 种解释:

∙ \bullet 通俗理解:归纳偏置可以理解为,从现实生活中观察到的现象中归纳出一定的 规则 (heuristics),然后对模型做一定的约束,从而可以起到 “模型选择” 的作用,类似贝叶斯学习中的 “先验”。
∙ \bullet 西瓜书解释:机器学习算法在学习过程中对某种类型假设的偏好,称为归纳偏好。归纳偏好可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或 “价值观”。
∙ \bullet 维基百科解释:如果学习器需要去预测 “其未遇到过的输入” 的结果时,则需要一些假设来帮助它做出选择。
∙ \bullet 广义解释:归纳偏置会促使学习算法优先考虑具有某些属性的解。

深度神经网络偏好性地认为,层次化处理信息有更好效果;卷积神经网络认为信息具有空间局部性,可用滑动卷积共享权重的方式降低参数空间;循环神经网络则将时序信息纳入考虑,强调顺序重要性;图网络则认为中心节点与邻居节点的相似性会更好地引导信息流动。通常,模型容量 (capacity) 很大但 Inductive Bias 匮乏则容易过拟合 (overfitting),如 Transformer

CNN的空间归纳偏差内容如下:

CNN 的 归纳偏置(Inductive Bias)局部性 (Locality) 空间不变性 (Spatial Invariance) / 平移等效性 (Translation Equivariance),即空间位置上的元素 (Grid Elements) 的联系/相关性近大远小,以及空间平移的不变性 (Kernel 权重共享)。

⋆ \star locality:CNN是以滑动窗口的形式一点一点地在图片上进行卷积的,所以假设图片上相邻的区域会有相邻的特征,靠得越近的东西相关性越强;

⋆ \star translation equivariance(平移等变性或平移同变性):用公式表示为f(g(x))=g(f(x)),不论是先经过g映射,还是先经过f映射,其结果是不变的;其中f代表卷积操作,g代表平移操作。因为在卷积神经网络中,卷积核相当于是一个模板,不论图片中同样的物体移动到哪里,只要是相同的输入,经过相同的卷积核,其输出是不变的。

一旦网络(CNN)模型有了这两个归纳偏置,它就拥有很多的先验信息所以只需要相对较少的数据就可以学习一个相对比较好的模型。但是对于transformer来说,它没有这些先验信息,所以它对视觉的感知全部需要从这些数据中自己学习。

因此transformer结构的网络模型需要大量的数据才能得到不错的效果,如果使用少量数据进行训练,那么会掉点很明显。这是因为Transformer缺少空间归纳偏置,空间归纳偏置允许CNN在不同的视觉任务中学习较少参数的表示

虽然Transformer缺少空间归纳偏置必须要大量数据来进行学习数据中的某种特性,从而导致无法很好的应用在这样的边缘设备。但是CNN也有缺点CNN在空间上获取的信息是局部的,因此一定程度上会制约着CNN网络结构的性能,而Transformer的自注意力机制能够获取全局信息。

模型迁移困难

这个问题核心是引入的位置编码导致的。 Transformer 网络需要先对原始的图像进行切片处理,一般来说训练好的 ViT 网络原始的输入图像大小 224×224,patch 大小为 16×16,那么得到的 patch个数也就固定了。由于 Transformer 网络缺少空间归纳偏置在计算某一个 token 时其他 token 位置顺序发生变化并不会影响到最终的实验结果,也即输出与位置信息无关。而我们知道对于图像来说,空间信息是非常重要且具有实际意义的,因此,Transformer 通过加上位置偏置ViT使用绝对位置偏置,Swin T引入相对位置偏置来解决位置信息的丢失问题

但是,当输入图像的尺寸或者 patch 大小发生变化时,训练好的模型就会因为位置信息不准确而失效。目前常见的处理方法是将位置偏置信息进行插值,插值到所需要的序列长度从而匹配到图像的尺寸。这种方式需要对训练好的模型进行微调才能保证性能不出现大幅损失,每次改变输入图像的尺寸或者 patch 的尺寸均需要对位置编码进行插值和对网络进行微调,这提高了网络迁移的难度

Swin T网络使用了相对位置偏置,理论上来说序列的长度只与窗 windows 的大小有关而与输入图像的尺寸无关。但是,windows的大小一般被设定与输入尺寸匹配,当输入尺寸变大时,window 的大小也应该相应的增大,那么所使用的相对位置偏置序列也应该增大,这也会导致上述问题。这些问题将导致 Transformer 网络迁移时比 CNN 网络迁移得更加困难和繁琐。

模型训练困难

根据现有的一些经验,Transformer相比CNN要更难训练。Transformer需要更多的训练数据需要迭代更多的epoch需要更大的正则项(L2正则)需要更多的数据增强(且对数据增强很敏感)。

针对以上问题,采用CNN与Transformer的混合架构CNN能够提供空间归纳偏置所以可以解决位置偏置,而且加入CNN后能够加速网络的收敛,使网络训练过程更加的稳定

二、MobileViT

1.MobileViT网络结构

🍓 Vision Transformer结构

下图是MobileViT论文中绘制的Standard visual Transformer。首先将输入的图片划分成N个Patch,然后通过线性变化将每个Patch映射到一维向量中(Token),接着加上位置偏置信息(可学习参数),再通过一系列Transformer Block,最后通过一个全连接层得到最终预测输出。
在这里插入图片描述
首先将C,H,W的图片进行Patch处理成N个向量,然后经过线性层进行降低向量维度,再经过位置编码,然后再经过N个Transformer块,在通过class token来进行分类。

这个Standard visual Transformer和前面文章中ViT有一点不同,这里没有class token,class token只是针对分类才加上去的,上面这个网络才是最标准的视觉ViT网络。

由于VIT忽略了空间归纳偏差,所以它们需要更多的参数来学习视觉表征。此外,与CNN相比,VIT及其多种变体的优化性能不佳,这些模型对L2正则化很敏感,需要大量的数据增强以防止过拟合

🍉MobileViT结构

上面展示是标准视觉ViT模型,下面来看下本次介绍的重点:Mobile-ViT网路结构,如下图所示:
在这里插入图片描述通过上图可以看到MobileViT主要由普通卷积MV2(MobiletNetV2中的Inverted Residual block),MobileViT block全局池化以及全连接层共同组成。

其中,MobileViT块中的Convn × n表示一个标准的n × n卷积MV2指的是MobileNetv2块执行下采样的块用↓2标记

2.MV2(MobileNet v2 block)

MV2 块指MobileNet v2 block,是一个Inverted Residual Block(倒残差结构)。 在倒残差结构中,即特征图的维度是先升后降,据相关论文中描述说,更高的维度经过激活函数后,它损失的信息就会少一些。(注意倒残差结构中基本使用的都是ReLU6激活函数,但是最后一个1x1的卷积层使用的是线性激活函数)。具体网络结构如下图所示。
在这里插入图片描述
MobileViT结构图中标有向下箭头的MV2结构代表stride等于2的情况,即需要进行下采样

🌼Residual Block(残差结构):
①1x1卷积降维
②3x3卷积
③1x1卷积升维
🌻Inverted Residual Block(倒残差结构)
①1x1卷积升维
②3x3卷积DW
③1x1卷积降维

3.MobileViT block

MV2来源于mobilenetv2,所以Mobile-ViT的核心是MobileViT block模块。MobileViT block的结构如下图所示:
在这里插入图片描述
MobileViT Block旨在用更少的参数对输入张量中的局部全局信息进行建模。由上图可知MobileViT Block 整体由三部分组成分别为:Local representationsTransformers as Convolutions (global representations)Fusion

大致流程:首先将特征图通过一个卷积核大小为nxn(代码中是3x3)的卷积层进行局部的特征建模,然后通过一个卷积核大小为1x1的卷积层调整通道数。接着通过Unfold -> Transformer -> Fold结构进行全局的特征建模,然后再通过一个卷积核大小为1x1的卷积层将通道数调整回原始大小。接着通过shortcut分支(在V2版本中将该捷径分支取消了)与原始输入特征图进行Concat拼接(沿通道channel方向拼接),最后再通过一个卷积核大小为nxn(代码中是3x3)的卷积层做特征融合得到输出

Global representations它的具体计算过程如下图所示,
在这里插入图片描述
首先对特征图划分Patch(忽略了通道channels),图中的Patch大小为2x2,即每个Patch由4个Pixel组成。

在进行Self-Attention计算的时候,每个Token(图中的每个Pixel或者说每个小颜色块)只和颜色相同的Token进行Attention,可以减少参数计算量。对于原始的Self-Attention计算每个Token是需要和所有的Token进行Self-Attention。

假设特征图的高宽和通道数分别为H, W, C,在输入到Transformer中,在Self-Attention的时候,每个图中的每个像素和其他的像素进行计算,这样计算量就是:
P 1 = W ∗ H ∗ C P_1 = W*H*C P1=WHC

MobileViT中的是先对输入的特征图划分成多个的patch,但是在计算Self-Attention的时候只对相同位置的像素计算,即图中展示的颜色相同的位置,这样就可以相对的减少计算量,这个时候的计算量为:
P 2 = W ∗ H ∗ C 4 P_2 = \frac{W*H*C}{4} P2=4WHC即理论上的计算成本仅为原始的 1 4 \frac{1}{4} 41

在本次的自注意力机制中,只选择了位置相同的像素点进行点积操作。这样做的原因大概就是因为和所有的像素点都进行自注意力操作会带来信息冗余,毕竟不是所有的像素含有有用的信息对于图像数据本身就存在大量的数据冗余,一张图像的每个像素点的周围的像素值都差不多,并且分辨率越高相差越小,所以这样做并不会损失太多的信息。而且MobileViT在做全局表征之前已经做了一次局部表征(Local representations),进行全局建模时可以忽略一些信息。

Global representations中的​UnfoldFold只是为了将数据给reshape成计算Self-Attention时所需的数据格式。unfold就是将颜色相同的部分拼成一个序列输入到Transformer进行建模,最后再通过fold是调整为原始大小,如下图所示:
在这里插入图片描述
下面来简单的看下patch size对模型性能的影响,patch如果划分的比较大的话是可以减少计算量的,但是划分的太大的话又会忽略更多的语义信息,影响模型的性能。

下图从左到右对语义信息的要求逐渐递增。其中配置A的patch大小为{2, 2, 2},配置B的patch大小为{8, 4, 2},这三个数字分别对应下采样倍率为8,16,32的特征图所采用的patch大小。通过对比可以发现,在图像分类目标检测任务中(对语义细节要求不高的场景),配置A和配置B在Acc和mAP上没太大区别,但配置B要更快。但在语义分割任务中(对语义细节要求较高的场景)配置A的效果要更好。
在这里插入图片描述

🥇Local representations

Local representations 表示输入信息的局部表达。在这个部分,输入MobileViT Block 的数据会经过一个 n × n n \times n n×n的卷积块和一个 1 × 1 1 \times 1 1×1的卷积块。

从上文所述的CNN的空间归纳偏差就可以得知:经过 n × n n \times n n×n(n=3)的卷积块的输出获取到了输入模型的局部信息表达(因为卷积块是对一个整体块进行操作,但是这个卷积核的n是远远小于数据规模的,所以是局部信息表达,而不是全局信息表达)。另外, 1 × 1 1 \times 1 1×1的卷积块是为了线性投影将数据投影至高维空间。例如:对于 9 × 9 9\times 9 9×9的数据,使用 3 × 3 3\times 3 3×3的卷积层,获取到的每个数据都是对 9 × 9 9\times 9 9×9 数据的局部表达

🥈Transformers as Convolutions (global representations)

Transformers as Convolutions (global representations) 表示输入信息的全局表示。在Transformers as Convolutions 中首先通过Unfold 对数据进行转换,转化为 Transformer 可以接受的 1D 数据。然后将数据输入到Transformer 块中。最后通过Fold再将数据变换成原有的样子。

🥉Fusion

Fusion中,经过Transformers as Convolutions得到的信息原始输入信息 ( A ∈ R H × W × C ) (\mathrm{A} \in \mathrm{R^{\mathrm{H \times W \times C}}}) (ARH×W×C)进行合并,然后使用另一个 n × n n\times n n×n卷积层来融合这些连接的特征。这里,得到的信息指:全局表征 X F ∈ R H × W × d \mathrm{X_F} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFRH×W×d经过逐点卷积( 1 × 1 1\times 1 1×1卷积)得到的输出 X F u ∈ R H × W × d \mathrm{X_{Fu}} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFuRH×W×d ,并通过Concat操作与 X \mathrm{X} X组合。

4.模型配置

论文中总共给出了三组模型配置,即MobileViT-S(small)、MobileViT-XS(extra small)、MobileViT-XXS(extra extra small),三种配置是越来越轻量化,三者的主要区别在于特征图的通道数不同

下图为MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
在这里插入图片描述
对于MobileViT-XXS,Layer1~5的详细配置信息如下:
在这里插入图片描述
对于MobileViT-XS,Layer1~5的详细配置信息如下:
在这里插入图片描述
对于MobileViT-S,Layer1~5的详细配置信息如下:
在这里插入图片描述
参数说明:
⋆ \star out_channels表示该模块输出的通道数
⋆ \star mv2_exp表示Inverted Residual Block中的expansion ratio
⋆ \star transformer_channels表示Transformer模块输入Token的序列长度(特征图通道数)
⋆ \star num_heads表示多头自注意力机制中的head数
⋆ \star ffn_dim表示FFN中间层Token的序列长度
⋆ \star patch_h表示每个patch的高度
⋆ \star patch_w表示每个patch的宽度

5.MobileViT优势

🍄更好的性能: 对于给定的参数预算,MobileViT 在不同的移动视觉任务(图像分类、物体检测、语义分割)中取得了比现有的轻量级 CNN 更好的性能

🍄更好的泛化能力泛化能力是指训练和评价指标之间的差距。对于具有相似训练指标的2个模型,具有更好评价指标的模型更具有通用性,因为它可以更好地预测未知数据集。与CNN相比,即使有广泛的数据增强,其泛化能力也很差,MobileViT显示出更好的泛化能力。

🍄更好的鲁棒性:一个好的模型应该对超参数具有鲁棒性,因为调优这些超参数会消耗时间和资源。与大多数基于ViT的模型不同,MobileViT模型使用基本增强训练,对L2正则化不太敏感

总之,MobileViT使用CNNTransformer相融合的方案,在减少模型复杂度的同时,提高了模型的精度和鲁棒性

⋆ \star 对于一个模型,如果全都使用 CNN 结构。模型只能获取到数据的局部信息而获取不到全局信息
⋆ \star 对于一个模型,如果全部使用 Transformer 结构。模型可以获取到全局信息。但是,Transformer 结构会带来较大的复杂度,存在训练时间上升,模型容易过拟合等等问题。

因此,基于上述问题。作者先使用CNN获取局部信息,然后使用 Transformer 结构获取全局信息。通过上述的理解可以发现:在MobileViT 中的Transformer 结构中,复杂度相比于 ViT 结构 中复杂度降低了很多,因为输入数据复杂度的降低。最终实验结果同时表明:MobileViT 精度更高且鲁棒性更好

三、MobileViT网络实现

1.构建网络模型

首先要构建MobileViT block,其结构图如下所示:
在这里插入图片描述
Transformer实现:
在这里插入图片描述


class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
        and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, C]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, C] -> [N, h, P, C] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out


class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int) : Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
        and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        *args,
        **kwargs
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed forward network
        x = x + self.pre_norm_ffn(x)
        return x

MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
在这里插入图片描述

def get_config(mode: str = "xxs") -> dict:
    if mode == "xx_small":
        mv2_exp_mult = 2
        config = {
            "layer1": {
                "out_channels": 16,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 24,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 48,
                "transformer_channels": 64,
                "ffn_dim": 128,
                "transformer_blocks": 2,
                "patch_h": 2,  # 8,
                "patch_w": 2,  # 8,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 64,
                "transformer_channels": 80,
                "ffn_dim": 160,
                "transformer_blocks": 4,
                "patch_h": 2,  # 4,
                "patch_w": 2,  # 4,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 80,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "x_small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 48,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 64,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 80,
                "transformer_channels": 120,
                "ffn_dim": 240,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    elif mode == "small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 64,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {  # 28x28
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer4": {  # 14x14
                "out_channels": 128,
                "transformer_channels": 192,
                "ffn_dim": 384,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "layer5": {  # 7x7
                "out_channels": 160,
                "transformer_channels": 240,
                "ffn_dim": 480,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "num_heads": 4,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
            "cls_dropout": 0.1
        }
    else:
        raise NotImplementedError

    for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
        config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

    return config

现在开始构建MobileVit网络模型

def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# 卷积层
class ConvLayer(nn.Module):
    """
    Applies a 2D convolution over an input

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
        kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
        stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
        groups (Optional[int]): Number of groups in convolution. Default: 1
        bias (Optional[bool]): Use bias. Default: ``False``
        use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
        use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
                                Default: ``True``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        For depth-wise convolution, `groups=C_{in}=C_{out}`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = 1,
        groups: Optional[int] = 1,
        bias: Optional[bool] = False,
        use_norm: Optional[bool] = True,
        use_act: Optional[bool] = True,
    ) -> None:
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        if isinstance(stride, int):
            stride = (stride, stride)

        assert isinstance(kernel_size, Tuple)
        assert isinstance(stride, Tuple)

        padding = (
            int((kernel_size[0] - 1) / 2),
            int((kernel_size[1] - 1) / 2),
        )

        block = nn.Sequential()

        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=padding,
            bias=bias
        )

        block.add_module(name="conv", module=conv_layer)

        if use_norm:
            norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
            block.add_module(name="norm", module=norm_layer)

        if use_act:
            act_layer = nn.SiLU()
            block.add_module(name="act", module=act_layer)

        self.block = block

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)

# MV2
class InvertedResidual(nn.Module):
    """
    This class implements the inverted residual block, as described in `MobileNetv2 <https://arxiv.org/abs/1801.04381>`_ paper

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
        out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
        stride (int): Use convolutions with a stride. Default: 1
        expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
        skip_connection (Optional[bool]): Use skip-connection. Default: True

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    .. note::
        If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: Union[int, float],
        skip_connection: Optional[bool] = True,
    ) -> None:
        assert stride in [1, 2]
        hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)

        super().__init__()

        block = nn.Sequential()
        if expand_ratio != 1:
            block.add_module(
                name="exp_1x1",
                module=ConvLayer(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1
                ),
            )

        block.add_module(
            name="conv_3x3",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                stride=stride,
                kernel_size=3,
                groups=hidden_dim
            ),
        )

        block.add_module(
            name="red_1x1",
            module=ConvLayer(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                use_act=False,
                use_norm=True,
            ),
        )

        self.block = block
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.exp = expand_ratio
        self.stride = stride
        self.use_res_connect = (
            self.stride == 1 and in_channels == out_channels and skip_connection
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        if self.use_res_connect:
            return x + self.block(x)
        else:
            return self.block(x)


class MobileViTBlock(nn.Module):
    """
    This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_

    Args:
        opts: command line arguments
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        ffn_dim (int): Dimension of the FFN block
        n_transformer_blocks (int): Number of transformer blocks. Default: 2
        head_dim (int): Head dimension in the multi-head attention. Default: 32
        attn_dropout (float): Dropout in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
        patch_h (int): Patch height for unfolding operation. Default: 8
        patch_w (int): Patch width for unfolding operation. Default: 8
        transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
        conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
        no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
    """

    def __init__(
        self,
        in_channels: int,
        transformer_dim: int,
        ffn_dim: int,
        n_transformer_blocks: int = 2,
        head_dim: int = 32,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        patch_h: int = 8,
        patch_w: int = 8,
        conv_ksize: Optional[int] = 3,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        # 下面两个卷积层:Local representations
        conv_3x3_in = ConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )
        conv_1x1_in = ConvLayer(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False
        )

        # 下面两个卷积层:Fusion
        conv_1x1_out = ConvLayer(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1
        )
        conv_3x3_out = ConvLayer(
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1
        )

        # Local representations
        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim

        # global representations
        global_rep = [
            TransformerEncoder(
                embed_dim=transformer_dim,
                ffn_latent_dim=ffn_dim,
                num_heads=num_heads,
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(nn.LayerNorm(transformer_dim))
        self.global_rep = nn.Sequential(*global_rep)

        # Fusion
        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.ffn_dropout = ffn_dropout
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = patch_w * patch_h
        batch_size, in_channels, orig_h, orig_w = x.shape

        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
        x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
        # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
        x = x.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(batch_size, in_channels, num_patches, patch_area)
        # [B, C, N, P] -> [B, P, N, C]
        x = x.transpose(1, 3)
        # [B, P, N, C] -> [BP, N, C]
        x = x.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return x, info_dict

    def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
        n_dim = x.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            x.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        x = x.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = x.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] -> [B, C, N, P]
        x = x.transpose(1, 3)
        # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
        x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
        # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
        x = x.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
        x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
        if info_dict["interpolate"]:
            x = F.interpolate(
                x,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return x

    def forward(self, x: Tensor) -> Tensor:
        res = x

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)

        # learn global representations
        for transformer_layer in self.global_rep:
            patches = transformer_layer(patches)

        # [B x Patch x Patches x C] -> [B x C x Patches x Patch]
        fm = self.folding(x=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm


class MobileViT(nn.Module):
    """
    This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
    """
    def __init__(self, model_cfg: Dict, num_classes: int = 1000):
        super().__init__()

        image_channels = 3
        out_channels = 16

        self.conv_1 = ConvLayer(
            in_channels=image_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2
        )

        self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
        self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
        self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
        self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
        self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])

        exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
        self.conv_1x1_exp = ConvLayer(
            in_channels=out_channels,
            out_channels=exp_channels,
            kernel_size=1
        )

        self.classifier = nn.Sequential()
        self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
        self.classifier.add_module(name="flatten", module=nn.Flatten())
        if 0.0 < model_cfg["cls_dropout"] < 1.0:
            self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
        self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))

        # weight init
        self.apply(self.init_parameters)

    def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
        block_type = cfg.get("block_type", "mobilevit")
        if block_type.lower() == "mobilevit":
            return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
        else:
            return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)

    @staticmethod
    def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
        output_channels = cfg.get("out_channels")
        num_blocks = cfg.get("num_blocks", 2)
        expand_ratio = cfg.get("expand_ratio", 4)
        block = []

        for i in range(num_blocks):
            stride = cfg.get("stride", 1) if i == 0 else 1

            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=output_channels,
                stride=stride,
                expand_ratio=expand_ratio
            )
            block.append(layer)
            input_channel = output_channels

        return nn.Sequential(*block), input_channel

    @staticmethod
    def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
        stride = cfg.get("stride", 1)
        block = []

        if stride == 2:
            layer = InvertedResidual(
                in_channels=input_channel,
                out_channels=cfg.get("out_channels"),
                stride=stride,
                expand_ratio=cfg.get("mv_expand_ratio", 4)
            )

            block.append(layer)
            input_channel = cfg.get("out_channels")

        transformer_dim = cfg["transformer_channels"]
        ffn_dim = cfg.get("ffn_dim")
        num_heads = cfg.get("num_heads", 4)
        head_dim = transformer_dim // num_heads

        if transformer_dim % head_dim != 0:
            raise ValueError("Transformer input dimension should be divisible by head dimension. "
                             "Got {} and {}.".format(transformer_dim, head_dim))

        block.append(MobileViTBlock(
            in_channels=input_channel,
            transformer_dim=transformer_dim,
            ffn_dim=ffn_dim,
            n_transformer_blocks=cfg.get("transformer_blocks", 1),
            patch_h=cfg.get("patch_h", 2),
            patch_w=cfg.get("patch_w", 2),
            dropout=cfg.get("dropout", 0.1),
            ffn_dropout=cfg.get("ffn_dropout", 0.0),
            attn_dropout=cfg.get("attn_dropout", 0.1),
            head_dim=head_dim,
            conv_ksize=3
        ))

        return nn.Sequential(*block), input_channel

    @staticmethod
    def init_parameters(m):
        if isinstance(m, nn.Conv2d):
            if m.weight is not None:
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Linear,)):
            if m.weight is not None:
                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        else:
            pass

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_1(x)
        x = self.layer_1(x)
        x = self.layer_2(x)

        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.conv_1x1_exp(x)
        x = self.classifier(x)
        return x


def mobile_vit_xx_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
    config = get_config("xx_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_x_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
    config = get_config("x_small")
    m = MobileViT(config, num_classes=num_classes)
    return m


def mobile_vit_small(num_classes: int = 1000):
    # pretrain weight link
    # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
    config = get_config("small")
    m = MobileViT(config, num_classes=num_classes)
    return m

2.训练和测试模型

def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "classifier" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "classifier" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")

        torch.save(model.state_dict(), "./weights/latest_model.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="F:/NN/Learn_Pytorch/flower_photos")

    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
                        help='initial weights path')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

这里使用了预训练权重,在其基础上训练自己的数据集。训练50epoch的准确率能到达94%左右。
在这里插入图片描述

四、图像分类

这里使用花朵数据集,下载连接:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # 加载图片
    img_path = 'daisy2.jpg'
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    image = Image.open(img_path)

    # image.show()
    # [N, C, H, W]
    img = data_transform(image)
    # 扩展维度
    img = torch.unsqueeze(img, dim=0)

    # 获取标签
    json_path = 'class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, 'r') as f:
        # 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # 对输入图像进行预测
        output = torch.squeeze(model(img.to(device))).cpu()
        # 对模型的输出进行 softmax 操作,将输出转换为类别概率
        predict = torch.softmax(output, dim=0)
        # 得到高概率的类别的索引
        predict_cla = torch.argmax(predict).numpy()

    res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    draw = ImageDraw.Draw(image)
    # 文本的左上角位置
    position = (10, 10)
    # fill 指定文本颜色
    draw.text(position, res, fill='green')
    image.show()
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))


预测结果:
在这里插入图片描述

结束语

感谢阅读吾之文章,今已至此次旅程之终站 🛬。

吾望斯文献能供尔以宝贵之信息与知识也 🎉。

学习者之途,若藏于天际之星辰🍥,吾等皆当努力熠熠生辉,持续前行。

然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 💞。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1092000.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

三十四、【进阶】MySQL索引的操作

1、创建索引 &#xff08;1&#xff09;基础语法 &#xff08;2&#xff09;唯一索引 唯一索引与普通索引不同的是&#xff0c;索引列的数值必须唯一&#xff0c;但允许有空值null&#xff1b; 唯一索引与主键索引不同的是&#xff0c;主键索引不允许出现空值null&#xff0…

Webmin 远程命令执行漏洞_CVE-2019-15107

Webmin 远程命令执行漏洞_CVE-2019-15107 文章目录 Webmin 远程命令执行漏洞_CVE-2019-15107在线漏洞详情&#xff1a;漏洞描述&#xff1a;版本影响&#xff1a;环境搭建漏洞复现使用BP进行抓包POC发送以下POST请求来执行命令id:复现成功 漏洞利用-shell反弹执行反弹指令bp的i…

【Flutter Widget】AppBar 和 PopupMenu

App Bar 可以视为页面的标题栏&#xff0c;在 Flutter 中用AppBar组件实现&#xff1b;Popup Menu 是弹出菜单&#xff0c;用PopupMenuButton实现。下面结合这两个组件说明其用法。 1. 代码实现 一个简单的AppBar实现代码如下&#xff1a; import package:flutter/material.…

同源策略和跨域问题

1.跨域问题产生的原因 浏览器的同源策略影响&#xff0c;同源策略是一种安全机制&#xff0c;它限制了一个网页中的脚本只能访问同源的资源。 跨源网络访问的三种方式&#xff1a;跨域写操作&#xff0c;跨域资源嵌入&#xff0c;跨域读操作 2.跨域问题案例 ip和域名不一致…

使用 Certbot 为 Nginx 自动配置 SSL 证书

发布于 2023-07-13 on https://chenhaotian.top/linux/certbot-nginx/ 使用 Certbot 为 Nginx 自动配置 SSL 证书 配置步骤 以 Debian 11 为例 1. 安装Certbot和Nginx插件 sudo apt-get update sudo apt-get install certbot python3-certbot-nginx2. 获取和安装证书 运行…

nginx的location的优先级和匹配方式和nginx的重定向

在http模块有server,在server模块才有location,location匹配的是uri /test /image 在一个server当中有多个location,如何来确定匹配哪个location。 nginx的正则表达式&#xff1a; ^&#xff1a;字符串的起始位置 $&#xff1a;字符串的结束位置 *&#xff1a;匹配所有 &am…

【宏offsetof详解和模拟实现】

文章目录 一. 什么是offsetof&#xff1f;二. 用宏模拟实现offsetof三. 代码重点讲解四. 结束语 一. 什么是offsetof&#xff1f; 1.** 注意&#xff1a;offsetof不是函数而是一个宏** 2. 返回值是size_t无符号整形&#xff0c;头文件是<stddef.h>&#xff0c;参数是&…

智慧工地解决方案,智慧工地平台源码

一、现状描述 建筑工程建设具有明显的生产规模大宗性与生产场所固定性的特点。建筑企业70%左右的工作都发生在施工现场&#xff0c;施工阶段的现场管理对工程成本、进度、质量及安全等至关重要。同时随着工程建设规模不断扩大&#xff0c;工艺流程纷繁复杂&#xff0c;如何搞好…

腾讯云我的世界mc服务器多少钱一年?

腾讯云我的世界mc服务器多少钱&#xff1f;95元一年2核2G3M轻量应用服务器、2核4G5M带宽优惠价218元一年、4核8G12M带宽轻量服务器446元一年&#xff0c;云服务器CVM标准型S5实例2核2G优惠价280元一年、2核4G配置服务器748元一年&#xff0c;腾讯云百科txybk.com分享腾讯云我的…

Spring MVC 和Spring JDBC

目录 Spring MVC MVC模式 核心组件 工作流程 Spring JDBC Spring JDBC功能和优势 Spring JDBC的关键组件 Spring MVC Spring MVC&#xff08;Model-View-Controller&#xff09;是Spring框架的一个模块&#xff0c;用于构建Web应用程序。它的主要目标是将Web应用程序的不…

【Python】Python语言基础(上)

第一章 前言 1. Python简介 Python语言并不是新的语言&#xff0c;它早于HTTP 1.0协议5年&#xff0c;早于Java语言 4年。 ​ Python是由荷兰人Guido van Rossum&#xff08;吉多范罗苏姆&#xff09;于1989年圣诞节期间在阿姆斯特丹休假时为了打发无聊的假期而编写的一个脚本…

利用 Amazon CodeWhisperer 激发孩子的编程兴趣

我是一个程序员&#xff0c;也是一个父亲。工作之余我会经常和儿子聊他们小学信息技术课学习的 Scratch 和 Kitten 这两款图形化的少儿编程工具。 我儿子有一次指着书房里显示器上显示的 Visual Studio Code 问我&#xff0c;“为什么我们上课用的开发界面&#xff0c;和爸爸你…

巴以冲突中暴露的摄像头正对安全构成威胁

巴以冲突爆发后&#xff0c;许多配置不当的安全摄像头正暴露给黑客活动分子&#xff0c;使其周遭人员面临巨大安全风险。 Cyber​​news 研究人员发现&#xff0c;在以色列至少有165 个暴露的联网 RTSP 摄像头&#xff0c;在巴勒斯坦有 29 个暴露的 RTSP 摄像头。在巴勒斯坦&am…

CCPlotR | 轻松拿捏单细胞分析之细胞交互!~

1写在前面 周末了各位&#xff0c;昨天去看了奥本海默&#xff0c;不得不说&#xff0c;大神就是大神。&#x1f618; 比起我们的电影&#xff0c;似乎诺兰更好地还原了奥本海默的真实。&#x1f9d0; 言归正传&#xff0c;今天分享的是CCPlotR包&#xff0c;用于基于scRNAseq数…

SwiftUI Swift CoreData 计算某实体某属性总和

有一个名为 Item 的实体&#xff0c;它有一个名为 amount 的 Double 属性&#xff0c;向你的 View 添加一个计算属性&#xff1a; Code: struct ContentView: View {Environment(\.managedObjectContext) private var viewContextFetchRequest(sortDescriptors: [NSSortDescri…

Codeforces Round 903 (Div. 3) C(矩形旋转之后对应的坐标)

题目链接&#xff1a;Codeforces Round 903 (Div. 3) C 题目&#xff1a; 思想&#xff1a; 旋转之后对应的坐标&#xff1a; &#xff08;i&#xff0c;j&#xff09;&#xff08;n1-j&#xff0c;i&#xff09;&#xff08;n1-i&#xff0c;n1-j&#xff09;&#xff08;j…

CSS的布局 Day03

一、显示模式&#xff1a; 网页中HTML的标签多种多样&#xff0c;具有不同的特征。而我们学习盒子模型、使用定位和弹性布局把内容分块&#xff0c;利用CSS布局使内容脱离文本流&#xff0c;使用定位或弹性布局让每块内容摆放在想摆放的位置&#xff0c;让网站页面布局更合理、…

VCAP-DCV VMware vSphere: 运维、扩展和安全防护 [V8.0]

VMware官方授权合作机构&#xff0c;全国招生&#xff01; VCP-DCV VMware vSphere&#xff1a;安装、配置和管理[V8.x]-CSDN博客本课程重点讲授如何安装、配置和理 VMware vSphere 8.0&#xff08;包括 VMware ESXi™ 8.0 和 VMware vCenter Server™ 8.0&#xff09;。使用 …

代码随想录算法训练营第二十天丨 二叉树part07

530.二叉搜索树的最小绝对差 思路 题目中要求在二叉搜索树上任意两节点的差的绝对值的最小值。 注意是二叉搜索树&#xff0c;二叉搜索树是有序的。 遇到在二叉搜索树上求什么最值啊&#xff0c;差值之类的&#xff0c;就把它想成在一个有序数组上求最值&#xff0c;求差值…

出游热潮再起,IPIDEA代理IP帮你应对旅游数据采集的挑战

随着互联网的快速发展&#xff0c;旅游行业也随之迅速发展。在线旅游预订已经成为人们出行前的必要步骤&#xff0c;然而&#xff0c;旅游信息的采集却是一项具有挑战性的任务。为了从酒店和航空公司网站、在线旅行社和其他类似来源收集数据&#xff0c;企业需要克服许多障碍。…