经典神经网络(4)Nin-Net及其在Fashion-MNIST数据集上的应用

news2024/11/24 2:19:49

经典神经网络(4)Nin-Net及其在Fashion-MNIST数据集上的应用

1 Nin-Net的简述

1.1 Nin-Net的概述

LeNet、AlexNet和VGG都有⼀个共同的设计模式:通过⼀系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进⾏处理。AlexNet和VGG对LeNet的改进主要在于如何扩⼤和加深这两个模块。

然⽽,如果使⽤了全连接层,可能会完全放弃表征的空间结构。⽹络中的⽹络(NiN)提供了⼀个⾮常简单的解决⽅案:在每个像素的通道上分别使⽤多层感知机。

NiN的想法是在每个像素位置(针对每个⾼度和宽度)应⽤⼀个全连接层。如果我们将权重连接到每个空间位置,我们可以将其视为1 *×* 1卷积层,或作为在每个像素位置上独⽴作⽤的全连接层。从另⼀个⻆度看,即将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)

1.2 Nin-Net的实现

在这里插入图片描述

import torch.nn as nn
import torch



class NinNet(nn.Module):


    def __init__(self):
        super().__init__()
        '''
        最初的NiN⽹络是在AlexNet后不久提出的,显然从中得到了⼀些启⽰。NiN使⽤窗⼝形状为11×11、5×5和3×
        3的卷积层,输出通道数量与AlexNet中的相同。每个NiN块后有⼀个最⼤汇聚层,汇聚窗⼝形状为3 × 3,步幅为2。
        
        NiN和AlexNet之间的⼀个显著区别是NiN完全取消了全连接层。相反,NiN使⽤⼀个NiN块,其输出通道数等
        于标签类别的数量。最后放⼀个全局平均汇聚层(global average pooling layer),⽣成⼀个对数⼏率(logits)。
        NiN设计的⼀个优点是,它显著减少了模型所需参数的数量。然⽽,在实践中,这种设计有时会增加训练模
        型的时间。
        '''
        self.model = nn.Sequential(
            self.nin_block(in_channels=1,out_channels=96,kernel_size=11,strides=4,padding=0),
            nn.MaxPool2d(kernel_size=3,stride=2),
            self.nin_block(in_channels=96, out_channels=256, kernel_size=5, strides=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            self.nin_block(in_channels=256, out_channels=384, kernel_size=3, strides=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Dropout(0.5),
            # 标签类别数是10
            self.nin_block(384, 10, kernel_size=3, strides=1, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
            # 将四维的输出转成二维的输出,其形状为(批量⼤⼩,10)
            nn.Flatten()

        )

    def forward(self, X):
        X = self.model(X)
        return X



    def nin_block(self,in_channels, out_channels, kernel_size, strides, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,kernel_size, strides, padding),nn.ReLU(),
            nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),
            nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU()
        )


if __name__ == '__main__':
    net = NinNet()
    # 测试神经网络是否可运行
    # inputs = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
    # outputs = net(inputs)
    # print(outputs.shape)
    X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
    for layer in net.model:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:', X.shape)
Sequential output shape: torch.Size([1, 96, 54, 54])
MaxPool2d output shape: torch.Size([1, 96, 26, 26])
    
Sequential output shape: torch.Size([1, 256, 26, 26])
MaxPool2d output shape: torch.Size([1, 256, 12, 12])
    
Sequential output shape: torch.Size([1, 384, 12, 12])
MaxPool2d output shape: torch.Size([1, 384, 5, 5])
    
Dropout output shape: torch.Size([1, 384, 5, 5])
Sequential output shape: torch.Size([1, 10, 5, 5])
# 全局平均汇聚层(global average pooling layer)
AdaptiveAvgPool2d output shape: torch.Size([1, 10, 1, 1])
Flatten output shape: torch.Size([1, 10])

2 Nin-Net在Fashion-MNIST数据集上的应用示例

3.1 创建Nin-Net网络模型

如1.2代码所示。

3.2 读取Fashion-MNIST数据集

其他所有的函数,与经典神经网络(1)LeNet及其在Fashion-MNIST数据集上的应用完全一致。

batch_size = 128

train_iter,test_iter = get_mnist_data(batch_size,resize=224)

3.3 在GPU上进行模型训练

from _04_NinNet import NinNet

# 初始化模型
net = NinNet()

lr, num_epochs = 0.1, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/553086.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线程池的创建与使用

void execute(Runnable run)方法处理Runnbale任务 Future<> submit(Callable<> task)方法处理Callable任务 void shutdown()结束线程池 List<\Runnable> shutdownNow()立即结束线程池&#xff0c;不管任务是否执行完毕 //创建线程池的一种方式 ExecutorServi…

基于WebApi实现ModbusTCP数据服务

在上位机开发过程中&#xff0c;有时候会遇到需要提供数据接口给MES或者其他系统&#xff0c;今天跟大家分享一下&#xff0c;如何在Winform等桌面应用程序中&#xff0c;开发WebApi接口&#xff0c;提供对外数据服务。 为了更好地演示应用场景&#xff0c;本案例以读取Modbus…

Leetcode 209. 长度最小的子数组——go语言实现

文章目录 一、题目描述二、代码实现方法一&#xff1a;暴力法解题思路代码实现复杂度分析 方法二&#xff1a;滑动窗口 双指针解题思路代码实现复杂度分析 方法三&#xff1a;前缀和 二分查找解题思路代码实现复杂度分析 一、题目描述 给定一个含有 n 个正整数的数组和一个正…

STM32 10个工程篇:1.IAP远程升级(四)

在前三篇博客中主要介绍了IAP远程升级的应用背景、下位机的实现原理、以及基于STM32CubeMX对STM32F103串口DMA的基本配置&#xff0c;第四篇博客主要想介绍Labview端上位机和下位机端的报文定义和通信等。 当笔者工作上刚接触到STM32 IAP升级的时候&#xff0c;实事求是地说存在…

【科普】电压和接地真的存在吗?如何测试?

经常在实验室干活的&#xff0c;难免不被电击过&#xff0c;尤其是在干燥的北方&#xff0c;“被电”是常有的事情&#xff0c;我记得有一次拿着射频线往仪表上拧的时候&#xff0c;遇到过一次严重的电火花&#xff0c;瞬间将仪表的射频口边缘烧出了一个疤&#xff0c;实验室遭…

LeetCode83. 删除排序链表中的重复元素

写在前面&#xff1a; 题目链接&#xff1a;LeetCode83. 删除排序链表中的重复元素 编程语言&#xff1a;C 题目难度&#xff1a;简单 一、题目描述 给定一个已排序的链表的头 head &#xff0c; 删除所有重复的元素&#xff0c;使每个元素只出现一次 。返回 已排序的链表 。 …

Java中异常的处理及捕获

Java中异常的处理及捕获 一、异常的概述 &#xff08;1&#xff09;Java中异常的作用&#xff1a;增强程序的健壮性 &#xff08;2&#xff09;在Java中所有的Error&#xff08;错误&#xff09;和异常&#xff08;Exception&#xff09;都继承了同一个父类Throwable 二、异…

postgresql内核源码分析-删除表drop table流程

专栏内容&#xff1a;postgresql内核源码分析个人主页&#xff1a;我的主页座右铭&#xff1a;天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物&#xff0e; 目录 前言 调用关系 概要流程 详细流程 创建对象列表空间 删除多个指定的数据库…

【蓝桥杯国赛真题27】Scratch LED屏幕 少儿编程scratch图形化编程 蓝桥杯国赛真题讲解

目录 scratch LED屏幕 一、题目要求 编程实现 二、案例分析 1、角色分析

C#中使用git将项目代码上传到远程仓库的操作

一、远程仓库创建操作&#xff08;远程仓库使用的是gitHub&#xff09; 1、登录GitHub官网&#xff0c;注册登录账号后&#xff0c;点击创建仓库 2、仓库名称命名&#xff0c;如下所示&#xff1a; 3、创建成功如下所示&#xff1a;获得https协议&#xff08;https://github.c…

Android开发不可缺少的辅助工具

目录 jadxandroid_toolscrcpy-guiCode CraftsSQLite Expert Personal jadx jadx是一款apk反编译工具。 PS&#xff1a;部分版本安装&#xff0c;无法打开类文件&#xff0c;需换个版本。 开源地址&#xff1a;https://github.com/skylot/jadx android_tool android_tool可以通…

【瑞萨RA_FSP】SCL UART 串口通信

文章目录 一、串口通信协议简介1. 物理层2. 协议层 二、SCI 简介三、SCI的结构框图四、UART波特率计算 一、串口通信协议简介 串口通讯(Serial Communication)是一种设备间非常常用的串行通讯方式&#xff0c;因为它简单便捷&#xff0c;因此大部分电子设备都支持该通讯方式&a…

SNAT和DNAT策略

文章目录 1.SNAT策略及应用1.1 SNAT原理与应用1.2 SNAT策略的工作原理1.3 实验步骤 2.DNAT策略2.1 DNAT策略的概述2.1 DNAT原理与应用2.3 实验步骤 3.规则的导出、导入4. 总结 1.SNAT策略及应用 1.1 SNAT原理与应用 SNAT 应用环境&#xff1a;局域网主机共享单个公网IP地址接…

【利用AI让知识体系化】关于浏览器内核的基础知识

I. 介绍 什么是浏览器内核 浏览器内核&#xff08;Browser Engine&#xff09;&#xff0c;也叫浏览器渲染引擎&#xff08;Rendering Engine&#xff09;&#xff0c;是浏览器的核心组成部分&#xff0c;它负责将 HTML、CSS、JavaScript 等代码经过解析和渲染后&#xff0c;…

End-to-End Object Detection with Transformers 论文学习

论文地址&#xff1a;End-to-End Object Detection with Transformers 1. 解决了什么问题&#xff1f; 现有的目标检测算法需要大量的人为先验的设计&#xff0c;如 anchor 和 NMS&#xff0c;整体架构并不是端到端的。现有的检测方法为了去除重叠框&#xff0c;一般会利用 p…

企业级信息系统开发——初探Spring - 利用组件注解符精简Spring配置文件

文章目录 一、打开项目二、利用组件注解符精简Spring配置文件&#xff08;一&#xff09;创建新包&#xff08;二&#xff09;复制四个类&#xff08;三&#xff09;修改杀龙任务类&#xff08;四&#xff09;修改救美任务类&#xff08;五&#xff09;修改勇敢骑士类&#xff…

NEEPU Sec 2023 公开赛 writeup

文章目录 WebCute CirnoCute Cirno(Revenge) RevHow to use ida?BaseHow to use python?IKUN检查器junk code CryptoFunnyRsaLossloud Misc吉林第一站倒影Shiro重生之我是CTFer 问卷 Web Cute Cirno 学艺不精的我脑袋要炸了 在Cirno界面的源代码中发现任意读 考虑之前的比…

在Ubuntu20.04部署Flink1.17实现基于Flink GateWay的Hive On Flink的踩坑记录(一)

在Ubuntu20.04部署Flink1.17实现基于Flink GateWay的Hive On Flink的踩坑记录&#xff08;一&#xff09; 前言 转眼间&#xff0c;Flink1.14还没玩明白&#xff0c;Flink已经1.17了&#xff0c;这迭代速度还是够快。。。 之前写过一篇&#xff1a;https://lizhiyong.blog.c…

View中的滑动冲突

View中的滑动冲突 1.滑动冲突的种类 滑动冲突一般有3种, 第一种是ViewGroup和子View的滑动方向不一致 比如: 父布局是可以左右滑动,子view可以上下滑动 第二种 ViewGroup和子View的滑动方向一致 第三种 第三种类似于如下图 2.滑动冲突的解决方式 滑动冲突一般情况下有2…

Ubuntu 20.04上安装和配置Samba

介绍&#xff1a; Samba是一个开源的软件套件&#xff0c;它允许不同操作系统之间共享文件和打印机。在Ubuntu 20.04上安装和配置Samba是一种方便的方法&#xff0c;可以在本地网络中共享文件夹&#xff0c;使多台计算机能够轻松访问共享文件。本文将向您展示如何在Ubuntu 20.0…