动手学深度学习(pytorch)学习记录30-含并行连接的网络(GoogLeNet)[学习记录]

news2024/9/20 3:08:46

目录

  • GoogLeNet
  • Inception块
  • GoogLeNet模型
  • 训练模型

GoogLeNet

GoogLeNet,也称为Inception v1,是由Google团队在2014年提出的深度学习模型,它在当年的ImageNet竞赛中取得了显著的成绩。GoogLeNet的设计引入了多个创新点,包括Inception模块、辅助分类器、全局平均池化层等,这些设计使得网络在保持深度的同时减少了参数数量和计算复杂度。

Inception模块是GoogLeNet的核心,它通过并行的方式使用不同尺寸的卷积核(1x1、3x3、5x5)和最大池化层来提取特征,然后将这些特征在通道维度上进行拼接。这种设计允许网络在不同的尺度上捕捉信息,并且通过1x1卷积进行降维,有效控制了参数数量和计算量。

GoogLeNet还引入了辅助分类器,这些分类器在训练过程中提供额外的梯度信号,有助于模型的收敛,并在一定程度上提高了最终的分类性能。

此外,GoogLeNet在最后一层使用了全局平均池化层代替传统的全连接层,这不仅进一步减少了参数数量,还提高了模型的泛化能力。在输出层之前,GoogLeNet还使用了Dropout技术来防止过拟合。

GoogLeNet的网络结构设计非常灵活,可以根据不同的需求调整Inception模块中的卷积层数量和通道数。这种设计使得GoogLeNet在图像分类任务中表现出色,同时也为后续的深度学习模型设计提供了重要的参考。

在实际应用中,GoogLeNet的变种如Inception v2、Inception v3等在原有的基础上进行了进一步的优化和改进,例如引入了批量归一化(Batch Normalization)和残差连接(Residual Connections),以提高训练效率和模型性能。

本文只介绍初代版本,后续会介绍改进版本。

Inception块

如图所示,一个块由四条并行路径组成,每条路径选用大小不同的卷积层,以实现从不同空间大小中提取信息。
中间两路在输入上使用1×1卷积核,减少像素级上的通道维数。这些通路使用合适的填充,使得输出尺寸一致,最后能够在输出通道维度上合并。
Inception块中,通常调整的超参数是每层输出通道数。

在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Inception(nn.Module):
    # c1--c4是每条路径的输出通道数
    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # 线路1,单1x1卷积层
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # 线路2,1x1卷积层后接3x3卷积层
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 线路3,1x1卷积层后接5x5卷积层
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 线路4,3x3最大汇聚层后接1x1卷积层
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # 在通道维度上连结输出
        return torch.cat((p1, p2, p3, p4), dim=1)

GoogLeNet模型

GoogleLeNet共使用9个Inception块和一个全局平均汇聚层的堆叠来生成估计值。
在这里插入图片描述

# 第一模块 64通道、7×7卷积层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# 第二模块 和Inception的第二条路径一样
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),
                   nn.ReLU(),
                   nn.Conv2d(64, 192, kernel_size=3, padding=1),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第三个模块串联两个完整的lnception块。第一个Inception块的输出通道数为64+128+32+32=256,四个路径之间的输出通道数量比为64:128:32:32 =2:4:1:1。第二个和第三个路径首先将输入通道的数量分别减少到96/192= 1/2和16/192 = 1/12,然后连接第二个卷积层。第二个|nception块的输出通道数增加到128+192+96+64=480,四个路径之间的输出通道数量比为128:192:96:64=4:6:3:2。第二条和第三条路径首先将输入通道的数量分别减少到128/256=1/2和32/256=1/8。

# 第三模块
b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),
                   Inception(256, 128, (128, 192), (32, 96), 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第四模块更加复杂,它串联了5个Inception块,其输出通道数分别是
192+208+48+64=512、
160+224+64+64=512、
128+256+64+64=512、
112+288+64+64=528和
256+320+128+128 =832。
这些路径的通道数分配和第三模块中的类似,首先是含3x3卷积层的第二条路径输出最多通道,其次是仅含1x1卷积层的第一条路径,之后是含5x5卷积层的第三条路径和含3x3最大汇聚层的第四条路径。 其中第二、第三条路径都会先按比例减小通道数。这些比例在各个Inception块中都略有不同。

# 第四模块
b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),
                   Inception(512, 160, (112, 224), (24, 64), 64),
                   Inception(512, 128, (128, 256), (24, 64), 64),
                   Inception(512, 112, (144, 288), (32, 64), 64),
                   Inception(528, 256, (160, 320), (32, 128), 128),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第五模块包含输出通道数为256+320+128+128=832和384+384+128+128=1024的两个lnception块。其中每条路径通道数的分配思路和第三、第四模块中的一致,只是在具体数值上有所不同。 需要注意的是,第五模块的后面紧跟输出层,该模块同NiN一样使用全局平均汇聚层,将每个通道的高和宽变成1。 最后,将输出变成二维数组,再接上一个输出个数为标签类别数的全连接层。

# 第五模块
b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),
                   Inception(832, 384, (192, 384), (48, 128), 128),
                   nn.AdaptiveAvgPool2d((1,1)),
                   nn.Flatten())

net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))

GoogLeNet模型的计算复杂,而且不如VGG那样便于修改通道数。

X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 64, 24, 24])
Sequential output shape:	 torch.Size([1, 192, 12, 12])
Sequential output shape:	 torch.Size([1, 480, 6, 6])
Sequential output shape:	 torch.Size([1, 832, 3, 3])
Sequential output shape:	 torch.Size([1, 1024])
Linear output shape:	 torch.Size([1, 10])

训练模型

使用Fashion-MNIST数据集来训练我们的模型。

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

· 本文使用了d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2131840.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot实战-文章管理系统(1.用户相关接口)

一、用户相关接口 1.注解 RestController:是一个组合注解,它结合了 Controller 和 ResponseBody 注解的功能(就相当于把两个注解组合在一起)。 在使用 RestController 注解标记的类中,每个方法的返回值都会以 JSON 或…

【数据结构】带你初步了解排序算法

文章目录 1. 排序的概念及运用1.1 概念1.2 运用 2. 常见的排序算法2.1 插入排序2.1.1 直接插入排序(简单插入排序)2.1.2 希尔排序 2.2 选择排序2.2.1 直接选择排序(简单选择排序)2.2.2 堆排序 2.3 交换排序2.3.1 冒泡排序2.3.2 快…

python | 2行命令解决pip模块不存在问题

一、报错情况 有时,在执行 pip 更新命令后,会提示更新失败或错误警告。 报错提示如下: 然后,再次使用 pip 安装命令时,会提示 pip 模块找不到或不存在: ModuleNotFoundError: No module named pip 导致…

在线IP代理检测:保护您的网络安全

在互联网飞速发展的今天,越来越多的人开始意识到网络安全和隐私保护的重要性。在线IP代理检测工具作为一种有效的网络安全手段,能够帮助用户识别和检测IP代理的使用情况,从而更好地保护个人隐私和数据安全。本文将详细介绍在线IP代理检测的相…

SQL数据库(MySQL)

一、在Ubuntu系统下安装MySQL数据库 1、更新软件源,在确保ubuntu系统能正常上网的情况下执行以下命令 sudo apt-get update 2、安装MySQL数据库及相关软件包 # 安装过程中设置root用户的密码 123456 sudo apt-get install mysql-server ​ # 安装访问数据库的客…

Rsync——远程同步

目录 一、rsync远程同步概述 1、rsync 简介 2、rsync的同步方式 3、rsync的备份方式 4、rsync与cp、scp对比 二、常用rsync命令 1、基本格式 2、配置源的两种表达方法 三、搭建rsync下行同步 1、搭建环境 2、配置rsync源服务器(172.16.88.44)…

出版学术专著需要具备哪些条件?

出版学术专著通常需要具备以下条件: 一、学术价值 1. 创新性 - 你的专著应在研究主题、方法、观点等方面具有一定的创新性。这可以是提出新的理论框架、发现新的现象、采用新的研究方法或对已有理论进行新的阐释和拓展。 - 例如,在某一特定学科领域中&…

【北京迅为】《STM32MP157开发板使用手册》- 第二十七章Cortex-M4按键实验

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

跟《经济学人》学英文:2024年09月07日这期 What to read about the British economy

What to read about the British economy Britain used to be the world’s richest country. These six books explain how it came to be, and why it is no longer 原文: IN RECENT YEARS the British economy has tended to be in the news for the wrong re…

凯伦股份融合®️TMP复合瓦系统实力硬扛摩羯台风

第11号台风“摩羯”,今年以来登陆我国的最强台风,也是继2014年“威马逊”之后登陆我国的最强台风。 沿海多地发布防风Ⅰ级应急响应,多市启动落实“六停”措施。面对17级台风,工商业厂房遭受严重的破坏。据前方报道,当地…

Vue实用操作篇-1-第一个 Vue 程序

安装 Vue 非常的简便&#xff0c;只需下载好 Vue 对应的 .js 文件&#xff0c;在 html 中引入 vue.js 即可使用 Vue 下载好了 vue.js 我们便可以编写我们的第一个 vue 程序了 <!doctype html> <html lang"zh-CN"><head><meta charset"utf…

【数据结构】十大经典排序算法总结与分析

文章目录 前言1. 十大经典排序算法分类2. 相关概念3. 十大经典算法总结4. 补充内容4.1 比较排序和非比较排序的区别4.2 稳定的算法就真的稳定了吗&#xff1f;4.3 稳定的意义4.4 时间复杂度的补充4.5 空间复杂度补充 结语 前言 排序算法是《数据结构与算法》中最基本的算法之一…

计算机视觉(一)—— 特刊推荐

特刊征稿 01 期刊名称&#xff1a; Computer Vision for Smart Cities 截止时间&#xff1a; 提交截止日期&#xff1a;2024 年 12 月 31 日 目标及范围&#xff1a; 以下是一些潜在的主题&#xff1a; 城市交通和交通管理&#xff1a; • 车辆检测和跟踪以实现高效的交通流…

相机SD卡删除的照片可以恢复吗?6个方法,快速找回删除照片!

相机SD卡的照片在相机中误删了&#xff0c;有什么恢复办法吗&#xff1f;今天我要和大家分享一些关于如何恢复相机SD卡中删除的照片的方法。相信很多摄影爱好者都遇到过不小心删除了重要照片的情况&#xff0c;这时候我们该怎么办呢&#xff1f;别担心&#xff0c;下面我将为大…

野兔在线工具箱系统(市面上最强最多)最新版本更新2024.9

野兔在线工具箱系统&#xff0c;采用最新ThinkPHP8框架开发完成&#xff0c;也是基于YETUADMIN开发的工具箱系统&#xff0c;这次野兔在线工具系统更新&#xff0c;更新了几个新的功能模块&#xff0c;和已知的问题&#xff0c;修复系统部分功能。 程序开发 程序名称&#xf…

【生产力必备工具】GPU加速计算的首选云服务——蓝耘GPU(点击我的链接注册登录,可获50使用卷)

点击下面我的链接注册并登录&#xff0c;可获50使用卷&#xff1a;https://cloud.lanyun.net/#/registerPage?promoterCode11f606c51ehttps://cloud.lanyun.net/#/registerPage?promoterCode11f606c51e获得广泛丰富的NVIDIA高端GPU选择。高可配置高可用&#xff0c;专为大规模…

Linux之CentOS 7.9-Minimal部署Oracle 11g r2 安装实测验证(桌面模式)

前言: 发个之前的库存… Linux之CentOS 7.9-Minimal部署Oracle 11g r2 安装实测验证(桌面模式) 本次验证的是CentOS_7_Minimal-2009桌面模式来部署Oracle 11g r2,大家可根据自身环境及学习来了解。 环境:下载地址都给你们超链好了 1、Linux系统镜像包: 1.1 CentOS-7-x86_…

系统出现d3dcompiler_47.dll缺失怎么修复?总结6种d3dcompiler_47.dll修复方法

在现代电脑游戏中&#xff0c;​d3dcompiler_47.dll​ 文件是一个非常重要的组件&#xff0c;它用于DirectX应用程序的编译。然而&#xff0c;许多用户在尝试运行游戏或应用程序时&#xff0c;都会遇到“d3dcompiler_47.dll缺失”的错误。本文将为您提供解决此问题的详细步骤和…

2024年江西省职业院校技能大赛赛项规程 (简要概括)

这里写目录标题 一、赛项说明二、大赛时间三、参赛资格四、名额分配五、竞赛规程六、选拔方式七、报名办法八、奖项设置九、大赛QQ群十、资格审查 一、赛项说明 二、大赛时间 2024年十月至十二月 具体时间 地点 参考 赛项信息表 三、参赛资格 四、名额分配 五、竞赛规程 六、…

安全、稳定、高速的跨国文件传输系统

在全球化的大潮中&#xff0c;跨国企业的合作日益频繁&#xff0c;这使得跨国文件传输变得至关重要。企业在这一过程中追求的是快速、安全且稳定的文件传输服务。然而&#xff0c;跨国传输文件时&#xff0c;企业往往会遇到一些挑战。 要实现跨国文件传输的高效、安全与稳定&am…