Pytorch之GoogLeNet图像分类

news2025/1/22 21:03:49
  • 💂 个人主页:风间琉璃
  • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
  • 💬 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连)订阅专栏

目录

前言

一、GoogLeNet网络结构

1.Inception 结构

(1)Inception v1 

(2)Inception v2

(3)Inception v3

(4)Inception v4

2.网络模型分析

(1)输入层

(2)第一个模块

(2)第二个模块

(3)第三个模块 Inception 3a

(4)第四个模块 Inception 3b 

(5)输出层 

(6)辅助分类器

3.网络创新点

(1)引入Inception

(2)1x1卷积核

(3)辅助分类器

(4)平均池化层

二、GoogLeNet实现

1.构建GoogLeNet网络

2.加载数据集

3.训练和测试模型

三、实现图像分类


前言

2014 年,GoogLeNet 和 VGG 是当年 ImageNet 挑战赛 (ILSVRC14) 的双雄,GoogLeNet 获得了第一名、VGG 获得了第二名,这两类模型结构的共同特点是层次更深了。VGG 继承了 LeNet 以及 AlexNet 的一些框架结构,而 GoogLeNet 则做了更加大胆的网络结构尝试,虽然深度只有 22 层,但大小却比 AlexNet 和 VGG 小很多,GoogleNet 参数为 500 万个,AlexNet 参数个数是 GoogleNet 的 12 倍,VGGNet 参数又是 AlexNet 的 3 倍,因此在内存或计算资源有限时,GoogleNet 是比较好的选择;从模型结果来看,GoogLeNet 的性能却更加优越

一、GoogLeNet网络结构

GoogLeNet是google推出的基于Inception模块的深度神经网络模型,在2014年的ImageNet竞赛中夺得了冠军。

一般来说,提升网络性能最直接的办法就是增加网络深度和宽度,深度指网络层次数量、宽度指神经元数量。但这种方式存在以下问题:

(1)参数太多,如果训练数据集有限,很容易产生过拟合;
(2)网络越大、参数越多,计算复杂度越大,难以应用;
(3)网络越深,容易出现梯度弥散问题(梯度越往后穿越容易消失),难以优化模型。

解决方法是在增加网络深度和宽度的同时减少参数。为了减少参数,一般将全连接变成稀疏连接。但是在实现上,全连接变成稀疏连接后实际计算量并不会有质的提升,因为大部分硬件是针对密集矩阵计算优化的,稀疏矩阵虽然数据量少,但是计算所消耗的时间却很难减少。

那么如何既能保持网络结构的稀疏性,又能利用密集矩阵的高计算性能。大量的文献表明可以将稀疏矩阵聚类为较为密集的子矩阵来提高计算性能,就如人类的大脑是可以看做是神经元的重复堆积,因此,GoogLeNet 团队提出了 Inception 网络结构,就是构造一种 “基础神经元” 结构,来搭建一个稀疏性、高计算性能的网络结构。

它的主要特点是网络不仅有深度,还在横向上具有“宽度”。由于图像信息在空间尺寸上的巨大差异,如何选择合适的卷积核大小来提取特征就显得比较困难了。空间分布范围更广的图像信息适合用较大的卷积核来提取其特征,而空间分布范围较小的图像信息则适合用较小的卷积核来提取其特征。 

在随后的两年中一直在改进,形成了Inception V2、Inception V3、Inception V4等版本。

 GoogLeNet网络(22层)结构如下:

1.Inception 结构

(1)Inception v1 

通过设计一个稀疏网络结构,但是能够产生稠密的数据,既能增加神经网络表现,又能保证计算资源的使用效率。谷歌提出了最原始 Inception 的基本结构:其主要思想是利用不同大小的卷积核实现不同尺度的感知最后进行融合,可以得到图像更好的表征。

Inception Module基本组成结构有四个成分:1*1卷积,3*3卷积,5*5卷积,3*3最大池化

该结构将 CNN 中常用的卷积(1x1,3x3,5x5)、池化操作(3x3)堆叠在一起(卷积、池化后的尺寸相同,将通道相加),一方面增加了网络的宽度,另一方面也增加了网络对尺度的适应性。
网络卷积层中的网络能够提取输入的每一个细节信息,同时 5x5 的滤波器也能够覆盖大部分接受层的的输入。还可以进行一个池化操作,以减少空间大小,降低过度拟合。在这些层之上,在每一个卷积层后都要做一个 ReLU 操作,以增加网络的非线性特征。

原始Inception结构存在很严重的问题:

1. 所有的卷积层(1×1、3×3、5×5)都是直接和输入对接的,因此卷积过程的参数计算量很大;

2.并行池化层的输出与输入维度相同,在和其他卷积层的输出做连接时,特征图的深度会变得很深,一样会增加很大的计算量。

为了避免这种情况,在 3x3 前、5x5 前、max pooling 后分别加上了 1x1 的卷积核,以起到了降低特征图厚度的作用,这也就形成了 Inception v1 的网络结构,如下图所示:

 

1x1 的卷积核作用:

1x1 卷积的主要目的是为了减少维度,还用于修正线性激活(ReLU)

假定上一层的特征图尺度为:224×224×128,经过256个5×5卷积核输出后,输出尺寸为:224×224×256,卷积层参数为:128×5×5×256

如果上一层先通过一个具有32个尺寸为1×1的卷积核后,再经过256个5×5卷积核输出,输出特征图尺寸仍为:224×224×256,但此时卷积层参数量变为了:128×1×1×32+32×5×5×256,大约减少了4倍。

这就是 Pointwise Convolution,即 1x1 卷积,简写为 PW,主要用于数据降维,减少参数量。当然也有使用 PW 做升维的,在 MobileNet v2 中就使用 PW 将 3 个特征图变成 6 个特征图,丰富输入数据的特征

(2)Inception v2

GoogLeNet 凭借其优秀的表现,得到了很多研究人员的学习和使用,因此 GoogLeNet 团队又对其进行了进一步地发掘改进,产生了升级版本的 GoogLeNet。

但是谷歌团队发现如果一味的堆叠Inception模块虽然对准确率有所提升,但对计算机效率并没有很好提升,反之会有明显下降,因此如何在不增加过多计算量的同时提高网络的表达能力就成为了一个问题。

Inception V2 版本的解决方案就是修改 Inception 的内部计算逻辑,提出了比较特殊的 “卷积” 计算结构

1.卷积分解

大尺寸的卷积核可以带来更大的感受野,但也意味着会产生更多的参数。因此,GoogLeNet 团队提出可以用 2 个连续的 3x3 卷积层组成的小网络来代替单个的 5x5 卷积层,即在保持感受野范围的同时又减少了参数量,如下图:

并进一步考虑了n×1卷积核,来取代3×3卷积核 。

任意 nxn 的卷积都可以通过 1xn 卷积后接 nx1 卷积来替代。GoogLeNet 团队发现在网络的前期使用这种分解效果并不好,在中度大小的特征图(feature map)上使用效果才会更好(特征图大小建议在 12 到 20 之间)。 

Inception模块优化过程:

 2.降低特征图大小

一般情况下,如果想让图像缩小,可以有如下两种方式:

方法一(左图):先池化再作 Inception 卷积,或者先作 Inception 卷积再作池化。但是方法一先作 pooling(池化)会导致特征表示遇到瓶颈(特征缺失)。

方法二(右图)是正常的缩小,但计算量很大。

为了同时保持特征表示且降低计算量,将网络结构改为下图,使用两个并行化的模块来降低计算量(卷积、池化并行执行,再进行合并) 。

以上所有的方式方法的融合就得到了Inception v2。

(3)Inception v3

Inception V3结构较V2并没有太多改进,主要有一下几点:

  • 对7×7卷积层分解为两个一维卷积(1×7,7×1),3x3也一样
  • 对损失函数添加正则项,避免在分类网络中,神经网络对某一类别具有高度拟合性;
  • 辅助分类器中也使用了BN。

分解既可以加速计算,又可以将 1 个卷积拆成 2 个卷积,使得网络深度进一步增加,增加了网络的非线性(每增加一层都要进行 ReLU)。 

(4)Inception v4

Inception V4 研究了 Inception 模块与残差连接的结合。ResNet 结构大大地加深了网络深度,还极大地提升了训练速度,同时性能也有提升。
Inception V4 主要利用残差连接(Residual Connection)来改进 V3 结构,得到 Inception-ResNet-v1,Inception-ResNet-v2,Inception-v4 网络。

ResNet 的残差结构和Inception-ResNet如下所示:

通过 20 个类似的模块组合,Inception-ResNet 构建如下:

2.网络模型分析

基于 Inception 构建了 GoogLeNet 的网络结构如下(共 22 层):主要由9个 I n c e p t i o n InceptionInception 块、全局平均汇聚层、辅助分类器构成。
 

1. GoogLeNet 采用了模块化的结构(Inception 结构),方便增添和修改。


2.网络最后采用 average pooling(平均池化)来代替全连接层,在最后还是加了一个全连接层,主要是为了方便对输出进行灵活调整。


3.虽然移除了全连接,但是网络中依然使用了 Dropout。


4.为了避免梯度消失,网络额外增加了 2 个辅助的 softmax 用于向前传导梯度(辅助分类器)。

辅助分类器是将中间某一层的输出用作分类,并按一个较小的权重(0.3)加到最终分类结果中,这样相当于做了模型融合,同时给网络增加了反向传播的梯度信号,也提供了额外的正则化,对于整个网络的训练很有裨益。而在实际测试的时候,这两个额外的 softmax 会被去掉。 

GoogLeNet 的网络结构图细节如下: 

列名
type网络名称
patch size/stride网络参数,卷积核大小/stride
output size输出特征矩阵的大小
depth对应该行结构的数量,如第三行卷积层,depth=2,表示经过两层卷积层,先是1x1,然后3x3
后8列关于Inception结构的配置

上表中的 “#3x3 reduce”,“#5x5 reduce” 表示在 3x3,5x5 卷积操作之前使用了 1x1 卷积的数量。"pool proj"表示在池化层后使用1x1卷积的数量。

(1)输入层

原始输入图像为 224x224x3,且都进行了零均值化的预处理操作(图像每个像素减去均值)。

(2)第一个模块

处理流程:卷积-->ReLU-->池化

卷积层:卷积核大小7*7,步长为2,padding为3,输出通道数64,输出特征图尺寸为(224-7+3*2)/2+1=112.5(向下取整)=112,输出特征图维度为112x112x64,卷积后进行ReLU操作。

池化层:窗口大小3*3,步长为2,输出特征图尺寸为((112 -3)/2)+1=55.5(向上取整)=56,输出特征图维度为56x56x64。

(2)第二个模块

处理流程:卷积-->卷积-->ReLU-->池化

卷积层:先用64个1x1的卷积核(3x3卷积核之前的降维)将输入的特征图(56x56x64)变为56x56x64,然后进行ReLU操作。
再用卷积核大小3*3,步长为1,padding为1,输出通道数192,进行卷积运算,输出特征图尺寸为(56-3+1*2)/1+1=56,输出特征图维度为56x56x192,然后进行ReLU操作。

池化层: 窗口大小3*3,步长为2,输出通道数192,输出为((56 - 3)/2)+1=27.5(向上取整)=28,输出特征图维度为28x28x192。



(3)第三个模块 Inception 3a


Inception 3a层分为四个分支,采用不同尺度的卷积核来进行处理。


(1)64 个 1x1 的卷积核,然后 RuLU,输出 28x28x64
(2)96 个 1x1 的卷积核,作为 3x3 卷积核之前的降维,变成 28x28x96,然后进行 ReLU 计算,再进行 128 个 3x3 的卷积(padding 为 1),输出 28x28x128
(3)16 个 1x1 的卷积核,作为 5x5 卷积核之前的降维,变成 28x28x16,进行 ReLU 计算后,再进行 32 个 5x5 的卷积(padding 为 2),输出 28x28x32
(4)pool 层,使用 3x3 的核(padding 为 1),输出 28x28x192,然后进行 32 个 1x1 的卷积,输出 28x28x32。
将四个结果进行连接,对这四部分输出结果的第三维并联,即 64+128+32+32=256,最终输出 28x28x256


(4)第四个模块 Inception 3b 


(1)128 个 1x1 的卷积核,然后 RuLU,输出 28x28x128
(2)128 个 1x1 的卷积核,作为 3x3 卷积核之前的降维,变成 28x28x128,进行 ReLU,再进行 192 个 3x3 的卷积(padding 为 1),输出 28x28x192
(3)32 个 1x1 的卷积核,作为 5x5 卷积核之前的降维,变成 28x28x32,进行 ReLU 计算后,再进行 96 个 5x5 的卷积(padding 为 2),输出 28x28x96
(4)pool 层,使用 3x3 的核(padding 为 1),输出 28x28x256,然后进行 64 个 1x1 的卷积,输出 28x28x64。
将四个结果进行连接,对这四部分输出结果的第三维并联,即 128+192+96+64=480,最终输出输出为 28x28x480

第四层(4a,4b,4c,4d,4e)、第五层(5a,5b)……,与 3a、3b 类似,在此就不再重复。

(5)输出层 

在输出层GoogLeNet与AlexNet、VGG采用3个连续的全连接层不同,GoogLeNet采用的是全局平均池化层,得到的是高和宽均为1的卷积层,然后添加丢弃概率为40%的Dropout,输出层激活函数采用的是softmax。 

(6)辅助分类器

根据实验数据,发现神经网络的中间层也具有很强的识别能力,为了利用中间层抽象的特征,在某些中间层中添加含有多层的分类器

如下图所示,红色边框内部代表添加的辅助分类器。GoogLeNet中共增加了两个辅助的softmax分支,作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0二是将中间某一层输出用作分类,起到模型融合作用。最后的loss=loss_2 + 0.3 * loss_1 + 0.3 * loss_0。实际测试时,这两个辅助softmax分支会被去掉。

3.网络创新点

(1)引入Inception

引入Inception结构,融合不同尺度的特征信息,能得到更好的特征表征。更意味着提高准确率,不一定需要堆叠更深的层或者增加神经元个数等,可以转向研究更稀疏但是更精密的结构同样可以达到很好的效果。

(2)1x1卷积核

使用1x1的卷积核进行降维以及映射处理。

(3)辅助分类器

添加两个辅助分类器帮助训练,在 GoogLeNet(Inception 网络)中,辅助分类器(Auxiliary Classifier)是一种用于训练过程中的辅助分类器,它有助于解决深度神经网络中的梯度消失问题(vanishing gradient problem)并加速训练。辅助分类器的作用如下:

  1. 缓解梯度消失问题:深度神经网络通常有很多层,而反向传播中的梯度在深度网络中可能会逐渐变得非常小,导致训练变得困难。辅助分类器通过在网络中间添加一个额外的分类器,可以提供额外的梯度信号,帮助在训练过程中传播梯度,从而缓解梯度消失问题。

  2. 正则化:辅助分类器可以看作是一种正则化技术。它强制网络中间的特征图具有一定的分类能力,因为这些特征图需要用于中间的分类任务。这有助于网络学习更具有区分性的特征。

  3. 多尺度特征:辅助分类器通常在网络的中间层添加,这使得它可以从中间层获取多尺度的特征表示。这些多尺度的特征可以对不同尺度的对象进行分类,有助于提高模型的分类性能。

  4. 减少过拟合:辅助分类器引入了额外的分类任务,可以视为一种正则化方法,有助于减少过拟合的风险,尤其是在训练数据较少的情况下。

需要注意的是,辅助分类器通常在训练过程中使用,而在推断(inference)阶段时通常不使用它们。在推断阶段,主要的分类器负责最终的分类任务。在训练过程中,辅助分类器的预测结果与主分类器的结果一起被用于计算损失函数,以帮助网络更好地训练。 

(4)平均池化层

丢弃全连接层,使用平均池化层(大大减少模型参数)

二、GoogLeNet实现

1.构建GoogLeNet网络

由于GoogLeNet网络中有大量的重复模块,我们可以将重复的模块单独定义,方便堆叠模块。

首先是卷积层模块,一般处理流程:卷积-->ReLU

# 卷积层基础模块:卷积 + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

然后就是GoogLeNet的核心模块Inception模块,主要依据网络结构图搭建该模块,一个输入一个输出,中间含有4条分支,然后在维度上进行拼接,

# Inception模块
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3reduce, ch3x3, ch5x5reduce, ch5x5, pool_proj):
        super(Inception, self).__init__()
        # 分支1:1x1卷积
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        # 分支2:1x1卷积 + 3x3卷积
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3reduce, kernel_size=1),
            BasicConv2d(ch3x3reduce, ch3x3, kernel_size=3, padding=1)  # 保证输出大小等于输入大小
        )
        # 分支3:1x1卷积 + 5x5卷积
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5reduce, kernel_size=1),
            BasicConv2d(ch5x5reduce, ch5x5, kernel_size=5, padding=2)  # 保证输出大小等于输入大小
        )
        # 分支4:池化 + 3x3卷积
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)  # 拼接

 最后还有两个辅助分类器,其输入层分别为4a,4d Inception模块的输出。

# 辅助分类器
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # 辅助分类器1:Nx512x14x14  辅助分类器2:Nx528x14x14
        x = self.averagePool(x)
        # 辅助分类器1:Nx512x4x4  辅助分类器2:Nx528x4x4
        x = self.conv(x)
        # Nx128x4x4
        x = torch.flatten(x, 1)
        x = F.dropout(x, p=0.5, training=self.training)  # 训练模型:self.training=True, 测试模型:self.training=False
        # Nx2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, p=0.5, training=self.training)
        # Nx1024
        x = self.fc2(x)
        # N x num_classes
        return x

 根据以上模块搭建GoogLeNet网络模型,其中有些参数需要根据以下的表格获取。

# GoogLeNet网络
class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # 这里无nn.LocalResponseNorm(),可自行添加
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)  # ceil_mode:向上取整

        # 查表可得inception的配置参数
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # 是否使用辅助分类器
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14

        # 训练模型开启辅助分类器1,测试时不使用
        if self.training and self.aux_logits:  # eval model lose this layer
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14

        # 训练模型开启辅助分类器2,测试时不使用
        if self.training and self.aux_logits:  # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)

        # 训练模型返回三个值,加权作为最终结果,测试时不使用
        if self.training and self.aux_logits:  # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.加载数据集

这里使用花朵数据集,数据集制造和数据集使用的脚本的参考:Pytorch之AlexNet花朵分类_风间琉璃•的博客-CSDN博客

 加载数据集和测试集,并进行相应的预处理操作。

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # 数据集根目录
    data_root = os.path.abspath(os.getcwd())
    print(os.getcwd())
    # 图片目录
    image_path = os.path.join(data_root, "data_set", "flower_data")
    print(image_path)
    assert os.path.exists(image_path), "{} path does not exit.".format(image_path)

    # 准备数据集
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)

    # 定义一个包含花卉类别到索引的字典:雏菊,蒲公英,玫瑰,向日葵,郁金香
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    # 获取包含训练数据集类别名称到索引的字典,这通常用于数据加载器或数据集对象中。
    flower_list = train_dataset.class_to_idx
    # 创建一个反向字典,将索引映射回类别名称
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # 将字典转换为格式化的JSON字符串,每行缩进4个空格
    json_str = json.dumps(cla_dict, indent=4)
    # 打开名为 'class_indices.json' 的JSON文件,并将JSON字符串写入其中
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # min: CPU 核心数量、批次大小(如果大于1),以及一个最大值8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print("using {} dataloader workers every process".format(nw))

    # 加载数据集
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

3.训练和测试模型

数据集预处理完成后,就可以进行网络模型的训练和验证。

net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    # 官方的模型中使用了bn层以及改了一些参数,不能混用
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 120
    best_acc = 0.0
    save_path = './GoogLeNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # 设置为训练模式
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            # 训练时,损失为3个输出损失的加权
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # 设置为测试模式
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                # 测试层仅有最后输出层
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

训练120epoch的准确率能到达80%左右。

三、实现图像分类

利用上述训练好的网络模型进行测试,验证是否能完成分类任务。

报错:注意这里加载模型的时候只需要加载主干网络的权重文件,不需要辅助分类器的相关文件。

加载模型文件如下:

    # 加载模型文件
    weights_path = "./GoogLeNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    # strict=False 表示在加载权重时允许不匹配的键,如果预训练权重文件中的一些权重参数与当前模型不完全匹配,也不会引发错误
    # missing_keys包含了在权重文件中存在但模型中不存在的键
    # unexpected_key包含了在模型中存在但权重文件中不存在的键
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)

    # model.load_state_dict(torch.load(weights_path))

RuntimeError: Error(s) in loading state_dict for GoogLeNet:
    Unexpected key(s) in state_dict: "aux1.conv.conv.weight", "aux1.conv.conv.bias", "aux1.fc1.weight", "aux1.fc1.bias", "aux1.fc2.weight", "aux1.fc2.bias", "aux2.conv.conv.weight", "aux2.conv.conv.bias", "aux2.fc1.weight", "aux2.fc1.bias", "aux2.fc2.weight", "aux2.fc2.bias". 

import os
import json

import torch
from PIL import Image, ImageDraw
from torchvision import transforms

from model import GoogLeNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # 加载图片
    img_path = 'daisy.jpg'
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    image = Image.open(img_path)

    # img.show()
    image.show()
    # [N, C, H, W]
    img = data_transform(image)
    # 扩展维度
    img = torch.unsqueeze(img, dim=0)

    # 获取标签
    json_path = 'class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, 'r') as f:
        # 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中
        class_indict = json.load(f)

    # 加载网络
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # 加载模型文件
    weights_path = "./GoogLeNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    # strict=False 表示在加载权重时允许不匹配的键,如果预训练权重文件中的一些权重参数与当前模型不完全匹配,也不会引发错误
    # missing_keys包含了在权重文件中存在但模型中不存在的键
    # unexpected_key包含了在模型中存在但权重文件中不存在的键
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)

    # model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # 对输入图像进行预测
        output = torch.squeeze(model(img.to(device))).cpu()
        # 对模型的输出进行 softmax 操作,将输出转换为类别概率
        predict = torch.softmax(output, dim=0)
        # 得到高概率的类别的索引
        predict_cla = torch.argmax(predict).numpy()

    res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    draw = ImageDraw.Draw(image)
    # 文本的左上角位置
    position = (10, 10)
    # fill 指定文本颜色
    draw.text(position, res, fill='red')
    image.show()
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))


if __name__ == '__main__':
    main()

运行结果:

 

结束语

感谢阅读吾之文章,今已至此次旅程之终站 🛬。

吾望斯文献能供尔以宝贵之信息与知识也 🎉。

学习者之途,若藏于天际之星辰🍥,吾等皆当努力熠熠生辉,持续前行。

然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 💞。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1045815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于FPGA的图像形态学膨胀算法实现,包括tb测试文件和MATLAB辅助验证

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 在FPGA中仿真结果如下所示: 将FPGA中的仿真结果导入到matlab显示二维图,效果如下: 2.算法运行软件版本 matla…

Java 8 CompletableFuture 学习及实践笔记

CompletableFuture 学习及实践笔记 CompletableFuture 是 Java 8 引入的一个强大的异步编程工具&#xff0c;它提供了一种简洁而灵活的方式来处理异步操作和构建复杂的异步流程。 创建 CompletableFuture 使用 CompletableFuture.supplyAsync(Supplier<U> supplier) 方…

WindTerm 安装使用教程【图解】

往期回顾 MobaXtermMobaXterm 安装使用教程【图解】-CSDN博客WindTermWindTerm 安装使用教程【图解】-CSDN博客 一、WindTerm 功能介绍 WindTerm 是一款 Github 上开源的 SSH 终端工具&#xff0c;到目前为止它已经收获了 16.9K 颗星&#xff0c;它是完全可以比肩 MobaXterm 工…

AI写稿软件,最新的AI写稿软件有哪些

写作已经成为各行各业无法绕开的重要环节。不论是企业的广告宣传、新闻媒体的报道、还是个人自媒体的内容创作&#xff0c;文字都扮演着不可或缺的角色。随着信息的爆炸式增长&#xff0c;写作的需求也不断攀升&#xff0c;这使得许多人感到困扰。时间不够用、创意枯竭、写作技…

GICI-LIB源码阅读(三)因子图优化模型

原始 Markdown文档、Visio流程图、XMind思维导图见&#xff1a;https://github.com/LiZhengXiao99/Navigation-Learning 文章目录 三、因子图优化&#xff08;FGO&#xff09;1、因子图模型2、因子图优化状态估计模型3、因子图优化求解4、Ceres 非线性最小二乘库5、GICI-LIB 中…

山西电力市场日前价格预测【2023-09-28】

日前价格预测 预测说明&#xff1a; 如上图所示&#xff0c;预测明日&#xff08;2023-09-28&#xff09;山西电力市场全天平均日前电价为310.91元/MWh。其中&#xff0c;最高日前电价为373.27元/MWh&#xff0c;预计出现在18: 30。最低日前电价为235.17元/MWh&#xff0c;预计…

Java基础篇 IO流

✅作者简介&#xff1a;大家好&#xff0c;我是Leo&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Leo的博客 &#x1f49e;当前专栏&#xff1a; Java从入门到精通 ✨特色专栏&#xf…

服务器中了360勒索病毒怎么办?勒索病毒解密,数据恢复

在众多类型的勒索病毒中&#xff0c;360勒索病毒算是占比较高、恢复难度较大的一种类型了。由于很多用户是第一次遇到这种情况&#xff0c;所以中招以后往往不知道该如何处理。所以云天数据恢复中心将根据自己的经验&#xff0c;来告诉用户服务器中了360勒索病毒怎么办。 断开网…

这3个方法,堪比U盘数据恢复大师!

“我的u盘数据可能是被我误删了&#xff0c;现在把u盘插入电脑后发现里面什么文件都没有了。这种情况还有可能恢复u盘中的数据吗&#xff1f;” 在使用u盘的过程中&#xff0c;我们可能会经常遇到u盘数据丢失的情况。先不要太担心&#xff0c;今天小编就给大家介绍一些好用的u盘…

谈谈 Redis 数据类型底层的数据结构?

谈谈 Redis 数据类型底层的数据结构? RedisObject 在 Redis 中&#xff0c;redisObject 是一个非常重要的数据结构&#xff0c;它用于保存字符串、列表、集合、哈希表和有序集合等类型的值。以下是关于 redisObject 结构体的定义&#xff1a; typedef struct redisObject {…

【python入门篇】基础知识(1)

网上关于python入门到实践的文章多不胜数&#xff0c;为什么我还要写呢&#xff1f; 一个就是对于基础知识的一个温习&#xff0c;二来就是通过详细讲解知识的同时对于自己的表达能力的一个提升&#xff0c;后续文中会出现多个案例以及练习题&#xff0c;这边我会说一些重点掌握…

韩国coupang需要懂韩文吗?平台入驻条件及费用?——站斧浏览器

coupang需要懂韩文吗 Coupang是韩国Top级电商网站&#xff0c;品类繁多&#xff0c;截止2018年&#xff0c;该网站的注册会员数超过了2500万。2017年 和 2018 年&#xff0c; Coupang APP被评为韩国受欢迎的购物APP。Coupang网站的日活移动用户数量是第二名的三倍。 那么做co…

如何在linux操作系统下安装nvm

本文主要介绍如何在linux操作系统下安装nvm&#xff0c;如果想知道nvm如何在windows操作系统下使用&#xff0c;请参考文章如何通过nvm管理多个nodejs版本_nvm 查看所有node版本-CSDN博客。 1、nvm下载 nvm全称Node Version Manager&#xff0c;即Node版本管理器。访问官网地址…

CSS笔记——伪类和伪元素

1、伪类 伪类是用于为元素在某些特定状态下添加样式的选择器。它们可以让我们为用户与页面交互时元素的外观和行为进行样式定义。 常用的伪类有: :hover &#xff0c;鼠标悬停 :active :focus &#xff0c;表单聚焦 :blur ,失去表单聚焦 :link &#xff0c;未访问 :visi…

毕业生求职应聘,性格测评怎么破?

进入到了面试和性格测试环节&#xff0c;也有很多小伙伴&#xff0c;被一个叫做性格测试的家伙给“干下来了”。性格测试这个工具&#xff0c;也没有那么神秘&#xff0c;早在头几年&#xff0c;很多大公司&#xff0c;例如阿里、美的、华为&#xff0c;都能在招聘中&#xff0…

现在的国内MBA教育是否同质化太严重?

如今在国内的MBA教育领域可以说是一片欣欣向荣&#xff0c;两百余所高校开设MBA项目招生&#xff0c;而报考市场也随着时代的发展持续升温&#xff0c;但是在这背后也存在一些问题伴随发生&#xff0c;其中就是MBA项目的同质化与跟风化趋势越来越明显&#xff0c;主要有以下几个…

每天学习3个小时能不能考上浙大MBA项目?

不少考生经常会问到上岸浙大MBA项目想要复习多长时间&#xff0c;这个问题其实没有固定答案。在行业十余年的经验总结来看&#xff0c;杭州达立易考教育认为基于每一位考生的个人复习时间、个人学习能力以及原有基础情况等不同&#xff0c;复习上岸的预期分数目标也会有差异&am…

接口测试--Postman常用断言

Postman的断言是用javascript语言写的 引入--什么是断言 结果中的特定属性或值与预期做对比&#xff0c;如果一致&#xff0c;则用例通过&#xff0c;如果不一致&#xff0c;断言失败&#xff0c;用例失败。断言&#xff0c;是一个完整测试用例所不可或缺的一部分&#xff0c…

maven下载、本地仓库设置与idea内置maven设置

一、下载安装maven maven下载官网&#xff1a;https://maven.apache.org/download.cgi 下载到本地后解压 二、配置环境变量 我的电脑-属性-高级系统设置-环境变量/系统变量 新建MAVEN_HOME 变量值为自己的maven包所在的位置 编辑path 添加 %MAVEN_HOME%\bin 三、测试 Win…

【面试八股】IP协议八股

IP协议八股 子网掩码的作用为什么IP协议需要分片IP协议什么时候需要分片IP协议是怎么进行分片的那么IP协议是如果进行标识属于同一个分片呢&#xff1f;TCP协议和UDP协议将数据交给IP协议之后&#xff0c;是否需要分片传输&#xff1f; 子网掩码的作用 用来标识网络号和主机号…