DeepLabV3+:搭建Mobilenetv2网络

news2024/11/8 23:54:49

目录

Mobilenetv2的介绍

Mobilenetv2的结构

Inverted Residual Block倒残差结构 

Pytorch实现Inverted Residual Block

搭建Mobilenetv2

Pytorch实现Mobilenetv2主干网络

相关参考资料


Mobilenetv2的介绍

Mobilenetv2网络设计基于Mobilenetv1,它保持了其简单性,不需要任何特殊的操作,同时显著提高了其准确性,实现了移动应用的多图像分类和检测任务的最先进水平。

MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩。另外,我们发现移除通道数很少的层做线性激活非常重要。

论文对模型在ImageNet分类、COCO目标检测和VOC图像分割的表现进行了度量,评估权衡了精度、乘加操作次数,实际延迟和参数的数量。

Mobilenetv2的结构

Inverted Residual Block倒残差结构 

 可以看见在我们上图的右边,就是倒残差结构,它会经历以下部分:

  • 1x1卷积升维
  • 3x3卷积DW
  • 1x1卷积降维

接下来请结合着下面的代码来看,首先有一个expand_ratio来表示是否对输入进来的特征层进行升维,如果不需要就会进行卷积、标准化、激活函数、卷积、标准化。不然就会先有1x1卷积进行通道数的上升,在用3x3逐层卷积,进行跨特征点的特征提取,最后1x1卷积进行通道数的下降。

上升是为了让我们的网络结构有具备更好的特征表征能力,下降是为了让我们的网络具备更低的运算量,在完成这样的特征提取后,如果要使用残差边,我们就会将特征提取的结果直接与输入相接,如果没有使用残差边,就会直接输出卷积结果。

Pytorch实现Inverted Residual Block

import torch.nn as nn

BatchNorm2d = nn.BatchNorm2d

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # 进行3x3的逐层卷积,进行跨特征点的特征提取
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=(3,3), stride=stride, padding=1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # 利用1x1卷积进行通道数的调整
                nn.Conv2d(hidden_dim, oup, kernel_size=(1,1), stride=(1,1), padding=0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # 利用1x1卷积进行通道数的上升
                nn.Conv2d(inp, hidden_dim, kernel_size=(1,1), stride=(1,1), padding=0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # 进行3x3的逐层卷积,进行跨特征点的特征提取
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=(3,3), stride=stride, padding=1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # 利用1x1卷积进行通道数的下降
                nn.Conv2d(hidden_dim, oup, kernel_size=(1,1), stride=(1,1), padding=0, bias=False),
                BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

搭建Mobilenetv2

在这里它的实现还是相对比较清晰的。在建立Mobilenetv2前,首先先定义了bn卷积,只有卷积核的大小有所不同,具体可以看下面pytoch实现当中。

变量features会先对图片有3x3大小、步长为2d的卷积进行一个高和宽的压缩。接下来会进入一个列表的循环,t表示是否进行1*1卷积上升的过程,c表示output_channel大小,n表示小列表倒残差次数,s是步长,表示是否对高和宽进行压缩。

那么这样来看,如果最初图片为(512,512,3),经过features后,在经过循环列表会有这样的处理。

  • 输入features:512,512,3 -> 256, 256, 32
  • 第1次循环:256, 256, 32 -> 256, 256, 16
  • 第2次循环:256, 256, 16 -> 128, 128, 24   
  • 第3次循环:128, 128, 24 -> 64, 64, 32     
  • 第4次循环:64, 64, 32 -> 32, 32, 64       
  • 第5次循环:32, 32, 64 -> 32, 32, 96
  • 第6次循环:32, 32, 96 -> 16, 16, 160     
  • 第7次循环:16, 16, 160 -> 16, 16, 320

接着会用1x1卷积调整通道数,完成features的建立。

论文给出的:

Pytorch实现Mobilenetv2主干网络

import math
import torch.nn as nn


BatchNorm2d = nn.BatchNorm2d

def conv_bn(inp, oup, strides):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=(3,3), stride=strides, padding=1, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=(1,1), stride=(1,1), padding=0, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1], 
            [6, 24, 2, 2],   
            [6, 32, 3, 2],     
            [6, 64, 4, 2],      
            [6, 96, 3, 1],
            [6, 160, 3, 2],   
            [6, 320, 1, 1], 
        ]

        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel

        self.features = [conv_bn(3, input_channel, 2)]

        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel

        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*self.features)

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self.initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

def mobilenetv2(pretrained=False, **kwargs):
    model = MobileNetV2(n_class=1000, **kwargs)
    if pretrained:
        pass
    return model

if __name__ == "__main__":
    model = mobilenetv2()
    for i, layer in enumerate(model.features):
        print(i, '->', layer)

运行成功,至此mobielnetv2的搭建完成

 

相关参考资料

DeepLabV3-/Mobilenetv2.pdf at main · Auorui/DeepLabV3- (github.com)

MobileNet_v2模型解读

MobileNet_v2模型解读——知乎

憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/185409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【进击的算法】动态规划——01背包

🍿本文主题:动态规划 01背包 背包问题 C/C 算法 🎈更多算法:基础回溯算法 基础动态规划 💕我的主页:蓝色学者的主页 文章目录一、前言二、概念✔️动态规划概念✔️01背包的概念三、问题描述与讲解&#x1…

spring 中 mybaits 的一级缓存失效

mybatis 的一级缓存 简单回顾下mybatis的一级缓存 本质上是一个基于map实现的内存级别的缓存,默认开启,生命周期是 sqlsession 级别的 为什么会失效 其实这个问题反向分析一下就会有思路了,一级缓存默认是sqlsession级别的,这个规…

2022年rust杂记

以下记录的是,我在学习中的一些学习笔记,这篇笔记是自己学习的学习大杂烩,主要用于记录,方便查找1、相关学习链接https://www.rust-lang.org/zh-CN/governance/ RUST 官网博客https://kaisery.github.io/trpl-zh-cn/(最…

应用性能监控对DMS系统综合分析案例

背景 DMS系统是某汽车集团的经销商在线系统,是汽车集团的重要业务系统。本次分析重点针对DMS系统性能进行分析,以供安全取证、性能分析、网络质量监测以及深层网络分析。 该汽车总部已部署NetInside流量分析系统,使用流量分析系统提供实时和…

好好的系统,为什么要分库分表?

不急于上手实战 ShardingSphere 框架,先来复习下分库分表的基础概念,技术名词大多晦涩难懂,不要死记硬背理解最重要,当你捅破那层窗户纸,发现其实它也就那么回事。 什么是分库分表 分库分表是在海量数据下&#xff0…

51单片机学习笔记-14 ADDA

14 ADDA [toc] 注:笔记主要参考B站江科大自化协教学视频“51单片机入门教程-2020版 程序全程纯手打 从零开始入门”。 注:工程及代码文件放在了本人的Github仓库。 14.1 AD/DA简介 14.1.1 AD/DA基本介绍 AD(Analog to Digital)…

FreeRTOS任务管理

RTOS 的核心是如果高效管理各个任务及任务之间通信,本章将向大家介绍 FreeRTOS 的任务管理,通过本章的学习,让大家对 RTOS 任务的理解更加深入, 为后面的学习做好铺垫。本章分为如下几部分内容: 1 任务管理介绍 2 常用…

ue4c++日记7(动画蓝图)

FVector Speed Pawn->GetVelocity();//获取方向向量FVector xyspeed FVector(Speed.X, Speed.Y,0);//不要z方向MovementSpeed xyspeed.Size();//xy取长//角色是否处于下落状态IsJumping Pawn->GetMovementComponent()->IsFalling();//#include "GameFramewor…

FreeRTOS中的信号量实验

信号量是操作系统中重要的一部分,信号量一般用来进行资源管理和任务同 步,FreeRTOS 中信号量又分为二值信号量、计数型信号量、互斥信号量和递归 互斥信号量。不同的信号量其应用场景不同,但有些应用场景是可以互换着使用。 本章要实现的功能…

【数据结构从0到1之树的初识】

目录 1.树的表达方式 1.1 树的定义 1.2树的相关概念 1.3树的存储结构 1.3.1 双亲表示法 1.3.2 孩子表示法 1.3.3 孩子兄弟表示法 1.4树在实际中的应用 后记: 🕺作者: 迷茫的启明星 😘欢迎关注:👍点…

Lua 迭代器

Lua 迭代器 参考文章: 菜鸟教程。 https://cloud.tencent.com/developer/article/2203215 迭代器(iterator)是一种对象,它能够用来遍历标准模板库容器中的部分或全部元素,每个迭代器对象代表容器中的确定的地址。 在 L…

23种设计模式之七种结构型模式

23种设计模式之七种结构型模式1. 设计模式概述1.1 什么是设计模式1.2 设计模式的好处2. 设计原则分类3. 详解3.1 单一职责原则3.2 开闭原则3.3 里氏代换原则3.4 依赖倒转原则3.5 接口隔离原则3.6 合成复用原则3.7 迪米特法则4. Awakening1. 设计模式概述 我们的软件开发技术也包…

[Python从零到壹] 番外篇之可视化利用D3库实现CSDN博客每日统计效果(类似github)

欢迎大家来到“Python从零到壹”,在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界。所有文章都将结合案例、代码和作者的经验讲解,真心想把自己近十年的编程经验分享给大家,希望…

关于对公司做项目的一些想法

项目管理法则里面最重要的是如下的三角形:基于一定的范围、合理的时间和足够的成本下实现项目完成,并保证质量。项目中最重要的是质量,质量不行就意味着项目失败,请参考大跃进时期的大炼钢铁(多快好省大炼钢&#xff0…

是什么影响了 MySQL 索引 B + 树的高度?

提到 MySQL,想必大多后端同学都不会陌生,提到 B 树,想必还是有很大部分都知道 InnoDB 引擎的索引实现,利用了 B 树的数据结构。 那 InnoDB 的一棵 B 树可以存放多少行数据?它又有多高呢? 到底是哪些因…

WebRTC → 信令服务器

相关简介 信令:驱动系统运转。控制各个模块的前后调用关系;业务不同,逻辑不同,信令也会千差万别 要实现一对一通信,驱动系统的核心就是信令。信令控制着系统各个模块之间的前后调用关系,比如当收到用户成功加入房间后…

3D模型在线查看利器【多种格式】

BimAnt 3DViewer网站可以 打开多种 3D 文件格式并在你的浏览器中可视化展示3D模型,支持 obj、3ds、stl、ply、gltf、glb、off、 3dm、fbx 等等。 1、支持的3D模型格式 BimAnt 3DViewer网站支持多种文件格式的导入和导出。 如果文件格式有文本和二进制版本&#x…

Minecraft 1.19.2 Fabric模组开发 09.Mixin

我们今天用mixin在1.19.2 fabric中实现一个望远镜 1.由于fabric已经自动配置好了mixin,所以我们无需配置mixin,先在ItemInit中新建一个我们的望远镜物品: ItemInit.java public static final Item BIRDWATCHER registerItem("birdwat…

Smart-doc的脚本生成在线文档(精简官方文档描述)

Smart-doc优点: 无侵入的接口文档、在线文档生成器。三种生成文档方式。对于程序代码开发中只需要加注释(符合一定的语法,五分钟可掌握)就能生成在线文档。可以支持c、java、php、node等等常见的主流语言。 如何使用: …

47.Isaac教程--ORB

ORB ISAAC教程合集地址: https://blog.csdn.net/kunhe0512/category_12163211.html 文章目录ORBGem 提供的类型关键点描述符如何使用 Gem(界面)构建包Isaac Codelets示例应用程序主机设备嵌入式 Jetson 设备这个 gem 提供了一个特征检测器和描述符提取器…