NiNNet

news2025/1/22 9:08:00

目录

一、网络介绍

1、全连接层存在的问题

2、NiN的解决方案(NiN块)

3、NiN架构

4、总结

二、代码实现

1、定义NiN卷积块

2、NiN模型

3、训练模型


一、网络介绍

       NiN(Network in Network)是一种用于图像识别任务的卷积神经网络模型。它由谷歌研究员Min Lin、Qiang Chen和Shouyuan Chen于2013年提出。NiN的设计理念是通过引入“网络中的网络”结构来增强模型的表示能力。

1、全连接层存在的问题

       在之前的网络(比如AlexNet和VGGNet)后面都用了几个比较大的全连接层,全连接层中的参数相比于卷积层多得多,一个网络的参数大多都在全连接层,并且可以认为主要分布在卷积层之后的第一个全连接层。因此全连接层最大的问题是可能造成过拟合。

2、NiN的解决方案(NiN块)

       NiN的核心思想是使用1x1卷积层替代传统的全连接层。传统的卷积神经网络通常使用卷积层提取特征,然后通过全连接层进行分类。而NiN则在卷积层中引入了一种称为“1x1卷积”的操作,这个操作可以看作是在每个像素点上进行的全连接操作。通过使用1x1卷积,NiN能够在卷积层中引入非线性,增加模型的表达能力,并且减少了参数的数量。

       和VGG一样,NiN也有自己的块(NiN块),每一个NiN块其实就相当于一个小的神经网络(因为它具有卷积层和类似于全连接层的 $1 \times 1$ 卷积层),因此叫网络中的网络。NiN块首先有一个卷积层,然后后跟两个 $1 \times 1$ 的卷积层($1 \times 1$ 的卷积层等价于全连接层)。

3、NiN架构

全局池化层:池化层的高和宽等于输入的高和宽,一个通道得出一个值,用这个值当作对类别的预测。

4、总结

二、代码实现

       NiN的想法是将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。下图说明了VGG和NiN及它们的块之间主要架构差异。NiN块以一个普通卷积层开始,后面是两个 $1 \times 1$ 的卷积层。NiN块第一层的卷积窗口形状通常由用户设置。随后的卷积窗口形状固定为 $1 \times 1$

1、定义NiN卷积块

import torch
from torch import nn
from d2l import torch as d2l

def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

2、NiN模型

       最初的NiN网络是在AlexNet后不久提出的,显然从中得到了一些启示。NiN使用窗口形状为$11\times 11$$5\times 5$ 和 $3\times 3$ 的卷积层,输出通道数量与AlexNet中的相同。每个NiN块后有一个最大池化层,池化窗口形状为 $3\times 3$,步幅为2。

       NiN和AlexNet之间的一个显著区别是NiN完全取消了全连接层。相反,NiN使用一个个NiN块,最后一个NiN块的输出通道数等于标签类别的数量。最后放一个全局平均池化层(global average pooling layer),生成一个对数几率(logits)。NiN设计的一个优点是,它显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),    # 通道数先增加后减少:1->96->256->384->10
    nn.AdaptiveAvgPool2d((1, 1)),   # 注意这里的(1, 1)不是kernel_size,而是output_size
    # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
    nn.Flatten())   # Flatten会把channel、height和width展平成一行

       我们创建一个数据样本来查看每个块的输出形状。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

3、训练模型

       我们使用Fashion-MNIST来训练模型。训练NiN与训练AlexNet、VGG时相似。

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224) # 调节图片尺寸为224
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.563, train acc 0.786, test acc 0.790
3087.6 examples/sec on cuda:0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1330497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

node-red:使用node-red-contrib-amqp节点,实现与RabbitMQ服务器(AMQP)的消息传递

node-red-contrib-amqp节点使用 一、简介1.1 什么是AMQP协议?1.2 什么是RabbitMQ? -> 开源的AMQP协议实现1.3 RabbitMQ的WEB管理界面介绍1.3 如何实现RabbitMQ的数据采集? -> node-red 二、node-red-contrib-amqp节点安装与使用教程2.1 节点安装2.2 节点使用2.2.1 amq…

tsconfig.app.json文件报红:Option ‘importsNotUsedAsValues‘ is deprecated...

在创建vue3 vite ts项目时的 tsconfig.json(或者tsconfig.app.json) 配置文件经常会报一个这样的错误: 爆红: Option ‘importsNotUsedAsValues’ is deprecated and will stop functioning in TypeScript 5.5. Specify compi…

干货:教你如何在JMeter中调用Python代码N种方法!

在性能测试领域,Jmeter已经成为测试专业人士的首选工具,用于模拟用户行为、测量响应时间、评估系统性能。而现在大部分接口都会涉及到验签、签名、加密等操作,为了满足特定需求,我们需要更多的灵活性,比如引入Python来…

推荐算法架构7:特征工程(吊打面试官,史上最全!)

系列文章,请多关注 推荐算法架构1:召回 推荐算法架构2:粗排 推荐算法架构3:精排 推荐算法架构4:重排 推荐算法架构5:全链路专项优化 推荐算法架构6:数据样本 推荐算法架构7:特…

QTNet:Query-based Temporal Fusion with Explicit Motion for 3D Object Detection

参考代码:QTNet 动机和出发点 自动驾驶中时序信息对感知性能具有较大影响,如在感知稳定性维度上。对于常见的时序融合多是在feature的维度上做,这个维度的融合主要分为如下两个方案: 1)BEV-based方案:将之…

信号与线性系统翻转课堂笔记7——信号正交与傅里叶级数

信号与线性系统翻转课堂笔记7——信号正交与傅里叶级数 The Flipped Classroom7 of Signals and Linear Systems 对应教材:《信号与线性系统分析(第五版)》高等教育出版社,吴大正著 一、要点 (1,重点&a…

2023年京东各行业年度数据报告-2023全年度空调十大热门品牌销量(销额)榜单

空调市场如今已经进入存量时代,加之消费市场的低迷,因此,2023年空调市场的整体销售下滑。 根据鲸参谋的统计数据,2023年度,京东平台上空调市场的总销量将近1400万,同比下滑约17%;销售额为410亿&…

CVE-2023-46604 Apache ActiveMQ RCE漏洞

一、Apache ActiveMQ简介 Apache ActiveMQ是一个开源的、功能强大的消息代理(Message Broker),由 Apache Software Foundation 所提供。ActiveMQ 支持 Java Message Service(JMS)1.1 和 2.0规范,提供了一个…

金蝶云星空打开应用报错‘D:\WorkSpace\XXXX\XXXX_k3Cloud‘ is already locked.

文章目录 金蝶云星空打开应用报错D:\WorkSpace\XXXX\XXXX_k3Cloud is already locked.报错界面报错内容原因分析解决方案工作空间下清除项目Clean up应用下-清除SVN锁定 重新打开应用就可以了 金蝶云星空打开应用报错’D:\WorkSpace\XXXX\XXXX_k3Cloud’ is already locked. 报…

多相机系统通用视觉 SLAM 框架的设计与评估

Design and Evaluation of a Generic Visual SLAM Framework for Multi-Camera Systems PDF https://arxiv.org/abs/2210.07315 Code https://github.com/neufieldrobotics/MultiCamSLAM Data https://tinyurl.com/mwfkrj8k 程序设置 主要目标是开发一个与摄像头系统配置无关…

渲染控制之条件渲染

目录 1、使用规则 2、更新机制 3、使用if进行条件渲染 4、if ... else ...语句和子组件状态 5、嵌套if语句 ArkTS提供了渲染控制的能力。条件渲染可根据应用的不同状态,使用if、else和else if渲染对应状态下的UI内容。 1、使用规则 支持if、else和else if语句…

网络技术基础与计算思维实验教程_2.3_单交换机VLAN配置实验

2.3.1 实验内容 2.3.2实验目的 实验的目的一是验证交换机 VLAN 配置过程; 二是验证属于同一 VLAN的终端之间的通信过程; 三是验证每一个 VLAN 为独立的广播域; 四是验证属于不同 VLAN的两个终端之间不能通信; 五是验证转发项和 VLAN的对应关系。 2.3.3实验原理 默认情况下,交换…

Spring IoCDI

文章目录 前言什么是Spring1. 什么是 IoC 容器1.1 什么是容器1.2 什么是 IoC 2. 什么是DI IoC & DI 的使用IoC详解Bean的存储Controller注解如何获取Bean1. 根据Bean的名称获取Bean2. 根据Bean类型获取Bean3. 根据Bean名和Bean类型获取Bean Service注解Repository注解Compo…

less 查看文本时,提示may be a binary file.See it anyway?

解决办法 首先使用echo $LESSCHARSET查看less的编码 看情况设置less的编码格式(我的服务器上使用utf-8查看中文) 还要特别注意一下,Linux中存在的文本文件的编码一定要是utf - 8;(这一步很关键) 例如:要保证windows上传到Linux的…

Ubuntu 常用命令之 ps 命令用法介绍

📑Linux/Ubuntu 常用命令归类整理 ps命令是Linux下的一个非常重要的命令,它用于查看系统中的进程状态。ps是Process Status的缩写,可以显示系统中当前运行的进程的状态。 以下是一些常用的参数 a:显示所有进程(包括…

【3D Max】入门

文章目录 概述界面介绍常用功能保存和导入基本建模编辑模型材质和贴图光源和阴影动画制作渲染设置导出和打印来源 概述 3 ds MAX是由 Discreet (后来被 Autodesk (Autodesk)合并)开发的一款基于 PC系统的3 d Max或3 ds MAX三维动画绘制和制作软件,其主要功能有建模…

选择谷歌SEO公司时怎么避开套路型公司?

选择谷歌SEO公司时,如何避开套路型公司是一个至关重要的话题。在当今数字化时代,优化网站以获得搜索引擎排名的重要性越发凸显,而选择一家信誉良好、专业的SEO公司将对企业的发展产生深远影响。然而,市场上存在许多套路型公司&…

java开发需要掌握的TypeScript相关的知识点,细致简洁版。

Typescript: 介绍: TypeScript(简称 TS)是JavaScript的超集(继承了JS全部语法),TypeScript Type JavaScript。 简单说,就是在JS的基础上,为JS添加了类型支持。是微软开…

构建创新学习体验:企业培训系统技术深度解析

企业培训系统在现代企业中发挥着越来越重要的作用,它不仅仅是传统培训的延伸,更是技术创新的结晶。本文将深入探讨企业培训系统的关键技术特点,并通过一些简单的代码示例,展示如何在实际项目中应用这些技术。 1. 前端技术&#…

springboot集成websocket全全全!!!

一、界面展示 二、前置了解 1.什么是websocket WebSocket是一种在单个TCP连接上进行全双工通信的持久化协议。 全双工协议就是客户端可以给我们服务器发数据 服务器也可以主动给客户端发数据。 2.为什么有了http协议 还要websocket 协议 http协议是一种无状态,非…