【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】

news2025/1/14 18:00:58

目录

    • 1. NiN块介绍
    • 2. 构建NiN模型
    • 3.NIN模型每层输出形状
    • 4. 获取Fashion-MNIST数据和训练NiN模型
    • 5. 总结

前几篇文章介绍的LeNetAlexNetVGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽(增加通道数)和加深。本文我们介绍网络中的网络(NiN)。它提出了另外一个思路,即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。

1. NiN块介绍

通常卷积层的输入和输出是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。之前介绍的 1 × 1 1\times 1 1×1卷积层,它可以看成全连接层,其中空间维度(高和宽)上的每个元素相当于样本,通道相当于特征。因此,NiN使用 1 × 1 1\times 1 1×1卷积层来替代全连接层,从而使空间信息能够自然传递到后面的层中去。图1对比了NiN同AlexNet和VGG等网络在结构上的主要区别。

在这里插入图片描述

NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的 1 × 1 1\times 1 1×1卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二和第三个卷积层的超参数一般是固定的。

下面我们定义一个NiN块的函数nin_block

import time
import torch
from torch import nn, optim
import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU())
    return blk

2. 构建NiN模型

NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之处。NiN使用卷积窗口形状分别为 11 × 11 11\times 11 11×11 5 × 5 5\times 5 5×5 3 × 3 3\times 3 3×3的卷积层,相应的输出通道数也与AlexNet中的一致。每个NiN块后接一个步幅为2、窗口形状为 3 × 3 3\times 3 3×3的最大池化层。

除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。

import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2), 
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    GlobalAvgPool2d(), 
    # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
    d2l.FlattenLayer())

3.NIN模型每层输出形状

我们构建一个数据样本来查看每一层的输出形状。

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): 
    X = blk(X)
    print(name, 'output shape: ', X.shape)

输出:

0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 26, 26])
3 output shape:  torch.Size([1, 256, 12, 12])
4 output shape:  torch.Size([1, 384, 12, 12])
5 output shape:  torch.Size([1, 384, 5, 5])
6 output shape:  torch.Size([1, 384, 5, 5])
7 output shape:  torch.Size([1, 10, 5, 5])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])

4. 获取Fashion-MNIST数据和训练NiN模型

此处依然使用Fashion-MNIST数据集来训练NiN模型

batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0101, train acc 0.513, test acc 0.734, time 260.9 sec
epoch 2, loss 0.0050, train acc 0.763, test acc 0.754, time 175.1 sec
epoch 3, loss 0.0041, train acc 0.808, test acc 0.826, time 151.0 sec
epoch 4, loss 0.0037, train acc 0.828, test acc 0.827, time 151.0 sec
epoch 5, loss 0.0034, train acc 0.839, test acc 0.831, time 151.0 sec

5. 总结

  • NiN重复使用由卷积层和代替全连接层的 1 × 1 1\times 1 1×1卷积层构成的NiN块来构建深层网络。
  • NiN去除了容易造成过拟合的全连接输出层,而是将其替换成输出通道数等于标签类别数的NiN块和全局平均池化层。

对文章存在的问题,或者其他关于Python相关的问题,都可以在评论区留言或者私信我哦

如果文章内容对你有帮助,感谢点赞+关注!

关注下方GZH:阿旭算法与机器学习,可获取更多干货内容~欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/94180.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RocketMQ基本概念及功能

文章目录背景架构模型NameServer 名字服务器Broker 代理服务器生产者主题队列消息消息标签消息位点消费者消费位点消费者分组订阅关系参考文章背景 RocketMQ是阿里巴巴在2012年开发的分布式消息中间件,专为万亿级超大规模的消息处理而设计,具有高吞吐量…

【VScode插件开发】<二>插件实践开发+发布

开发环境配置完,就得好好琢磨开发内容了,不能老停留在hello world上呀! 一、开发文档结构分析 1.Package.json {"name": "kidtest","displayName": "KidTest","description": "for…

Gnoppix Linux系统发布

导读基于 Kali Linux 的 Linux 滚动发行版 Gnoppix 22.12 带来了 GNOME 43、Linux 内核 6.0 和新的升级。作为传统的现场 CD 发行版 Knoppix 项目的继承者,​​Gnoppix Linux​​ 是专门为渗透测试和反向工程而设计的。它为网页应用安全和数字权利保护进行了优化。除…

Java也可以轻松编写并发程序

如今,多核处理器在服务器,台式机及笔记本电脑上已经很普遍了,同时也被应用在更小的设备上,比如智能手机和平板电脑。这就开启了并发编程新的潜力,因为多个线程可以在多个内核上并发执行。在应用中要实现最大性能的一个…

SpringBoot+Vue实现前后端分离的小而学在线考试系统

文末获取源码 开发语言:Java 使用框架:spring boot 前端技术:JavaScript、Vue.js 、css3 开发工具:IDEA/MyEclipse/Eclipse、Visual Studio Code 数据库:MySQL 5.7/8.0 数据库管理工具:phpstudy/Navicat JD…

访问者模式(Visitor)

参考: 模板方法设计模式 (refactoringguru.cn) design-patterns-cpp/TemplateMethod.cpp at master JakubVojvoda/design-patterns-cpp GitHubhttps://github.com/JakubVojvoda/design-patterns-cpp/blob/master/state/State.cpp) 文章目录一、什么是访问者模式…

【Python机器学习】Sklearn库中Kmeans类、超参数K值确定、特征归一化的讲解(图文解释)

一、局部最优解 采用随机产生初始簇中心 的方法,可能会出现运行 结果不一致的情况。这是 因为不同的初始簇中心使 得算法可能收敛到不同的 局部极小值。 不能收敛到全局最小值,是最优化计算中常常遇到的问题。有一类称为凸优化的优化计算,不…

数字货币市场风暴肆虐,币圈人应该把握哪些新的赛道机遇

11月11日(周五)美股盘前,曾经为全球第二大加密货币交易所FTX在推特发布了申请破产保护的声明,创始人SBF已经辞去CEO职务。据声明,FTX已经任命John J. Ray III 担任CEO,SBF还将协助相关破产事宜。据FTX在推特…

[附源码]Python计算机毕业设计Django面向高校活动聚App

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

记录--三分钟打造自己专属的uni-app工具箱

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 介绍 可曾想过我们每次创建新项目,或者换地方写程序,都要把之前写过的工具类找出来又要复制粘贴一遍有些麻烦,尤其是写uni-app自定义模板主要还是开发工具完成的。这…

反序列化漏洞之CVE-2016-7124

目录 魔术函数 发生条件 靶场练习 魔术函数 __constuct: 构建对象的时被调用 __destruct: 明确销毁对象或脚本结束时被调用 __invoke: 当以函数方式调用对象时被调用 __toString: 当一个类被转换成字符串时被调用 __wakeup: 当使用unserialize时被调用,可用于做些…

【python】pandas 之 DataFrame、Series使用详解

目录 一:Pandas简介 二:Pandas数据结构 三:Series 四:字典生成Series 五:标量值生成Series 六:Series类似多维数组 七:Series类似字典 八:矢量操作与对齐 Series 标签 九…

对话交通银行:中国金融业数据仓库有哪些重要趋势?

数字经济时代,什么才是金融机构的核心竞争力?笔者访谈了交通银行软件开发中心总经理刘雷。刘雷指出:“数据和数据能力是金融机构发展的核心竞争力”。 当下,金融机构的数字化转型正迈入纵深阶段,使得两大核心诉求更加…

SpringMVC学习:四、SpringMVC的高级开发(异常处理器、文件上传、 拦截器)

5. SpringMVC的高级开发 5.1 异常处理器 ​ springmvc在处理请求过程中出现异常信息交由异常处理器进行处理,自定义异常处理器可以实现一个系统的异常处理逻辑。 思路: ​ 系统中异常包括两类:预期异常和运行时异常RuntimeException,前者…

如何利用代理IP做SEO监控优化?

从事互联网营销相关的用户多多少少都会接触到SEO,一般来说企业为了实现传播效果,每天都需要大量重复地做各种渠道的投放,这是一项逐渐累积的长期性工作。而这其中关键的优化分析与监控,势必需要大量的数据支持。接下来就一起来了解…

Linux--seq命令

seq(sequeue)用于序列化输出一个数到另一个数之间的整数,输出连续的数字、 固件间隔的数字、指定格式的数字。 一、使用方法 seq [选项] 尾数seq [选项] 首数 尾数seq [选项] 首数 增量 尾数 [选项] -f, --formatFORMAT use printf style floating-point FO…

深度学习 Day22——利用LSTM实现火灾温度预测

深度学习 Day22——利用LSTM实现火灾温度预测 文章目录深度学习 Day22——利用LSTM实现火灾温度预测一、前言二、我的环境三、LSTM介绍1、长期依赖的问题2、LSTM3、LSTM结构四、前期工作1、设置GPU2、导入数据3、数据可视化五、构建数据集1、设置X、y2、设置归一化3、划分数据集…

[附源码]Nodejs计算机毕业设计基于的校园失物招领平台Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

[附源码]Python计算机毕业设计高校学生综合素质测评系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等…

分析Linux 内核 SCSI IO 子系统

【推荐阅读】 浅析linux内核网络协议栈--linux bridge virtio-net 实现机制【一】(图文并茂) 怎么在Windows下使用Makefile文件 概述 LINUX 内核中 SCSI 子系统由 SCSI 上层,中间层和底层驱动模块 [1] 三部分组成,主要负责管…