神经网络(五):U2Net图像分割网络

news2024/9/28 14:46:28

文章目录

  • 一、网络结构
    • 1.1第一种block结构
    • 1.2第二种block结构
    • 1.3特征图融合模块
    • 1.4损失函数
    • 1.5总体网络架构
    • 1.6代码汇总
    • 1.7普通残差块与RSU对比
  • 二、代码复现


  参考论文:U2-Net: Going deeper with nested U-structure for salient object detection
  这篇文章基于显著目标检测任务提出,显著目标检测是指将图像中最吸引人的目标或者区域分割出来,因此只有前景和背景两个类别,相当于语义分割中的二分类任务。例如,下图中展示了三张图片进行显著目标检测后的结果:
在这里插入图片描述
其中,白色区域代表前景 ,即最吸引人的目标或区域,而黑色区域代表背景。
  在 U 2 N e t U^2Net U2Net被提出前,显著目标检测领域主要面临两个问题:

  • 1.现有的SOD网络大多基于当时已有的网络架构(主干网络)进行深度特征的提取,如 A l e x N e t 、 V G G 、 R e s N e t 、 R e s N e X t AlexNet、VGG、ResNet、ResNeXt AlexNetVGGResNetResNeXt等。这些网络最初是为图像分类设计的,它们提取代表语义含义的特征(如某一具体的猫、狗),而非局部细节和全局对比度信息(SOD的主要任务是将图像划分为前景与背景,而不注重某一具体特征的提取),使得这些模型在进行SOD任务时效率低下。
  • 2.当时的SOD网络架构不断通过向现有的主干网络中添加特征聚合模块以提取多级显著特征,使得模型过于复杂。且这些图像分类模型往往通常通过牺牲特征图的高分辨率来实现更深的架构,即特征图在早期阶段会被缩小到较低的分辨率,如ResNet和DenseNet会使用步长为2的卷积和步长为2的最大池化将特征图大小减小到输入图的四分之一。但是,高分辨率在图像分割中有着重要作用,这也使得这些模型并不适用SOD任务。

  为解决上述问题,提出了 U 2 N e t U^2Net U2Net网络结构:

  • 一种两级嵌套的U形结构,专为SOD 设计,无需使用任何来自图像分类的预训练主干。
  • 提出U型残差块RSU,能在不降低特征图分辨率的情况下提取阶段内多尺度特征。

一、网络结构

   U 2 − N e t U^2-Net U2Net网络基于 U N e t UNet UNet网络设计而来。事实上,该网络的整体结构与 U N e t UNet UNet网络几乎相同,但所使用的上采样、下采样模块变成了小型的 U N e t UNet UNet网络,即, U 2 N e t U^2Net U2Net本质上是 U N e t UNet UNet网络的嵌套。而 U 2 − N e t U^2-Net U2Net网络的核心就是这些作为模块的小型 U N e t UNet UNet网络,并将其起名为 R e S i d u a l U − b l o c k ( R S U ,残差 U 块 ) ReSidual U-block(RSU,残差U块) ReSidualUblock(RSU,残差U)。网络结构如下:
在这里插入图片描述
这些模块其实可以分为两种, E n c o d e r 1   E n c o d e r 4 、 D e c o d e r 1   D e c o d e r 4 Encoder1~Encoder4、Decoder1~Decoder4 Encoder1 Encoder4Decoder1 Decoder4采用的是同一种结构的残差块,只不过深度不同,而Encoder5、Encoder6、Decoder5 采用的是另一种结构的残差块。整体流程可概况为:

  • Encoder阶段:每通过一个模块后都会下采样两倍,使用的是torch.nn.MaxPool2d
  • Decoder阶段:每通过一个模块后都会上采用两倍,使用的是torch.nn.functional.interpolate()
  • 跳跃链接:与 U N e t UNet UNet网络思路相同,将编码器的输出与解码器输出的特征图进行拼接,最后得到分割后的图像。

1.1第一种block结构

  本地和全局上下文信息对于显著对象检测和图像分割任务都非常重要,现代CNN网络设计中VGG、ResNet、DenseNet 等,一般使用1x1或3x3的小型卷积核提取特征。但在SOD任务中,由于它们的感受野太小而无法捕捉全局信息,使得浅层的输出特征图仅包含局部特征。在下图(图 ( a ) − ( c ) (a)-(c) (a)(c))中给出了具有小感受野的典型现有卷积块。为从浅层获得高分辨率特征图的更多全局信息,最直接的想法是扩大感受野,图 ( e ) (e) (e)是一种双向消息传递模块(见论文ieee),它试图通过使用扩张卷积扩大感受野来提取局部和非局部特征,以原始分辨率对输入特征图进行多次扩张卷积(尤其是在早期阶段)需要太多的计算和内存资源。
  为解决上述问题,本文提出了RSU模块(图 ( e ) (e) (e),L表示RSU的深度,图中L=7):

  • 一个输入卷积层,用于将尺寸为 ( H , W , C i n ) (H,W,C_{in}) (H,W,Cin)输入特征图x添加到中间映射 F 1 ( x ) F_1(x) F1(x)中,这是一个用于局部特征提取的普通卷积层,通道数为 C o u t C_{out} Cout
  • 高度为L的类似 U N e t UNet UNet的对称编码器-解码器结构,采用 F 1 ( x ) F_1(x) F1(x)作为输入,学习多尺度上下文信息。用 U ( x ) U(x) U(x)表示该类似 U N e t UNet UNet的结构,则提取的信息可表示为 U ( F 1 ( x ) ) U(F_1(x)) U(F1(x))。较大的 L 会生成更深的RSU、更多的池化操作、更大的感受野范围以及更丰富的局部和全局特征。配置此参数可以从具有任意空间分辨率的输入特征图中提取多尺度特征。从逐渐下采样的特征图中提取多尺度特征,并通过渐进式上采样、连接和卷积,将特征图还原为高分辨率特征图。此过程可减轻因使用大比例的直接上采样而导致的精细细节损失。
  • 通过求和融合局部特征和多尺度特征的残差连接: F 1 ( x ) + U ( F 1 ( x ) ) F_1(x)+U(F_1(x)) F1(x)+U(F1(x))

在这里插入图片描述

   R S U − 7 RSU-7 RSU7真实结构如下图所示:
在这里插入图片描述
  回到 U 2 − N e t U^2-Net U2Net结构,该RSU的使用场景有:

  • Encoder1 和 Decoder1 采用的是 RSU-7 结构。
  • Encoder2 和 Decoder2 采用的是 RSU-6 结构。
  • Encoder3 和 Decoder3 采用的是 RSU-5 结构。
  • Encoder4 和 Decoder4 采用的是 RSU-4 结构。

可见,相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。代码实现如下:


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):    #实现conv2d+BN+ReLU操作                                                      
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        # dilation用于实现空洞卷积
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
 
def _upsample_like(src,tar):
    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     
    return src
 
 
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   
 
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2
 
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             
 
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
 
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
 
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
 
        hx = x
        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
 
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
 
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
 
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
 
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
 
        hx6 = self.rebnconv6(hx)
 
        hx7 = self.rebnconv7(hx6)                                  
 		#实现残差连接
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
 
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
 
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
 
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
 
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
 
        hx2dup = _upsample_like(hx2d,hx1)
 
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
 
        return hx1d + hxin

在这里插入图片描述

1.2第二种block结构

  数据经过 E n 1 − E n 4 En_1-En4 En1En4下采样处理后对应特征图的分辨率就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息。为保留上下文信息,在 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中将原始RSU中的上采样、下采样结构换成了空洞卷积操作,从而得到了 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本。此时 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。在这里插入图片描述
需要注意,在 E n c o d e r 5 Encoder5 Encoder5中特征图大小已经到了18*18,非常小(也因此不需要再下采样),故采用了空洞卷积操作,目的在不改变特征图大小的情况下增大感受野。故在代码中使用了dalition=2、4、8 E n c o d e r 6 、 D e c o d e r 5 Encoder6、Decoder5 Encoder6Decoder5同理。这一特点在原图中显示为使用了虚线构成的长方体块。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):                                                          #CBL
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
        
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
 
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
 
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
        hx = x

        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
 
        hx4 = self.rebnconv4(hx3)
 
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
 
        return hx1d + hxin

1.3特征图融合模块

  在通过编码、解码器的运算后,最后通过特征图融合模块(红框标出)将 D e 1 、 D e 2 、 D e 3 、 D e 4 、 D e 5 、 E n 6 De_1、De_2、De_3、De_4、De_5、En_6 De1De2De3De4De5En6模块的输出分别通过一个3x3的卷积层(卷积层的卷积核个数均为1),并通过双线性插值将得到的特征图还原回输入图像的大小,之后将得到的6个特征图进行拼接(Concatenation),最后再经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的图像。
在这里插入图片描述

1.4损失函数

   U 2 N e t U^2Net U2Net使用多监督算法构建损失函数。网络输出不仅仅包含最终特征图,还包含前面6个不同尺度的特征图,即,不仅要监督网络输出,还要监督中间融合特征图。 损失函数计算公式:
在这里插入图片描述
其中, M = 1 , 2 , 3 , . . . , 6 M=1,2,3,...,6 M=1,2,3,...,6 l s i d e m l_{side}^{m} lsidem表示特征图 S u p 1 、 S u p 2 、 . . . 、 S u p 6 Sup1、Sup2、...、Sup6 Sup1Sup2...Sup6的损失,而 l f u s e l_{fuse} lfuse表示最终特征图的损失, w w w则表示两种损失的权重参数(论文给出的源码中全为1)。 l s i d e l_{side} lside l f u s e l_{fuse} lfuse采用二值交叉熵(standard binary cross-entropy)进行计算:
在这里插入图片描述
其中, ( r , c ) (r,c) (r,c)表示像素坐标值, ( H , W ) (H,W) (H,W)表示图像高度和宽度, P G ( r , c ) P_{G(r,c)} PG(r,c)表示标签图像素灰度值, P S ( r , c ) P_{S(r,c)} PS(r,c)表示预测的图像素灰度值。

1.5总体网络架构

在这里插入图片描述
   U 2 N e t U^2Net U2Net主要由三部分组成:

  • 一个六级编码器:在 E n 1 、 E n 2 、 E n 3 、 E n 4 En_1、En_2、En_3、En_4 En1En2En3En4中分别使用 R S U 7 、 R S U 6 、 R S U 5 、 R S U 4 RSU7、RSU6、RSU5、RSU4 RSU7RSU6RSU5RSU4 L L L通常根据输入特征图的空间分辨率进行配置。对于高宽较大的特征图,使用更大的 L L L来捕获更多的大比例度信息。 E n 5 En_5 En5 E n 6 En_6 En6中特征图的分辨率相对较低,进一步降低这些特征图的采样会导致有用的上下文丢失。因此,在 E n 5 En_5 En5 E n 6 En_6 En6阶段,都使用 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本,其用膨胀卷积代替池化和上采样操作这意味着 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。
  • 一个五级解码器: D e 5 De_5 De5阶段同样使用 R S U − 4 F RSU-4F RSU4F,并且每个解码器都使用前一阶段的上采样特征图和来自其对称编码器阶段的特征图的串联作为输入(跳跃连接)。
  • 特征图融合模块:将六个侧输出显著性特征图上采样到输入图像的尺寸,之后使用融合操作,并通过1x1卷积层和sigmoid函数生成最终的显著性特征图。

  研究中将3320320的图像裁剪为3288288大小输入模型,最终得到1288288的图像分割结果(二值图像):
在这里插入图片描述

1.6代码汇总

import torch
import torch.nn as nn
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

1.7普通残差块与RSU对比

在这里插入图片描述
  普通残差块的操作可概况为 H ( x ) = F 2 ( F 1 ( x ) ) + x H(x)=F_2(F_1(x))+x H(x)=F2(F1(x))+x,其中, F 1 、 F 2 F_1、F_2 F1F2代表权重层,此处设为卷积运算。RSU 和残差块之间的主要设计区别在于,RSU 用类似 U N e t UNet UNet的结构替换了普通的单流卷积,并将原始特征替换为由权重层转换的局部特征: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x))+F_1(x) HRSU(x)=U(F1(x))+F1(x),其中 U U U表示多层U型结构。种设计更改使网络能够直接从每个残差块中提取来自多个尺度的特征。更值得注意的是,由于 U 结构导致的计算开销很小,因为大多数操作都应用于下采样的特征图。下图中给出了RSU中 F 1 、 U F_1、U F1U的含义:
在这里插入图片描述

  残差块性能比较:
在这里插入图片描述

  • PLN:普通卷积块。
  • RES:残差块。
  • DSE:密集块。
  • INC:初始块。
  • RSU:U型残差块。

二、代码复现

https://github.com/xuebinqin/U-2-Net/tree/master

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167223.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从销售到 AI 算法工程师 | 转行人工智能大模型(含面经裁员幸存指南)

我叫王东,90后,和大家分享一下我的人工智能转型之路。 农学毕业,投身互联网做销售 机遇难求,养殖梦碎 我是土生土长的农村人,小时候经常和小鱼小虾打交道,上大学的时候就选择了农学专业,想着…

qmt量化交易策略小白学习笔记第67期【qmt编程之获取ETF申赎清单】

qmt编程之获取ETF申赎清单 qmt更加详细的教程方法,会持续慢慢梳理。 也可找寻博主的历史文章,搜索关键词查看解决方案 --获取ETF申赎清单! 实盘或回测qmt,可关注博主咨询~ ETF申赎清单 提示 使用前需要调用xtdata.download_…

JDBC 事务

文章目录 准备数据JDBC操作事务API介绍案例代码小结 准备数据 # 创建一个表:账户表. create database day05_db; # 使用数据库 use day05_db; # 创建账号表 create table account(id int primary key auto_increment,name varchar(20),money double ); # 初始化数据…

如何管理自己的工作任务和时间

在当今快节奏的工作环境中,有效地管理工作任务和时间是取得成功和保持工作生活平衡的关键。以下是一些实用的方法,可以帮助你更好地掌控自己的工作。 一、明确工作任务 1、制定任务清单 每天开始工作前,列出当天需要完成的所有任务。可以使用…

PWA(Progressive web APPs,渐进式 Web 应用)

文章目录 引言I 什么是 PWA功能特性II Web 应用清单引言 PWA 是 Google 于 2016 年提出的概念,于 2017 年正式落地,于 2018 年迎来重大突破,全球顶级的浏览器厂商,Google、Microsoft、Apple 已经全数宣布支持 PWA 技术。 PWA 目的是通过各种 Web 技术实现与原生 App 相近…

Java读取YAML文件

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storm…

基于注意力机制的图表示学习:GRAPH-BERT模型

人工智能咨询培训老师叶梓 转载标明出处 图神经网络(GNNs)在处理图结构数据方面取得了显著的进展,但现有模型在深层结构中存在性能问题,如“悬挂动画问题”和“过平滑问题”。而且图数据内在的相互连接特性限制了大规模图输入的并…

三天搞了7000,AI绘本副业赚钱新途径,抓住绘本创业,轻松开启副业

**项目简介:**这个项目的原理是利用AI文本工具来编写绘本故事脚本、人物描述及场景设定,完成整个绘本内容的创作。接着,使用百度翻译将文本翻译成英文,并在MJ软件上输入关键词生成相应的图片。随后,将这些图片上传至Pi…

《我的世界:地下城》风灵月影修改器使用指南 —— 掌控冒险,轻松畅游

对于热爱《我的世界:地下城》并寻求更加个性化游戏体验的玩家来说,风灵月影修改器提供了一个强大而便捷的辅助工具。 以下简要步骤将指导你如何安全高效地利用此修改器,让你的探险之旅如虎添翼。 1.下载与安装: 首先&#xff0c…

ESP32 入门笔记02: ESP32-C3 系列( 芯片ESP32-C3FN4) (ESP-IDF + VSCode)

ESP32-C3 系列的 芯片 / 模组 / 开发板 ESP32-C3-DevKitM-1是乐鑫一款搭载 ESP32-C3-MINI-1 或 ESP32-C3-MINI-1U 模组的入门级开发板(内置 ESP32-C3FH4 或 ESP32-C3FN4 芯片)。 板上模组大部分管脚均已引出至两侧排针,可根据开发实际需求&a…

Android平台RTMP推送模块的设计意义

为什么要做RTMP推送 RTMP是一种广泛使用的流媒体传输协议,它允许视频和音频数据在互联网上实时、高效地传输。实现RTMP推送功能,主要是为了满足以下需求: 实时性要求:RTMP协议具有低延迟的特点,适合用于需要实时交互的…

大数据毕业设计选题推荐-租房数据分析系统-Hive-Hadoop-Spark

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

LVS-DR实战案例,实现四层负载均衡

环境准备:三台虚拟机(NET模式或者桥接模式) 192.168.88.200 (web1)(安装nginx服务器作为测试) 192.168.88.201 (服务器)(用于部署lvs-dr) 192.168.88.202 (web2)…

Ubuntu20.04 安装汉语拼音后重启登入黑屏

在虚拟机上装了一个Ubuntu用来学C,默认没有安装中文输入。于是按照网上教程装了几个汉语包。切换输入法的时候突然死机,重启登入直接黑屏。百度后发现有不少老哥和我这个问题一模一样,按照他们的方法也终于整好了,虚惊一场。 解决…

axios proxy 和 httpsAgent 的使用差异案例详解

背景 因为 wadesk 开发了本地 http 服务,http 本地服务是运行在 electron-main 的纯 node 环境中的,这个之前探讨了 node 下怎么使用 fetch 时就提到了一个 https-proxy-agent 库,这次使用 axios,发现 axios 自带 proxy 配置项&a…

Ubuntu20.04中ros2 foxy版本安装gazebo,并运行小车运动demo

这里默认你安装好了ros2 foxy版本 sudo apt install gazebo11sudo apt install ros-foxy-gazebo-ros-pkgs建议把其他的包也安装了 sudo apt install ros-foxy-gazebo-*安装速度的话,比安装ros环境快多了。 此时,可以在/opt/ros/foxy/share目录下看到若…

NXP(恩智浦)—TJA1042TK/3 CAN收发器芯片详解

写在前面 本系列文章主要讲解NXP(恩智浦)—TJA1042TK/3 CAN收发器芯片的相关知识,希望能帮助更多的同学认识和了解NXP(恩智浦)—TTJA1042TK/3 CAN收发器芯片。 若有相关问题,欢迎评论沟通,共同…

大厂再侵美团腹地

美团在拿自己的生命线,来对抗大厂们的第二曲线。 转载:新熵 原创 作者丨茯神 编辑丨九犁 中国互联网行业的“黄金年代”一去不复返后,大厂们纷纷调转航线探寻,数年间尝尽了To B、元宇宙、直播带货等酸甜苦辣。如今蓦然回首才发现…

腾讯云linux服务器修改root用户登录密码操作步骤

https://cloud.tencent.com/loginhttps://cloud.tencent.com/login 点击上面链接 登录腾讯云控制台 在打开页面 确认服务器后 点 登录 按钮 操作命令: sudo passwd root 密码设置不小于16位 字母大小写数字加特殊符号组合 修改成功后关闭登录窗口即可。

9.26号算法题

数组的遍历 414.第三大的数 题解&#xff1a; class Solution {public int thirdMax(int[] nums) {TreeSet<Integer>treeSet new TreeSet<Integer>(); //生成一个TreeSet对象&#xff0c;存储有序唯一整数for (int num : nums){//遍历数组treeSet.add(num);//将…