7.3 详解NiN模型--首次使用多层感知机(1x1卷积核)替换掉全连接层的模型

news2024/12/24 9:46:09

一.前提知识

多层感知机:由一个输入层,一个或多个隐藏层和一个输出层组成。(至少有一个隐藏层,即至少3层)

全连接层:是MLP的一种特殊情况,每个节点都与前一层的所有节点连接,全连接层可以解决线性可分问题,无法学习到非线性特征。(只有输入和输出层)

二.NiN模型特点

NiN与过去模型的区别:AlexNet和VGG对LeNet的改进在于如何扩大加深这两个模块。他们都使用了全连接层,使用全连接层就可能完全放弃表征的空间结构。
NiN放弃了使用全连接层,而是使用两个1x1卷积层(将空间维度中的每个像素视为单个样本,将通道维度视为不同特征。),相当于在每个像素的通道上分别使用多层感知机

优点:NiN去除了全连接层,可以减少过拟合,同时显著减少NiN的参数数量

三.模型架构

在这里插入图片描述

四.代码

import torch
from torch import nn
from d2l import torch as d2l
import time
def nin_block(in_channels,out_channels,kernel_size,strides,padding):
    return nn.Sequential(
        # 卷积层
        nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),
        nn.ReLU(),
        # 两个带有ReLU激活函数的 1x1卷积层
        nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU()
    )
net = nn.Sequential(
    nin_block(1,96,kernel_size=11,strides=4,padding=0),
    nn.MaxPool2d(3,stride=2),
    nin_block(96,256,kernel_size=5,strides=1,padding=2),
    nn.MaxPool2d(3,stride=2),
    nin_block(256,384,kernel_size=3,strides=1,padding=1),
    nn.MaxPool2d(3,stride=2),
    nn.Dropout(0.5),
    # 标签类别是10
    nin_block(384,10,kernel_size=3,strides=1,padding=1),
    # 二维自适应平均池化,不用指定池化窗口大小
    nn.AdaptiveAvgPool2d((1,1)),
    # 将(样本,通道,w,h) = (批量,10,1,1),四维的输出转成2维的输出,其形状为(批量大小,10)
    nn.Flatten()
)
X = torch.rand(size=(1,1,224,224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

六.不同参数训练结果

学习率是0.1的情况

# 训练模型
lr,num_epochs,batch_size = 0.1,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

学习率是0.05的情况(提升了6个点)

'''开始计时'''
start_time = time.time()
# 训练模型
lr,num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
print(f'{round(run_time,2)}s')

在这里插入图片描述

学习率为0.01,批次等于30的情况(反而下降了)

在这里插入图片描述

思考

为什么NiN块中有两个1x1卷积层?

从NiN替换掉全连接层,使用多层感知机角度来说:
因为1个1x1卷基层相当于全连接层,两个1x1卷积层使输入和输出层中间有了隐藏层,才相当于多层感知机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/852045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stm32项目(6)——基于stm32的人体检测系统

目录 1.功能设计 2.硬件方案 1.单片机选择 2.人体传感器 3.报警模块 3.程序设计 4.课题意义 5.未来发展 1.功能设计 本系统为日常生活而开发的人体感应报警系统,主体通过HC-SR501模块达到感知人体靠近,检测到人体后单片机控制蜂鸣器和LE…

MyBatis动态sql标签帮你轻松搞定sql拼接

动态sql介绍 由于在开发过程不同的业务中会用到不同的操作条件,如果每个业务都拼接不同sql语句的话会是一个庞大的工作量;此时动态sql就能解决这个问题,可以针对不确定的操作条件动态拼接sql语句,根据提交的条件来完成业务sql的执…

【LeetCode 75】第二十四题(2390)从字符串中移除星号

目录 题目: 示例: 分析: 代码运行结果: 题目: 示例: 分析: 题目给我们一个字符串,然后字符串中包含星号*,要求每个星号消除一个从星号左边起最近的一个字符&#xf…

随着野火的增加,甲烷排放也会增加

2020 年对加利福尼亚州造成严重破坏的野火使大气中充满了强效温室气体。 2020 年,溪火烧毁了北加州的内华达山脉。图片来源:Zachary Cava/Flickr,CC BY-NC-SA 2.0 2020 年,在高温和干旱的推动下,加州野火烧毁了超过160…

揭示CTGAN的潜力:利用生成AI进行合成数据

推荐:使用 NSDT场景编辑器 助你快速搭建可编辑的3D应用场景 我们都知道,GAN在生成非结构化合成数据(如图像和文本)方面越来越受欢迎。然而,在使用GAN生成合成表格数据方面所做的工作很少。合成数据具有许多好处&#x…

Ae 效果:CC Spotlight

透视/CC Spotlight Perspective/CC Spotlight CC Spotlight(CC 聚光灯) 主要用途是创建和控制逼真的聚光灯效果。通过调整这些属性,可以模拟出各种不同的照明环境和效果,比如舞台照明、日出日落、特定的颜色照明等。 ◆ ◆ ◆ 效…

儿童台灯什么光源好?如何挑选儿童护眼台灯

很多家长有时候会说孩子觉得家里的台灯灯光刺眼,看书看久了就不舒服。这不仅要看光线亮度是否柔和,还要考虑台灯是不是有做遮光式设计。没有遮光式设计的台灯,光源外露,灯光会直射孩子头部,孩子视线较低,很…

c语言——字符串函数和内存操作函数

深度学习字符串函数和内存操作函数 一、求字符串长度函数(strlen)二、长度不受限制的字符串函数2.1字符串拷贝函数(strcpy)2.2字符串连接函数(strcat)2.3字符串比较函数(strcmp) 三、…

Smart HTML Elements 16.1 Crack

Smart HTML Elements 是一个现代 Vanilla JS 和 ES6 库以及下一代前端框架。企业级 Web 组件包括辅助功能(WAI-ARIA、第 508 节/WCAG 合规性)、本地化、从右到左键盘导航和主题。与 Angular、ReactJS、Vue.js、Bootstrap、Meteor 和任何其他框架集成。 智…

thinkphp:分组查询(多条相同列的数据只展示一条)

例子:数据库中有trans_num、subinventory_from、transaction_type、creation_date有相同值,在查询该数据库使,只展示这几个值相同的一条 效果: 限制之前 限制之后 代码 限制前,后端代码 public function select_i…

使用威胁建模进行DevSecOps实践丨IDCF

作者: 姚圣伟(现就职天津引元科技 天津市区块链技术创新中心) 研发效能(DevOps)工程师认证学员 一、从DevOps到 DevSecOps DevOps 最开始最要是强调开发和运维的协作与配合,至今,已不仅仅涉…

第3章 数据和C

本章介绍以下内容: 关键字:int 、short、long、unsigned、char、float、double、_Bool、_Complex、_Imaginary 运算符:sizeof() 函数:scanf() 整数类型和浮点数类型的区别 如何书写整型和浮点型常数,如何声明这些类型的…

在windows中使用parLapply函数执行并行计算

目录 1-lapply()函数介绍: 例子1: 例子2: 例子3: 2-在Windows使用并行计算,使用parLapply()函数 2.1-并行计算的准备阶段: 2.2-parLapply()函数介绍 2.3-使用parLapply()函数编写执行并行计算 2.4-…

领航优配:医药股发力拉升,双成药业等涨停,科源制药等大涨

医药股8日盘中拉升走高,到发稿,科源制药涨超16%,盘中一度冲击涨停;北陆药业涨超10%,精华制药、双成药业、海欣股份、奇特制药、龙津药业、永安药业等涨停,诚达药业、特一药业涨逾8%。 音讯面上,…

《实战AI模型》:GPT语义缓存为什么用GPTCache而不是Redis?

为什么不是Redis? 验证完可行性,便到了搭建系统的环节。这里我有一点必须要分享,在搭建 ChatGPT 缓存系统时,Redis 并不是我们的首选。 个人而言,我很喜欢用 Redis,它性能出色又十分灵活,适用于各种应用。但是 Redis 使用键值数据模型是无法查询近似键的。 如果用户提…

天!刚进公司的00后凭这套大屏模板直接涨薪5K,这是什么加薪利器

现在这世道,对普通人来说有一份谋生的工作真的都是奢求。前几天,坐在老李旁边,在公司待了10多年的老员工被叫到主管办公室说裁员的事情,他本来以为就是个普通的工作讨论,结果是通知他被裁员了,出来后整个人…

ESRVCC准备阶段优化提升方案

一、eSRVCC信令流程 UE发测量报告给eNb;基于测量报告,源eNb触发到GERAN的SRVCC切换流程;源eNb发送Handover Required 消息给源MME,消息中携带Target ID, generic Source to Target Transparent Container, SRVCC HO Indication等信息&#xf…

企业邮箱安全评估:选择最佳安全性的企业邮箱

企业在网络安全技术上投入了数十亿美元,但当涉及到邮箱安全时,风险甚至更高。随着网络钓鱼攻击、勒索软件和其他恶意威胁的兴起,确保邮箱免受入侵至关重要。 幸运的是,有许多安全解决方案和方法可供企业使用,包括: 反垃…

2023年国内最新的CRM系统排名

随着互联网技术的发展和竞争的激烈,越来越多的企业将目光放到CRM上,希望可以提高效率和收入。针对正在选型的企业,这里有一份国内crm系统排名【2023最新】,请注意查收。 1、Zoho CRM Zoho CRM是一款知名的在线CRM系统&#xff0…

uniapp开发小程序实现考勤打卡,附带源码

效果图: 考勤打卡三步走: 在地图上绘制打卡区域: uniapp开发小程序之在地图上进行绘制图形,并将经纬度转为固定格式的字符串_uniapp 绘制地图_阿晨12138的博客-CSDN博客 获取到用户定位,并跳转到当前用户定位&#xf…