欢迎关注『youcans动手学模型』系列
本专栏内容和资源同步到 GitHub/youcans
【youcans动手学模型】Xception 模型-CIFAR10图像分类
- 1. Xception 神经网络模型
- 1.1 模型简介
- 1.2 论文介绍
- 1.3 分析与讨论
- 2. 在 PyTorch 中定义 Xception 模型类
- 2.1 深度可分离卷积
- 2.2 带残差连接的深度可分离卷积模块
- 2.3 简化的 Xception 模型类
- 2.4 完整的 Xception 模型类
- 3. 基于 Xception 模型的 CIFAR10 图像分类
- 3.1 PyTorch 建立神经网络模型的基本步骤
- 3.2 加载 CIFAR10 数据集
- 3.3 建立 Xception 网络模型
- 3.4 Xception 模型训练
- 3.5 Xception 模型的保存与加载
- 3.6 模型检验
- 3.7 模型推理
本文用 PyTorch 实现 Xception 网络模型,使用 CIFAR10 数据集训练模型,进行图像分类。
1. Xception 神经网络模型
Francois Chollet 在 2017 年发表的论文“ Xception: Deep Learning with Depthwise Separable Convolutions ”,提出了 Xception 网络模型。本文作者 Francois Chollet 来自 Google,也是 Keras 的作者 。
【论文下载地址】
Xception: Deep Learning with Depthwise Separable Convolutions
【GitHub地址】:作者例程
【PyTorch实现】:参考例程
1.1 模型简介
传统的卷积操作同时对输入特征图的空间交互性(spatial correlations)和跨通道交互性(cross-channel correlations)进行映射。
Inception 系列结构致力于对该过程进行分解,在一定程度上实现了跨通道相关性和空间相关性的解耦。Xception 与深度可分离卷积类似,使用 “Extreme Inception” 实现了跨通道相关性和空间相关性的完全解耦。
“深度可分离卷积(Depthwise Separable Convolution,DSC)由深度卷积(depthwise convolution)和逐点卷积(pointwise convolution)连接组成,实现了跨通道相关性和空间相关性的完全解耦。
- 深度卷积,每个卷积核只作用于单一通道的分组卷积,分组数等于输入通道数,实现空间相关性的映射。
- 逐点卷积,在级联通道上进行 1*1 卷积,实现跨通道相关性的映射。
以 16 个输入通道和 32 个输出通道上的 3x3卷积层为例:
-
常规的卷积操作有 16*32*3*3=4608 个参数。
-
在深度可分离卷积中,第一步空间卷积有 16*3*3= 144 个参数,第二步深度方向卷积有 16*32*1*1= 512 个参数,共 656 个参数。
因此,深度可分离卷积大大减少了参数计数,具有更高效的复杂性,而且还保持了跨通道功能。
1.2 论文介绍
【论文摘要】
我们将卷积神经网络中的 Inception modules 解释为常规卷积运算和深度可分离卷积运算的中间步骤。从这个角度来看,深度上可分离的卷积可以理解为具有最大分支数量的 Inception modules 。
由此,我们提出一种新型深度卷积神经网络架构,用深度可分离卷积取代了 Inception模块,称为 Xception 架构。Xception 体系结构具有与 Inception V3 相同数量的参数,在 ImageNet 数据集上的性能略微优于Inception V3,在包括 3.5亿张图像和 17000个类的更大图像分类数据集上显著优于 InceptionV3。
【论文背景】
传统的卷积操作同时对输入特征图的空间交互性(spatial correlations)和跨通道交互性(cross-channel correlations)进行映射。例如,卷积层的输入尺寸为 h*w*d_in,卷积核尺寸为 s*s*d_in,卷积操作既在 s*s 的空间范围上对特征图进行信息融合,又对通道数为 d_in 的输入特征图进行跨通道的信息融合。
多分支的 Inception 结构在一定程度上对跨通道相关性和空间相关性进行解耦。例如,1*1 卷积分支,相当于只进行跨通道融合,不进行空间卷积(类似于 RGB 通道融合为灰度图像);先做 1*1 卷积进行跨通道融合,再做 3*3 卷积相当于进行空间信息融合,也可以在一定程度上进行跨通道相关性和空间相关性的解耦。
考虑图4的 Inception 的极端情况,首先使用 1*1 卷积来映射输入的跨通道相关性,然后将每个输出通道作为一组(而不是如 Inception 分为 3~4 组),将 h*h*d_in 的输入分为 d_in 组,使用 3*3 卷积来映射空间相关性。
这就是 Extreme 版本的 “Inception”,意思是 Inception 体系结构的更强版本。Extreme 版本的 “Inception” 与基于深度可分离卷积(DSC)的 MobileNet非常相似,二者的区别在于:
(1)操作顺序:DSC 先做单通道空间卷积,再做跨通道 1x1 卷积,Xception 先做 1*1 卷积再做 3*3 卷积 。
(2)激活函数:DSC 在深度卷积与逐点卷积之间没有 ReLU 层,Xception 在两次卷积之后都有 ReLU 层。
【模型结构】
Xception 提出了一种完全基于深度可分离卷积层的卷积神经网络架构,该网络架构基于以下的假设:卷积神经网络特征图中的跨通道相关性和空间相关性的映射可以完全解耦。这个假设是 Inception 架构假设的更强版本,所以命名为 Xception ,代表 “Extreme Inception”。
完整的 Xception 模型是具有残差连接的深度可分离卷积层的线性堆叠,具有36个卷积层组成的特征提取结构。
- 36个卷积层被构造成14个模块,除第一个和最后一个模块外,其它所有模块周围都有线性残差连接。。
- 这些模块分为三个连续的虚拟流:Entry/Middle/Exit 三个flow,每个flow内部使用不同的重复模块。
- Entry flow 主要是用来不断下采样,减小空间尺寸;Middle 没有下采样,用来学习关联关系,优化特征;Exit flow 用于汇总、整理特征。
【模型性能】
Xception 作为Inception v3的改进,主要是在Inception v3的基础上引入了depthwise separable convolution,在基本不增加网络复杂度的前提下提高了模型的效果。
根据论文的报道,Xception 在精度、参数量、运算时间几个方面都略优于 Inception V3,但优势都不太大。
【论文结论】
本文提出了一种新的卷积网络架构 Xception,通过使用深度可分离的卷积代替 Inception 模块来改进 Inception 系列体系结构,构建深度可分离卷积堆栈模型。Xception 的参数数量与 Inception V3 相似,在精度、参数量、运算时间上略优于 Inception V3。
1.3 分析与讨论
(1)需要特别注意的是,本文虽然分析了 Inception 模块的 Extreme 版本本质上也是一种深度可分离的卷积,并讨论了与深度可分离卷积(DSC)的区别,但是,在正文中使用的 Xception 模型架构中所使用的,并不是 Extreme 版本的 Inception 模块,而就是直接使用深度可分离卷积(DSC)。
所以,文中所说的 “Extreme 版本的 Inception 模块” 真的是 Inception,而 Xception(Extreme Inception)模型中真正用的是深度可分离卷积(DSC),与 Inception 并没有关系。
(2)Google 公司的另一些研究者在 2016 年还提出了 Inception-ResNet,精度不仅优于 Inception V3,也优于 Xception。所以,Xception 在精度上的提高,到底是由于深度可分离卷积,还是由于引入了残差连接,从论文中并不能得到明确的结论。
2. 在 PyTorch 中定义 Xception 模型类
总的来说,Xception 模型是一种网络架构,针对不同的任务可以进行不同的网络结构设计和超参数配置。
本节先面向 CIFAR10 数据集图像分类问题,详细介绍一个简化版 Xception 模型类的构造过程。最后也将给出复现论文的完整版 Xception 模型类的例程。
2.1 深度可分离卷积
深度可分离卷积(DSC)是 Xception 网络架构的核心,由深度卷积(depthwise convolution)和逐点卷积(pointwise convolution)连接组成,实现了跨通道相关性和空间相关性的完全解耦。
深度可分离卷积模块(DSC)的例程如下。
# 定义深度可分离卷积
class SeparableConv2d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, padding, stride=1):
super(SeparableConv2d, self).__init__()
# 深度卷积 depthwise, 逐个通道操作, groups=in_channels=out_channels
self.depth_conv = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch)
# 逐点卷积 pointwise, 1x1 卷积
self.point_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, groups=1)
def forward(self, x):
out = self.depth_conv(x)
out = self.point_conv(out)
return out
2.2 带残差连接的深度可分离卷积模块
深度可分离卷积在深度卷积和逐点卷积之间不使用 ReLU,在逐点卷积之后加入 ReLu 和 BN,这两层可以在 SeparableConv2d 类中定义,也可以在 ResDSC 类中定义。
Xception 网络架构中使用带残差连接的深度可分离卷积模块。简化的带残差连接的深度可分离卷积模块的例程如下。
# 定义 带残差连接的深度可分离卷积模块
class ResDSC(nn.Module):
def __init__(self, in_ch, out_ch):
super(ResDSC, self).__init__()
self.residual = nn.Sequential(SeparableConv2d(in_ch, out_ch, kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
SeparableConv2d(out_ch, out_ch, kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(out_ch),
nn.MaxPool2d((3, 3), stride=(2, 2), padding=1))
self.shortcut = nn.Sequential(nn.Conv2d(in_ch, out_ch, (1, 1), stride=(2, 2)),
nn.BatchNorm2d(out_ch))
def forward(self, x):
residual = self.residual(x)
shortcut = self.shortcut(x)
output = shortcut + residual
return output
2.3 简化的 Xception 模型类
Xception 模型是一种网络架构,针对不同的任务可以进行不同的网络结构设计和超参数配置。
相对于 ImageNet 数据集来说,CIFAR10 数据集的规模较小、图片尺寸较小,使用论文中的 Xception 网络架构过于庞大和复杂。因此,面向 CIFAR10 数据集图像分类问题,我们构建一个简化的 Xception 模型类,该模型类以带残差连接的深度可分离卷积模块为核心,但没有使用复杂的 Entry/Middle/Exit flow。该简化模型的速度很快,性能也还不错。
# 简化的 Xception 模型类
class mini_Xception(nn.Module):
def __init__(self, num_classes=10):
super(mini_Xception, self).__init__()
self.base = nn.Sequential(nn.Conv2d(3, 16, (3, 3), stride=(1, 1)),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 32, (3, 3), stride=(1, 1)),
nn.BatchNorm2d(32),
nn.ReLU())
self.module1 = ResDSC(in_ch=32, out_ch=32)
self.module2 = ResDSC(in_ch=32, out_ch=64)
self.module3 = ResDSC(in_ch=64, out_ch=64)
self.module4 = ResDSC(in_ch=64, out_ch=128)
# output
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(128, num_classes)
def forward(self, x):
x = self.base(x)
x = self.module1(x)
x = self.module2(x)
x = self.module3(x)
x = self.module4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
2.4 完整的 Xception 模型类
为了复现原始论文中的 Xception 模型,定义一个完整的 Xception 模型类如下。默认类别数量 num_class=100,可以在实例化模型时根据任务需求来设置。
注意,完整的 Xception 模型类需要的 GPU 内存很高,在训练时要减小批大小 batchsize。
# 定义 深度可分离卷积
class SeparableConv2d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, padding, bias=False):
super(SeparableConv2d, self).__init__()
self.depth_conv = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias)
self.point_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)
def forward(self, x):
out = self.depth_conv(x)
out = self.point_conv(out)
return out
class Xception(nn.Module):
def __init__(self, input_channel, num_classes=10):
super(Xception, self).__init__()
# Entry Flow
self.entry_flow1 = nn.Sequential(
nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(True)
)
self.entry_flow2 = nn.Sequential(
SeparableConv2d(64, 128, 3, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
SeparableConv2d(128, 128, 3, 1),
nn.BatchNorm2d(128),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.entry_flow2_residual = nn.Conv2d(64, 128, kernel_size=1, stride=2)
self.entry_flow3 = nn.Sequential(
nn.ReLU(True),
SeparableConv2d(128, 256, 3, 1),
nn.BatchNorm2d(256),
nn.ReLU(True),
SeparableConv2d(256, 256, 3, 1),
nn.BatchNorm2d(256),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.entry_flow3_residual = nn.Conv2d(128, 256, kernel_size=1, stride=2)
self.entry_flow4 = nn.Sequential(
nn.ReLU(True),
SeparableConv2d(256, 728, 3, 1),
nn.BatchNorm2d(728),
nn.ReLU(True),
SeparableConv2d(728, 728, 3, 1),
nn.BatchNorm2d(728),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.entry_flow4_residual = nn.Conv2d(256, 728, kernel_size=1, stride=2)
# Middle Flow
self.middle_flow = nn.Sequential(
nn.ReLU(True),
SeparableConv2d(728, 728, 3, 1),
nn.BatchNorm2d(728),
nn.ReLU(True),
SeparableConv2d(728, 728, 3, 1),
nn.BatchNorm2d(728),
nn.ReLU(True),
SeparableConv2d(728, 728, 3, 1),
nn.BatchNorm2d(728)
)
# Exit Flow
self.exit_flow1 = nn.Sequential(
nn.ReLU(True),
SeparableConv2d(728, 728, 3, 1),
nn.BatchNorm2d(728),
nn.ReLU(True),
SeparableConv2d(728, 1024, 3, 1),
nn.BatchNorm2d(1024),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.exit_flow1_residual = nn.Conv2d(728, 1024, kernel_size=1, stride=2)
self.exit_flow2 = nn.Sequential(
SeparableConv2d(1024, 1536, 3, 1),
nn.BatchNorm2d(1536),
nn.ReLU(True),
SeparableConv2d(1536, 2048, 3, 1),
nn.BatchNorm2d(2048),
nn.ReLU(True)
)
self.linear = nn.Linear(2048, num_classes)
def forward(self, x):
entry_out1 = self.entry_flow1(x)
entry_out2 = self.entry_flow2(entry_out1) + self.entry_flow2_residual(entry_out1)
entry_out3 = self.entry_flow3(entry_out2) + self.entry_flow3_residual(entry_out2)
entry_out = self.entry_flow4(entry_out3) + self.entry_flow4_residual(entry_out3)
middle_out = self.middle_flow(entry_out) + entry_out
for i in range(7):
middle_out = self.middle_flow(middle_out) + middle_out
exit_out1 = self.exit_flow1(middle_out) + self.exit_flow1_residual(middle_out)
exit_out2 = self.exit_flow2(exit_out1)
exit_avg_pool = F.adaptive_avg_pool2d(exit_out2, (1, 1))
exit_avg_pool_flat = exit_avg_pool.view(exit_avg_pool.size(0), -1)
output = self.linear(exit_avg_pool_flat)
return output
3. 基于 Xception 模型的 CIFAR10 图像分类
3.1 PyTorch 建立神经网络模型的基本步骤
使用 PyTorch 建立、训练和使用神经网络模型的基本步骤如下。
- 准备数据集(Prepare dataset):加载数据集,对数据进行预处理。
- 建立模型(Design the model):实例化模型类,定义损失函数和优化器,确定模型结构和训练方法。
- 模型训练(Model trainning):使用训练数据集对模型进行训练,确定模型参数。
- 模型推理(Model inferring):使用训练好的模型进行推理,对输入数据预测输出结果。
- 模型保存与加载(Model saving/loading):保存训练好的模型,以便以后使用或部署。
以下按此步骤讲解 Xception 模型的例程。
3.2 加载 CIFAR10 数据集
通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。使用通用的数据集训练神经网络,不仅可以提高工作效率,而且便于评估模型性能。
PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets
类中。torchvision
模块实现神经网络所需的核心类和方法, torchvision.datasets
包含流行的数据集、模型架构和常用的图像转换方法。
CIFAR 数据集是一个经典的图像分类小型数据集,有 CIFAR10 和 CIFAR100 两个版本。CIFAR10 有 10 个类别,CIFAR100 有 100 个类别。CIFAR10 每张图像大小为 32*32,包括飞机、小汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车 10 个类别。CIFAR10 共有 60000张图像,其中训练集 50000张,测试集 10000张。每个类别有 6000张图片,数据集平衡。
加载和使用 CIFAR 数据集的方法为:
torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()
CIFAR 数据集可以从官网下载:http://www.cs.toronto.edu/~kriz/cifar.html 后使用,也可以使用 datasets 类自动加载(如果本地路径没有该文件则自动下载)。
下载数据集时,使用预定义的 transform 方法进行数据预处理,包括调整图像尺寸、标准化处理,将数据格式转换为张量。标准化处理所使用 CIFAR10 数据集的均值和方差为 (0.49, 0.48, 0.45), (0.25, 0.24, 0.26)。
transform_train在训练过程中,增加随机性,提高泛化能力
大型训练数据集不能一次性加载全部样本来训练,可以使用 Dataloader 类自动加载数据。Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。
使用 DataLoader 类加载 CIFAR-10 数据集的例程如下。
# (1) 将[0,1]的PILImage 转换为[-1,1]的Tensor
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转
transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.Resize((32, 32)), # 图像大小调整为 (w,h)=(32,32)
transforms.ToTensor(), # 将图像转换为张量 Tensor
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
# 测试集不需要进行数据增强
transform = transforms.Compose([
transforms.Resize((32, 32)), # 图像大小调整为 (w,h)=(32,32)
transforms.ToTensor(), # 将图像转换为张量 Tensor
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
# (2) 加载 CIFAR10 数据集
batchsize = 128
# 加载 CIFAR10 数据集, 如果 root 路径加载失败, 则自动在线下载
# 加载 CIFAR10 训练数据集, 50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='../dataset', train=True,
download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize)
# 加载 CIFAR10 验证数据集, 10000张验证图片
test_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)
# 创建生成器,用 next 获取一个批次的数据
valid_data_iter = iter(test_loader) # _SingleProcessDataLoaderIter 对象
valid_images, valid_labels = next(valid_data_iter) # images: [batch,3,32,32], labels: [batch]
valid_size = valid_labels.size(0) # 验证数据集大小,batch
print(valid_images.shape, valid_labels.shape)
# 定义类别名称,CIFAR10 数据集的 10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
3.3 建立 Xception 网络模型
建立一个 Xception 网络模型进行训练,包括三个步骤:
- 实例化 Xception 模型对象;
- 设置训练的损失函数;
- 设置训练的优化器。
为了使用 GPU 设备进行模型训练和模型推理,使用 model.to(device) 将网络分配到指定的设备中。
torch.nn.functional 模块提供了各种内置损失函数,本例使用交叉熵损失函数 CrossEntropyLoss。
torch.optim 模块提供了各种优化方法。本例使用 Adam 优化器,注意要将 model 的参数 model.parameters() 传给优化器对象,以便扫描需要优化的参数。
# (3) 构造 Xception 网络模型
model = mini_Xception(num_classes=10) # 实例化 Xception 网络模型
model.to(device) # 将网络分配到指定的device中
print(model)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 定义损失函数 CrossEntropy
optimizer = torch.optim.Adam(lr=0.001, params=model.parameters()) # Adam 优化器
3.4 Xception 模型训练
PyTorch 模型训练的基本步骤是:
- 前馈计算模型的输出值;
- 计算损失函数值;
- 计算权重 weight 和偏差 bias 的梯度;
- 根据梯度值调整模型参数;
- 将梯度重置为 0(用于下一循环)。
在模型训练过程中,可以使用验证集数据评价训练过程中的模型精度,以便控制训练过程。模型验证就是用验证数据进行模型推理,前向计算得到模型输出,但不反向计算模型误差,因此需要设置 torch.no_grad()。
使用 PyTorch 进行模型训练的例程如下。
# (4) 训练 Xception 模型
epoch_list = [] # 记录训练轮次
loss_list = [] # 记录训练集的损失值
accu_list = [] # 记录验证集的准确率
num_epochs = 100 # 训练轮次
for epoch in range(num_epochs): # 训练轮次 epoch
running_loss = 0.0 # 每个轮次的累加损失值清零
for step, data in enumerate(train_loader, start=0): # 迭代器加载数据
optimizer.zero_grad() # 损失梯度清零
inputs, labels = data # inputs: [batch,3,32,32] labels: [batch]
outputs = model(inputs.to(device)) # 正向传播
loss = criterion(outputs, labels.to(device)) # 计算损失函数
loss.backward() # 反向传播
optimizer.step() # 参数更新
# 累加训练损失值
running_loss += loss.item()
# if step%100==99: # 每 100 个 step 打印一次训练信息
# print("\t epoch {}, step {}: loss = {:.4f}".format(epoch, step, loss.item()))
# 计算每个轮次的验证集准确率
with torch.no_grad(): # 验证过程, 不计算损失函数梯度
outputs_valid = model(valid_images.to(device)) # 模型对验证集进行推理, [batch, 10]
pred_labels = torch.max(outputs_valid, dim=1)[1] # 预测类别, [batch]
accuracy = torch.eq(pred_labels, valid_labels.to(device)).sum().item() / valid_size * 100 # 计算准确率
print("Epoch {}: train loss={:.4f}, accuracy={:.2f}%".format(epoch, running_loss, accuracy))
# 记录训练过程的统计数据
epoch_list.append(epoch) # 记录迭代次数
loss_list.append(running_loss) # 记录训练集的损失函数
accu_list.append(accuracy) # 记录验证集的准确率
程序运行结果如下:
Epoch 0: train loss=585.6872, accuracy=56.70%
Epoch 1: train loss=454.1163, accuracy=65.40%
Epoch 2: train loss=400.7813, accuracy=68.80%
Epoch 3: train loss=366.8728, accuracy=71.30%
…
Epoch 98: train loss=143.8827, accuracy=86.30%
Epoch 99: train loss=142.7452, accuracy=85.50%
经过 20 轮左右的训练,使用验证集中的 1000 张图片进行验证,模型准确率可以达到 80%。继续训练可以进一步降低训练损失函数值,经过 100轮左右的训练验证集的准确率保持在 85%左右。
3.5 Xception 模型的保存与加载
模型训练好以后,将模型保存起来,以便下次使用。PyTorch 中模型保存主要有两种方式,一是保存模型权值,二是保存整个模型。本例使用 model.state_dict() 方法以字典形式返回模型权值,torch.save() 方法将权值字典序列化到磁盘,将模型保存为 .pth 文件。
由于本例程中模型存储在 CUDA 设备上,在保存模型时要将模型移动到 CPU。
# (5) 保存 Xception 网络模型
save_path = "../models/Xception_Cifar1"
model_cpu = model.cpu() # 将模型移动到 CPU
model_path = save_path + ".pth" # 模型文件路径
torch.save(model.state_dict(), model_path) # 保存模型权值
使用训练好的模型,首先要实例化模型类,然后调用 load_state_dict() 方法加载模型的权值参数。
# 以下模型加载和模型推理,可以是另一个独立的程序
# (6) 加载 Xception 网络模型进行推理
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检测并指定设备
# 加载 Xception 预训练模型
model = mini_Xception(num_classes=10) # 实例化 Xception 网络模型
model.to(device) # 将网络分配到指定的device中
model_path = "../models/Xception_Cifar1.pth"
model.load_state_dict(torch.load(model_path))
model.eval() # 模型推理模式
需要特别注意的是:
(1)PyTorch 中的 .pth 文件只保存了模型的权值参数,而没有模型的结构信息,因此必须先实例化模型对象,再加载模型参数。
(2)模型对象必须与模型参数严格对应,才能正常使用。注意即使都是 LeNet5 模型,模型类的具体定义也可能有细微的区别。如果从一个来源获取模型类的定义,从另一个来源获取模型参数文件,就很容易造成模型结构与参数不能匹配。
(3)无论从 PyTorch 模型仓库加载的模型和参数,或从其它来源获取的预训练模型,或自己训练得到的模型,模型加载的方法都是相同的,也都要注意模型结构与参数的匹配问题。
3.6 模型检验
使用加载的 Xception模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。
使用测试集数据进行模型推理,根据模型预测结果与图片标签进行比较,可以检验模型的准确率。模型验证集与模型检验集不能交叉使用,但为了简化例程在本程序中未做区分。
# (7) 模型检测
correct = 0
total = 0
for data in test_loader: # 迭代器加载测试数据集
imgs, labels = data # torch.Size([batch,3,32,32) torch.Size([batch])
# print(imgs.shape, labels.shape)
outputs = model(imgs.to(device)) # 正向传播, 模型推理, [batch, 10]
labels_pred = torch.max(outputs, dim=1)[1] # 模型预测的类别 [batch]
# _, labels_pred = torch.max(outputs.data, 1)
total += labels.size(0)
correct += torch.eq(labels_pred, labels.to(device)).sum().item()
accuracy = 100. * correct / total
print("Test samples: {}".format(total))
print("Test accuracy={:.2f}%".format(accuracy))
使用测试集进行模型推理,测试模型准确率为 85.01%。
Test samples: 10000
Test accuracy=85.01%
3.7 模型推理
使用加载的 Xception模型,输入新的图片进行模型推理,可以由模型输出结果确定输入图片所属的类别。
从测试集中提取几张图片,或者读取图像文件,进行模型推理,获得图片的分类类别。在提取图片或读取文件时,要注意对图片格式和图片大小进行适当的转换。
# (8) 提取测试集图片进行模型推理
batch = 8 # 批次大小
data_set = torchvision.datasets.CIFAR10(root='../dataset', train=False,
download=False, transform=None)
plt.figure(figsize=(9, 6))
for i in range(batch):
imgPIL = data_set[i][0] # 提取 PIL 图片
label = data_set[i][1] # 提取 图片标签
# 预处理/模型推理/后处理
imgTrans = transform(imgPIL) # 预处理变换, torch.Size([3,32,32])
imgBatch = torch.unsqueeze(imgTrans, 0) # 转为批处理,torch.Size([batch=1,3,32,32])
outputs = model(imgBatch.to(device)) # 模型推理, 返回 [batch=1, 10]
indexes = torch.max(outputs, dim=1)[1] # 注意 [batch=1], device = 'device
index = indexes[0].item() # 预测类别,整数
# 绘制第 i 张图片
imgNP = np.array(imgPIL) # PIL -> Numpy
out_text = "label:{}/model:{}".format(classes[label], classes[index])
plt.subplot(2, 4, i+1)
plt.imshow(imgNP)
plt.title(out_text)
plt.axis('off')
plt.tight_layout()
plt.show()
结果如下。
# (9) 读取图像文件进行模型推理
from PIL import Image
filePath = "../images/img_plane_01.jpg" # 数据文件的地址和文件名
imgPIL = Image.open(filePath) # PIL 读取图像文件, <class 'PIL.Image.Image'>
# 预处理/模型推理/后处理
imgTrans = transform["test"](imgPIL) # 预处理变换, torch.Size([3, 32, 32])
imgBatch = torch.unsqueeze(imgTrans, 0) # 转为批处理,torch.Size([batch=1, 3, 32, 32])
outputs = model(imgBatch.to(device)) # 模型推理, 返回 [batch=1, 10]
indexes = torch.max(outputs, dim=1)[1] # 注意 [batch=1], device = 'device
percentages = nn.functional.softmax(outputs, dim=1)[0] * 100
index = indexes[0].item() # 预测类别,整数
percent = percentages[index].item() # 预测类别的概率,浮点数
# 绘制第 i 张图片
imgNP = np.array(imgPIL) # PIL -> Numpy
out_text = "Prediction:{}, {}, {:.2f}%".format(index, classes[index], percent)
print(out_text)
plt.imshow(imgNP)
plt.title(out_text)
plt.axis('off')
plt.tight_layout()
plt.show()
结果如下。
【参考文献】
Francois Chollet, Xception: Deep Learning with Depthwise Separable Convolutions, 2017
【本节完】
版权声明:
欢迎关注『youcans动手学模型』系列
转发请注明原文链接:
【youcans动手学模型】Xception 模型-CIFAR10图像分类
Copyright 2023 youcans, XUPT
Crated:2023-06-16