GoogLeNet网络

news2024/10/5 5:39:05

目录

1. 创新点

1.1 引入Inception结构

1.2 1×1卷积降维

1.3 两个辅助分类器 

1.4 丢弃全连接层,使用平均池化层

2. 网络结构

3. 知识点

3.1 torch.cat

3.2 关于self.training

3.3 关于load_state_dict中的strict

4. 代码 

4.1 model.py

4.2 train.py

4.3 predict.py

5. 结果


1. 创新点

1.1 引入Inception结构

作用:融合不同尺度的特征信息

注意:每个分支所得特征矩阵的宽、高必须相同

下图来自:Going deeper with convolutions

1.2 1×1卷积降维

channels: 512

a.不使用1×1卷积核降维

使用:64个5×5卷积核进行卷积

参数:5×5×512×64=819,200

b.使用1×1卷积核降维

使用:24个1×1卷积核进行卷积

1.3 两个辅助分类器 

内容:GoogLeNet有三个输出层(两个为辅助分类层)

 Going deeper with convolutions文章里:

  • An average pooling layer with 5×5 filter size and stride 3, resulting in an 4×4×512 output for the (4a), and 4×4×528 for the (4d) stage.
  • A 1×1 convolution with 128 filters for dimension reduction and rectified linear activation.
  • A fully connected layer with 1024 units and rectified linear activation.
  • A dropout layer with 70% ratio of dropped outputs.
  • A linear layer with softmax loss as the classifier (predicting the same 1000 classes as the main classifier, but removed at inference time).

1.4 丢弃全连接层,使用平均池化层

作用:大大减少模型的参数

2. 网络结构

Inception层太多,列出几个:

3. 知识点

3.1 torch.cat

import torch

a = torch.Tensor([1, 2, 3])
b = torch.Tensor([4, 5, 6])
c = [a, b]
print(torch.cat(c))
# tensor([1., 2., 3., 4., 5., 6.])

3.2 关于self.training

使用model.train()和model.eval()控制模型的状态

在model.train()模式下self.training=True

在model.eval()模式下self.training=False

3.3 关于load_state_dict中的strict

为True:有什么要什么,每一个键都有(默认为True)

为False:有什么要什么,没有的就不要

missing_keys和unexpected_keys:缺失的、不期望的键

4. 代码 

4.1 model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_use=True, init_weight=False):
        super(GoogLeNet, self).__init__()
        self.aux_use = aux_use
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)  # ceil_mode默认向下取整 True为向上取整
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_use:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应平均池化 指定输出(H,W)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weight:
            self._initialize_weights_()

    def forward(self, x):
        # N×3×224×224
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        if self.training and self.aux_use:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_use:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(x)
        x = self.fc(x)
        if self.training and self.aux_use:
            return x, aux1, aux2
        return x;

    def _initialize_weights_(self):
        for v in self.modules():
            if isinstance(v, nn.Conv2d):
                nn.init.xavier_uniform_(v.weight)
                if v.bias is not None:
                    nn.init.constant_(v.bias, 0)
            if isinstance(v, nn.Linear):
                nn.init.xavier_uniform_(v.weight)
                if v.bias is not None:
                    nn.init.constant_(v.bias, 0)


# set BasicConv2d class
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x;


# set Inception class
class Inception(nn.Module):
    # 各分支最后的输出宽高要一样
    def __init__(self, in_channels, ch11, ch33_reduce, ch33, ch55_reduce, ch55, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch11, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch33_reduce, kernel_size=1),
            BasicConv2d(ch33_reduce, ch33, kernel_size=3, padding=1)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch55_reduce, kernel_size=1),
            BasicConv2d(ch55_reduce, ch55, kernel_size=5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, dim=1)


# set InceptionAux class
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output:[batch,128,4,4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # Input: Aux1(batch,512,14,14) Aux2(batch,528,14,14)
        x = self.averagePool(x)
        # output: Aux1(batch,512,4,4) Aux2(batch,528,4,4)
        x = self.conv(x)
        # output:Aux1、Aux2(batch,128,4,4)
        x = torch.flatten(x, start_dim=1)
        x = F.dropout(x, 0.5, training=self.training)
        # batch × 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # batch × 1024
        x = self.fc2(x)
        # batch × num_classes
        return x

4.2 train.py

import os
import sys

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import json
from model import GoogLeNet
import torch.optim as optim
from tqdm import tqdm

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }
    data_root = os.path.abspath(os.getcwd())
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    assert os.path.exists(image_path), 'file:{} is not exist!'.format(image_path)

    # set dataset
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'), transform=data_transform['train'])
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'), transform=data_transform['val'])
    train_num = len(train_dataset)
    val_num = len(val_dataset)

    # write dict into file
    flower_list = train_dataset.class_to_idx
    class_dict = dict((k, v) for v, k in flower_list.items())
    json_str = json.dumps(class_dict, indent=4)
    with open('./class_indices.json', 'w') as file:
        file.write(json_str)

    # set dataloader
    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    print('using {} images for training, {} images for validation.'.format(train_num, val_num))

    net = GoogLeNet(num_classes=5, aux_use=True, init_weight=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './GoogLeNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        epoch_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            output, aux1_output, aux2_output = net(images.to(device))
            loss0 = loss_function(output, labels.to(device))
            loss1 = loss_function(aux1_output, labels.to(device))
            loss2 = loss_function(aux2_output, labels.to(device))
            loss = loss0 + 0.3 * loss1 + 0.3 * loss2
            loss.backward()
            optimizer.step()
            # print statistics
            epoch_loss += loss.item()
            train_bar.desc = 'train epoch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            for step, data in enumerate(val_bar):
                val_images, val_labels = data
                outputs = net(val_images.to(device))
                predict_y = torch.argmax(outputs, dim=1)
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_acc = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, epoch_loss / train_steps, val_acc))

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)
    print('Finished Training!')


if __name__ == '__main__':
    main()

4.3 predict.py

import os
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from model import GoogLeNet

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    img_path = './sunflower.jpg'
    assert os.path.exists(img_path), 'file:{} is not exist!'.format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)

    # [N,C,H,W]
    img = transform(img)
    img = torch.unsqueeze(img, dim=0)

    # read class_dict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), 'file:{} is not exist!'.format(json_path)
    with open(json_path, 'r') as file:
        class_dict = json.load(file)

    # create model
    net = GoogLeNet(num_classes=5, aux_use=False).to(device)

    # load model weights
    weight_path = './GoogLeNet.pth'
    assert os.path.exists(weight_path), 'file:{} is not exist!'.format(weight_path)
    # unexpected_keys里面存放的是辅助分类器aux1与aux2的权重
    missing_keys, unexpected_keys = net.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
    net.eval()
    with torch.no_grad():
        outputs = torch.squeeze(net(img.to(device))).cpu()
        predict = torch.softmax(outputs, dim=0)
        predict_class = torch.argmax(predict).numpy()

    print_res = 'class:{} probability:{:.3f}'.format(class_dict[str(predict_class)], predict[predict_class])
    plt.title(print_res)
    for i in range(len(predict)):
        print('class:{:10} probability:{:.3f}'.format(class_dict[str(i)], predict[i]))
    plt.show()


if __name__ == '__main__':
    main()

5. 结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1041677.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MQ - 30 基础功能:死信队列的设计

文章目录 导图概述什么是死信队列死信队列实现的技术方案死信队列的存储目标死信队列的方案设计生产死信队列消费死信队列Broker 的死信队列主流消息队列的死信功能RocketMQRabbitMQ总结导图 概述 在日常业务的消费数据过程中,如果遇到数据无法被正确处理,就需要先手动把消息…

202309解决新版Animate Diff报错问题

在安装最新版的Animate Diff的时候发现报错了 ImportError: cannot import name images_tensor_to_samples from modules.sd_samplers_common (W:\A1111\stable-diffusion-webui\modules\sd_samplers_common.py)原因很简单,Animate diff已经适配了新版webui1.6 所以…

基于LQR算法的一阶倒立摆控制

1. 一阶倒立摆建模 2. 数学模型 倒立摆的受力分析网上有很多,这里就不再叙述。直接放线性化后的方程: F (Mm)x″-mLφ″ (ImL)φ″ mLx″ mgLφ(F为外力,x为物块位移,M,m为物块和摆杆的质量,…

计算机是怎么跑起来的(2)?程序如何驱动硬件工作的?

上文计算机是怎么跑起来的?从零开始手动组装微型计算机我们说了,如何手动从来组装一台计算机,那组装完后的计算机上是怎么跑起来程序的呢?程序是如何驱动硬件工作的? 前面我们通过DMA将代码输入到内存的指定位置,然后…

JavaWeb 学习

1. 基本概念 1.1 Web web:网络,网页 静态 web html,css提供给所有人看的数据始终不会变化 动态 web 淘宝提供给每个人看的数据会有所不同技术栈:Servlet/JSP,ASP,PHP Java 中,动态 web 资…

如何通过优化Read-Retry机制降低SSD读延迟?

近日,小编发现发表于2021论文中,有关于优化Read-Retry机制降低SSD读延迟的研究,小编这里给大家分享一下这篇论文的核心的思路,感兴趣的同学可以,可以在【存储随笔】VX公号后台回复“Optimizing Read-Retry”获取下载链接。 本文中主要基于Charge Trap NAND架构分析。NAND基…

[C++网络协议] 优于select的epoll

1.epoll函数为什么优于select函数 select函数的缺点: 调用select函数后,要针对所有文件描述符进行循环处理。每次调用select函数,都需要向该函数传递监视对象信息。 对于缺点2,是提高性能的最大障碍。因为,套接字是…

python+vue驾校驾驶理论考试模拟系统

管理员的主要功能有: 1.管理员输入账户登陆后台 2.个人中心:管理员修改密码和账户信息 3.用户管理:管理员可以对用户信息进行添加,修改,删除,查询 4.添加选择题:管理员可以添加选择题目&#xf…

读高性能MySQL(第4版)笔记15_备份与恢复(下)

1. 二进制日志 1.1. 服务器的二进制日志是需要备份的最重要元素之一 1.2. 对于基于时间点的恢复是必需的,并且通常比数据要小,所以更容易被进行频繁的备份 1.3. 如果有某个时间点的数据备份和所有从那时以后的二进制日志,就可以重放从上次…

MySQL数据库入门到精通8--进阶篇( MySQL管理)

7. MySQL管理 7.1 系统数据库 Mysql数据库安装完成后,自带了一下四个数据库,具体作用如下: 7.2 常用工具 7.2.1 mysql 该mysql不是指mysql服务,而是指mysql的客户端工具。 语法 : mysql [options] [database] 选…

NumPy数值计算

1、Numpy概念 1.1Numpy是什么? Numpy是(Numerical Python的缩写):一个开源的Python科学计算库使用NumPy可以方便的使数组、矩阵进行计算包含线性代数、傅里叶变换、随机数生成等大量函数 1.2为什么使用Numpy 对于同样的数值计…

Windows打开:控制面板\网络和 Internet\网络连接 显示空白怎么办?

Windows打开:控制面板\网络和 Internet\网络连接 显示空白怎么办? 最近有用户反馈遇到这个问题,问题产生原因:在卸载某个软件的时候,系统提示需要重新启动计算机,但是,启动之后,就…

前端项目练习(练习-002-NodeJS项目初始化)

首先,创建一个web-002项目,内容和web-001一样。 下一步,规范一下项目结构,将html,js,css三个文件放到 src/view目录下面: 由于html引入css和js时,使用的是相对路径,所以…

从聚水潭到金蝶云星空通过接口配置打通数据

从聚水潭到金蝶云星空通过接口配置打通数据 数据源平台:聚水潭 聚水潭SaaSERP于2014年4月上线,目前累计超过2.5万商家注册使用,成为淘宝应用服务市场ERP类目商家数和商家月订单增速最快的ERP。2014年及2015年“双十一”当天,聚水潭SaaSERP平稳…

python的讲解和总结V2.0

python的讲解和总结V2.0 一、Python的历史二、Python的特点三、Python的语法四、Python的应用领域五、Python的优缺点优点a. 简单易学:b. 可读性强:c. 库和框架丰富:d. 可移植性强:e. 开源: 缺点a. 运行速度较慢&#…

C 语言基础题:PTA L1-027 出租

下面是新浪微博上曾经很火的一张图: 一时间网上一片求救声,急问这个怎么破。其实这段代码很简单,index数组就是arr数组的下标,index[0]2 对应 arr[2]1,index[1]0 对应 arr[0]8,index[2]3 对应 arr[3]0&…

springboot 捕获数据库唯一索引导致的异常

在一些业务场景中,需要保证数据的唯一性,一般情况下,我们会先到数据库中去查询是否存在,再去判断是否可以插入新的数据.如果是在高并发的情况下,可能还是会出现重复的情况.这时候可能就需要用到锁.也可以在数据库中设置唯一索引. 如果使用唯一索引,在插入相同数据的情况下会抛出…

【postgresql】ERROR: column “xxxx.id“ must appear in the GROUP BY

org.postgresql.util.PSQLException: ERROR: column "xxx.id" must appear in the GROUP BY clause or be used in an aggregate function 错误:列“XXXX.id”必须出现在GROUP BY子句中或在聚合函数中使用 在mysql中是正常使用的,在postgre…

GAN笔记:利普希茨连续(Lipschitz continuity)

利普希茨连续(Lipschitz continuity)是一个数学概念,用于描述一个函数在其定义域内的变化程度。在生成对抗网络(GAN)中,利普希茨连续性对于鉴别器(Discriminator)的设计和训练具有重…

麻将技术从入门到高手,麻将听牌从基础到进阶

一、教程描述 本套麻将教程,大小8.82G,共有132个文件。 二、教程目录 麻将教程001-麻将的基本概念.mp4 麻将教程002-数牌的特性.mp4 麻将教程003-好坏搭判断.mp4 麻将教程004-拆搭原则.mp4 麻将教程005-听牌攻略.mp4 麻将教程006-进程判断.mp4 …