经典轻量级神经网络(1)MobileNet V1及其在Fashion-MNIST数据集上的应用

news2024/11/24 10:46:10

经典轻量级神经网络(1)MobileNet V1及其在Fashion-MNIST数据集上的应用

1 MobileNet V1的简述

自从2017年由谷歌公司提出,MobileNet可谓是轻量级网络中的Inception,经历了一代又一代的更新。

  1. MobileNet 应用了Depthwise 深度可分离卷积来代替常规卷积,从而降低计算量,减少模型参数。
  2. MobileNet 不仅产生了小型网络,还重点优化了预测延迟。与之相比,有一些小型网络虽然网络参数较少,但是预测延迟较大。
  3. 论文下载地址: https://arxiv.org/pdf/1704.04861.pdf

1.1 深度可分离卷积

对于传统的卷积层,单个输出feature 这样产生:

  • 首先由一组滤波器对输入的各通道执行滤波,生成滤波feature 。这一步仅仅考虑空间相关性。

  • 然后计算各通道的滤波feature 的加权和,得到单个feature

    这里不同位置处的通道加权和的权重不同,这意味着在空间相关性的基础上,叠加了通道相关性。

Depthwise 深度可分离卷积打破了空间相关性和通道相关性的混合:

  • 首先由一组滤波器对输入的各通道执行滤波,生成滤波feature 。这一步仅仅考虑空间相关性。
  • 然后执行1x1 卷积来组合不同滤波feature 。这里不同位置处的通道加权和的权重都相同,这意味着这一步仅仅考虑通道相关性。
  • depthwise 卷积的参数数量和计算代价都是常规卷积的 1/8 到 1/9。

传统卷积和深度可分离卷积的区别可参考下面博客:

Pytorch常用的函数(三)深度学习中常见的卷积操作详细总结_undo_try的博客-CSDN博客

1.2 网络结构

1.2.1 V1卷积层

在这里插入图片描述

上图左边是标准卷积层,右边是V1的卷积层。V1的卷积层,首先使用3×3的深度卷积提取特征,接着是一个BN层,随后是一个ReLU层,在之后就会逐点卷积,最后就是BN和ReLU了。这也很符合深度可分离卷积,将左边的标准卷积拆分成右边的一个深度卷积和一个逐点卷积

注意:深度可分离卷积里面的ReLU,是ReLU6。

在这里插入图片描述

上图左边是普通的ReLU,对于大于0的值不进行处理,右边是ReLU6,当输入的值大于6的时候,返回6,relu6“具有一个边界”。作者认为ReLU6作为非线性激活函数,在低精度计算下具有更强的鲁棒性

标准卷积核深度可分离卷积层到底对结果有什么样的影响?

在这里插入图片描述

从上图可以看到使用深度可分离卷积与标准卷积,参数和计算量能下降为后者的九分之一到八分之一左右。但是准确率只有下降极小的1%

1.2.2 网络结构

MobileNeet 网络结构如下表所示。其中:

  • Conv 表示标准卷积,Conv dw 表示深度可分离卷积。
  • 所有层之后都跟随BNReLU (除了最后的全连接层,该层的输出直接送入到softmax 层进行分类)。
  • 先是一个3x3的标准卷积,s2进行下采样。然后就是堆积深度可分离卷积,并且其中的部分深度卷积会利用s2进行下采样。然后采用平均池化层将feature变成1x1,根据预测类别大小加上全连接层,最后是一个softmax层。整个网络有28层,其中深度卷积层有13层。
  • 与训练大模型相反,训练MobileNet 时较少的采用正则化和数据集增强技术,因为MobileNet 是小模型,而小模型不容易过拟合。

在这里插入图片描述

整个计算量基本集中在1x1卷积上。对于参数也主要集中在1x1卷积,除此之外还有就是全连接层占了一部分参数。

  • Conv 1x1 包含了所有的1x1 卷积层,包括可分离卷积中的1x1 卷积。
  • Conv DW 3x3 仅包括可分离卷积中的 3x3 卷积。

在这里插入图片描述

1.3 宽度乘子、分辨率乘子

尽管基本的MobileNet 架构已经很小,延迟很低,但特定应用需要更快的模型。为此MobileNet 引入了两个超参数:宽度乘子、分辨率乘子。

宽度乘子width multiplier ,记做a 。实际上是减小每层网络的输入、输出 feature map 的通道数量。

  • 宽度乘子应用于第一层(是一个全卷积层)的输出通道数上。这也影响了后续所有Depthwise可分离卷积层的输入feature map通道数、输出feature map通道数。

    这可以通过直接调整第一层的输出通道数来实现。

  • 它大概以 a^2的比例减少了参数数量,降低了计算量。

  • 通常将其设置为:0.25、0.5、0.75、1.0 四档。

分辨率乘子resolution multiplier,记做p 。其作用是:降低输出的feature map 的尺寸。

  • 分辨率乘子应用于输入图片上,改变了输入图片的尺寸。这也影响了后续所有Depthwise可分离卷积层的输入feature map 尺寸、输出feature map 尺寸。

    这可以通过直接调整网络的输入尺寸来实现。

  • 它不会改变模型的参数数量,但是大概以p^2 的比例降低计算量。

如果模型同时实施了宽度乘子和分辨率乘子,则模型大概以 a^2 的比例减少了参数数量,大概以 a^2p^2 的比例降低了计算量。

假设输入feature map 尺寸为14x14,通道数为 512 ;卷积尺寸为3x3;输出feature map 尺寸为14x14,通道数为512

层类型乘-加操作(百万)参数数量(百万)
常规卷积4622.36
深度可分离卷积52.30.27
a=0.75 的深度可分离卷积29.60.15
a=0.75,p=0.714 的深度可分离卷积15.10.15

1.4 模型的性能

1.4.1 更瘦的模型和更浅的模型的比较

在计算量和参数数量相差无几的情况下,采用更瘦的MobileNet 比采用更浅的MobileNet 更好。

  • 更瘦的模型:采用a=0.75 宽度乘子( 表示模型的通道数更小)。
  • 更浅的模型:删除了MobileNet5x Conv dw/s 部分(即:5层 feature size=14x14@512 的深度可分离卷积)。
模型ImageNet Accuracy乘-加操作(百万)参数数量(百万)
更瘦的MobileNet68.4%3252.6
更浅的MobileNet65.3%3072.9

1.4.2 不同宽度乘子及分辨率乘子的比较

随着a降低,模型的准确率一直下降(a=1 表示基准 MobileNet)。

with multiplierImageNet Accuracy乘-加 操作(百万)参数数量(百万)
1.070.6%5694.2
0.7568.4%3252.6
0.563.7%1491.3
0.2550.6%410.5

同样,随着 p的降低,模型的准确率一直下降( p=1表示基准MobileNet)。

resolutionImageNet Accuracy乘-加 操作(百万)参数数量(百万)
224x22470.6%5694.2
192x19269.1%4184.2
160x16067.2%2904.2
128x12864.4%1864.2

1.4.3 MobileNet和其它模型的比较

作者将V1与大型网络GoogleNet和VGG16进行了比较。

可以发现,作为轻量级网络的V1在计算量小于GoogleNet,参数量差不多是在一个数量级的基础上,在分类效果上比GoogleNet还要好,这就是要得益于深度可分离卷积了。VGG16的计算量参数量比V1大了30倍,但是结果也仅仅只高了1%不到。

在这里插入图片描述

瘦身的MobileNet(宽度乘子a=0.75 ,分辨率乘子p=0.714 )和 Squeezenet 模型大小差不多,但是准确率更高,计算量小了 22 倍。

在这里插入图片描述

2 MobileNet V1在Fashion-MNIST数据集上的应用示例

2.1 创建MobileNet V1网络模型

我们实现一个简化版本的模型。

import torch.nn as nn
import torch

class MobileNetV1(nn.Module):
    def __init__(self):
        super(MobileNetV1, self).__init__()

        def conv_bn(inp, oup, stride):
            """
               标准卷积块
            """
            return nn.Sequential(
                nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        def conv_dw(inp, oup, stride):
            """
               深度可分离卷积
            """
            return nn.Sequential(
                # 深度卷积
                nn.Conv2d(
                    in_channels=inp,
                    out_channels=inp, # out_channels=in_channels
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=inp,      # groups=in_channels
                    bias=False
                ),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                # 逐点卷积
                nn.Conv2d(
                    in_channels=inp,
                    out_channels=oup,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False
                ),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            # conv_bn(3, 32, 2),
            conv_bn(1, 32, 2),
            conv_dw(32, 64, 1),  # 深度卷积层有13层
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7),
        )
        # self.fc = nn.Linear(1024, 1000)
        self.fc = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

if __name__ == '__main__':
    net = MobileNetV1()
    X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
    for layer in net.model:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:', X.shape)

equential output shape: torch.Size([1, 32, 112, 112])
Sequential output shape: torch.Size([1, 64, 112, 112])
Sequential output shape: torch.Size([1, 128, 56, 56])
Sequential output shape: torch.Size([1, 128, 56, 56])
Sequential output shape: torch.Size([1, 256, 28, 28])
Sequential output shape: torch.Size([1, 256, 28, 28])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 1024, 7, 7])
Sequential output shape: torch.Size([1, 1024, 7, 7])
AvgPool2d output shape: torch.Size([1, 1024, 1, 1])

2.2 读取Fashion-MNIST数据集

# 我们将图片大小设置224×224
# 训练机器内存有限,将批量大小设置为64
batch_size = 64

train_iter,test_iter = get_mnist_data(batch_size,resize=224)

2.3 在GPU上进行模型训练

from _07_MobileNetV1 import MobileNetV1

# 初始化模型
net = MobileNetV1()

lr, num_epochs = 0.1, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/721561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【hadoop】Google的基本思想

Google的基本思想 三架马车GFS分布式文件系统的核心架构和原理机架感知 MapReduce计算模型PageRank问题MapReduce BigTable 三架马车 Google的基本思想主要有三个,称之为三架马车,分别是GFS(Google File System)、MapReduce计算模…

gitlab/gerrit

gitlab/gerrit 1. gitlab2. gerrit2.1 环境准备2.2 下载软件2.3 创建启动账户2.4 安装gerrit2.5 创建登录账户2.6 启动服务2.7 修改配置文件2.8 配置反向代理(nginx)2.9 gerrit主页 3. gitlabgerrit3.1 配置gerrit replication功能(用于复制具体项目)3.2…

深入浅出讲解Stable Diffusion原理,新手也能看明白

说明 最近一段时间对多模态很感兴趣,尤其是Stable Diffusion,安装了环境,圆了自己艺术家的梦想。看了这方面的一些论文,也给人讲过一些这方面的原理,写了一些文章,具体可以参考我的文章: 北方…

51单片机驱动 mg996r金属舵机 STC89C52单片机直接驱动金属大舵机

/*无论是大舵机&#xff0c;还是小舵机&#xff0c;控制方法都一样会区别在 大舵机只能接P0口&#xff08;此口外接上拉&#xff0c;驱动电流最大&#xff09;小舵机任意口 */ //#include<reg51.h> //#define uint unsigned int //#define uchar unsigned char //sbit S…

10、架构:组件通信设计

通信是一个应用中不可或缺的一个功能&#xff0c;现如今前端视图类框架大多数都是由数据驱动&#xff0c;通过数据来进行视图层的展示渲染。举个简单的例子如下&#xff0c;这是一个常见的 React 列表渲染&#xff1a; // each const numbers [1, 2, 3, 4, 5]; const listIte…

应用级监控方案Spring Boot Admin

1.简介 Spring Boot Admin为项目常用的监控方式&#xff0c;可以动态的监控服务是否运行和运行的参数&#xff0c;如类的调用情况、流量等。其中分为server与client&#xff1a; server&#xff1a; 提供展示UI与监控服务。client&#xff1a;加入server&#xff0c;被监控的…

C语言王国探险记之函数的简单概念

王国探险记系列 文章目录&#xff08;5&#xff09; 目录 王国探险记系列 文章目录&#xff08;5&#xff09; 前言 一&#xff0c;函数的基本概念 二&#xff0c;调用外部函数和main()函数区别 2.1如果我们将函数的定义放到后面&#xff0c;可不可以呢&#xff1f; 总结…

插值应用案例1

案例1 一阶线性插值 待加工零件外形根据工艺要求在一组数据(x,y)给定&#xff08;如下表&#xff09;&#xff0c;用程控铣床加工时每一刀只能沿着x方向或y方向走非常小的一步&#xff0c;需要从已知数据得到加工步长很小的(x,y)的坐标。 下表中所给x,y数据位于机翼断面的下…

使用Vue脚手架

(193条消息) 第 3 章 使用 Vue 脚手架_qq_40832034的博客-CSDN博客 初始化脚手架 说明 1.Vue脚手架是Vue官方提供的标准化开发工具&#xff08;开发平台&#xff09; 2.最新的版本是4.x 3.文档Vue CLI脚手架&#xff08;命令行接口&#xff09; 具体步骤 1.如果下载缓慢…

Libvirt Event Loop简介

文章目录 前言实现原理处理框架编程接口 原理验证事件订阅服务监听验证流程 前言 Event Loop顾名思义就是事件循环&#xff0c;整个程序是一个大的循环&#xff0c;通过事件来驱动程序要做的事情。传统编程模型是顺序的&#xff0c;程序运行一次然后终止&#xff0c;这种模型简…

JavaScript Day10 DOM详解

DOM DOM是JS操作网页的接口&#xff0c;全称为“文档对象模型”&#xff08;Document Object Model&#xff09;。它的作用是将网页转为一个JS对象&#xff0c;从而可以用脚本进行各种操作&#xff08;比如增删内容&#xff09;。 • 文档 – 文档表示的就是整个的HTML网页文档…

19-Linux 权限

目录 1.用户操作 1.1.创建用户 1.2.配置密码 1.3. 切换用户 2.三种角色 3.文件类型和访问权限 3.1.文件类型 3.2.基本权限 4.修改文件权限 1.用户操作 Linux下有两种用户&#xff1a; 超级用户&#xff08;root&#xff09;普通用户 超级用户&#xff1a;可以再lin…

【Cache】Redis主从复制哨兵模式集群

文章目录 一、Redis 持久化1. 主从复制2. 哨兵模式3. 集群 二、 Redis 主从复制1. 概述2. 主从复制的作用3. 主从复制流程4. 搭建 Redis 主从复制4.1 环境准备4.2 安装 Redis4.3 修改 Master 节点配置文件4.4 修改Slave节点配置文件&#xff08;Slave1和Slave2配置相同&#xf…

【vant移动端表格数据排版】用vant2简单实现一个把PC端表格数据展示在移动端的排版。上拉加载更多,下拉刷新页面,新增,编辑,删除功能

前言 上次做了一个移动端的表格功能&#xff0c;纯表格的那种。 跟PC一样&#xff0c;但是我一直觉得在移动端上写表格很糟糕的体验&#xff0c;毕竟手机就那么大。这不合理。 但是我这公司又需要把PC端的表格的数据展示在移动端。 导致我只能去试试看怎么排版比较好。由于网上…

【Qt-14】QT小知识点

1、关闭程序时报错 解决方案&#xff1a; 报这个错误可能是内存溢出&#xff0c;申请的空间与注销的空间不一致导致&#xff0c;排查了好久&#xff0c;我不是因为这个原因&#xff0c;我的问题如下&#xff0c;没有new窗体。 2、固定QT窗体大小 this->setMinimumSize(QSi…

NLP实战6:seq2seq翻译实战-Pytorch复现-小白版

目录 一、前期准备 1. 搭建语言类 2. 文本处理函数 3. 文件读取函数 二、Seq2Seq 模型 1. 编码器&#xff08;Encoder&#xff09; 2. 解码器&#xff08;Decoder&#xff09; 三、训练 1. 数据预处理 2. 训练函数 四、训练与评估 &#x1f368; 本文为[&#x1f51…

【算法集训之线性表篇】Day 02

文章目录 题目一思路分析代码实现效果 题目二思路分析代码实现效果 题目一 01.设置一个高效算法&#xff0c;将顺序表L的所有元素逆置&#xff0c;要求其空间复杂度为O(1)。 思路分析 首先&#xff0c;根据题目要求&#xff0c;空间复杂度度为O(1),则不能通过空间换时间的方…

为什么编程更关注内存而很少关注CPU?

我们知道&#xff0c;我们编写的程序&#xff0c;不管是什么编程语言&#xff0c;最后执行的时候&#xff0c;基本上都是CPU在完成。之所以说基本上&#xff0c;是因为还有GPU、FPGA等特殊情况。 但不知道大家发现没有&#xff0c;我们编程的时候&#xff0c;经常在关注内存问…

大促转化率精准预估优化论文随笔记

这是一篇阿里妈妈的论文【KDD’23 | 转化率预估新思路&#xff1a;基于历史数据复用的大促转化率精准预估】 常规的销量预测&#xff0c;遇到一些特大事件&#xff0c;直播、大促&#xff0c;一般很难预估得准确。而且现在电商机制也比较多样&#xff0c;预售、平台折扣等。 本…

初识MySQL:了解MySQL特性、体系结构以及在Linux中部署MySQL

目录 MySQL简介 MySQL特性 MySQL体系结构 SQL的四个层次&#xff1a; 连接层&#xff1a; SQL层&#xff1a; 插件式存储引擎&#xff1a; 物理文件层&#xff1a; 一条SQL语句的执行流程&#xff1a; MySQL在Linux中的安装、部署 首先需要下载mysql软件包&#xff…