MobileNetV1详细原理(含torch源码)

news2024/9/20 1:23:41

目录

MobileNetV1原理

MobileNet V1的网络结构如下:

为什么要设计MobileNet:

MobileNetV1的主要特点如下:

MobileNetV1的创新点:

MobileNetV1源码(torch版)

训练10个epoch的效果


MobileNetV1原理

        MobileNet V1是一种轻量级的卷积神经网络,能够在保持较高准确率的情况下具有较少的参数量和计算时间。它是由Google的研究人员在2017年提出的,并成为当时最流行的轻量级模型之一。

        MobileNet V1的核心思想是通过深度分离卷积来减少模型的参数量和计算时间。与标准卷积不同,深度分离卷积将空间卷积和通道卷积分为两个独立的卷积层,这使得网络更加高效。具体来说,在深度分离卷积中,首先使用一个空间卷积,然后使用一个通道卷积来提取特征。这与标准卷积相比可以减少参数数量并加速运算。

MobileNet V1的网络结构如下:

        MobileNet V1由序列卷积和1x1卷积两个部分组成。序列卷积包括13个深度可分离卷积层,每个层都包括一个3x3的卷积和一个批量归一化层(BN层),并且在卷积之后使用了ReLU6激活函数。最后,1x1卷积层用于生成最终的特征向量,并使用全局平均池化来缩小特征图的大小。在最后一层之后,使用一个全连接层来进行分类。MobileNet V1可以根据需要使用不同的输入分辨率,其超参数取决于输入分辨率和需要的精度。

为什么要设计MobileNet:

        Mobilenetv1是一种轻量级的深度神经网络模型,设计的目的是在保持较高的精度的同时减小模型的大小和计算量,使其适合于移动设备的推理任务。在过去,大部分深度神经网络模型都是基于卷积神经网络(CNN)进行设计的,这些模型往往非常庞大(比如VGG16/VGG19),因此不能直接应用于手机或其他嵌入式设备上。同时,运行这些大型模型所需要的计算资源也很昂贵。

        为了解决这个问题,Google Brain团队提出了Mobilenetv1。Mobilenetv1是基于深度可分离卷积(depthwise separable convolution)的设计,它将标准的卷积层分成深度卷积层和逐点卷积层两个部分,用较少的参数和计算量达到了相当不错的准确率。具体来说,深度卷积层用于在每个通道上执行空间卷积,而逐点卷积层(Pointwise Convolution)用于在不同通道之间执行线性变换。这种设计可以减少计算量和模型大小,并使得Mobilenetv1在移动设备上能够运行得更快。

        除此之外,Mobilenetv1还使用了其他一些技巧来进一步缩小模型。例如,通过扩张系数(expansion factor)来控制输出通道数和输入通道数之间的关系,从而精细控制模型的大小和复杂度;通过残差连接(Residual Connection)来提高信息流动,从而提高模型的准确性和训练速度。

        综合来说,Mobilenetv1是一种非常出色的深度神经网络模型,它在保持较高精确度的同时,大大减小了模型大小和计算量,使得它更容易嵌入到移动和嵌入式设备中。

MobileNetV1的主要特点如下:

  1. 轻量级:MobileNetv1的模型参数量非常少,只有4.2M,比起其他深度神经网络模型如VGG16、ResNet等模型,模型大小大大减小,更适合移动设备等资源受限环境下进行应用。

  2. 深度可分离卷积:MobileNetv1主要使用了深度可分离卷积,即将标准卷积分解成一个深度卷积和一个逐点卷积两个部分,分离后分别进行卷积操作,可以大大减少计算量和参数数量,从而实现轻量化的目的。

  3. 使用卷积核大小为1x1的卷积层和全局平均池化层:MobileNetv1使用了大量的1x1卷积层和全局平均池化层来代替传统的卷积层,可以减少特征图的空间尺寸,从而减少计算量和参数数量。

  4. 加入线性层和ReLU6激活函数:为了减少梯度消失的现象,MobileNetv1在每个深度可分离卷积结构后加入一个线性层和ReLU6激活函数,同时提高模型的非线性能力。

  5. 高性能:MobileNetv1在性能表现方面也做得很好,准确率达到了当时的state-of-the-art水平,同时模型具有高效率的特点,能够在较短的时间内完成较为复杂的任务。

MobileNetV1的创新点:

  1. Depthwise Separable Convolution(深度可分离卷积)

        MobileNetV1使用Depthwise Separable Convolution代替了传统的卷积操作。Depthwise Separable Convolution分为两个步骤,首先进行深度卷积,然后进行点卷积。深度卷积可以在每个输入通道上进行滤波操作,而点卷积使用1×1卷积来对每个通道进行线性组合。这样可以减少运算量以及减小模型的大小,同时也可以提高模型的精度和鲁棒性。

     2. Width Multiplier(宽度乘法参数)

        MobileNetV1引入了width multiplier的概念,可以通过调整宽度乘数来控制模型的大小和计算量。宽度乘数是作用于每一层的通道数目,可以取0到1的任意值。当宽度乘数为1时,模型与原始模型一致,而当宽度乘数小于1时,模型会变得更轻巧。

     3. Global Depthwise Pooling(全局深度池化)

        MobileNetV1使用Global Depthwise Pooling代替了全连接层。全局深度池化是在每个通道上进行求和操作,并将结果作为输出。这样可以有效地减少模型的参数量和计算量,提高模型的速度和精度。

        总的来说,MobileNetV1在模型轻量化方面具有显著的创新,可以在计算资源有限的设备上进行高效的推理操作,成为了移动设备上的高效神经网络模型。

MobileNetV1源码(torch版)

数据集运行代码时自动下载,如果网络比较慢,可以自行点击我分享的链接下载cifar数据集。

链接:百度网盘
提取码:kd9a 

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.autograd import Variable


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DepthwiseSeparableConv, self).__init__()

        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        x = self.relu(x)
        return x


class MobileNetV1(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV1, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.dw_separable_conv1 = DepthwiseSeparableConv(32, 64)
        self.dw_separable_conv2 = DepthwiseSeparableConv(64, 128)
        self.dw_separable_conv3 = DepthwiseSeparableConv(128, 128)
        self.dw_separable_conv4 = DepthwiseSeparableConv(128, 256)
        self.dw_separable_conv5 = DepthwiseSeparableConv(256, 256)
        self.dw_separable_conv6 = DepthwiseSeparableConv(256, 512)
        self.dw_separable_conv7 = DepthwiseSeparableConv(512, 512)
        self.dw_separable_conv8 = DepthwiseSeparableConv(512, 512)
        self.dw_separable_conv9 = DepthwiseSeparableConv(512, 512)
        self.dw_separable_conv10 = DepthwiseSeparableConv(512, 512)
        self.dw_separable_conv11 = DepthwiseSeparableConv(512, 512)
        self.dw_separable_conv12 = DepthwiseSeparableConv(512, 1024)
        self.dw_separable_conv13 = DepthwiseSeparableConv(1024, 1024)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)

        x = self.dw_separable_conv1(x)
        x = self.dw_separable_conv2(x)
        x = self.dw_separable_conv3(x)
        x = self.dw_separable_conv4(x)
        x = self.dw_separable_conv5(x)
        x = self.dw_separable_conv6(x)
        x = self.dw_separable_conv7(x)
        x = self.dw_separable_conv8(x)
        x = self.dw_separable_conv9(x)
        x = self.dw_separable_conv10(x)
        x = self.dw_separable_conv11(x)
        x = self.dw_separable_conv12(x)
        x = self.dw_separable_conv13(x)

        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
def main():
    train_data = CIFAR10('cifar',train=True,transform = transforms.ToTensor())
    data = DataLoader(train_data,batch_size=128,shuffle=True)

    device = torch.device("cuda")
    net = MobileNetV1(num_classes=10).to(device)
    print(net)
    cross = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(),0.001)
    for epoch in range(10):
        for img,label in data:
            img = Variable(img).to(device)
            label = Variable(label).to(device)
            output = net.forward(img)
            loss = cross(output,label)
            loss.backward()
            optimizer.zero_grad()
            optimizer.step()
            pre = torch.argmax(output,1)
            num = (pre == label).sum().item()
            acc = num / img.shape[0]
        print("epoch:",epoch + 1)
        print("loss:",loss.item())
        print("acc:",acc)
    pass


if __name__ == '__main__':
    main()

        上述代码中,我使用的是CIFAR-10数据集,通过训练MobileNet V1对图像进行分类。在训练过程中,我使用Adam优化器和交叉熵损失函数,并在训练后使用验证集评估模型的性能。

        其中,模型中使用了Depthwise Separable Convolution,它包含一层深度卷积和一层1x1卷积。深度卷积用于处理输入数据的不同通道,1x1卷积用于将不同通道的特征图合并成更多的通道。这个操作可以有效地减少参数数量和计算量,并提高模型的性能。

        另外,模型还使用了AdaptiveAvgPool2d,该层可以自适应地将输入特征图的大小调整为任意大小,并对每个子区域进行平均池化操作。这可以使模型对输入图像的尺寸具有更强的鲁棒性。

        通过MobileNet V1,我们可以在保持较高精度的同时具有较少的参数量和计算时间,在计算资源受限的情况下尤其有用。

训练10个epoch的效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/432757.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

玩转ChatGPT:中科院ChatGPT Academic项目部署与测评

一、ChatGPT Academic简介 最近,以ChatGPT为代表的超大规模语言模型火出了圈,各种二次开发项目也是层出不穷。 比如说今天我们玩弄的这个“ChatGPT Academic”,在GitHub上已经13.7K的点赞了。 项目地址:https://github.com/bina…

因为这5大工具,同事直呼我时间管理小王子

写在前面 关于时间管理、如何做计划、如何提高执行力等等相关话题其实很早之前我就想写了,但一直拖着迟迟没有动笔。 在之前的一篇文章里我曾详细聊过自己对于时间管理,如何提高执行力,以及如何摆脱那种没有灵魂的任务计划的一些思考和做法…

【C语言】深度理解指针(中)

前言✈ 上回说到,我们学习了一些与指针相关的数据类型,如指针数组,数组指针,函数指针等等,我们还学习了转移表的基本概念,学会了如何利用转移表来实现一个简易计算器。详情请点击传送门:【C语言…

Windows 下安装和使用Redis

Redis 一般安装在Linux中, 但有时出于学习和其他目的,需要在Windows机器运行Redis, 本篇介绍如果在Windows中运行和使用Redis。 关于Redis的基本介绍可以参考: Redis介绍、安装与初体验 Windows 下Redis的下载 可…

【NestJs】日志收集

Nest 附带一个默认的内部日志记录器实现,它在实例化过程中以及在一些不同的情况下使用,比如发生异常等等(例如系统记录)。这由 nestjs/common 包中的 Logger 类实现。你可以全面控制如下的日志系统的行为: 完全禁用日…

jenkins windows安装 部署项目 前端 后端

安装 需要安装的程序: 1.下载jenkins windows版本 2.400 此版本需要jdk11 https://www.jenkins.io/ 按着提示安装即可 2.下载jdk 11 https://login.oracle.com/ 按着提示安装即可 部署pc 1.新建项目 2.源码管理 3.添加git用户 4.Build Steps 构建 初始化np…

vue2数据响应式原理(2)搭建webpack认识一下Object.defineProperty

在1中我们讲到 Object.defineProperty() 是vue2实现数据响应的关键 那么我们就来好好的看看这个方法 方法字面意思是定义属性 而他是通过Object对象调用的 所以说 他是用来控制对象的某个属性的 比较官方的解释是 object.defineProperty() 方法会直接在一个对象上定义一个新属…

单片机添加版本号的一些小技巧

平时我们写程序,通常都会备注软件版本,那么,怎么在单片机中保存版本信息呢? 方法其实有很多,但基本原理都是在指定存储区域(Flash)中写入软件版本信息。 实现方法 下面就分享一个最常用&#xf…

算法风险防控

算法风险防控是指在算法应用过程中,通过对算法应用场景、数据、模型和结果等多个方面的风险进行评估和控制,以保障算法应用的安全性、可靠性和合法性。以下是一些常见的算法风险防控措施: 数据风险防控:在算法应用中,…

【python】Python基础入门:从变量到异常处理

天池实验室代码链接:https://tianchi.aliyun.com/notebook-ai/home#notebookLabId491001 简介 Python 是一种通用编程语言,其在科学计算和机器学习领域具有广泛的应用。如果我们打算利用 Python 来执行机器学习,那么对 Python 有一些基本的了…

51单片机定时器与计数器

文章目录 51单片机定时器与计数器一、定时器与计数器的结构与功能计数功能定时功能 二、定时器与计数器的控制TMOD 工作方式寄存器TCON 定时器控制寄存器 三、仿真案例(一).8个LED 1 秒周期闪烁。(二) 产品包装生产线。 51单片机定时器与计数器 一、定时器与计数器的结构与功能…

ESP32设备驱动-BMP388气压传感器驱动

BMP388气压传感器驱动 文章目录 BMP388气压传感器驱动1、BMP388介绍2、硬件准备3、软件准备4、驱动实现1、BMP388介绍 BMP388 是一款非常小巧、低功耗和低噪声的 24 位绝对气压传感器。 它可以实现精确的高度跟踪,特别适合无人机应用。 BMP388 在 0-65C 之间的同类最佳 TCO,…

港联证券|AI概念板块无死角杀跌,主题炒作熄火后资金会流向哪些板块?

ChatGPT概念指数大跌7%,单日跌幅创历史之最。 4月10日,炒作逾月的ChatGPT概念板块团体大跌,云从科技(688327.SH)、三六零(601360.SH)、科大讯飞(002230.SZ)等热门股跌停&…

集中式版本控制工具 —— SVN

一、简介 1️⃣ SVN 是什么? 代码版本管理工具他能记住每次的修改查看所有的修改记录恢复到任何历史版本恢复已经删除的文件 2️⃣ SVN 与 Git 相比有什么优势? 使用简单、上手快目录级权限控制,企业安全必备子目录 Checkout,…

RK3568平台开发系列讲解(Linux系统篇)文件系统的读写

🚀返回专栏总目录 文章目录 一、文件IO1.1、文件 IO read()1.2、文件 IO write()二、系统调用层和虚拟文件系统层三、ext4 文件系统层沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们一起学习 read 和 write 调用过程。 一、文件IO 1.1、文件 IO read() rea…

openLdap2.4.44的安装部署

openLdap2.4.44的安装部署 一、安装 1.从yum源拉取 yum install -y openldap openldap-clients openldap-servers 2.复制DB到指定目录 cp /usr/share/openldap-servers/DB_CONFIG.example /var/lib/ldap/DB_CONFIG 3.给目录授权 (如果没有ldap ,可…

定时任务框架快速入门

一、Quartz 1. Quartz 概述 Quartz 是一个开源的作业调度框架(job scheduler),几乎可以集成到任何 Java 应用程序中,从最小的独立应用程序到最大的电子商务系统。Quartz 可用于创建简单或复杂的调度来执行数十个、数百个甚至数万个作业;其任务…

[NOIP1999 普及组] Cantor 表

[NOIP1999 普及组] Cantor 表 题目描述: 现代数学的著名证明之一是 Georg Cantor 证明了有理数是可枚举的。他是用下面这一张表来证明这一命题的: 1/1 , 1/2 , 1/3 , 1/4, 1/5, … 2/1, 2/2 , 2/3, 2/4, … 3/1 , 3/2, 3/3, … 4…

win11下载配置Python环境+pycharm下载

前两天快乐的把我重装的win10升级成win11,升级的时候超怕不能成功,但效果还不错,然后突然想学一学Python,所以首先来配置环境吧 一、下载安装包 建议去官网,因为自从有了Python3之后,Python2就慢慢的被淘汰…

测试市场已经饱和了吗?现在转行软件测试会不会太迟?

非常有意思的话题,某种程度上来说,测试职场一条从未设想过的道路真的走通了。 这条路指广大测试呼吁对测试从业进行学历保护、专业保护,就像医学那样设置护城河,以一种令人意想不到的方式完成了。 得益于大量培训机构为了赚钱&a…