DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合

news2025/1/18 6:37:51

目录

Mobilenetv2的改进

浅层特征和深层特征的融合

完整代码

参考资料


Mobilenetv2的改进

在DeeplabV3当中,一般不会5次下采样,可选的有3次下采样和4次下采样。因为要进行五次下采样的话会损失较多的信息。

在这里mobilenetv2会从之前写好的模块中得到,但注意的是,我们在这里获得的特征是[-1],也就是最后的1x1卷积不取,只取循环完后的模型。

down_idx是InvertedResidual进行的次数。

# t, c, n, s
[1, 16, 1, 1], 
[6, 24, 2, 2],    2
[6, 32, 3, 2],    4
[6, 64, 4, 2],    7  
[6, 96, 3, 1],
[6, 160, 3, 2],   14
[6, 320, 1, 1], 

根据下采样的不同,当downsample_factor=8时,进行3次下采样,对倒数两次,步长为2的InvertedResidual进行参数的修改,让步长变为1,膨胀系数为2。

当downsample_factor=16时,进行4次下采样,只需对最后一次进行参数的修改。

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from net.mobilenetv2 import mobilenetv2
from net.ASPP import ASPP

class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        
        model           = mobilenetv2(pretrained)
        self.features   = model.features[:-1]

        self.total_idx  = len(self.features)
        self.down_idx   = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
        
    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x

forward当中,会输出两个特征层,一个是浅层特征层,具有浅层的语义信息;另一个是深层特征层,具有深层的语义信息。

浅层特征和深层特征的融合

 具有高语义信息的部分先进行上采样,低语义信息的特征层进行1x1卷积,二者进行特征融合,再进行3x3卷积进行特征提取

self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)

这一步就是获得那个绿色的特征层;

low_level_features = self.shortcut_conv(low_level_features)

从这里将是对浅层特征的初步处理(1x1卷积);

x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
x = self.cat_conv(torch.cat((x, low_level_features), dim=1))

上采样后进行特征融合;

完整代码

# deeplabv3plus.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from net.xception import xception
from net.mobilenetv2 import mobilenetv2
from net.ASPP import ASPP

class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        
        model           = mobilenetv2(pretrained)
        self.features   = model.features[:-1]

        self.total_idx  = len(self.features)
        self.down_idx   = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
        
    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x

class DeepLab(nn.Module):
    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone=="xception":
         
            #   获得两个特征层:浅层特征 主干部分    
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone=="mobilenet":

            #   获得两个特征层:浅层特征 主干部分
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        #   ASPP特征提取模块
        #   利用不同膨胀率的膨胀卷积进行特征提取
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)
       
        # 浅层特征边
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )		

        self.cat_conv = nn.Sequential(
            nn.Conv2d(48+256, 256, kernel_size=(3,3), stride=(1,1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, kernel_size=(3,3), stride=(1,1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, kernel_size=(1,1), stride=(1,1))

    def forward(self, x):
        H, W = x.size(2), x.size(3)

        # 获得两个特征层,low_level_features: 浅层特征-进行卷积处理
        #                x : 主干部分-利用ASPP结构进行加强特征提取
  
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)

        #   将加强特征边上采样,与浅层特征堆叠后利用卷积进行特征提取
        x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

参考资料

DeepLabV3-/论文精选 at main · Auorui/DeepLabV3- (github.com)

(6条消息) 憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台_Bubbliiiing的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/197659.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第四天链表

24. 两两交换链表中的节点力扣题目链接(opens new window)给定一个链表,两两交换其中相邻的节点,并返回交换后的链表。你不能只是单纯的改变节点内部的值,而是需要实际的进行节点交换。接下来就是交换相邻两个元素了,此时一定要画…

vite --- 为什么选Vite

目录 什么是Vite 为什么选Vite 现实问题 为什么生产环境仍需打包 Vite 与竞品 什么是Vite Vite(法语意为 "快速的",发音 /vit/,发音同 "veet")是一种新型前端构建工具,能够显著提升前端开发体…

SpringBoot+Vue图书馆管理系统1.0

简介:本项目采用了基本的SpringBootVue设计的图书馆管理系统。详情请看截图。经测试,本项目正常运行。本项目适用于Java毕业设计、课程设计学习参考等用途。 项目描述 项目名称SpringBootVue图书馆管理系统1.0源码作者LHL项目类型Java EE项目 &#xff…

Zebec 上线投票治理机制,全新流支付生态正在起航

随着加密货币的兴起,其除了成为一种备受关注的投资品外,它也正在成为一种新兴的支付手段。虽然在加密行业发展早期,以BTC、LTC等为代表的以支付为定位老牌加密资产,因支付效率低下、支付成本高、合规等问题而没能实现早期的愿景&a…

Node.js:CommonJS模块化规范

CommonJS 上文提到了 Node 采用的模块化规范是 CommonJS,它主要规定了如何定义模块,如果导出模块和如何导入模块: 定义模块:一个文件就是一个模块导出模块:通过 module.exports 导出模块导入模块:通过 re…

【Linux】第八部分 Linux常用基本命令

【Linux】第八部分 Linux常用基本命令 文章目录【Linux】第八部分 Linux常用基本命令8. Linux常用基本命令8.1 帮助命令8.2 文件目录类命令pwd 显示当前工作目录的绝对路径cd 切换目录ls 列出目录的内容mkdir 创建目录rmdir 删除目录touch 创建文件cp 复制文件或者目录rm 删除文…

Kaggle系列之预测泰坦尼克号人员的幸存与死亡(随机森林模型)

Kaggle是开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,本节是对于初次接触的伙伴们一个快速了解和参与比赛的例子,快速熟悉这个平台。当然提交预测结果需要注册,这个可能需要科学上网了。我们选择一个预测的入…

【操作系统】4、设备管理

文章目录四、设备管理4.1 I/O设备基本概念4.2 I/O控制方式4.2.1 程序控制方式4.2.2 中断方式4.2.3 DMA控制方式4.2.4 通道控制方式4.3 缓冲技术4.4 假脱机技术四、设备管理 I/O控制方式:程序控制、中断、DMA、通道, 缓冲技术;假脱机技术(SPO…

大龄学长的浙大MBA提面优秀之路分享

作为今年上岸浙大MBA项目的一名中年老学长,想把自己在提面中取得优秀资格的经验做个梳理供大家参考,因为以我的经历来说,我认为浙大MBA提前批面试是非常有价值的,而且在提面过程中也发现了优秀资格其实遍布于各个年龄段和层级&…

2023-02-04 Elasticsearch环境安装

1 JDK-8的安装 查询资料自我安装即可,这里不做展示。 2 Elasticsearch 的安装 Elasticsearch目录结构: 配置文件: #节点名称,集群内要唯一 node.name: node-1001node.master: true node.data: true#ip 地址 network.host: localhost htt…

细讲TCP三次握手四次挥手(一)

计算机网络体系结构 在计算机网络的基本概念中,分层次的体系结构是最基本的。计算机网络体系结构的抽象概念较多,在学习时要多思考。这些概念对后面的学习很有帮助。 网络协议是什么? 在计算机网络要做到有条不紊地交换数据,就必…

lsof - list open file

lsof 指令全称 list open file,用官方的话说 Lsof revision 4.91 lists on its standard output file information about files opened by processes -i 平常工作中,用到最多的就是 -i 参数,后面跟端口号,可以查看和这个端口有关…

【嵌入式】MDK使用sct文件将代码段放入RAM中执行

sct文件即分散加载文件,是ARMCC编译器使用的链接脚本文件,等同于GCC编译器的ld链接脚本。MDK IDE使用的是ARMCC。 支持NorFlash中运行代码(XIP)的MCU例如STM32,一般将所有代码(text段)都放在FL…

[ 云计算 | AWS ] 亚马逊云科技核心服务之计算服务(Part1:AWS EC2 星巴克为什么横向排队)

(星爸爸网络上的一张图) 注意上图中的5个人,对没错这5个人。一般情况星巴克的人员配置大概是这样的: 1个经理,在办公室两个收银,在收银台(本文关注的重点)三个人做咖啡 当你去过星巴克买咖啡时&#xff0…

【NS2】tcl与c++互相调用/传参

在NS2,做实验的时候,为了能通过循环配合传值实验,一直找不到tcl传参给c的方法,网上的只po出一部分看不懂,只能通过源码自己研究。最后的解决办法就是,模仿源码的操作,以下通过tcl→ex→sat-irid…

Navicat Monitor 3.0 现已上市 | 欢迎下载试用

Navicat Monitor 3.0 现已上市Navicat Montior 3.0 现已发布!一经发布,受到广大专业运维人员的关注与选择! 五大新亮点带给运维团队最为实用且有效地提升监控能力。其具备 PostgreSQL 服务器监控能力、支持优化慢查询、构建自定义指标、性能分析工具优化…

flutter问题

问题一1.报错:Flutter ios/Flutter/Debug.xcconfig: unable to open file (in target "Runner" in project "Runner")2.解决:cd 项目目录flutter cleanflutter create --org solanddriver .运行Xcode问题二1.Cannot run with sound …

Java线程安全问题的原因和解决方案

1.什么是线程安全2.线程不安全的原因 及 解决措施2.1 多线程同时修改同一个变量2.2 修改操作不是原子性加锁操作关键字:synchronized2.3 抢占式执行,随机调度 (根本原因)2.4内存可见性问题volatile 关键字2.5指令重排序1.什么是线程安全 线程安全的确切定义是比较复…

Java——SSM项目(瑞吉外卖)笔记

阅读提醒:最重要的内容都是我手打的字,还有截图上的红字备注部分。 nginx是一个服务器,主要部署一些静态的资源,包括后面做tomcat的集群, 可以接收前端的请求,然后分发给各个tomcat 第一步搭建数据库&…

浏览器网页视频怎么快速下载到本地?

我们在浏览网页时,经常会遇到一些特别喜欢的视频文件,想要下载收藏却苦于不会操作怎恶魔办呢?这时候可以通过一些小插件快速达成下载,比如通过猫爪视频下载插件用户可以轻松的抓取任意网页的视频文件,并将其保存到本地…