MobileNetV2详细原理(含torch源码)

news2025/1/20 16:27:35

 

目录

MobilneNetV2原理

MobileNetV2的创新点:

MobileNetV2对比MobileNetV1

MobilneNetV2源码(torch版)

训练10个epoch的效果


MobilneNetV2原理

        MobileNetV2是由谷歌开发的一种用于移动设备的轻量级卷积神经网络。与传统卷积神经网络相比,它具有更高的计算效率和更小的模型尺寸,可以在移动设备上实现高精度的图像识别任务。

        MobileNetV2的主要原理是使用深度可分离卷积来减少模型的参数数量和计算量。深度可分离卷积将传统的卷积操作分解为两个独立的操作:深度卷积和逐点卷积。深度卷积仅在通道维度上进行卷积操作,而逐点卷积仅在空间维度上进行卷积操作。这种分解能大大降低计算复杂度,同时保持较高的分类精度。

        另外,MobileNetV2还使用了线性瓶颈函数来加速网络训练,以及Inverted Residuals结构来充分使用低维特征信息。

线性瓶颈结构:

MobileNetV2两种残差块:

        它还采用了轻量级的特征网络Design Spaces提升性能的策略,优化卷积核大小和数量,调整网络宽度和深度,最终得到一个更加高效的网络。网络结构图如下:

 

MobileNetV2的创新点:

        MobileNetV2相较于MobileNetV1在以下方面进行了创新:

  1. Inverted Residuals:MobileNetV2使用了Inverted Residuals结构,将输入先进行低维变换,再使用残差模块加上上采样,最后使用1x1卷积进行通道变换,从而减少计算量。

  2. Linear Bottlenecks: MobileNetV2使用1x1卷积核将输入通道数缩小到一个较小的值,然后进行卷积操作,最后再使用1x1卷积通道扩展回原来的通道数。这样可以减少计算量和参数量,同时提高模型准确度。

  3. 使用深度可分离卷积:MobileNetV2中使用了深度可分离卷积,在计算相同的特征图时用的参数远少于传统卷积。而且,深度可分离卷积允许使用不同的卷积核、池化层和标准化层,从而提高了模型的灵活性。

  4. 设计高效的shortcut连接:在MobileNetV2中,shortcut连接采用的是identity mapping方法,使用1x1卷积将跳过的特征图的通道数与当前特征图的通道数对应起来,同时这种结构可以避免梯度消失和梯度爆炸的问题,提高了模型的稳定性。

  5. 激活函数采用Scaled Exponential Linear Unit (SELU):MobileNetV2将激活函数采用了Scaled Exponential Linear Unit (SELU),可以在不增加计算量的情况下提高模型的准确性。

        总之,MobileNetV2通过使用深度可分离卷积和其他技术来减少计算量和模型尺寸,同时保持高精度的分类任务,是一种非常有前途的轻量级卷积神经网络。

MobileNetV2对比MobileNetV1

        MobileNetV2相比于MobileNetV1,主要改进有以下几个方面:

  1. 更优的性能:MobileNetV2在ImageNet上的Top-1准确率为72.0%,相比MobileNetV1(70.6%)有显著提升。

  2. 更高的效率:MobileNetV2在相同的计算资源下,参数量比MobileNetV1少了40%,计算量比它少了30%。

  3. 更好的适应性:MobileNetV2引入了一些新的技术,例如倒置余弦线性单元(Inverted Residuals with Linear Bottlenecks)和线性瓶颈模块(Linear Bottlenecks),使得它更适应于移动设备上的实时推理场景。

  4. 更好的鲁棒性:MobileNetV2在对小型变形的物体分类和检测任务上,能够显著提高模型的准确率。

MobilneNetV2源码(torch版)

数据集运行代码时自动下载,如果网络比较慢,可以自行点击我分享的链接下载cifar数据集。

链接:百度网盘
提取码:kd9a 

此代码是使用的GPU运行的,如果没有GPU导致运行失败,就把代码中的device、.to(device)删除,使用默认CPU运行。

如果使用GPU,GPU显存小导致运行报错,就将主函数main()里面的batch_size调小即可。



from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.autograd import Variable


import torch.nn as nn
import torch.nn.functional as F
import torch

class Block(nn.Module):
    def __init__(self, in_planes, out_planes, expansion, stride):
        super(Block, self).__init__()
        self.stride = stride

        planes = expansion * in_planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()

        self.cfgs = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 1],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for t, c, n, s in self.cfgs:
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(Block(in_planes, c, t, stride))
                in_planes = c
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


if __name__ == '__main__':
    train_data = CIFAR10('cifar', train=True, transform=transforms.ToTensor())
    data = DataLoader(train_data, batch_size=148, shuffle=True)

    device = torch.device("cuda")
    net = MobileNetV2().to(device)
    print(net)
    cross = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), 0.0001)
    for epoch in range(10):
        for img, label in data:
            img = Variable(img).to(device)
            label = Variable(label).to(device)
            output = net.forward(img)
            loss = cross(output, label)
            loss.backward()
            optimizer.zero_grad()
            optimizer.step()
            pre = torch.argmax(output, 1)
            num = (pre == label).sum().item()
            acc = num / img.shape[0]
        print("epoch:", epoch + 1)
        print("loss:", loss.item())
        print("Accuracy:", acc)

        以上代码采用的是跟PyTorch官方模型一样的模型结构,在_make_layers函数中构建了7个Block块,每个Block块都是跟MobileNetV2一样的结构,通过自定义一个Block类来实现。在forward函数中,通过调用这7个Block块的方式构建整个网络的结构。在最后的分类层中,采用了一个线性层。 

        使用cifar10训练做测试。

训练10个epoch的效果

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/434637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RapidOCR调优尝试教程

目录 引言常见错例种类个别字丢失调优篇个别字识别错误调优篇情况一:轻量中英文模型识别对个别汉字识别错误情况二:轻量中英文模型对个别英文或数字识别错误 相关链接 引言 由于小伙伴们使用OCR的场景多种多样,单一的参数配置往往不能满足要…

qt6.2.4下载在线安装

前言 qt官网声明5.15版本以后不提供安装包安装,均需在线安装:Due to The Qt Company offering changes, open source offline installers are not available any more since Qt 5.15。此文主要记录在线安装方法及遇到问题解决方式。 一. 在线安装执行文…

mingw32-make -j$(nproc) 命令含义

系列文章目录 文章目录 系列文章目录前言一、具体操作二、使用步骤 前言 在使用krita源码编译时遇到报错: 这段代码是 Krita 源码中的一个 CMakeLists.txt 文件片段,用于配置 Krita 项目的构建系统。以下是对这段代码的解释: find_package(…

如何写科技论文?(以IEEE会议论文为例)

0. 写在前面 常言道,科技论文犹如“八股文”,有固定的写作模式。本篇博客主要是针对工程方面的论文的结构以及写作链条的一些整理,并不是为了提高或者润色一篇论文的表达。基本上所有的论文,都需要先构思好一些点子,有…

一文带你快速了解业务流程分析和流程建模

🔥业务流程分析与建模 01业务流程分析要了解的问题 有哪些业务流程?业务流程如何完成?业务流程有谁参与?流程中有哪些控制流(如判断、 同步分支和会合)?多个不同流程建的关系?完成…

JUC线程池之线程池架构

JUC线程池之线程池架构 在多线程编程中,任务都是一些抽象且离散的工作单元,而线程 是使任务异步执行的基本机制。随着应用的扩张,线程和任务管理也 变得非常复杂。为了简化这些复杂的线程管理模式,我们需要一个 “管理者”来统一…

SOLIDWORKS Composer如何实现可视化产品交互

SOLIDWORKS Composer是一款让工程师和非工程人员都能够直接访问 3D CAD 模型、并为技术交流材料创建图形内容的 3D 软件。现如今很多制造型企业都已逐步实现其产品设计流程的自动化,以期比竞争对手更快进入市场。但遗憾的是在很多企业中,技术交流内容&am…

Android之修改Jar包源码后再重新打Jar包

一、找到jar包使用框架的github源码,并下载 例如:原有jar包 找到框架源码的github地址: https://github.com/eclipse/paho.mqtt.android 使用git拉取源码项目到本地 二、New Module — 选择Java or Kotlin Library新建 (1&…

RabbitMQ安装教程

目录 Erlang官网 Erlang下载 RabbitMQ官网 windows RabbitMQ docker安装rabbitmq 最近入职新公司,要求会RabbitMQ,所以自学了一下,现将自学的结果总结如下: 安装RabbitMQ之前,需要先安装 Erlang,因为RabbitMQ使用…

【C++初阶】缺省参数与函数重载

一.缺省参数 C祖师爷在用C写代码的时候,就觉得有些地方用着很不方便,于是就在C设计了缺省参数,在实现一些功能时,用这个就方便了许多。 1.概念 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用该函数时&#xff1a…

强训之【井字棋和密码强度等级】

目录 1.井字棋1.1题目1.2思路讲解1.3代码展示 2.密码强度判断2.1题目2.2思路讲解2.3代码 3.选择题 1.井字棋 1.1题目 链接: link 描述 给定一个二维数组board,代表棋盘,其中元素为1的代表是当前玩家的棋子,0表示没有棋子,-1代表…

C/C++每日一练(20230419)

目录 1. 插入区间 2. 单词拆分 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日一练 专栏 1. 插入区间 给你一个 无重叠的 ,按照区间起始端点排序的区间列表。 在列表中插入一个新的区间…

数据分析实战(二百零四):项目分析思路 —— 某线下连锁水果店销售数据分析

版权声明:本文为博主原创文章,未经博主允许不得转载。 文章目录 一、问题确认与指标拆解:业务逻辑图 一、问题确认与指标拆解:业务逻辑图 版权声明:本文为博主原创文章,未经博主允许不得转载。

高精度(加法+减法+除法+乘法)合集

由于c/c语言特性,当数很大时,就要考虑精度问题,python和java则不用,因此c学会精度运算很重要的,这里作个总结 1.高精度加法 给定两个正整数(不含前导 0),计算它们的和。 输入格式…

一定要会的算法复杂度分析

本文首发自「慕课网」,想了解更多IT干货内容,程序员圈内热闻,欢迎关注"慕课网"! 原作者:s09g|慕课网讲师 我们知道面对同一道问题时可能有多种解决方案。自然地,我们会将多种方法进行比较。那么…

【Linux】网络协议(应用层与传输层)

应用层传输层协议 应用层HTTP协议格式请求格式响应格式头部字段中的 Cookie (请求头) & Set-Cookie(响应头)cookiesession:会话cookie vs session HTTPS 协议:对 HTTP 协议进行加密 传输层UDP 协议TCP 协议 应用层 序列化:指将多个数组对…

使用shell封装Linux命令实现自定义Linux命令

前言 在日常工作中,尤其是在Linux上做开发的同学或者运维的同学们肯定会遇到过如下场景,比如在Linxu下通过find查找一个文件并且想看这个文件的详细信息,如果直接使用命令可能会一时想不起来,或者想起来了但是有个别参数忘记了。…

SpringMVC02注解与Rest风格

SpringMVC02 SpringMVC的注解 一、RequestParam 1、RequestParam注解介绍 位置:在方法入参位置作用:指定参数名称,将该请求参数 绑定到注解参数的位置属性 name:指定要绑定的请求参数名称; name属性和value属性互为…

vue-quill-editor富文本编辑框使用

vue富文本中实现上传图片及修改图片大小等功能。 1&#xff0c;配置使用 配置使用网上很多&#xff0c;记录下自己的使用过程 第一步&#xff1a;components/Editor文件夹下创建QuillEditor.vue文件 <template><div :class"prefixCls"><quill-edito…

bitset的用法

bitset的用法 bitset介绍 C的 bitset 在 bitset 头文件中&#xff0c;它是一种类似数组的结构&#xff0c;它的每一个元素只能是&#xff10;或&#xff11;&#xff0c;每个元素仅用&#xff11;bit空间&#xff0c;相当于一个char元素所占空间的八分之一。 bitset中的每个…