【神经网络】【GoogleNet】

news2025/1/11 10:03:06

1、引言

卷积神经网络是当前最热门的技术,我想深入地学习这门技术,从他的发展历史开始,了解神经网络算法的兴衰起伏;同时了解他在发展过程中的**里程碑式算法**,能更好的把握神经网络发展的未来趋势,了解神经网络的特征。
之前的LeNet为以后的神经网络模型打下了一个基础的框架,真正让神经网络模型在外界广泛引起关注的还是AlexNet,在AlexNet之后也出现了相应对他的改进,或多或少会有一些效果。但是ZFNet是在AlexNet上的改进,他的论文对神经网络的各个层级的作用,做了十分详细的阐述,为如何优化模型,怎样“有理有据”地调节参数指出了一个方向,这是他最大的一个贡献。VggNet也是在2014年的ImageNet的定位赛和分类赛上获得了第一名和第二名,他在原来卷积神经网络结构的基础上,大大增加了网络的深度,最后取得了不错的成绩。GoogleNet同样是在2014年诞生的,他在ImageNet大规模视觉识别挑战赛(ILSVRC14)上提出了一种代号为Inception的深度卷积神经网络结构,并在分类和检测上取得了新的最好结果。
GoogleNet论文原文下载地址
GoogleNet论文中文详解

2、GoogleNet提出背景

在LeNet之后,随着计算能力的提升,研究者不断改进模型的表达能力,最明显的是卷积层数的增加,每一层的内部通道数目也在增加。其中最具代表性的就是AlexNet模型。
AlexNet由Hinton和他的学生Alex Krizhevsky设计,模型名字来源于论文第一作者的姓名Alex。该模型以很大的优势获得了2012年ISLVRC竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+,自那年之后,深度学习开始迅速发展。GoogLeNet是2014年Christian Szegedy提出的一种全新的深度学习结构,在这之前的AlexNet、VGG等结构都是通过增大网络的深度(层数)来获得更好的训练效果,但层数的增加会带来很多负作用,比如overfit、梯度消失、梯度爆炸等。inception的提出则从另一种角度来提升训练结果:能更高效的利用计算资源,在相同的计算量下能提取到更多的特征,从而提升训练结果。值得注意的是GoogLeNet的参数为500w个(5M),VGG16的参数是138M,在表现接近的情况下,GoogLeNet的参数量有明显的优势。
注:

ImageNet是一个在2009年创建的图像数据集,从2010年开始到2017年举办了七届的ImageNet 挑战赛——ImageNet
Large Scale Visual Recognition ChallengeI (LSVRC),在这个挑战赛上诞生了AlexNet、ZFNet、OverFeat、VGG、Inception、ResNet、WideResNet、FractalNet、DenseNet、ResNeXt、DPN、SENet 等经典模型。

一般来说,提升网络性能最直接的办法就是增加网络深度和宽度,深度指网络层次数量、宽度指神经元数量。但这种方式存在以下问题:

  • (1)参数太多,如果训练数据集有限,很容易产生过拟合;
  • (2)网络越大、参数越多,计算复杂度越大,难以应用;
  • (3)网络越深,容易出现梯度弥散问题(梯度越往后穿越容易消失),难以优化模型。

所以,有人调侃“深度学习”其实是“深度调参”。
解决这些问题的方法当然就是在增加网络深度和宽度的同时减少参数,为了减少参数,自然就想到将全连接变成稀疏连接。但是在实现上,全连接变成稀疏连接后实际计算量并不会有质的提升,因为大部分硬件是针对密集矩阵计算优化的,稀疏矩阵虽然数据量少,但是计算所消耗的时间却很难减少。
那么,有没有一种方法既能保持网络结构的稀疏性,又能利用密集矩阵的高计算性能。大量的文献表明可以将稀疏矩阵聚类为较为密集的子矩阵来提高计算性能,就如人类的大脑是可以看做是神经元的重复堆积,因此,GoogLeNet团队提出了Inception网络结构,就是构造一种“基础神经元”结构,来搭建一个稀疏性、高计算性能的网络结构。

3、GoogleNet的模型详解

在这里插入图片描述
GoogLeNet的网络结构图细节如下:
在这里插入图片描述GoogLeNet网络结构明细表解析如下:

  • 0、输入

原始输入图像为224x224x3,且都进行了零均值化的预处理操作(图像每个像素减去均值)。

  • 1、第一模块

第一模块采用的是一个单纯的卷积层紧跟一个最大池化层。
卷积层:卷积核大小77,步长为2,padding为3,输出通道数64,输出特征图尺寸为(224-7+32)/2+1=112.5(向下取整)=112,输出特征图维度为112x112x64,卷积后进行ReLU操作。
池化层:窗口大小33,步长为2,输出特征图尺寸为((112 -3)/2)+1=55.5(向上取整)=56,输出特征图维度为56x56x64。

  • 2、第二模块

第二模块采用2个卷积层,后面跟一个最大池化层。

在这里插入图片描述卷积层:(1)先用64个1x1的卷积核(3x3卷积核之前的降维)将输入的特征图(56x56x64)变为56x56x64,然后进行ReLU操作。参数量是116464=4096。
(2)再用卷积核大小33,步长为1,padding为1,输出通道数192,进行卷积运算,输出特征图尺寸为(56-3+12)/1+1=56,输出特征图维度为56x56x192,然后进行ReLU操作。参数量是3364192=110592。第二模块卷积运算总的参数量是110592+4096=114688,即114688/1024=112K。
池化层: 窗口大小33,步长为2,输出通道数192,输出为((56 - 3)/2)+1=27.5(向上取整)=28,输出特征图维度为28x28x192。

  • 第三模块(Inception 3a层)

Inception 3a层,分为四个分支,采用不同尺度,图示如下:
在这里插入图片描述
再看下表格结构,来分析和计算吧:
在这里插入图片描述
(1)使用64个1x1的卷积核,运算后特征图输出为28x28x64,然后RuLU操作。参数量1119264=12288。
(2)96个1x1的卷积核(3x3卷积核之前的降维)运算后特征图输出为28x28x96,进行ReLU计算,再进行128个3x3的卷积,输出28x28x128。参数量1119296+3396128=129024。
(3)16个1x1的卷积核(5x5卷积核之前的降维)将特征图变成28x28x16,进行ReLU计算,再进行32个5x5的卷积,输出28x28x32。参数量1119216+551632=15872。
(4)pool层,使用3x3的核,输出28x28x192,然后进行32个1x1的卷积,输出28x28x32.。总参数量1119232=6144。
将四个结果进行连接,对这四部分输出结果的第三维并联,即64+128+32+32=256,最终输出28x28x256。总的参数量是12288+129024+15872+6144=163328,即163328/1024=159.5K,约等于159K。

  • 第三模块(Inception 3b层)

在这里插入图片描述

Inception 3b层,分为四个分支,采用不同尺度。
(1)128个1x1的卷积核,然后RuLU,输出28x28x128。
(2)128个1x1的卷积核(3x3卷积核之前的降维)变成28x28x128,进行ReLU,再进行192个3x3的卷积,输出28x28x192。
(3)32个1x1的卷积核(5x5卷积核之前的降维)变成28x28x32,进行ReLU,再进行96个5x5的卷积,输出28x28x96。
(4)pool层,使用3x3的核,输出28x28x256,然后进行64个1x1的卷积,输出28x28x64。
将四个结果进行连接,对这四部分输出结果的第三维并联,即128+192+96+64=480,最终输出输出为28x28x480。
Inception 3b和Inception 4a之间有一个最大池化下采样层 窗口大小33,步长为2,输出特征图维度为14x14x480。
在这里插入图片描述

  • 第四模块(Inception 4a、4b、4c、4e)

与Inception3a,3b类似
在这里插入图片描述
Inception 4e和Inception 5a之间有一个最大池化下采样层 窗口大小33,步长为2,输出特征图维度为7x7x832。
在这里插入图片描述

  • 第五模块(Inception 5a、5b)

与Inception3a,3b类似
在这里插入图片描述

  • 输出层

在输出层GoogLeNet与AlexNet、VGG采用3个连续的全连接层不同,GoogLeNet采用的是全局平均池化层,得到的是高和宽均为1的卷积层,然后添加丢弃概率为40%的Dropout,输出层激活函数采用的是softmax。

  • 激活函数

GoogLeNet每层使用的激活函数为ReLU激活函数。

  • 辅助分类器(分别来自于Inception 4a和Inception 4d的输出)

根据实验数据,发现神经网络的中间层也具有很强的识别能力,为了利用中间层抽象的特征,在某些中间层中添加含有多层的分类器。如下图所示,红色边框内部代表添加的辅助分类器。GoogLeNet中共增加了两个辅助的softmax分支,作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0。二是将中间某一层输出用作分类,起到模型融合作用。最后的loss=loss_2 + 0.3 * loss_1 + 0.3 * loss_0。实际测试时,这两个辅助softmax分支会被去掉。
(1)辅助分类器的第一层是一个平均池化下采样层,池化核大小为5x5,stride=3。使得(4a)阶段的输出为4×4×512,(4d)的输出为4×4×528。
(2)第二层是卷积层,卷积核大小为1x1,stride=1,卷积核个数是128。
(3)第三层是全连接层,节点个数是1024。
(4)丢弃70%输出的丢弃层。
(5)第四层是全连接层,节点个数是1000(对应分类的类别个数)。
在这里插入图片描述

4、GoogleNet的创新之处

4.1、提出Inception结构

Inception的设计原则

  • 逐层构造网络:如果数据集的概率分布能够被一个神经网络所表达,那么构造这个网络的最佳方法是逐层构筑网络,即将上一层高度相关的节点连接在一起。几乎所有效果好的深度网络都具有这一点,不管AlexNet
    VGG堆叠多个卷积,googleNet堆叠多个inception模块,还是ResNet堆叠多个resblock。
  • 稀疏的结构:人脑的神经元连接就是稀疏的,因此大型神经网络的合理连接方式也应该是稀疏的。稀疏的结构对于大型神经网络至关重要,可以减轻计算量并减少过拟合。
    卷积操作(局部连接,权值共享)本身就是一种稀疏的结构,相比于全连接网络结构是很稀疏的。
  • 符合Hebbian原理: Cells that fire together, wire together. 一起发射的神经元会连在一起。
    相关性高的节点应该被连接而在一起。
    inception中 1×1的卷积恰好可以融合三者。我们一层可能会有多个卷积核,在同一个位置但在不同通道的卷积核输出结果相关性极高。一个1×1的卷积核可以很自然的把这些相关性很高,在同一个空间位置,但不同通道的特征结合起来。而其它尺寸的卷积核(3×3,5×5)可以保证特征的多样性,因此也可以适量使用。于是,这就完成了inception module下图的设计初衷:4个分支:
    在这里插入图片描述

4.2、1*1卷积核降维

左上图是GoogleNet作者设计的初始inception结构(native inception),其想法是用多个不同类型的卷积核(1 × 1 1\times11×1,3 × 3 3\times33×3,5 × 5 5\times55×5,3 × 3 P o o l 3\times3Pool3×3Pool)堆叠在一起(卷积、池化后的尺寸相同,将通道相加)代替一个3x3的小卷积核,好处是可以使提取出来的特征具有多样化,并且特征之间的co-relationship不会很大,最后用把feature map都concatenate起来使网络做得很宽,然后堆叠Inception Module将网络变深。但仅仅简单这么做会使一层的计算量爆炸式增长
native inception中所有的卷积核都在上一层的所有输出上来做,而那个5x5的卷积核所需的计算量就太大了,造成了特征图的厚度很大,为了避免这种情况,在3x3前、5x5前、max pooling后分别加上了1x1的卷积核,以起到了降低特征图厚度的作用,这也就形成了Inception v1的网络结构(右上图)。
假设input feature map的size为28 × 28 × 256 28\times28\times25628×28×256,output feature map的size为28 × 28 × 480 28\times28\times48028×28×480,则native Inception Module的计算量有854M。计算过程如下
在这里插入图片描述从上图可以看出,计算量主要来自高维卷积核的卷积操作,因而在每一个卷积前先使用1 × 1 1\times11×1卷积核将输入图片的feature map维度先降低,进行信息压缩,在使用3x3卷积核进行特征提取运算,相同情况下,Inception v1的计算量仅为358M。
在这里插入图片描述
Inception结构总共有4个分支,输入的feature map并行的通过这四个分支得到四个输出,然后在在将这四个输出在深度维度(channel维度)进行拼接(concate)得到我们的最终输出(注意,为了让四个分支的输出能够在深度方向进行拼接,必须保证四个分支输出的特征矩阵高度和宽度都相同),因此inception结构的参数为:

  • branch1: C o n v 1 × 1 Conv 1\times1Conv1×1, stride=1
  • branch2: C o n v 3 × 3 Conv 3\times3Conv3×3, stride=1, padding=1
  • branch3: C o n v 5 × 5 Conv 5\times5Conv5×5, stride=1, padding=2
  • branch4: M a x P o o l 3 × 3 MaxPool 3\times3MaxPool3×3, stride=1,padding=1
    GoogLeNet中使用了9个Inception v1 module,分别被命名为inception(3a)、inception(3b)、inception(4a)、inception(4b)、inception(4c)、inception(4d)、inception(4e)、inception(5a)、inception(5b)。

4.3、两个辅助分类器帮助训练

作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0。二是将中间某一层输出用作分类,起到模型融合作用。
  辅助函数Axuiliary Function:从信息流动的角度看梯度消失,因为是梯度信息在BP过程中能量衰减,无法到达浅层区域,因此在中间开个口子,加个辅助损失函数直接为浅层。
  GoogLeNet网络结构中有深层和浅层2个分类器,为了避免梯度消失,两个辅助分类器结构是一模一样的,其组成如下图所示,这两个辅助分类器的输入分别来自Inception(4a)和Inception(4d)。

辅助分类器的第一层是一个平均池化下采样层,池化核大小为5x5,stride=3;第二层是卷积层,卷积核大小为1x1,stride=1,卷积核个数是128;第三层是全连接层,节点个数是1024;第四层是全连接层,节点个数是1000(对应分类的类别个数)。
  辅助分类器只是在训练时使用,在正常预测时会被去掉。辅助分类器促进了更稳定的学习和更好的收敛,往往在接近训练结束时,辅助分支网络开始超越没有任何分支的网络的准确性,达到了更高的水平。

4.4、使用平均化吃层(减少模型参数)

网络最后采用了average pooling(平均池化)来代替全连接层,该想法来自NIN(Network in Network),事实证明这样可以将准确率提高0.6%。

5、GoogleNet的代码实现

#class_indices.json
{
    "0": "daisy",
    "1": "dandelion",
    "2": "roses",
    "3": "sunflowers",
    "4": "tulips"
}
#model.py
import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
#predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
#train.py
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    # 官方的模型中使用了bn层以及改了一些参数,不能混用
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

6、总结

书山有路勤为径,学海无涯苦作舟。

7、参考文章

7.1(四)卷积神经网络模型之——GoogLeNet
7.2 GoogLeNet网络详解与模型搭建
7.3 Google Inception Net论文细读
7.4 深度学习入门笔记之GoogLeNet网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1182826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第1天:Python基础语法(一)

** 1、Python简介 ** Python是一种高级、通用的编程语言,由Guido van Rossum于1989年创造。它被设计为易于阅读和理解,具有简洁而清晰的语法,使得初学者和专业开发人员都能够轻松上手。 Python拥有丰富的标准库,提供了广泛的功…

生态环境领域基于R语言piecewiseSEM结构方程模型

结构方程模型(Sructural Equation Modeling,SEM)可分析系统内变量间的相互关系,并通过图形化方式清晰展示系统中多变量因果关系网,具有强大的数据分析功能和广泛的适用性,是近年来生态、进化、环境、地学、…

AI 绘画 | Stable Diffusion 涂鸦功能与局部重绘

在 StableDiffusion图生图的面板里,除了图生图(img2img)选卡外,还有局部重绘(Inpaint),涂鸦(Sketch),涂鸦重绘(Inpaint Sketch),上传重绘蒙版(Inpaint Uplaod)、批量处理&#xff08…

图像标注工具lableImg安装出错怎么办?

我们要训练自己的图像识别模型,首先要进行图像的标注。labelimg就是一款可视化的图像标注工具。它是用Python编写的,通过Qt实现其图形界面,尽管它只支持矩形框标注,但因跨平台,支持Linux、Mac OS、Windows,…

部分iOS机型 new Date() 时间 NAN

部分 iOS 机型 new Date() 时间 NAN 解决代码 是因为部分 iOS 机型 new Date(2023-01-01 00:00:00) 时, 获取时间戳的时间年月日用 - 分隔,将 - 分隔改为 / 分隔即可 new Date(2023/01/01 00:00:00)

【java】实现自定义注解校验——方法二

自定义注解校验的实现步骤: 1.创建注解类,编写校验注解,即类似NotEmpty注解 2.编写自定义校验的逻辑实体类,编写具体的校验逻辑。(这个类可以实现ConstraintValidator这个接口,让注解用来校验) 3.开启使用自定义注解进…

独立开发者学习的技术栈

# 前端 语言 - HTML - CSS/Sass/PostCSS - JavaScript/TypeScriptJS框架 - Vue - NuxtJS - React - NextJS - RemixJS CSS框架 - Tailwindcss - Bulma# 设计语言 - Ant Design - Material Design#后端 语言 - JavaScript/TypeScript - Python - Java - PHP 框架 - NestJS - Exp…

github 上传代码报错 fatal: Authentication failed for ‘xxxxxx‘

问题 今天一时兴起创建了个 github 新仓库,首次上传本地代码时,遇到了一个报错。本来以为是账号密码的问题,搞了好几次,发现都没错的情况下还是上传不上去。目测判断是认证相关问题,具体报错信息如下: rem…

JavaScript基础入门03

目录 1.条件语句 1.1if 语句 1.1.1基本语法格式 1.1.2练习案例 1.2三元表达式 1.3switch 2.循环语句 2.1while 循环 2.2continue 2.3break 2.4for 循环 3.数组 3.1创建数组 3.2获取数组元素 3.3新增数组元素 3.3.1. 通过修改 length 新增 3.3.2. 通过下标新增 …

OpenShift - 利用容器的特权配置实现对OpenShift攻击,以及如何使用 PSA 和 RHACS 防范风险

《OpenShift / RHEL / DevSecOps 汇总目录》 说明:本文已经在 OpenShift 4.14 的环境中验证 本文是《容器安全 - 利用容器的特权配置实现对Kubernetes攻击》的后续篇,来介绍 在 OpenShift 环境中的容器特权配置和攻击过程和 Kubernetes 环境的差异&…

【Spring】Spring IOCDI(万字详解)

文章目录 1. Spring是什么?2. 认识IOC2.1 传统程序开发1. Main.java2. Car.java3. Framework.java4. Bottom.java5. Tire.java 2.2 分析传统开发2.3 IOC程序开发1. Main.java2. Car.java3. Framework.java4. Bottom.java5. Tire.java 2.4 分析IOC开发2.5 IOC容器优点…

零代码编程:用ChatGPT批量将Mp4视频转为Mp3音频

文件夹中有很多mp4视频文件,如何利用ChatGPT来全部转换为mp3音频呢? 在ChatGPT中输入提示词: 你是一个Python编程专家,要完成一个批量将Mp4视频转为Mp3音频的任务,具体步骤如下: 打开文件夹:…

解决Java中https请求接口报错问题

1. 解决SSLException: Certificate for <域名> doesn‘t match any of the subject alternative报错问题 1.1 问题描述 最近在做一个智能问答客服项目,对接的是云问接口,然后云问接口对接使用的是https方式,之前一直…

MySQL | MySQL不区分大小写配置

MySQL不区分大小写配置 1.表内数据条件查询不区分大小写2. 表名字段名不区分大小写 1.表内数据条件查询不区分大小写 MySQL 表内数据条件查询不区分大小写是因为排序规则的问题. 在MySQL中,InnoDB存储引擎默认的字符集是utf8,utf8mb4等,这些字符集再存储数据时没有…

Flink -- 事件时间 Watermark

1、事件时间: 指的是数据产生的时间或是说是数据发生的时间。 在Flink中有三种时间分别是: Event Time:事件时间,数据产生的时间,可以反应数据真实发生的时间 Infestion Time:事件接收时间 Processing Tim…

【机器学习2】模型评估

模型评估主要分为离线评估和在线评估两个阶段。 针对分类、 排序、 回归、序列预测等不同类型的机器学习问题, 评估指标的选择也有所不同。 1 评估指标 1.1准确率 准确率是指分类正确的样本占总样本个数的比例 但是准确率存在明显的问题,比如当负样本…

互联网Java工程师面试题·Spring篇·第六弹

目录 ​编辑 21.什么是 Spring beans? 22、一个 Spring Bean 定义 包含什么? 23、如何给 Spring 容器提供配置元数据? 24、你怎样定义类的作用域? 25、解释 Spring 支持的几种 bean 的作用域。 26、Spring 框架中的单例 bean 是线程安全的吗? 27、解释 …

C/C++(a/b)*c的值 2021年6月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C(a/b)*c的值 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C(a/b)*c的值 2021年6月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 给定整数a、b、c,计算(a / b)*c的值&…

专业128分总分390+上岸中山大学884信号与系统电通院考研经验分享

专业课884 信号系统 过年期间开始收集报考信息,找到了好几个上岸学姐和学长,都非常热情,把考研的准备,复习过程中得与失,都一一和我分享,非常感谢。得知这两年专业课难度提高很多,果断参加了学长…

智能网联汽车基础软件信息安全需求分析

目录 1.安全启动 2.安全升级 3.安全存储 4.安全通信 5.安全调试 6.安全诊断 7.小结 1.安全启动 对于MCU,安全启动主要是以安全岛BootROM为信任根,在MCU启动后,用户程序运行前,硬件加密模块采用逐级校验、并行校验或者混合校…