MobileNetV3原理说明及实践落地

news2024/11/24 8:04:57

本文参考:

pytorch实现并训练MobileNetV3 - 灰信网(软件开发博客聚合)

 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码 - 代码天地

 

1 MobileNetV3与V1、V2对比

(1)MobileNetV1的主要思想是将普通卷积操作分解为两步,先做一次仅卷积不求和的depthwise conv(深度卷积:可理解为处理长宽方向的空间信息)操作,再使用1*1的pointwise conv(逐点卷积:只处理跨通道方向的信息)对深度卷积得到的多通道结果进行融合,减少了减少了大量的通道融合时间和参数。

(2)MobileNetV2:对比ResNet中的Bottleneck残差块通道数先变少再变多,每层中的激活函数会导致丢失一部分的信息,那么在输入通道数很多时丢的信息会很多。MobileNetV2的Bottlenect改为先将通道数增加,再将通道数减少,所以被称为“反转残差块”。

(3)MobileNetV3:引入了H-Switch激活函数与ReLU搭配使用,另外还引入了注意力机制的SE模块,SE模块为该版本最精华部分,同时使用了新的激活函数。

2 MobileNetV3概述

2.1 SE注意力机制

每个Block经过两个卷积层后得到一个由channel个元素组成的向量,每个元素是针对每个通道的权重,将权重和原特征图对应相乘,得到新的特征图数据。

2.2 使用不同的激活函数

 2.3 总体流程

图像输入,先通过1*1卷积上升通道数;

然后在高维空间下使用深度卷积;

再经过SE注意力机制优化特征图数据;

最后经过1*1卷积下降通道数(使用F(x)=x的线性激活函数)。

当步长等于1且输入和输出特征图的shape相同时,使用残差连接输入和输出;当步长等于2(下采样阶段)直接输出降维后的特征图。

 3、代码实现

import torch
from torch import nn
import torch.nn.functional as F


class hswish(nn.Module):
    def __init__(self):
        super(hswish, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = x * self.relu6(x + 3) / 6
        return out

class hsigmoid(nn.Module):
    def __init__(self):
        super(hsigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = self.relu6(x + 3) / 6
        return out

# 注意力机制
class SE(nn.Module):
    def __init__(self, in_channels, reduce=4):
        super(SE, self).__init__()

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),
            nn.BatchNorm2d(in_channels // reduce),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            hsigmoid()
        )

    def forward(self, x):
        out = self.se(x)
        out = x * out
        return out

class Block(nn.Module):
    def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):
        super(Block, self).__init__()

        self.se = nn.Sequential()
        if se:
            self.se = SE(expand_size)

        if nolinear == 'RE':
            self.nolinear = nn.ReLU6(inplace=True)
        elif nolinear == 'HS':
            self.nolinear = hswish()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand_size),
            self.nolinear,

            nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False),
            nn.BatchNorm2d(expand_size),
            self.se,
            self.nolinear,

            nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.stride = stride

    def forward(self, x):
        out = self.block(x)

        if self.stride == 1:
            out += self.shortcut(x)

        return out

class MobileNetV3(nn.Module):
    def __init__(self, class_num=10):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish()
        )

        self.neck = nn.Sequential(
            Block(3, 16, 16, 16, 2, se=True),
            Block(3, 16, 72, 24, 2),
            Block(3, 24, 88, 24, 1),
            Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 576, 1, bias=False),
            nn.BatchNorm2d(576),
            hswish()
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv3 = nn.Sequential(
            nn.Conv2d(576, 1280, 1, bias=False),
            nn.BatchNorm2d(1280),
            hswish()
        )

        self.conv4 = nn.Conv2d(1280, class_num, 1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.neck(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.flatten(1)

        return x

if __name__ == '__main__':
    model = MobileNetV3(10)

    input = torch.randn(2, 3, 516, 516)   # batch_size =1 会报错
    out = model(input)
    print(out.shape)

4、MobileNetV3在CenterNet目标检测落地情况

(1)训练情况

训练loss,mobilenetV1在batch_size=16时最少达到4.0左右;

mobileNetV2在batch_size=16时最少达到0.5以下;

mobileNetV3在batch_size=16时最少达到0.25左右。

(2)目标检测效果

(3)模型参数量:

DLASeg为2000W个左右

MobileNetV1为320W个左右

MobileNetV2为430W个左右,总模型大小为17M

MobileNetV3为166W个左右,总模型大小为7M

(4)CPU运行时间

DLASeg为1.2s

MobileNetV1为250ms

MobileNetV2为600ms

MobileNetV3为120ms

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/91346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode每日一题:1945. 字符串转化后的各位数字之和~~~模拟】

题目描述 给你一个由小写字母组成的字符串 s ,以及一个整数 k 。 首先,用字母在字母表中的位置替换该字母,将 s 转化 为一个整数(也就是,‘a’ 用 1 替换,‘b’ 用 2 替换,… ‘z’ 用 26 替换…

匿名浏览器是什么?为什么联盟营销需要借助匿名浏览器?

这段时间小伙伴们都对联盟营销很感兴趣,东哥也是陆陆续续出了两三篇相关的科普文章,今天继续给大家介绍匿名浏览器在联盟营销上的帮助,毕竟互联网时代,学会如何借助工具高效工作是很重要的。关于联盟营销的概念科普文章大家可以看…

学不会的python之通过某几个关键字排序、分组一个字典列表(列表中嵌套字典)

通过某个关键字排序、分组一个字典列表排序问题描述解决方案1.operator 模块的 itemgetter 函数2.lambda 表达式引申分组问题描述解决方案1.itertools.groupby() 函数2.defaultdict() 构建多值字典排序 问题描述 现在你有一个字典列表(列表中嵌套字典),你想要根据…

web 向 unity 传输文件流 blob 记录

场景:web 与unity 通信,向 unity 传输文件 二进制流。 由 unity 转换并下载文件。 流程: web 端将缓存的 blob 数据流读取为 base64 编码的数据 → 传给 unity, →unity 解码转换 base64 数据并下载。 web 端: 1、 将数据转换成…

【Axure教程】自定义审批流原型模板

审批流即审批流程,是对某项工作的审批活动的一系列有序组合。审批流在业务系统中担当者非常重要的角色,所以今天作者就教大家制作一个通用的自定也审批流的原型模板,方便大家日后的工作。 一、效果展示 1、可以根据业务需要添加多个审批节点…

QT学习笔记(中)

QT学习笔记(中) 文章目录QT学习笔记(中)P21 消息对话框P22 其他标准对话框P23 登录窗口界面和布局P24 控件 按钮组P25 QListWidget控件P26 QTreeWidget控件的使用P27 tableWidgetP28 其他常用控件介绍P30 自定义控件P31 QEventP32…

PyQt5 QtChart-折线图

PyQt5 QtChart-QLineSeries 折线图QLineSeriesQLineSeries QLineSeries类将数据序列显示为折线图,其核心代码: lineSeries QLineSeries() lineSeries.append(1, 3) lineSeries.append(5, 8) … chart.addSeries(lineSeries) 常用方法: set…

【linux】容器之代码自动发布-docker

一、分析 旧: 代码发布环境提前准备,以主机为颗粒度静态 新: 代码发布环境 多套,以容器为颗粒度编译 二、业务发布逻辑设计图 三、工具使用流程图 工具 gitgitlabjenkinstomcatmavenharbordocker 流程图 四、主机规划 五…

​智能化加速,「中国供应商」如何跨越规模化周期|高工观察

在过去的十年时间里,中国在智能电动汽车行业下了巨大的「赌注」,整个行业及其背后快速成长的本地化产业链生态系统成为新一轮汽车产业增长的新引擎。 与此同时,电动化、智能化技术的国产化突围,也让整个中国本土汽车产业链获得了…

SuperMap GIS的TIN地形数据处理QA

目录 一、TIN地形数据简介 二、TIN地形数据格式 三、TIN地形数据处理 3.1 导入数据集 3.2 生成TIN地形缓存 3.3 IDesktop场景加载TIN地形 3.4 发布服务 3.5 WebGL场景加载 3.5.1 viewer初始化加载 3.5.2 scene.open加载 四、可能遇到的报错及解决方案 问题一:多个TI…

蓝海创意云×可米酷 || “360VR全景直播解决方案”亮相企业产品发布会

12月8日,可米酷2023新品发布会重磅召开,蓝海创意云为可米酷提供了前沿技术支持,助力整场活动实现了360全景VR在线直播,为企业线下发布会直播活动提供借鉴。 发布会现场采用了全新的虚拟现实技术VR视频全景直播方式,全国…

Spring 中 PageHelper 不生效问题

使用这个插件时要注意版本的问题&#xff0c;不同的版本可能 PageHelper 不会生效 springboot 导入的 pagehelper 包 <dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper-spring-boot-starter</artifactId><vers…

java+mysql 基于ssm的校园二手交易系统

现如今,校园二手交易系统是商业贸易中的一条非常重要的道路,可以把其从传统的实体模式中解放中来,网上购物可以为消费者提供巨大的便利。通过校园二手交易系统这个平台,可以使用户足不出户就可以了解现今的流行趋势和丰富的商品信息,为用户提供了极大的方便,校园二手交易系统的…

技术分享 | 跨平台API对接(Java)

本章介绍基于 Jenkins API 调用的跨平台 API 对接。 基于Jenkins实现跨平台API对接 Jenkins 提供了远程访问应用编程接口&#xff08;Remote Access API&#xff09;&#xff0c;能够通过 Http 协议远程调用相关命令操作 Jenkins 进行 Jenkins 视图、任务、插件、构建信息、任…

vue3 安装使用scss

1、安装相关依赖 node-sass css-loader style-loader sass-loader 2、声明 lang"scss" 或者 scss文件中就可以直接使用 3、重点&#xff1a;安装依赖的过程中出现的各种问题 3.1、安装node-sass 报错 如果没有安装python,就去下个安装包装一下记得配置环境变量…

世界杯小吐槽

冷门 在看这次世界杯的时候&#xff0c;心里真的是一上一下&#xff0c;今年的冷门太多了&#xff01; 如&#xff1a; 阿根延 VS 沙特阿拉伯 阿根延输了&#xff08;我想可能是阿拉伯的战术比较新吧!&#xff09;那场比赛之后&#xff0c;阿拉伯还全国放假一天。到现在&#…

1.浮动 float

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 1.4什么是浮动 float属性用于创建浮动框&#xff0c;将其移动到右边&#xff0c;直到左边缘或右边缘触及包含块或另一个浮动框的边缘。 1、语法&#xff1a; <style> …

2023年pmp的考试时间是什么时候?

PMP 考试一年是有四次考试&#xff0c;分别是 3 月、6月、9月、12月&#xff0c;不出意外的话就是这几个月了&#xff0c;提前 2 个月开始报名&#xff0c;但还是要关注PMI/基金会官网的信息&#xff0c;以官网的消息为准。 一、报考条件 报考条件其实挺简单的&#xff0c;最核…

MFC 错误 error C2504: “CDialogEx”: 未定义基类-报错解决

错误&#xff1a; 在MFC文件中添加资源窗口&#xff0c;后添加新类&#xff0c;随后在.h头文件中出现 CDialogEx C class 未定义基类错误。 原因&#xff1a; 首先&#xff0c;下图这个framework.h非常关键&#xff0c;它在pch.h中也有定义&#xff0c;所以下图这个framework.h…

编译原理实验三

编译原理实验三 问题1: cpp与.ll的对应 请描述你的cpp代码片段和.ll的每个BasicBlock的对应关系。描述中请附上两者代码。 assign 对应的.ll代码如下&#xff1a; define i32 main() #0 {%1 alloca [10 x i32] ;int a[10]%2 getelementptr inbounds [10 x i32], [10 …