【youcans动手学模型】SqueezeNet 模型-CIFAR10图像分类

news2024/11/17 10:42:32

欢迎关注『youcans动手学模型』系列
本专栏内容和资源同步到 GitHub/youcans


【youcans动手学模型】SqueezeNet 模型-CIFAR10图像分类

    • 1. SqueezeNet 卷积神经网络模型
      • 1.1 模型简介
      • 1.2 论文介绍
      • 1.3 分析与讨论
    • 2. 在 PyTorch 中定义 SqueezeNet 模型类
      • 2.1 定义 Fire Module
      • 2.2 简化的 SqueezeNet 模型类
      • 2.3 按特征提取和分类器模块封装
      • 2.4 加载官方的 SqueezeNet 模型类
    • 3. 基于 Squeeze 模型的 CIFAR10 图像分类
      • 3.1 PyTorch 建立神经网络模型的基本步骤
      • 3.2 加载 CIFAR10 数据集
      • 3.3 建立 SqueezeNet 网络模型
      • 3.4 SqueezeNet 模型训练
      • 3.5 SqueezeNet 模型的保存与加载
      • 3.6 模型检验
      • 3.7 模型推理
    • 4. 使用 SqueezeNet 预训练模型进行图像分类


本文用 PyTorch 实现 SqueezeNet 网络模型,使用 CIFAR10 数据集训练模型,进行图像分类。


1. SqueezeNet 卷积神经网络模型

Forresti, Moskewcz 等在 2016年发表的论文 “SqueezeNet: AlexNet Level Accuracy with 50x Fewer Parameters and <0.5 MB Model Size”,提出一种轻量级深度学习神经网络模型,称为 SqueezeNet。

SqueezeNet 与 MobileNet、ShuffleNet 和 Xception 都是 2016年在 arXiv 上公开的,被称为四大轻量级模型。

【论文下载地址】
SqueezeNet: AlexNet Level Accuracy with 50x Fewer Parameters and <0.5 MB Model Size

【GitHub地址】:
[https://github.com/forresti/SqueezeNet]
[https://github.com/DeepScale/SqueezeNet]


1.1 模型简介

SqueezeNet 的创新在于提出了 fire module,包括 squeeze 和 expand 两个部分,以降低参数规模和计算量。

  • squeeze 层采用 1*1 卷积核对上一层的特征图进行卷积,以降低特征图的维数;
  • expand 层使用 Inception 结构,分为 1*1 卷积和 3*3卷积 2个分支进行拼接。

SqueezeNet 预训练模型大小约 4.8 MB,在 ImageNet 数据集上 Top-5 准确率 80.3%,非常轻量高效。

在这里插入图片描述


1.2 论文介绍

【论文摘要】

近年来对深度卷积神经网络(CNNs)的研究主要集中在提高准确性。对于给定的精度水平,通常有多个 CNN 架构可以实现。在同等精度下,较小的 CNN 架构具有更大的优势:(1)较小的 CNN 在分布式训练期间需要的服务器通信较少。(2)对于自动驾驶任务,较小的 CNN 需要的带宽较少。(3)较小的 CNN 更适合部署在FPGA和内存有限的硬件设备。

我们提出了一种称为 SqueezeNet 的小型 CNN 架构。SqueezeNet 在 ImageNet 数据集上实现了 AlexNet 级别的精度,而参数减少了50倍。通过模型压缩技术,SqueezeNet 可以压缩到 0.5MB 以下。


【论文背景】

自从 LeNet 模型开创卷积神经网络以来,2012年 AlexNet 引发深度学习的研究热潮。此后从 ZF-Net、VGGNet、GoogleNet、ResNet 到 DenseNet,都以追求提升精确率为主要目的,采用的主要方向包括加深网络结构和增强卷积模块功能,但是这也导致了网络模型越来越复杂,需要的内存和计算量也大大增加。SqueezeNet 开辟了另一个方向,在保证模型精度不降低的前提下,尽可能减小模型参数、提高运算速度。

简化模型的一条路径是对现有的 CNN 模型进行压缩,例如使用奇异值分解(SVD)方法,网络剪枝方法,深度压缩方法, EIE 硬件加速方法。

轻量化设计是在模型设计时就采用轻量化的思想,例如轻量卷积方式(深度可分离卷积、分组卷积),平均池化代替全连接层,1×1卷积进行通道降维。


【主要创新】

为了减少网络模型的大小和计算量 ,SqueezeNet 遵循了以下三个策略:

  • 使用 1*1 卷积替代大部分的 3*3 卷积,显著地减小了参数量和计算量。
  • 使用 Squeeze 层(PointConv)降低 3*3 卷积核的输入通道数(深度),减小参数量和计算量。
  • 推迟下采样,在较晚的阶段进行下采样,使卷积层具有较大的特征图,以获得更好的分类结果。

SqueezeNet 的核心是压缩-扩展(Squeeze-Expand)结构,称为 fire module,包括只有 1*1 卷积的 Squeeze 层,由 1*1 卷积、3*3 卷积 2 个分支拼接而成的 Expand 层。

  • Squeeze 层采用 1*1 卷积核对上一层的特征图进行卷积,以降低特征图的维数。

  • Expand 层使用 Inception 结构,分为 1*1 卷积和 3*3卷积 2个分支进行拼接。

这两个分支的卷积运算的 stride=1,padding=same,输出的特征图大小相同,可以进行拼接。拼接后输出的深度是 1*1 卷积深度 e1 与 3*3 卷积深度 e3 之和 ( e 1 + e 3 ) (e_1+e_3) (e1+e3)

在这里插入图片描述

如果对照 Xception 和 DSC 的架构,SqueezeNet 才是真正的 “Extreme Inception”,而 Xception 模型实际上使用的是 DSC。


【模型结构】

SqueezeNet 卷积神经网络的架构如下。

  • SqueezeNet 从一个标准卷积层(conv1)开始,然后设有 8 个 Fire 模块(fire2-fire9),最后是一个标准卷积层(conv10)。
  • 从网络的输入段到输出段,逐渐增加每个Fire 模块的深度(特征图数量)。
  • 在 conf1、fire4、fire8和 conf10 层之后设有步幅 stride=2 的最大池化层。

SqueezeNet 的设计细节如下。

  • 为了使 1*1 卷积和 3*3 卷积输出的特征图大小相同以进行拼接,对于 3*3 卷积使用 padding=1。

  • 使用 ReLU激活函数,用于 squeeze 和 expand 层。

  • 在 fire9 模块之后使用比例为 50% 的 Dropout,以减小规模。

  • 受 NiN 的启发,没有采用全连接层。

  • 初始的学习率设为 0.04,在训练过程中线性降低。

Caffe 框架本身不支持包含多分辨率的卷积层。使用 1*1 和 3*3 的两个独立的卷积层,在通道维度中将两个卷积层的输出连接在一起,来实现 expand 层。

在这里插入图片描述


【模型配置】

Fire Module 有 1*1、3*3 两种尺寸的卷积核,论文将这两种卷积核的数量作为超参数,需要人为设定。论文中给出的 SqueezeNet 模型的具体结构配置参数如下。

在这里插入图片描述


1.3 分析与讨论

SqueezeNet 模型的目的,是在达到一定精度的条件下,最大程度地简化模型、提高运算速度。

SqueezeNet 模型的方向,是降低模型的参数数量和计算量。

SqueezeNet 模型的设计策略,是减少 3*3 卷积的通道数,并用 1*1 卷积替换部分 3*3 卷积。

SqueezeNet 模型的不足是:

(1)嵌入式应用环境的主要问题是实时性,SqueezeNet 通过更深的深度达到更少的参数数量,降低了网络的并行能力,推理时间反而更长。

(2)虽然 SqueezeNet 比 AlexNet 的参数减少了 50倍,但这主要是由于 AlexNet 的全连接层过于庞大,而 SqueezeNet 使用了平均池化层,与 SqueezeNet 的网络结构关系不大。

(3)SqueezeNet 得到的模型大小 5MB左右(Top5/80.3%),Deep Compression 以后的模型大小 0.5 MB 也与 SqueezeNet 的网络结构关系不大。

不过,不论是 5MB 还是 0.5 Mb,SqueezeNet 预训练模型在 ImageNet 数据集上 Top-5 准确率 80.3%,都已经非常轻量高效了。


2. 在 PyTorch 中定义 SqueezeNet 模型类

SqueezeNet 模型是一种网络框架,针对不同的任务可以进行不同的网络结构设计和超参数配置。

本节先面向 CIFAR10 数据集图像分类问题,详细介绍 SqueezeNet 模型类的构造过程。最后也将给出从 torchvision.model 加载预定义模型类的方法。

2.1 定义 Fire Module

Fire module 是 SqueezeNet 网络架构的核心,包括只有 1*1 卷积的 Squeeze 层,由 1*1 卷积、3*3 卷积 2 个分支拼接而成的 Expand 层。

定义 Fire Module 的例程如下。

# 定义 Fire 模块 (Squeeze + Expand)
class Fire(nn.Module):
    def __init__(self, in_ch, squeeze_ch, e1_ch, e3_ch):  # 声明 Fire 模块的超参数
        super(Fire, self).__init__()
        # Squeeze, 1x1 卷积
        self.squeeze = nn.Conv2d(in_ch, squeeze_ch, kernel_size=1)
        # # Expand, 1x1 卷积
        self.expand1 = nn.Conv2d(squeeze_ch, e1_ch, kernel_size=1)
        # Expand, 3x3 卷积
        self.expand3 = nn.Conv2d(squeeze_ch, e3_ch, kernel_size=3, padding=1)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.activation(self.squeeze(x))
        x = torch.cat([self.activation(self.expand1(x)),
                       self.activation(self.expand3(x))], dim=1)
        return x    

2.2 简化的 SqueezeNet 模型类

简化的 SqueezeNet 模型类定义如下。该模型与 SqueezeNet 论文原文模型相比进行了简化,而且使用全连接层作为分类器。

对于不同的数据集,可能需要进行一些适应性的调整。例如 CIFAR10 数据集图像分类问题数据集规模较小,图片尺寸为 32*32,因此对 SqueezeNet 模型进行了简化。

# 定义简化的 SqueezeNet 模型类 1
class SqueezeNet1(nn.Module):
    def __init__(self, num_classes=100):
        super(SqueezeNet1, self).__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, stride=1, padding=1)  # 3x32x32 -> 96x32x32
        self.relu = nn.ReLU(inplace=True)
        self.fire2 = Fire(96, 48, 32, 32)  # 96x32x32 -> 64x32x32
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x32x32 -> 64x16x16
        self.fire3 = Fire(64, 32, 64, 64)  # 64x16x16 -> 128x16x16
        self.fire4 = Fire(128, 64, 128, 128)  # 128x16x16 -> 256x16x16
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x16x16 -> 256x8x8
        self.fire5 = Fire(256, 64, 192, 192)  # 256x8x8 -> 384x8x8
        self.fire6 = Fire(384, 128, 256, 256)  # 384x8x8 -> 512x8x8
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 512x8x8 -> 512x4x4
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))  # 512x4x4 -> 512x1x1
        self.linear = nn.Linear(512, num_classes)  # 512 -> num_classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.fire2(x)
        x = self.maxpool1(x)  # torch.Size([1, 64, 16, 16])
        x = self.fire3(x)
        x = self.fire4(x)
        x = self.maxpool2(x)  # torch.Size([1, 256, 8, 8])
        x = self.fire5(x)
        x = self.fire6(x)
        x = self.maxpool3(x)    # torch.Size([1, 512, 4, 4])
        x = self.avg_pool(x)  # torch.Size([1, 512, 1, 1])
        x = x.view(x.size(0), -1)  # torch.Size([1, 512])
        x = self.linear(x)  # torch.Size([1, 10])
        return x

2.3 按特征提取和分类器模块封装

PyTorch 通过 torch.nn 模块提供了高阶的 API,可以从头开始构建网络。

通过 Sequential 可以构建序列化的模块,使得网络模块的层次更加清晰,便于构造大型和复杂的网络模型。将初始卷积层和 Fire 模块封装为特征提取模块 self.features,按照论文采用 avgpool 层构造分类器模块并进行封装,定义 SqueezeNet 模型类。

# 定义简化的 SqueezeNet 模型类 2
class SqueezeNet2(nn.Module):
    def __init__(self, num_classes=100):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 3x32x32 -> 64x32x32
            nn.ReLU(inplace=True),
            Fire(64, 16, 64, 64),  # 64x32x32 -> 128x32x32
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),  # 128x32x32 -> 128x16x16
            Fire(128, 32, 64, 64),  # 128x16x16 -> 128x16x16
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),  # 128x16x16 -> 128x8x8
            Fire(128, 64, 128, 128),  # 128x8x8 -> 256x8x8
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),  # 256x8x8 -> 256x4x4
            Fire(256, 64, 256, 256)  # 256x4x4 -> 512x4x4
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Conv2d(512, self.num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),  # 512x4x4 -> 10x1x1
        )

    def forward(self, x):
        x = self.features(x)  # torch.Size([1, 512, 4, 4])
        x = self.classifier(x)  # torch.Size([1, 10, 1, 1])
        x = x.view(x.size(0), -1)  # torch.Size([1, 10])
        return x

2.4 加载官方的 SqueezeNet 模型类

torchvision.models 包和 Torch Hub 中都提供了 SqueezeNet 模型,该模型与 SqueezeNet 论文原文的结构基本一致。

torchvision.models 提供了 SqueezeNet模型类和预训练模型 SqueezeNet | PyTorch 可以直接使用,原始代码可以参考:SOURCE CODE。

torchvision.models 包中 SqueezeNet 模型类的定义如下:

torchvision.models.squeezenet1_0(*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) → SqueezeNet

参数说明:

  • weights (SqueezeNet1_0_Weights, optional) ,预训练模型的权值参数,默认只加载模型结构不加载模型参数。

  • progress (bool, optional) ,显示下载进度条,默认为 True。

  • **kwargs,模型参数。

程序说明:

SqueezeNet 模型类基于 Fire 模块堆叠构建 squeezenet 网络模型,具体代码如下。

网络模型分为特征提取器 features 与分类器 classifier 两部分。在特征提取部分实现了 v1.0 和 v1.1 两个版本的代码。

  • v1.0 版本:初始卷积层 conv1 选择 nn.Conv2d(3,96, kernel_size=7, stride=2),nn.MaxPool2d 放置在 conv1, Fire4, Fire8 之后。
  • v1.1 版本:初始卷积层 conv1 选择 nn.Conv2d(3,64, kernel_size=3, stride=2),nn.MaxPool2d 放置在 conv1, Fire3, Fire5 之后。

注意模型的输入是形状为 (1,3,H,W) 的 RGB 图像的 batch,其中 H 和 W 至少为 224。

class Fire(nn.Module):
    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat(
            [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
        )

class SqueezeNet(nn.Module):
    def __init__(
        self,
        version: str = '1_0',
        num_classes: int = 1000
    ) -> None:
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # FIXME: Is this needed? SqueezeNet should only be called from the
            # FIXME: squeezenet1_x() functions
            # FIXME: This checking is not done for the other models
            raise ValueError("Unsupported SqueezeNet version {version}:"
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

所有预训练模型都要以相同方式归一化的输入图像,即形状为(3* H* W)的 3通道 RGB 图像的 batch,其中 H 和 W 至少为224。图像必须加载到 [0,1] 的范围内,然后使用均值 [0.485, 0.456, 0.406] 和标准差[0.229, 0.224, 0.225] 进行归一化。使用例程如下。

import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_0', pretrained=True)
# or
# model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_1', pretrained=True)
model.eval()

3. 基于 Squeeze 模型的 CIFAR10 图像分类

3.1 PyTorch 建立神经网络模型的基本步骤

使用 PyTorch 建立、训练和使用神经网络模型的基本步骤如下。

  1. 准备数据集(Prepare dataset):加载数据集,对数据进行预处理。
  2. 建立模型(Design the model):实例化模型类,定义损失函数和优化器,确定模型结构和训练方法。
  3. 模型训练(Model trainning):使用训练数据集对模型进行训练,确定模型参数。
  4. 模型推理(Model inferring):使用训练好的模型进行推理,对输入数据预测输出结果。
  5. 模型保存与加载(Model saving/loading):保存训练好的模型,以便以后使用或部署。

以下按此步骤讲解 Squeeze 模型的例程。


3.2 加载 CIFAR10 数据集

通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。使用通用的数据集训练神经网络,不仅可以提高工作效率,而且便于评估模型性能。

PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets 类中。torchvision 模块实现神经网络所需的核心类和方法, torchvision.datasets 包含流行的数据集、模型架构和常用的图像转换方法。

CIFAR 数据集是一个经典的图像分类小型数据集,有 CIFAR10 和 CIFAR100 两个版本。CIFAR10 有 10 个类别,CIFAR100 有 100 个类别。CIFAR10 每张图像大小为 32*32,包括飞机、小汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车 10 个类别。CIFAR10 共有 60000 张图像,其中训练集 50000张,测试集 10000张。每个类别有 6000张图片,数据集平衡。

加载和使用 CIFAR 数据集的方法为:

torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()

CIFAR 数据集可以从官网下载:http://www.cs.toronto.edu/~kriz/cifar.html 后使用,也可以使用 datasets 类自动加载(如果本地路径没有该文件则自动下载)。

下载数据集时,使用预定义的 transform 方法进行数据预处理,包括调整图像尺寸、标准化处理,将数据格式转换为张量。标准化处理所使用 CIFAR10 数据集的均值和方差为 (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)。transform_train在训练过程中,增加随机性,提高泛化能力。

大型训练数据集不能一次性加载全部样本来训练,可以使用 Dataloader 类自动加载数据。Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。

使用 DataLoader 类加载 CIFAR-10 数据集的例程如下。

    # (1) 将[0,1]的PILImage 转换为[-1,1]的Tensor
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(10),  # 随机旋转
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Resize(32),  # 图像大小调整为 (w,h)=(32,32)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
    # 测试集不需要进行数据增强
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

    # (2) 加载 CIFAR10 数据集
    batchsize = 128
    # 加载 CIFAR10 数据集, 如果 root 路径加载失败, 则自动在线下载
    # 加载 CIFAR10 训练数据集, 50000张训练图片
    train_set = torchvision.datasets.CIFAR10(root='../dataset', train=True,
                                            download=True, transform=transform_train)
    # train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize,
                                              shuffle=True, num_workers=8)
    # 加载 CIFAR10 验证数据集, 10000张验证图片
    test_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000,
                                              shuffle=True, num_workers=8)
    # 创建生成器,用 next 获取一个批次的数据
    valid_data_iter = iter(test_loader)  # _SingleProcessDataLoaderIter 对象
    valid_images, valid_labels = next(valid_data_iter)  # images: [batch,3,32,32], labels: [batch]
    valid_size = valid_labels.size(0)  # 验证数据集大小,batch
    print(valid_images.shape, valid_labels.shape)

    # 定义类别名称,CIFAR10 数据集的 10个类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')


3.3 建立 SqueezeNet 网络模型

建立一个 SqueezeNet 网络模型进行训练,包括三个步骤:

  • 实例化 SqueezeNet 模型对象;
  • 设置训练的损失函数;
  • 设置训练的优化器。

torch.nn.functional 模块提供了各种内置损失函数,本例使用交叉熵损失函数 CrossEntropyLoss。

torch.optim 模块提供了各种优化方法,本例使用 Adam 优化器。注意要将 model 的参数 model.parameters() 传给优化器对象,以便优化器扫描需要优化的参数。

    # (3) 构造 Squeeze 网络模型
    model = SqueezeNet1(num_classes=10)  # 实例化 Squeeze 网络模型
    model.to(device)  # 将网络分配到指定的device中
    # print(model)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()  # 定义损失函数 CrossEntropy
    optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.01)  # 定义优化器 SGD

3.4 SqueezeNet 模型训练

PyTorch 模型训练的基本步骤是:

  1. 前馈计算模型的输出值;
  2. 计算损失函数值;
  3. 计算权重 weight 和偏差 bias 的梯度;
  4. 根据梯度值调整模型参数;
  5. 将梯度重置为 0(用于下一循环)。

在模型训练过程中,可以使用验证集数据评价训练过程中的模型精度,以便控制训练过程。模型验证就是用验证数据进行模型推理,前向计算得到模型输出,但不反向计算模型误差,因此需要设置 torch.no_grad()。

使用 PyTorch 进行模型训练的例程如下。

    # (4) 训练 Squeeze 模型
    epoch_list = []  # 记录训练轮次
    loss_list = []  # 记录训练集的损失值
    accu_list = []  # 记录验证集的准确率
    num_epochs = 100  # 训练轮次
    for epoch in range(num_epochs):  # 训练轮次 epoch
        running_loss = 0.0  # 每个轮次的累加损失值清零
        for step, data in enumerate(train_loader, start=0):  # 迭代器加载数据
            optimizer.zero_grad()  # 损失梯度清零
            inputs, labels = data  # inputs: [batch,3,32,32] labels: [batch]
            outputs = model(inputs.to(device))  # 正向传播
            loss = criterion(outputs, labels.to(device))  # 计算损失函数
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新

            # 累加训练损失值
            running_loss += loss.item()
            # if step%100==99:  # 每 100 个 step 打印一次训练信息
            #     print("\t epoch {}, step {}: loss = {:.4f}".format(epoch, step, loss.item()))

        # 计算每个轮次的验证集准确率
        with torch.no_grad():  # 验证过程, 不计算损失函数梯度
            outputs_valid = model(valid_images.to(device))  # 模型对验证集进行推理, [batch, 10]
        pred_labels = torch.max(outputs_valid, dim=1)[1]  # 预测类别, [batch]
        accuracy = torch.eq(pred_labels, valid_labels.to(device)).sum().item() / valid_size * 100  # 计算准确率
        print("Epoch {}: train loss={:.4f}, accuracy={:.2f}%".format(epoch, running_loss, accuracy))

        # 记录训练过程的统计数据
        epoch_list.append(epoch)  # 记录迭代次数
        loss_list.append(running_loss)  # 记录训练集的损失函数
        accu_list.append(accuracy)  # 记录验证集的准确率

程序运行结果如下:

Epoch 0: train loss=900.4685, accuracy=8.59%
Epoch 1: train loss=900.4323, accuracy=10.94%
Epoch 2: train loss=900.4668, accuracy=10.55%
Epoch 3: train loss=900.4570, accuracy=10.94%

Epoch 98: train loss=193.8689, accuracy=80.86%
Epoch 99: train loss=192.4832, accuracy=80.86%

比较特殊地,经过 20 轮左右的训练,模型的训练损失几乎没有降低,使用验证集中的图片进行验证,模型的准确率极低,表明模型没有有效地学习特征,这可能是梯度消失导致的。经过 30 多轮的训练,训练损失开始降低,检验准确率也逐渐增大,表明此时模型训练进入正常状态。经过 100 轮的训练,验证集的准确率达到 80%左右。

在这里插入图片描述


3.5 SqueezeNet 模型的保存与加载

模型训练好以后,将模型保存起来,以便下次使用。PyTorch 中模型保存主要有两种方式,一是保存模型权值,二是保存整个模型。本例使用 model.state_dict() 方法以字典形式返回模型权值,torch.save() 方法将权值字典序列化到磁盘,将模型保存为 .pth 文件。

    # (5) 保存 Squeeze 网络模型
    save_path = "../models/Squeeze_Cifar1"
    model_cpu = model.cpu()  # 将模型移动到 CPU
    model_path = save_path + ".pth"  # 模型文件路径
    torch.save(model.state_dict(), model_path)  # 保存模型权值
    # 优化结果写入数据文件
    result_path = save_path + ".csv"  # 优化结果文件路径
    WriteDataFile(epoch_list, loss_list, accu_list, result_path)

使用训练好的模型,首先要实例化模型类,然后调用 load_state_dict() 方法加载模型的权值参数。

    # 以下模型加载和模型推理,可以是另一个独立的程序
    # (6) 加载 Squeeze 网络模型进行推理
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检测并指定设备
    # 加载 Squeeze 预训练模型
    model = SqueezeNet1(num_classes=10)  # 实例化 Squeeze 网络模型
    model.to(device)  # 将网络分配到指定的device中
    model_path = "../models/Squeeze_Cifar1.pth"
    model.load_state_dict(torch.load(model_path))
    model.eval()  # 模型推理模式

需要特别注意的是:

(1)PyTorch 中的 .pth 文件只保存了模型的权值参数,而没有模型的结构信息,因此必须先实例化模型对象,再加载模型参数。

(2)模型对象必须与模型参数严格对应,才能正常使用。注意即使都是 Squeeze 模型,模型类的具体定义也可能有细微的区别。如果从一个来源获取模型类的定义,从另一个来源获取模型参数文件,就很容易造成模型结构与参数不能匹配。

(3)无论从 PyTorch 模型仓库加载的模型和参数,或从其它来源获取的预训练模型,或自己训练得到的模型,模型加载的方法都是相同的,也都要注意模型结构与参数的匹配问题。


3.6 模型检验

使用加载的 SqueezeNet 模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

使用测试集数据进行模型推理,根据模型预测结果与图片标签进行比较,可以检验模型的准确率。模型验证集与模型检验集不能交叉使用,但为了简化例程在本程序中未做区分。

    # (7) 模型检测
    correct = 0
    total = 0
    for data in test_loader:  # 迭代器加载测试数据集
        imgs, labels = data  # torch.Size([batch,3,32,32) torch.Size([batch])
        # print(imgs.shape, labels.shape)
        outputs = model(imgs.to(device))  # 正向传播, 模型推理, [batch, 10]
        labels_pred = torch.max(outputs, dim=1)[1]  # 模型预测的类别 [batch]
        # _, labels_pred = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += torch.eq(labels_pred, labels.to(device)).sum().item()
    accuracy = 100. * correct / total
    print("Test samples: {}".format(total))
    print("Test accuracy={:.2f}%".format(accuracy))

使用测试集进行模型推理,测试模型准确率为 81.86%。

Test samples: 10000
Test accuracy=81.86%


3.7 模型推理

使用加载的 SqueezeNet 模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

从测试集中提取几张图片,或者读取图像文件,进行模型推理,获得图片的分类类别。在提取图片或读取文件时,要注意对图片格式和图片大小进行适当的转换。

    # (8) 提取测试集图片进行模型推理
    batch = 8  # 批次大小
    data_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=False, transform=None)
    plt.figure(figsize=(9, 6))
    for i in range(batch):
        imgPIL = data_set[i][0]  # 提取 PIL 图片
        label = data_set[i][1]  # 提取 图片标签
        # 预处理/模型推理/后处理
        imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3,32,32])
        imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1,3,32,32])
        outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
        indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
        index = indexes[0].item()  # 预测类别,整数
        # 绘制第 i 张图片
        imgNP = np.array(imgPIL)  # PIL -> Numpy
        out_text = "label:{}/model:{}".format(classes[label], classes[index])
        plt.subplot(2, 4, i+1)
        plt.imshow(imgNP)
        plt.title(out_text)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

结果如下。

在这里插入图片描述

    # (9) 读取图像文件进行模型推理
    from PIL import Image
    filePath = "../images/img_car_01.jpg"  # 数据文件的地址和文件名
    imgPIL = Image.open(filePath)  # PIL 读取图像文件, <class 'PIL.Image.Image'>

    # 预处理/模型推理/后处理
    imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3, 32, 32])
    imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1, 3, 32, 32])
    outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
    indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
    percentages = nn.functional.softmax(outputs, dim=1)[0] * 100
    index = indexes[0].item()  # 预测类别,整数
    percent = percentages[index].item()  # 预测类别的概率,浮点数

    # 绘制第 i 张图片
    imgNP = np.array(imgPIL)  # PIL -> Numpy
    out_text = "Prediction:{}, {}, {:.2f}%".format(index, classes[index], percent)
    print(out_text)
    plt.imshow(imgNP)
    plt.title(out_text)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

结果如下。

在这里插入图片描述


4. 使用 SqueezeNet 预训练模型进行图像分类

Torchvision.models 包和 Torch Hub 中不仅提供了 SqueezeNet 模型类,也提供了在 ImageNet 数据集上训练好的预训练模型,可以直接用来进行图像分类或进行迁移学习。

注意问题:

  1. SqueezeNet 预训练模型大小约 4.8 MB,在 ImageNet 数据集上 Top-5 准确率 80.3%,非常轻量高效。

  2. SqueezeNet 模型有 v1.0 和 v1.1 两个版本,二者的大小、性能相同,v1.1 版本的计算量更小。

  3. 加载的 SqueezeNet 预训练模型是在ImageNet 数据集上训练,模型的输入是形状为 (1,3,H,W) 的 RGB 图像的 batch,H, W 至少为 224。图像要加载到 [0,1] 范围,使用均值 [0.485, 0.456, 0.406] 和标准差[0.229, 0.224, 0.225] 进行归一化。

使用 SqueezeNet 预训练模型进行图像分类的完整例程如下。

# Begin_Squeeze_3.py
# SqueezeNet model for beginner with PyTorch
# 加载 SqueezeNet 预训练模型和参数,对图像进行分类
# Copyright: youcans@qq.com
# Crated: Huang Shan, 2023/06/04

# _*_coding:utf-8_*_
import torch
from torchvision import models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import numpy as np

if __name__ == '__main__':

    # (1) 加载 Squeeze/PyTorch 预训练模型
    model = models.squeezenet1_0(pretrained=True)  # torchvision.models 方式加载
    # model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_0', pretrained=True)  # torch.hub 方式加载
    # model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_1', pretrained=True)
    model.eval()

    # (2) 定义输入图像的预处理变换,将 [0,1] 的 PILImage 转换为 [-1,1] 的Tensor
    transform = transforms.Compose([  # 定义图像变换组合
        transforms.Resize([256,256]),  # 图像大小调整为 (w,h)=(256,256)
        transforms.CenterCrop([224,224]),  # 图像中心裁剪为 (w,h)=(224,224)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize(  # 对图像进行归一化
            mean=[0.485, 0.456, 0.406],  # 均值
            std=[0.229, 0.224, 0.225]  # 标准差
        )])

    # (3) 加载输入图像并进行预处理
    from PIL import Image
    filePath = "../images/img_car_01.jpg"  # 数据文件的地址和文件名
    imgPIL = Image.open(filePath)  # PIL 读取图像文件, <class 'PIL.Image.Image'>
    # 预处理/模型推理/后处理
    imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3,224,224])
    input_batch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1,3,224,224])

    # (4) 模型推理
    with torch.no_grad():
        outputs = model(input_batch)  # 返回所有类别的置信度score,torch.Size([batch, 1000])
    # _, index = torch.max(outputs, 1)  # Top-1 类别的索引,tensor([208])
    # print("index: ", index.item())  # 208 : sports car, sport car

    # (5) 模型输出后处理
    # 读取 ImageNet 文本格式类别名称文件
    with open("../dataset/imagenet_classes.txt") as f:  # 类别名称保存为 txt 文件
        categories = [line.strip() for line in f.readlines()]
    print(type(categories), len(categories))  # <class 'list'> 1000

    # 计算所有类别的概率
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0) * 100  # 所有类别的概率,torch.Size([batch, 1000])
    # 查找 Top-5 类别的索引
    top5_prob, top5_idx = torch.topk(probabilities, 5)  # Top-5 类别的概率和索引, torch.Size([5])
    print("Top-5 possible categories:")
    for i in range(top5_prob.size(0)):
        print(top5_idx[i], categories[top5_idx[i]], top5_prob[i].item())

    # (6) 图像分类结果的可视化
    import cv2
    imgCV = cv2.cvtColor(np.asarray(imgPIL), cv2.COLOR_RGB2BGR)  # PIL 转换为 CV 格式
    out_text = f"{categories[top5_idx[0]]}, {top5_prob[0].item():.3f}"  # 类别标签 + 概率
    cv2.putText(imgCV, out_text, (25, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)  # 在图像上添加类别标签
    cv2.imshow('Image classification', imgCV)
    key = cv2.waitKey(0)  # delay=0, 不自动关闭
    cv2.destroyAllWindows()

    # # 绘制图片
    # imgNP = np.array(imgPIL)  # PIL -> Numpy
    # out_text = f"{categories[top5_idx[0]]}, {top5_prob[0].item():.3f}"  # 类别标签 + 概率
    # print(out_text)
    # plt.title(out_text)
    # plt.imshow(imgNP)
    # plt.axis('off')
    # plt.tight_layout()
    # plt.show()

图像分类结果如下。

Top-5 possible categories:
tensor(817) 817: 'sports car, sport car', 62.149803161621094
tensor(436) 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 11.559002876281738
tensor(511) 511: 'convertible', 5.815169811248779
tensor(656) 656: 'minivan', 4.306943893432617
tensor(627) 627: 'limousine, limo', 4.196065902709961

在这里插入图片描述


参考文献:

  1. Forrest Iandols, Song Han, Matthew Moskewicz, et al. SqueezeNet: AlexNet Level Accuracy with 50x Fewer Parameters and <0.5 MB Model Size, 2017

【本节完】


版权声明:
欢迎关注『youcans动手学模型』系列
转发请注明原文链接:
【youcans动手学模型】SqueezeNet 模型-CIFAR10图像分类
Copyright 2023 youcans, XUPT
Crated:2023-06-27


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/692458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

少有人告诉你!工科硕士对应届生的肺腑之言!

自己是电子信息工科硕士狗一枚&#xff0c;自认为毕业于985院校还可以。刚毕业时&#xff0c;去了一家世界500强公司&#xff0c;校招签约时只说是“技术岗”&#xff0c;没有指定具体岗位。等待毕业的时间里&#xff0c;对新公司、新岗位充满了无限的期待&#xff0c;认为自己…

5、R语言所需安装包安装教程

一、R包安装 1.鼠标右键&#xff0c;以管理员的身份运行&#xff0c;然后点击确定。 2.点击安装导向中的下一步。 3.修改安装路径&#xff0c;安装到自己所需的路径&#xff0c;然后点击下一步。 4.勾选自己所需组件&#xff0c;然后点击下一步。 5.启动选项&#xff0c;选…

百万奖金悬赏大模型不擅长的任务!这 11 个任务模型越大,效果越差!

夕小瑶科技说 原创 作者 | 智商掉了一地、Python 去年咱们在介绍百万悬赏时提到&#xff0c;“海量资源砸出的大模型真的会一直那么香吗&#xff1f;”&#xff0c;目前来看&#xff0c;自打 ChatGPT 横空出世引领一众大模型开辟新的生活和工作方式以来&#xff0c;还是挺香的…

vue — 高德地图实现来回切换卫星图

默认初始化地图展示标准3d地图&#xff08;这里添加蒙层&#xff09; initMap () {this.mapObj new AMap.Map(mapContainer, {features: [bg, road, point, building],showLabel: true,rotateEnable: false,pitchEnable: false,zoom: 17,pitch: 65,rotation: 45,viewMode: 3D,…

一起学SF框架系列4.8-模块context-事件机制(Event)

ApplicationContext中的事件处理是通过ApplicationEvent类和ApplicationListener接口提供的。如果将实现ApplicationListener接口的bean部署到上下文中&#xff0c;则每次将ApplicationEvent发布到ApplicationContext时&#xff0c;都会通知该bean。从本质上讲&#xff0c;这是…

Allegro如何使用打印预览功能操作指导

Allegro如何使用打印预览功能操作指导 Allegro时常需要使用打印功能,将某个视图打印成pdf文件,如下图 在打印成pdf文件之前,可以使用打印预览的功能,具体操作如下 点击shape Add Rect命令Options出现如下选项<

阿里云地域和可用区分布表

阿里云服务器地域和可用区有哪些&#xff1f;阿里云服务器地域节点遍布全球29个地域、88个可用区&#xff0c;包括中国大陆、中国香港、日本、美国、新加坡、孟买、泰国、首尔、迪拜等地域&#xff0c;同一个地域下有多个可用区可以选择&#xff0c;阿里云服务器网分享2023新版…

Nginx的Location和Rewrite

目录 Rewrite简介 1.0 Rewrite实际场景 1.1 Rewrite跳转场景 1.2 Rewrite跳转实现 1.3 Nginx正则表达式 1.4 Rewrite命令&&语法格式 1.5 flag标记说明 2 Location分类 2.1 Location优先级 3 Rewrite&&Location比较 4 场景跳转实验 4.1 基于域名的跳转 …

vcruntime140.dll无法继续执行代码怎么办

今天打开photoshop软件的时候&#xff0c;突然间就打不开&#xff0c;电脑报错由于找不到vcruntime140.dll&#xff0c;无法继续执行此代码&#xff0c;然后我就把photoshop卸载了&#xff0c;再重新安装&#xff0c;依然还是报错。这个可怎么办&#xff1f;vcruntime140.dll如…

【笔记】肥胖代码:减肥的秘密

直接原因与根本原因 直接原因与根本原因的区别是什么&#xff1f;直接原因是直接造成体重增加的原因&#xff0c;根本原因是导致事物发生变化的根源。 以酗酒为例。酗酒的原因是什么&#xff1f; 直接原因是饮酒过量。这是不可否认的事&#xff0c;但显然不能解决问题。直接…

佑友防火墙默认口令及RCE漏洞

先用fofa脚本爬取所有碧海威相关资产&#xff08;fofa脚本下载地址&#xff1a;&#xff09; python3 fofa-cwillchris.py -k title"佑友防火墙" 将上面爬取到的文件&#xff08;一般是final****.txt&#xff09;移动到脚本目录下&#xff0c;保存为1.txt ./佑友防…

浅谈智能安全用电系统在轨道交通中的应用

安科瑞 华楠 摘要&#xff1a; 随着轨道交通电气设备的增加和用电负荷的变大&#xff0c;用电安全问题愈发突出&#xff0c;而对电力状况在线监测和故障预警是实现安全用电的关键。本文研究了轨道交通安全用电智能监测系统。该系统通过电力载波技术可利用原电缆进行数据传输&am…

适用ddddocr自动化测试验证码识别

原打算使用tesseract进行验证码识别的但后面发现实在太辣鸡了 不知道tesseract以及没安装的可以看这篇文章&#xff1a; tesseract安装以及联调python 使用tesseract的代码&#xff1a; import pytesseract from PIL import Image, ImageEnhance """ 步骤①&…

DOTA大环配体化合物:DOTA PEG5 amine/azide/DBCO,特点分享说明

一、DOTA-PEG5-amine&#xff0c;DOTA PEG5 NH2&#xff0c;DOTA-PEG5-amine HCl salt&#xff0c;DOTA五聚乙二醇氨基Product structure&#xff1a; 1.CAS No&#xff1a;N/A 2.Molecular formula&#xff1a;C28H54N6O12 3.Molecular weight&#xff1a;666.8 5.Appearance …

【图像处理】去雾代码收(附halcon、python、C#、VB、matlab)

【图像处理】去雾代码收&#xff08;附halcon、python、C#、VB、matlab&#xff09; 一、halcon算法1.1 halcon算法源码1.2 halcon算法效果图![在这里插入图片描述](https://img-blog.csdnimg.cn/8ad5217a59be4de29b5a7b6eee997b85.png#pic_center) 二、opencv算法2.1 python源…

了解架构是什么

前言&#xff1a; \textcolor{Green}{前言&#xff1a;} 前言&#xff1a; &#x1f49e;这个专栏就专门来记录一下寒假参加的第五期字节跳动训练营 &#x1f49e;从这个专栏里面可以迅速获得Go的知识 了解架构是什么 01. 什么是架构1.1 定义1.1 问题1.2 什么是架构 - 单机1.3 …

Linux服务器同步Windows目录同步-rsync

前言 最近需要&#xff0c;Linux的服务器同步Windows的一个目录。查了下&#xff0c;大概有三种方法&#xff1a;网盘同步&#xff1b;rsync同步&#xff1b;挂载目录。 网盘同步&#xff0c;可以选择搭建一个Nextcloud 。但是问题在于&#xff0c;我需要的是&#xff0c;客户…

react context上下文与vue中 provide inject的用法区别

一、react中&#xff1a; 数据传递 1、引入createContext import { createContext } from "react"; 2、创建并导出 export const FspThemeContext createContext({}); 3、传递数据&#xff08;value项不能缺少&#xff01;&#xff01;&#xff09; ①不解构…

微流控芯片压力和流量的超高精度串级控制解决方案

摘要&#xff1a;针对微流控芯片压力驱动进样系统中压力和流量的高精度控制&#xff0c;本文提出了国产化替代解决方案。解决方案采用了积木式结构&#xff0c;便于快速搭建起气压驱动进样系统。解决方案的核心是采用了串级控制模式&#xff0c;结合高精度的传感器、电气比例阀…

JMeter如何进行多服务器远程测试

JMeter是Apache软件基金会的开源项目&#xff0c;主要来做功能和性能测试&#xff0c;用Java编写。 我们一般都会用JMeter在本地进行测试&#xff0c;但是受到单个电脑的性能影响&#xff0c;往往达不到性能测试的要求&#xff0c;无法有效的模拟高并发的场景&#xff0c;那么…