【youcans动手学模型】Xception 模型-CIFAR10图像分类

news2024/12/21 16:34:18

欢迎关注『youcans动手学模型』系列
本专栏内容和资源同步到 GitHub/youcans


【youcans动手学模型】Xception 模型-CIFAR10图像分类

    • 1. Xception 神经网络模型
      • 1.1 模型简介
      • 1.2 论文介绍
      • 1.3 分析与讨论
    • 2. 在 PyTorch 中定义 Xception 模型类
      • 2.1 深度可分离卷积
      • 2.2 带残差连接的深度可分离卷积模块
      • 2.3 简化的 Xception 模型类
      • 2.4 完整的 Xception 模型类
    • 3. 基于 Xception 模型的 CIFAR10 图像分类
      • 3.1 PyTorch 建立神经网络模型的基本步骤
      • 3.2 加载 CIFAR10 数据集
      • 3.3 建立 Xception 网络模型
      • 3.4 Xception 模型训练
      • 3.5 Xception 模型的保存与加载
      • 3.6 模型检验
      • 3.7 模型推理


本文用 PyTorch 实现 Xception 网络模型,使用 CIFAR10 数据集训练模型,进行图像分类。


1. Xception 神经网络模型

Francois Chollet 在 2017 年发表的论文“ Xception: Deep Learning with Depthwise Separable Convolutions ”,提出了 Xception 网络模型。本文作者 Francois Chollet 来自 Google,也是 Keras 的作者 。

【论文下载地址】
Xception: Deep Learning with Depthwise Separable Convolutions

【GitHub地址】:作者例程

【PyTorch实现】:参考例程

在这里插入图片描述


1.1 模型简介

传统的卷积操作同时对输入特征图的空间交互性(spatial correlations)和跨通道交互性(cross-channel correlations)进行映射。

在这里插入图片描述

Inception 系列结构致力于对该过程进行分解,在一定程度上实现了跨通道相关性和空间相关性的解耦。Xception 与深度可分离卷积类似,使用 “Extreme Inception” 实现了跨通道相关性和空间相关性的完全解耦。

“深度可分离卷积(Depthwise Separable Convolution,DSC)由深度卷积(depthwise convolution)和逐点卷积(pointwise convolution)连接组成,实现了跨通道相关性和空间相关性的完全解耦。

  • 深度卷积,每个卷积核只作用于单一通道的分组卷积,分组数等于输入通道数,实现空间相关性的映射。
  • 逐点卷积,在级联通道上进行 1*1 卷积,实现跨通道相关性的映射。

在这里插入图片描述

以 16 个输入通道和 32 个输出通道上的 3x3卷积层为例:

  • 常规的卷积操作有 16*32*3*3=4608 个参数。

  • 在深度可分离卷积中,第一步空间卷积有 16*3*3= 144 个参数,第二步深度方向卷积有 16*32*1*1= 512 个参数,共 656 个参数。

因此,深度可分离卷积大大减少了参数计数,具有更高效的复杂性,而且还保持了跨通道功能。


1.2 论文介绍

【论文摘要】

我们将卷积神经网络中的 Inception modules 解释为常规卷积运算和深度可分离卷积运算的中间步骤。从这个角度来看,深度上可分离的卷积可以理解为具有最大分支数量的 Inception modules 。

由此,我们提出一种新型深度卷积神经网络架构,用深度可分离卷积取代了 Inception模块,称为 Xception 架构。Xception 体系结构具有与 Inception V3 相同数量的参数,在 ImageNet 数据集上的性能略微优于Inception V3,在包括 3.5亿张图像和 17000个类的更大图像分类数据集上显著优于 InceptionV3。


【论文背景】

传统的卷积操作同时对输入特征图的空间交互性(spatial correlations)和跨通道交互性(cross-channel correlations)进行映射。例如,卷积层的输入尺寸为 h*w*d_in,卷积核尺寸为 s*s*d_in,卷积操作既在 s*s 的空间范围上对特征图进行信息融合,又对通道数为 d_in 的输入特征图进行跨通道的信息融合。

多分支的 Inception 结构在一定程度上对跨通道相关性和空间相关性进行解耦。例如,1*1 卷积分支,相当于只进行跨通道融合,不进行空间卷积(类似于 RGB 通道融合为灰度图像);先做 1*1 卷积进行跨通道融合,再做 3*3 卷积相当于进行空间信息融合,也可以在一定程度上进行跨通道相关性和空间相关性的解耦。

在这里插入图片描述

在这里插入图片描述

考虑图4的 Inception 的极端情况,首先使用 1*1 卷积来映射输入的跨通道相关性,然后将每个输出通道作为一组(而不是如 Inception 分为 3~4 组),将 h*h*d_in 的输入分为 d_in 组,使用 3*3 卷积来映射空间相关性。

这就是 Extreme 版本的 “Inception”,意思是 Inception 体系结构的更强版本。Extreme 版本的 “Inception” 与基于深度可分离卷积(DSC)的 MobileNet非常相似,二者的区别在于:

(1)操作顺序:DSC 先做单通道空间卷积,再做跨通道 1x1 卷积,Xception 先做 1*1 卷积再做 3*3 卷积 。
(2)激活函数:DSC 在深度卷积与逐点卷积之间没有 ReLU 层,Xception 在两次卷积之后都有 ReLU 层。


【模型结构】

Xception 提出了一种完全基于深度可分离卷积层的卷积神经网络架构,该网络架构基于以下的假设:卷积神经网络特征图中的跨通道相关性和空间相关性的映射可以完全解耦。这个假设是 Inception 架构假设的更强版本,所以命名为 Xception ,代表 “Extreme Inception”。

在这里插入图片描述

完整的 Xception 模型是具有残差连接的深度可分离卷积层的线性堆叠,具有36个卷积层组成的特征提取结构。

  • 36个卷积层被构造成14个模块,除第一个和最后一个模块外,其它所有模块周围都有线性残差连接。。
  • 这些模块分为三个连续的虚拟流:Entry/Middle/Exit 三个flow,每个flow内部使用不同的重复模块。
  • Entry flow 主要是用来不断下采样,减小空间尺寸;Middle 没有下采样,用来学习关联关系,优化特征;Exit flow 用于汇总、整理特征。

在这里插入图片描述

【模型性能】

Xception 作为Inception v3的改进,主要是在Inception v3的基础上引入了depthwise separable convolution,在基本不增加网络复杂度的前提下提高了模型的效果。

根据论文的报道,Xception 在精度、参数量、运算时间几个方面都略优于 Inception V3,但优势都不太大。

在这里插入图片描述

【论文结论】

本文提出了一种新的卷积网络架构 Xception,通过使用深度可分离的卷积代替 Inception 模块来改进 Inception 系列体系结构,构建深度可分离卷积堆栈模型。Xception 的参数数量与 Inception V3 相似,在精度、参数量、运算时间上略优于 Inception V3。


1.3 分析与讨论

(1)需要特别注意的是,本文虽然分析了 Inception 模块的 Extreme 版本本质上也是一种深度可分离的卷积,并讨论了与深度可分离卷积(DSC)的区别,但是,在正文中使用的 Xception 模型架构中所使用的,并不是 Extreme 版本的 Inception 模块,而就是直接使用深度可分离卷积(DSC)。

所以,文中所说的 “Extreme 版本的 Inception 模块” 真的是 Inception,而 Xception(Extreme Inception)模型中真正用的是深度可分离卷积(DSC),与 Inception 并没有关系。

(2)Google 公司的另一些研究者在 2016 年还提出了 Inception-ResNet,精度不仅优于 Inception V3,也优于 Xception。所以,Xception 在精度上的提高,到底是由于深度可分离卷积,还是由于引入了残差连接,从论文中并不能得到明确的结论。


2. 在 PyTorch 中定义 Xception 模型类

总的来说,Xception 模型是一种网络架构,针对不同的任务可以进行不同的网络结构设计和超参数配置。

本节先面向 CIFAR10 数据集图像分类问题,详细介绍一个简化版 Xception 模型类的构造过程。最后也将给出复现论文的完整版 Xception 模型类的例程。

2.1 深度可分离卷积

深度可分离卷积(DSC)是 Xception 网络架构的核心,由深度卷积(depthwise convolution)和逐点卷积(pointwise convolution)连接组成,实现了跨通道相关性和空间相关性的完全解耦。

深度可分离卷积模块(DSC)的例程如下。

# 定义深度可分离卷积
class SeparableConv2d(nn.Module):        
    def __init__(self, in_ch, out_ch, kernel_size, padding, stride=1):
        super(SeparableConv2d, self).__init__()
        # 深度卷积 depthwise, 逐个通道操作, groups=in_channels=out_channels
        self.depth_conv = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch)
        # 逐点卷积 pointwise, 1x1 卷积
        self.point_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, groups=1)

    def forward(self, x):
        out = self.depth_conv(x)
        out = self.point_conv(out)
        return out

2.2 带残差连接的深度可分离卷积模块

深度可分离卷积在深度卷积和逐点卷积之间不使用 ReLU,在逐点卷积之后加入 ReLu 和 BN,这两层可以在 SeparableConv2d 类中定义,也可以在 ResDSC 类中定义。

Xception 网络架构中使用带残差连接的深度可分离卷积模块。简化的带残差连接的深度可分离卷积模块的例程如下。

# 定义 带残差连接的深度可分离卷积模块
class ResDSC(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ResDSC, self).__init__()
        self.residual = nn.Sequential(SeparableConv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
                                      nn.BatchNorm2d(out_ch),
                                      nn.ReLU(),
                                      SeparableConv2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
                                      nn.BatchNorm2d(out_ch),
                                      nn.MaxPool2d((3, 3), stride=(2, 2), padding=1))
        self.shortcut = nn.Sequential(nn.Conv2d(in_ch, out_ch, (1, 1), stride=(2, 2)),
                                      nn.BatchNorm2d(out_ch))

    def forward(self, x):
        residual = self.residual(x)
        shortcut = self.shortcut(x)
        output = shortcut + residual
        return output

2.3 简化的 Xception 模型类

Xception 模型是一种网络架构,针对不同的任务可以进行不同的网络结构设计和超参数配置。

相对于 ImageNet 数据集来说,CIFAR10 数据集的规模较小、图片尺寸较小,使用论文中的 Xception 网络架构过于庞大和复杂。因此,面向 CIFAR10 数据集图像分类问题,我们构建一个简化的 Xception 模型类,该模型类以带残差连接的深度可分离卷积模块为核心,但没有使用复杂的 Entry/Middle/Exit flow。该简化模型的速度很快,性能也还不错。

# 简化的 Xception 模型类
class mini_Xception(nn.Module):
    def __init__(self, num_classes=10):
        super(mini_Xception, self).__init__()
        self.base = nn.Sequential(nn.Conv2d(3, 16, (3, 3), stride=(1, 1)),
                                  nn.BatchNorm2d(16),
                                  nn.ReLU(),
                                  nn.Conv2d(16, 32, (3, 3), stride=(1, 1)),
                                  nn.BatchNorm2d(32),
                                  nn.ReLU())
        self.module1 = ResDSC(in_ch=32, out_ch=32)
        self.module2 = ResDSC(in_ch=32, out_ch=64)
        self.module3 = ResDSC(in_ch=64, out_ch=64)
        self.module4 = ResDSC(in_ch=64, out_ch=128)
        # output
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.base(x)
        x = self.module1(x)
        x = self.module2(x)
        x = self.module3(x)
        x = self.module4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

2.4 完整的 Xception 模型类

为了复现原始论文中的 Xception 模型,定义一个完整的 Xception 模型类如下。默认类别数量 num_class=100,可以在实例化模型时根据任务需求来设置。

注意,完整的 Xception 模型类需要的 GPU 内存很高,在训练时要减小批大小 batchsize。

# 定义 深度可分离卷积
class SeparableConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, padding, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depth_conv = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias)
        self.point_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depth_conv(x)
        out = self.point_conv(out)
        return out

class Xception(nn.Module):
    def __init__(self, input_channel, num_classes=10):
        super(Xception, self).__init__()

        # Entry Flow
        self.entry_flow1 = nn.Sequential(
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )

        self.entry_flow2 = nn.Sequential(
            SeparableConv2d(64, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            SeparableConv2d(128, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.entry_flow2_residual = nn.Conv2d(64, 128, kernel_size=1, stride=2)

        self.entry_flow3 = nn.Sequential(
            nn.ReLU(True),
            SeparableConv2d(128, 256, 3, 1),
            nn.BatchNorm2d(256),

            nn.ReLU(True),
            SeparableConv2d(256, 256, 3, 1),
            nn.BatchNorm2d(256),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.entry_flow3_residual = nn.Conv2d(128, 256, kernel_size=1, stride=2)

        self.entry_flow4 = nn.Sequential(
            nn.ReLU(True),
            SeparableConv2d(256, 728, 3, 1),
            nn.BatchNorm2d(728),

            nn.ReLU(True),
            SeparableConv2d(728, 728, 3, 1),
            nn.BatchNorm2d(728),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.entry_flow4_residual = nn.Conv2d(256, 728, kernel_size=1, stride=2)

        # Middle Flow
        self.middle_flow = nn.Sequential(
            nn.ReLU(True),
            SeparableConv2d(728, 728, 3, 1),
            nn.BatchNorm2d(728),

            nn.ReLU(True),
            SeparableConv2d(728, 728, 3, 1),
            nn.BatchNorm2d(728),

            nn.ReLU(True),
            SeparableConv2d(728, 728, 3, 1),
            nn.BatchNorm2d(728)
        )

        # Exit Flow
        self.exit_flow1 = nn.Sequential(
            nn.ReLU(True),
            SeparableConv2d(728, 728, 3, 1),
            nn.BatchNorm2d(728),

            nn.ReLU(True),
            SeparableConv2d(728, 1024, 3, 1),
            nn.BatchNorm2d(1024),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.exit_flow1_residual = nn.Conv2d(728, 1024, kernel_size=1, stride=2)
        self.exit_flow2 = nn.Sequential(
            SeparableConv2d(1024, 1536, 3, 1),
            nn.BatchNorm2d(1536),
            nn.ReLU(True),

            SeparableConv2d(1536, 2048, 3, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(True)
        )
        
        self.linear = nn.Linear(2048, num_classes)

    def forward(self, x):
        entry_out1 = self.entry_flow1(x)
        entry_out2 = self.entry_flow2(entry_out1) + self.entry_flow2_residual(entry_out1)
        entry_out3 = self.entry_flow3(entry_out2) + self.entry_flow3_residual(entry_out2)
        entry_out = self.entry_flow4(entry_out3) + self.entry_flow4_residual(entry_out3)

        middle_out = self.middle_flow(entry_out) + entry_out
        for i in range(7):
            middle_out = self.middle_flow(middle_out) + middle_out
        exit_out1 = self.exit_flow1(middle_out) + self.exit_flow1_residual(middle_out)
        exit_out2 = self.exit_flow2(exit_out1)
        exit_avg_pool = F.adaptive_avg_pool2d(exit_out2, (1, 1))
        exit_avg_pool_flat = exit_avg_pool.view(exit_avg_pool.size(0), -1)
        output = self.linear(exit_avg_pool_flat)
        return output

3. 基于 Xception 模型的 CIFAR10 图像分类

3.1 PyTorch 建立神经网络模型的基本步骤

使用 PyTorch 建立、训练和使用神经网络模型的基本步骤如下。

  1. 准备数据集(Prepare dataset):加载数据集,对数据进行预处理。
  2. 建立模型(Design the model):实例化模型类,定义损失函数和优化器,确定模型结构和训练方法。
  3. 模型训练(Model trainning):使用训练数据集对模型进行训练,确定模型参数。
  4. 模型推理(Model inferring):使用训练好的模型进行推理,对输入数据预测输出结果。
  5. 模型保存与加载(Model saving/loading):保存训练好的模型,以便以后使用或部署。

以下按此步骤讲解 Xception 模型的例程。


3.2 加载 CIFAR10 数据集

通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。使用通用的数据集训练神经网络,不仅可以提高工作效率,而且便于评估模型性能。

PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets 类中。torchvision 模块实现神经网络所需的核心类和方法, torchvision.datasets 包含流行的数据集、模型架构和常用的图像转换方法。

CIFAR 数据集是一个经典的图像分类小型数据集,有 CIFAR10 和 CIFAR100 两个版本。CIFAR10 有 10 个类别,CIFAR100 有 100 个类别。CIFAR10 每张图像大小为 32*32,包括飞机、小汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车 10 个类别。CIFAR10 共有 60000张图像,其中训练集 50000张,测试集 10000张。每个类别有 6000张图片,数据集平衡。

加载和使用 CIFAR 数据集的方法为:

torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()

CIFAR 数据集可以从官网下载:http://www.cs.toronto.edu/~kriz/cifar.html 后使用,也可以使用 datasets 类自动加载(如果本地路径没有该文件则自动下载)。

下载数据集时,使用预定义的 transform 方法进行数据预处理,包括调整图像尺寸、标准化处理,将数据格式转换为张量。标准化处理所使用 CIFAR10 数据集的均值和方差为 (0.49, 0.48, 0.45), (0.25, 0.24, 0.26)。

transform_train在训练过程中,增加随机性,提高泛化能力

大型训练数据集不能一次性加载全部样本来训练,可以使用 Dataloader 类自动加载数据。Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。

使用 DataLoader 类加载 CIFAR-10 数据集的例程如下。

    # (1) 将[0,1]的PILImage 转换为[-1,1]的Tensor
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(10),  # 随机旋转
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Resize((32, 32)),  # 图像大小调整为 (w,h)=(32,32)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    # 测试集不需要进行数据增强
    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # 图像大小调整为 (w,h)=(32,32)
        transforms.ToTensor(),  # 将图像转换为张量 Tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    # (2) 加载 CIFAR10 数据集
    batchsize = 128
    # 加载 CIFAR10 数据集, 如果 root 路径加载失败, 则自动在线下载
    # 加载 CIFAR10 训练数据集, 50000张训练图片
    train_set = torchvision.datasets.CIFAR10(root='../dataset', train=True,
                                            download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize)
    # 加载 CIFAR10 验证数据集, 10000张验证图片
    test_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)
    # 创建生成器,用 next 获取一个批次的数据
    valid_data_iter = iter(test_loader)  # _SingleProcessDataLoaderIter 对象
    valid_images, valid_labels = next(valid_data_iter)  # images: [batch,3,32,32], labels: [batch]
    valid_size = valid_labels.size(0)  # 验证数据集大小,batch
    print(valid_images.shape, valid_labels.shape)

    # 定义类别名称,CIFAR10 数据集的 10个类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')


3.3 建立 Xception 网络模型

建立一个 Xception 网络模型进行训练,包括三个步骤:

  • 实例化 Xception 模型对象;
  • 设置训练的损失函数;
  • 设置训练的优化器。

为了使用 GPU 设备进行模型训练和模型推理,使用 model.to(device) 将网络分配到指定的设备中。

torch.nn.functional 模块提供了各种内置损失函数,本例使用交叉熵损失函数 CrossEntropyLoss。

torch.optim 模块提供了各种优化方法。本例使用 Adam 优化器,注意要将 model 的参数 model.parameters() 传给优化器对象,以便扫描需要优化的参数。

    # (3) 构造 Xception 网络模型
    model = mini_Xception(num_classes=10)  # 实例化 Xception 网络模型
    model.to(device)  # 将网络分配到指定的device中
    print(model)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()  # 定义损失函数 CrossEntropy
    optimizer = torch.optim.Adam(lr=0.001, params=model.parameters())  # Adam 优化器

3.4 Xception 模型训练

PyTorch 模型训练的基本步骤是:

  1. 前馈计算模型的输出值;
  2. 计算损失函数值;
  3. 计算权重 weight 和偏差 bias 的梯度;
  4. 根据梯度值调整模型参数;
  5. 将梯度重置为 0(用于下一循环)。

在模型训练过程中,可以使用验证集数据评价训练过程中的模型精度,以便控制训练过程。模型验证就是用验证数据进行模型推理,前向计算得到模型输出,但不反向计算模型误差,因此需要设置 torch.no_grad()。

使用 PyTorch 进行模型训练的例程如下。

    # (4) 训练 Xception 模型
    epoch_list = []  # 记录训练轮次
    loss_list = []  # 记录训练集的损失值
    accu_list = []  # 记录验证集的准确率
    num_epochs = 100  # 训练轮次
    for epoch in range(num_epochs):  # 训练轮次 epoch
        running_loss = 0.0  # 每个轮次的累加损失值清零
        for step, data in enumerate(train_loader, start=0):  # 迭代器加载数据
            optimizer.zero_grad()  # 损失梯度清零

            inputs, labels = data  # inputs: [batch,3,32,32] labels: [batch]
            outputs = model(inputs.to(device))  # 正向传播
            loss = criterion(outputs, labels.to(device))  # 计算损失函数
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新

            # 累加训练损失值
            running_loss += loss.item()
            # if step%100==99:  # 每 100 个 step 打印一次训练信息
            #     print("\t epoch {}, step {}: loss = {:.4f}".format(epoch, step, loss.item()))

        # 计算每个轮次的验证集准确率
        with torch.no_grad():  # 验证过程, 不计算损失函数梯度
            outputs_valid = model(valid_images.to(device))  # 模型对验证集进行推理, [batch, 10]
        pred_labels = torch.max(outputs_valid, dim=1)[1]  # 预测类别, [batch]
        accuracy = torch.eq(pred_labels, valid_labels.to(device)).sum().item() / valid_size * 100  # 计算准确率
        print("Epoch {}: train loss={:.4f}, accuracy={:.2f}%".format(epoch, running_loss, accuracy))

        # 记录训练过程的统计数据
        epoch_list.append(epoch)  # 记录迭代次数
        loss_list.append(running_loss)  # 记录训练集的损失函数
        accu_list.append(accuracy)  # 记录验证集的准确率    

程序运行结果如下:

Epoch 0: train loss=585.6872, accuracy=56.70%
Epoch 1: train loss=454.1163, accuracy=65.40%
Epoch 2: train loss=400.7813, accuracy=68.80%
Epoch 3: train loss=366.8728, accuracy=71.30%

Epoch 98: train loss=143.8827, accuracy=86.30%
Epoch 99: train loss=142.7452, accuracy=85.50%

经过 20 轮左右的训练,使用验证集中的 1000 张图片进行验证,模型准确率可以达到 80%。继续训练可以进一步降低训练损失函数值,经过 100轮左右的训练验证集的准确率保持在 85%左右。

在这里插入图片描述


3.5 Xception 模型的保存与加载

模型训练好以后,将模型保存起来,以便下次使用。PyTorch 中模型保存主要有两种方式,一是保存模型权值,二是保存整个模型。本例使用 model.state_dict() 方法以字典形式返回模型权值,torch.save() 方法将权值字典序列化到磁盘,将模型保存为 .pth 文件。

由于本例程中模型存储在 CUDA 设备上,在保存模型时要将模型移动到 CPU。

    # (5) 保存 Xception 网络模型
    save_path = "../models/Xception_Cifar1"
    model_cpu = model.cpu()  # 将模型移动到 CPU
    model_path = save_path + ".pth"  # 模型文件路径
    torch.save(model.state_dict(), model_path)  # 保存模型权值

使用训练好的模型,首先要实例化模型类,然后调用 load_state_dict() 方法加载模型的权值参数。

    # 以下模型加载和模型推理,可以是另一个独立的程序
    # (6) 加载 Xception 网络模型进行推理
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检测并指定设备
    # 加载 Xception 预训练模型
    model = mini_Xception(num_classes=10)  # 实例化 Xception 网络模型
    model.to(device)  # 将网络分配到指定的device中
    model_path = "../models/Xception_Cifar1.pth"
    model.load_state_dict(torch.load(model_path))
    model.eval()  # 模型推理模式

需要特别注意的是:

(1)PyTorch 中的 .pth 文件只保存了模型的权值参数,而没有模型的结构信息,因此必须先实例化模型对象,再加载模型参数。

(2)模型对象必须与模型参数严格对应,才能正常使用。注意即使都是 LeNet5 模型,模型类的具体定义也可能有细微的区别。如果从一个来源获取模型类的定义,从另一个来源获取模型参数文件,就很容易造成模型结构与参数不能匹配。

(3)无论从 PyTorch 模型仓库加载的模型和参数,或从其它来源获取的预训练模型,或自己训练得到的模型,模型加载的方法都是相同的,也都要注意模型结构与参数的匹配问题。


3.6 模型检验

使用加载的 Xception模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

使用测试集数据进行模型推理,根据模型预测结果与图片标签进行比较,可以检验模型的准确率。模型验证集与模型检验集不能交叉使用,但为了简化例程在本程序中未做区分。

    # (7) 模型检测
    correct = 0
    total = 0
    for data in test_loader:  # 迭代器加载测试数据集
        imgs, labels = data  # torch.Size([batch,3,32,32) torch.Size([batch])
        # print(imgs.shape, labels.shape)
        outputs = model(imgs.to(device))  # 正向传播, 模型推理, [batch, 10]
        labels_pred = torch.max(outputs, dim=1)[1]  # 模型预测的类别 [batch]
        # _, labels_pred = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += torch.eq(labels_pred, labels.to(device)).sum().item()
    accuracy = 100. * correct / total
    print("Test samples: {}".format(total))
    print("Test accuracy={:.2f}%".format(accuracy))

使用测试集进行模型推理,测试模型准确率为 85.01%。

Test samples: 10000
Test accuracy=85.01%


3.7 模型推理

使用加载的 Xception模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。

从测试集中提取几张图片,或者读取图像文件,进行模型推理,获得图片的分类类别。在提取图片或读取文件时,要注意对图片格式和图片大小进行适当的转换。

    # (8) 提取测试集图片进行模型推理
    batch = 8  # 批次大小
    data_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
                                           download=False, transform=None)
    plt.figure(figsize=(9, 6))
    for i in range(batch):
        imgPIL = data_set[i][0]  # 提取 PIL 图片
        label = data_set[i][1]  # 提取 图片标签
        # 预处理/模型推理/后处理
        imgTrans = transform(imgPIL)  # 预处理变换, torch.Size([3,32,32])
        imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1,3,32,32])
        outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
        indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
        index = indexes[0].item()  # 预测类别,整数
        # 绘制第 i 张图片
        imgNP = np.array(imgPIL)  # PIL -> Numpy
        out_text = "label:{}/model:{}".format(classes[label], classes[index])
        plt.subplot(2, 4, i+1)
        plt.imshow(imgNP)
        plt.title(out_text)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

结果如下。

在这里插入图片描述

    # (9) 读取图像文件进行模型推理
    from PIL import Image
    filePath = "../images/img_plane_01.jpg"  # 数据文件的地址和文件名
    imgPIL = Image.open(filePath)  # PIL 读取图像文件, <class 'PIL.Image.Image'>

    # 预处理/模型推理/后处理
    imgTrans = transform["test"](imgPIL)  # 预处理变换, torch.Size([3, 32, 32])
    imgBatch = torch.unsqueeze(imgTrans, 0)  # 转为批处理,torch.Size([batch=1, 3, 32, 32])
    outputs = model(imgBatch.to(device))  # 模型推理, 返回 [batch=1, 10]
    indexes = torch.max(outputs, dim=1)[1]  # 注意 [batch=1], device = 'device
    percentages = nn.functional.softmax(outputs, dim=1)[0] * 100
    index = indexes[0].item()  # 预测类别,整数
    percent = percentages[index].item()  # 预测类别的概率,浮点数

    # 绘制第 i 张图片
    imgNP = np.array(imgPIL)  # PIL -> Numpy
    out_text = "Prediction:{}, {}, {:.2f}%".format(index, classes[index], percent)
    print(out_text)
    plt.imshow(imgNP)
    plt.title(out_text)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

结果如下。

在这里插入图片描述


【参考文献】

Francois Chollet, Xception: Deep Learning with Depthwise Separable Convolutions, 2017

【本节完】


版权声明:
欢迎关注『youcans动手学模型』系列
转发请注明原文链接:
【youcans动手学模型】Xception 模型-CIFAR10图像分类
Copyright 2023 youcans, XUPT
Crated:2023-06-16


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653381.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【九章斩题录】从尾到头打印链表(JZ6)

精品题解 &#x1f525; 《九章斩题录》 &#x1f448; 猛戳订阅 目录 JZ6 - 从尾到头打印链表 「 法一 」链表元素存入数组后再反转 「 法二 」递归大法 「 法三 」栈 JZ6 - 从尾到头打印链表 &#x1f4da; 题目&#xff1a;输入一个链表的头节点&#xff0c;按链表从…

苹果iOS 17新功能:重置密码 72 小时内可使用旧密码再次重置

一些用户可能会遇到&#xff0c;在修改了 iPhone 密码之后&#xff0c;突然忘记新密码的情况。现在苹果在 iOS 17 中加入了新的解决方案&#xff1a;在重置密码的 72 小时之内&#xff0c;如果用户不小心忘记了新设置的密码&#xff0c;仍然可以使用旧密码进行再次重置。 在重…

TS系列之工具类型Partial、Required、Pick、Record详解,示例

TS系列之工具类型Partial、Required、Pick、Record详解&#xff0c;示例 文章目录 前言一、Partial<Type>二、Required<Type>三、Pick<Type, Keys>四、Record<Keys, Type>总结 前言 本片文章主要利用上一篇所学的keyof做一个延申与扩展的学习。在Type…

OpenGaussDB2.0.1

目录 1. GaussDB版本2. OpenGaussDB介绍3. 单节点安装3.1 环境配置3.2 安装 4. 远程连接设置 1. GaussDB版本 GaussDB的版本&#xff1a; GaussDB 100&#xff1a;目前暂不发布&#xff0c;公司合作伙伴需向华为提交申请&#xff08;GaussDB 100 将在 2020 年被正式命名为 Gaus…

Redis入门 - 3种特殊数据类型

原文首更地址&#xff0c;阅读效果更佳&#xff01; Redis入门 - 3种特殊数据类型 | CoderMast编程桅杆https://www.codermast.com/database/redis/three-special-datatype.html 在我们平常的业务中基本只会使用到Redis的基本数据类型&#xff08;String、List、Hash、Set、S…

浅谈.NET语言开发应用领域

.NET语言是一种跨平台的开发框架&#xff0c;适用于各种应用程序的开发。以下是一些常见的.NET语言开发应用领域&#xff1a; 桌面应用程序开发&#xff1a;使用.NET框架开发的桌面应用程序可以在Windows操作系统上运行&#xff0c;包括Windows Forms和WPF。这些应用程序可以用…

CentOS开机报错““error can‘t find command ‘:‘“处理方法

CentOS开机报错"error cant find command :"处理方法 本文为故障描述和问题记录。converterP2V迁移CentOS7到虚拟机&#xff0c;开机报错"error can’t find command ‘:’"的处理方法。 本文为CentOS7的操作记录&#xff0c;其他版本可以参考&#xff0c;…

CASAIM光学彩色三维扫描仪助力文物艺术品三维数字化3D打印

文物艺术品数字化&#xff0c;实际上是一种文物艺术品信息的记录方式&#xff0c;除了运用视频、照片、录音等多媒体形式将某些文物进行记录外&#xff0c;文物艺术品数字化主要是指针对有形文化遗产采用非接触式扫描得到的三维数字化记录&#xff0c;通过实景三维建模&#xf…

GitHub 上传自己的项目

文章目录 前言一、步骤1.GitHub 创建项目2.Git 上传本地项目到 GitHub3.Git 命令整理 总结 前言 不附 Demo 连接的博客不是好博客&#xff0c;所以我们要做个乐于助人&#xff0c;有责任心的人&#xff0c;这篇文章手把手教你如何在 GitHub 上传自己的项目&#xff0c; 一、步…

【OpenMMLab AI实战营二期笔记】第五天 MMPretrain代码课

1.环境安装 conda activate mmpre # 激活创建好的环境,确保安装好pytorch,可以使用gpu git clone https://github.com/open-mmlab/mmpretrain.git # 下载mmpre源码 cd mmpretrain # 进入mmpretrian目录 pip install openmim # 安装管理工具 mim install -e ".[multimodal…

JAVA开发运维(系统上到生产环境准备工作)

一、前言 java项目在开发环境开发完成&#xff0c;在测试环境测试没有问题后&#xff0c;就需要发布到生产环境&#xff0c;如果系统是对公众的&#xff0c;那就需要很多工作了。比如服务器申请&#xff0c;域名申请&#xff0c;渗透测试&#xff0c;漏洞扫描&#xff0c;公网…

第二章(第二节):导数与微分

1.导数与微分 1.导数概念 设曲线 L 的方程 y=f(x),a ≤ x ≤ b,x0 ∈ (a, b),在曲线 L 上的点 M0(x0, y0) 附近任取一点 M(x0 + Δx, y0 + Δy),过 M0 与 M 作曲线的割线M~0~M,的斜率为:当 x→x0 时,点 M 沿着曲线 L 趋向 M0,与此同时,割线 M0M 趋向一个极限位置 M0T…

想要转行的一定要看软件测试发展简史+学习路线

迄今为止&#xff0c;软件测试的发展一共经历了五个重要时期&#xff1a; 调试为主 20世纪50年代&#xff0c;计算机刚诞生不久&#xff0c;只有科学家级别的人才会去编程&#xff0c;需求和程序本身也远远没有现在这么复杂多变&#xff0c;相当于开发人员一人承担需求分析&am…

idea设置注释模板

目录 设置注释文件模板设置模板 设置注释文件模板 Ctrl Alt S 打开设置&#xff0c;Editor - File and Code Templates 选择class、interface、enum根据自己需要选择需要添加注释的文件&#xff0c;依次添加如下配置内容 /**1. ClassName ${NAME}2. Description TODO3. Aut…

BUUCTF Unencode 1

题目描述&#xff1a; 密文&#xff1a; 89FQA9WMD<V1A<V1S83DY.#<W3$Q,2TM]解题思路&#xff1a; 1、观察密文&#xff0c;尝试Base85、Base91等编码&#xff0c;均失败。 2、结合题目&#xff0c;联想到UUencode编码&#xff0c;尝试后成功&#xff0c;得到flag。 …

驱动LSM6DS3TR-C实现高效运动检测与数据采集(5)----上报匿名上位机实现可视化

概述 lsm6ds3trc包含三轴陀螺仪与三轴加速度计。 姿态有多种数学表示方式&#xff0c;常见的是四元数&#xff0c;欧拉角&#xff0c;矩阵和轴角。他们各自有其自身的优点&#xff0c;在不同的领域使用不同的表示方式。在四轴飞行器中使用到了四元数和欧拉角。 姿态解算选用的…

SpringBoot配置多数据源

SpringBoot配置多数据源 最近在做一个SpringBoot项目时需要关联两个数据库,于是乎我就研究了下关于springboot的多数据源配置,记录配置过程,分享一下 一、基础配置 (这里只展示主要配置) JDK1.8springBoot2.3.4.RELEASEmybatis2.1.0mysql-connector-java 8.0.21maven仓…

知乎家居产品种草营销怎么做?

近年来&#xff0c;家居产品种草营销已经成为了一种新型营销方式。知乎作为全球最大的中文问答社区&#xff0c;拥有着海量的用户和优质内容&#xff0c;逐渐成为了家居产品种草营销中不可忽视的平台。那么&#xff0c;在这个平台上如何进行家居产品种草营销呢&#xff1f;接下…

Python之函数【三】(高阶函数和闭包)

文章目录 前言一、高阶函数二、闭包&#xff08;也称之为&#xff1a;闭包函数&#xff09; 1、浅谈闭包函数 1.1、划重点1.2、注意点2、怎么判断是不是闭包函数呢&#xff1f; 2.1、那接下来&#xff0c;我们就细细的拆开解释2.2、对于这个作用域&#xff0c;在JavaSc…

【MySQL数据库基础】

MySQL数据库基础 1. 数据库的操作1.1 显示当前的数据库1.2 创建数据库1.3 使用数据库1.4 删除数据库 2. 常用数据类型2.1整数&#xff08;xxxint&#xff09;2.2日期时间类型2.3字符串型 3. 表的操作3.1 查看表结构3.2 创建表3.3 删除表 1. 数据库的操作 1.1 显示当前的数据库…