图像语义分割 pytorch复现U2Net图像分割网络详解

news2025/1/1 23:49:30

图像语义分割 pytorch复现U2Net图像分割网络详解

  • 1、U2Net网络模型结构
  • 2、block模块结构解析
    • RSU-7模块
    • RSU-4F
    • saliency map fusion module
  • U2Net网络结构详细参数配置
  • RSU模块代码实现
  • RSU4F模块代码实现
  • u2net_full与u2net_lite模型配置函数
  • U2Net网络整体定义类
  • 损失函数计算
  • 评价指标
  • 数据集
  • pytorch训练U2Net图像分割模型

在这里插入图片描述
U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection

1、U2Net网络模型结构

在这里插入图片描述
网络的主体类似于U-Net的网络结构,在大的U-Net中,每一个小的block都是一个小型的类似于U-Net的结构,因此作者取名U2Net
仔细观察,可以将网络中的block分成两类:
第一类:En_1 ~ En_4 与 De_1 ~ De_4这8个block采用的block其实是一样的,只不过模块的深度不同。

第二类:En_5、En_6、De_5

  • 在整个U2Net网络中,在Encoder阶段,每通过一个block都会进行一次下采样操作(下采样2倍,maxpool)
  • 在Decoder阶段,在每个block之间,都会进行一次上采样(2倍,bilinear)

2、block模块结构解析

在 En_1 与 De_1 模块中,采用的 block 是RSU-7;
En_2 与 De_2采用的 block 是RSU-6(RSU-6相对于RSU-7 就是少了一个下采样卷积以及上采样卷积的部分,RSU-6 block只会下采样16倍,RSU-7 block下采样的32倍);
En_3 与 De_3采用的 block 是RSU-5
En_4 与 De_4采用的 block 是RSU-4
En_5、En_6、De_5采用的block是RSU-4F
(使用RSU-4F的原因:因为数据经过En_1 ~ En4 下采样处理后对应特征图的高与宽就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息,作者为了保留上下文信息,就对En_5、En_6、De_5不再进行下采样了而是在RSU-4F的模块中,将下采样、上采样结构换成了膨胀卷积)

RSU-7模块

在这里插入图片描述详细结构图解
在这里插入图片描述

RSU-4F

在这里插入图片描述

saliency map fusion module

saliency map fusion module模块是将每个阶段的特征图进行融合,得到最终的预测概率图,即下图中,红色框标注的模块
在这里插入图片描述
其会收集De_1、De_2、De_3、De_4、De_5、En_6模块的输出,将这些输出分别通过一个3x3的卷积层(这些卷积层的kerner的个数都是为1)输出的featuremap的channel是为1的,在经过双线性插值算法将得到的特征图还原回输入图像的大小;再将得到的6个特征图进行concant拼接;在经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的预测概率图。

U2Net网络结构详细参数配置

在这里插入图片描述
u2net_full大小为176.3M、u2net_lite大小为4.7M

RSU模块代码实现

在这里插入图片描述

class RSU(nn.Module):
    def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()

        assert height >= 2
        self.conv_in = ConvBNReLU(in_ch, out_ch)

        encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]
        decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)]
        for i in range(height - 2):
            encode_list.append(DownConvBNReLU(mid_ch, mid_ch))
            decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))

        encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))
        self.encode_modules = nn.ModuleList(encode_list)
        self.decode_modules = nn.ModuleList(decode_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        x = x_in
        encode_outputs = []
        for m in self.encode_modules:
            x = m(x)
            encode_outputs.append(x)

        x = encode_outputs.pop()
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = m(x, x2)

        return x + x_in

RSU4F模块代码实现

在这里插入图片描述

class RSU4F(nn.Module):
    def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()
        self.conv_in = ConvBNReLU(in_ch, out_ch)
        self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),
                                             ConvBNReLU(mid_ch, mid_ch, dilation=2),
                                             ConvBNReLU(mid_ch, mid_ch, dilation=4),
                                             ConvBNReLU(mid_ch, mid_ch, dilation=8)])

        self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),
                                             ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),
                                             ConvBNReLU(mid_ch * 2, out_ch)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        x = x_in
        encode_outputs = []
        for m in self.encode_modules:
            x = m(x)
            encode_outputs.append(x)

        x = encode_outputs.pop()
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = m(torch.cat([x, x2], dim=1))

        return x + x_in

u2net_full与u2net_lite模型配置函数

def u2net_full(out_ch: int = 1):
    cfg = {
        # height, in_ch, mid_ch, out_ch, RSU4F, side     side:表示是否要收集当前block的输出
        "encode": [[7, 3, 32, 64, False, False],      # En1
                   [6, 64, 32, 128, False, False],    # En2
                   [5, 128, 64, 256, False, False],   # En3
                   [4, 256, 128, 512, False, False],  # En4
                   [4, 512, 256, 512, True, False],   # En5
                   [4, 512, 256, 512, True, True]],   # En6
        # height, in_ch, mid_ch, out_ch, RSU4F, side
        "decode": [[4, 1024, 256, 512, True, True],   # De5
                   [4, 1024, 128, 256, False, True],  # De4
                   [5, 512, 64, 128, False, True],    # De3
                   [6, 256, 32, 64, False, True],     # De2
                   [7, 128, 16, 64, False, True]]     # De1
    }

    return U2Net(cfg, out_ch)


def u2net_lite(out_ch: int = 1):
    cfg = {
        # height, in_ch, mid_ch, out_ch, RSU4F, side
        "encode": [[7, 3, 16, 64, False, False],  # En1
                   [6, 64, 16, 64, False, False],  # En2
                   [5, 64, 16, 64, False, False],  # En3
                   [4, 64, 16, 64, False, False],  # En4
                   [4, 64, 16, 64, True, False],  # En5
                   [4, 64, 16, 64, True, True]],  # En6
        # height, in_ch, mid_ch, out_ch, RSU4F, side
        "decode": [[4, 128, 16, 64, True, True],  # De5
                   [4, 128, 16, 64, False, True],  # De4
                   [5, 128, 16, 64, False, True],  # De3
                   [6, 128, 16, 64, False, True],  # De2
                   [7, 128, 16, 64, False, True]]  # De1
    }

U2Net网络整体定义类

class U2Net(nn.Module):
    def __init__(self, cfg: dict, out_ch: int = 1):
        super().__init__()
        assert "encode" in cfg
        assert "decode" in cfg
        self.encode_num = len(cfg["encode"])

        encode_list = []
        side_list = []
        for c in cfg["encode"]:
            # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
            assert len(c) == 6
            encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))     # 判断当前是构建RSU模块,还是构建RSU4F模块

            if c[5] is True:
                side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))
        self.encode_modules = nn.ModuleList(encode_list)

        decode_list = []
        for c in cfg["decode"]:
            # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
            assert len(c) == 6
            decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))

            if c[5] is True:
                side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))    # 收集当前block的输出
        self.decode_modules = nn.ModuleList(decode_list)
        self.side_modules = nn.ModuleList(side_list)
        self.out_conv = nn.Conv2d(self.encode_num * out_ch, out_ch, kernel_size=1)   # 构建一个1x1的卷积层,去融合来自不同尺度的信息

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        _, _, h, w = x.shape

        # collect encode outputs
        encode_outputs = []
        for i, m in enumerate(self.encode_modules):
            x = m(x)
            encode_outputs.append(x)
            if i != self.encode_num - 1:  # 此处需要进行判断,因为在没通过一个encoder模块后,都需要进行下采样的,但最后一个模块后,是不需要下采样的
                x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)

        # collect decode outputs
        x = encode_outputs.pop()
        decode_outputs = [x]
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
            x = m(torch.concat([x, x2], dim=1))
            decode_outputs.insert(0, x)

        # collect side outputs
        side_outputs = []
        for m in self.side_modules:
            x = decode_outputs.pop()
            x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False)
            side_outputs.insert(0, x)

        x = self.out_conv(torch.concat(side_outputs, dim=1))

        if self.training:
            # do not use torch.sigmoid for amp safe
            return [x] + side_outputs     # 用于计算损失
        else:
            return torch.sigmoid(x)

损失函数计算

在这里插入图片描述
如上图所示,红色框部分为每个分量与真实标签的交叉熵损失函数求和;黄色框标部分为将各个分量经双线性插值恢复至原始尺寸、进行concant处理、经过1x1的卷积核与sigmoid处理后的结果与真实标签的交叉熵损失函数。
损失函数代码实现:

import math
import torch
from torch.nn import functional as F
import train_utils.distributed_utils as utils


def criterion(inputs, target):
    losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
    total_loss = sum(losses)

    return total_loss

评价指标

在这里插入图片描述
其中F-measure是在0~1之间的,数值越大,代表的网络分割效果越好;
MAE是Mean Absolute Error的缩写,其值是在0~1之间的,越趋近于0,代表网络性能越好。

数据集

在这里插入图片描述
在这里插入图片描述

pytorch训练U2Net图像分割模型

项目目录结构:

├── src: 搭建网络相关代码
├── train_utils: 训练以及验证相关代码
├── my_dataset.py: 自定义数据集读取相关代码
├── predict.py: 简易的预测代码
├── train.py: 单GPU或CPU训练代码
├── train_multi_GPU.py: 多GPU并行训练代码
├── validation.py: 单独验证模型相关代码
├── transforms.py: 数据预处理相关代码
└── requirements.txt: 项目依赖

项目目录:
在这里插入图片描述
项目中u2net_full大小为176.3M、u2net_lite大小为4.7M,演示过程中,训练的为u2net_lite版本
多GPU训练指令:
pytorch版本为1.7

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py --data-path ./data_root

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1097873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LXC、Docker、 Kubernetes 容器以及Hypervisor的区别

LXC、Docker、 Kubernetes 容器以及Hypervisor的区别 SaaS: Software-as-a-Service&#xff08;软件即服务&#xff09; PaaS: Platform-as-a-Service&#xff08;平台即服务&#xff09; IaaS: Infrastructure-as-a-Service&#xff08;基础设施即服务&#xff09; 1、Docke…

划重点!3DEXPERIENCE SOLIDWORKS 2024 十大增强功能

SOLIDWORKS 2024 以更加强大的姿态亮相&#xff0c;帮助您重塑设计。为了助力您简化和加快由概念到成品的产品开发流程&#xff0c;SOLIDWORKS 2024 涵盖全新以用户为中心的增强功能&#xff0c;致力帮您实现更智能、更快速地与您的团队和外部合作伙伴协同工作。本篇为大家介绍…

微信公众号怎么从个人转为企业?

公众号账号迁移的作用是什么&#xff1f;只能变更主体吗&#xff1f;1.可合并多个公众号的粉丝、文章&#xff0c;打造超级大V2.可变更公众号主体&#xff0c;更改公众号名称&#xff0c;变更公众号类型——订阅号、服务号随意切换3.可以增加留言功能4.个人订阅号可迁移到企业名…

封箱机不打盖是什么原因?

折盖封箱机是封箱机中自动化程度比较高的一款设备&#xff0c;该机既可以单机使用&#xff0c;也可以配自动化流水线一起使用&#xff0c;实现无人化操作&#xff0c;但这款设备在使用的过程中有时会出现一些小问题&#xff0c;今天就其中的一个常见故障-----不打盖的问题和您做…

单片机TDL的功能、应用与技术特点 | 百能云芯

在现代电子领域中&#xff0c;单片机&#xff08;Microcontroller&#xff09;是一种至关重要的电子元件&#xff0c;广泛应用于各种应用中。TDL&#xff08;Time Division Multiplexing&#xff0c;时分多路复用&#xff09;是一种数据传输技术&#xff0c;结合单片机的应用&a…

MySQL InnoDB引擎深入学习的一天(InnoDB架构 + 事务底层原理 + MVCC)

目录 逻辑存储引擎 架构 概述 内存架构 Buffer Pool Change Buffe Adaptive Hash Index Log Buffer 磁盘结构 System Tablespace File-Per-Table Tablespaces General Tablespaces Undo Tablespaces Temporary Tablespaces Doublewrite Buffer Files Redo Log 后台线程 事务原…

关于脑部的基础知识

脑部的基础知识 1 解剖学基本术语&#xff1a;1.1 解剖学方向1.2 解剖学平面1.3 神经元集合体1.4 神经元轴突集合体 2 中枢神经系统CNS2.1 脑 Brain2.1.1 **大脑** 大脑皮层 皮层下结构2.1.2 **间脑** **丘脑 下丘脑 垂体**2.1.3 **中脑 ** **顶盖 ** **大脑脚**2.1.4 脑桥…

【算法|前缀和系列No.2】牛客网 DP35 【模板】二维前缀和

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【手撕算法系列专栏】【牛客网刷题】 &#x1f354;本专栏旨在提高自己算法能力的同时&#xff0c;记录一下自己的学习过程&#xff0c;希…

功能测试-本地缓存修改时间戳:操作支付中心页面触发二次挽留弹窗

需求&#xff1a;在活动期间&#xff0c;指定的用户在关闭支付中心个人会员页时&#xff0c;增加二次挽留弹窗机制 涉及端口&#xff1a;PC官网支付中心、移动端官网支付中心、插件端支付中心 触发频率&#xff1a;用户每天触发一次&#xff08;需求写的其实这是错误的&#…

实赣!赣州与开源网安联手打造软件供应链安全检测中心

10月16日&#xff0c;开源网安与赣州市行政审批局、赣州市网信办三方先后签署战略协议及投资协议&#xff0c;签约后将在赣州打造软件供应链安全检测中心&#xff0c;为数字政府、数字经济等领域提供全面安全检测和软件安全运营监测等服务&#xff0c;提升软件的安全与质量&…

【IEEE会议】第三届信息技术与当代体育国际学术会议(TCS 2023)

【IEEE】第三届信息技术与当代体育国际学术会议&#xff08;TCS 2023&#xff09; 2023 3rd International Conference on Information Technology and Contemporary Sports 2023年第三届信息技术与当代体育国际学术会议&#xff08;TCS 2023&#xff09;将于2023年12月22-24…

阿里云服务器ECS实例规格族c/g/r等字母说明

阿里云服务器ECS实例命名规则&#xff1a;ecs.<规格族>.large字母含义命名说明&#xff0c;包括x86、ARM架构、GPU异构计算、弹性裸金属、超级计算集群SCC云服务器&#xff0c;c代表计算型、g代表通用型、r代表内存型、u代表通用算力型、e代表经济型e实例&#xff0c;阿里…

【题解】[NOIP2016]玩具谜题

题目描述 P1563 [NOIP2016 提高组] 玩具谜题 前置知识 无 题目分析 题目比较绕&#xff0c;关键是要搞清楚顺时针、逆时针和左、右&#xff0c;把每个指令转换为数组下标的移动。 首先提取出关键信息&#xff1a; 输入顺序为逆时针 面朝圈内&#xff1a;左->顺时针&am…

JAVASSMMYSQL高校学生选课系统01483-计算机毕业设计项目选题推荐(附源码)

目 录 摘要 1 绪论 1.1 研究背景 1.2开发意义 1.3ssm框架 1.4论文结构与章节安排 2 2 高校学生选课系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据增加流程 2.2.2 数据修改流程 2.2.3数据删除流程 2.3 系统功能分析 2.3.1功能性分析 2.3.2非功能性分析…

哪个文字转语音配音软件最好用?

现在TTS技术不断发展&#xff0c;文字转语音技术已经越来越成熟&#xff0c;声音听着拟人度非常高&#xff0c;现在好用的软件也不在少数。很多手机里面都有自带的朗读功能&#xff0c;如果觉得声音不够&#xff0c;也可以自己下载软件使用。给大家分享一下我一直使用的一款文字…

全志R128 BLE最高吞吐量测试正确配置测试

在R128使用前我们需要了解BLE的最高吞吐量&#xff0c;以方便评估相关功能的开发。 首先我们了解一下哪些因素会影响蓝牙的吞吐量&#xff1a; 1、蓝牙版本与PHY&#xff1a; 蓝牙设备的版本和物理层&#xff08;PHY&#xff09;对于吞吐量有很大影响。例如&#xff0c;R128设…

基于Python的车牌识别系统实现

本文将以基于Python的车牌识别系统实现为方向&#xff0c;介绍车牌识别技术的基本原理、常用算法和方法&#xff0c;并详细讲解如何利用Python语言实现一个完整的车牌识别系统。 目录 引言车牌识别技术的应用场景Python在车牌识别领域的优势 车牌识别技术概述图像处理和计算机视…

《软件方法》第1章2023版连载(07)UML的历史和现状

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 1.3 统一建模语言UML 1.3.1 UML的历史和现状 上一节阐述了A→B→C→D的推导是不可避免的&#xff0c;但具体如何推导&#xff0c;有各种不同的做法&#xff0c;这些做法可以称为“方…

正点原子嵌入式linux驱动开发——新字符设备驱动实验

经过之前两篇笔记的实战操作&#xff0c;已经掌握了Linux字符设备驱动开发的基本步骤&#xff0c;字符设备驱动开发重点是使用register_chrdev函数注册字符设备&#xff0c;当不再使用设备的时候就使用unregister_chrdev函数注销字符设备&#xff0c;驱动模块加载成功以后还需要…

广义回归神经网络预测程序

欢迎关注“电击小子程高兴的MATLAB小屋” %% 学习目标:广义回归神经网络 %% 训练速度快 非线性映射能力强 常用于函数逼近 clear all; close all; P1:30; T3*sin(P); netnewgrnn(P,T,0.3); %径向基函数的分布密度是0.3 ysim(net,P); figure; plot(P,T,:,P,T-y,-o);