神经网络(五):U2Net模型

news2024/9/27 17:29:29

文章目录

  • 一、网络结构
    • 1.1第一种block结构
    • 1.2第二种block结构
    • 1.3特征图融合
    • 1.4损失函数
    • 1.5总体网络架构
    • 1.6代码汇总
    • 1.7普通残差块与RSU对比
  • 二、代码复现


  参考论文:U2-Net: Going deeper with nested U-structure for salient object detection
  这篇文章基于显著目标检测任务提出,显著目标检测是指将图像中最吸引人的目标或者区域分割出来,因此只有前景和背景两个类别,相当于语义分割中的二分类任务。例如,下图中展示了三张图片进行显著目标检测后的结果:
在这里插入图片描述
其中,白色区域代表前景 ,即最吸引人的目标或区域,而黑色区域代表背景 。

一、网络结构

   U 2 − N e t U^2-Net U2Net网络基于 U N e t UNet UNet网络设计而来。事实上,该网络的整体结构与 U N e t UNet UNet网络几乎相同,但所使用的上采样、下采样模块变成了小型的 U N e t UNet UNet网络,即, U 2 N e t U^2Net U2Net本质上是 U N e t UNet UNet网络的嵌套。而 U 2 − N e t U^2-Net U2Net网络的核心就是这些作为模块的小型 U N e t UNet UNet网络,并将其起名为 R e S i d u a l U − b l o c k ( R S U ,残差 U 块 ) ReSidual U-block(RSU,残差U块) ReSidualUblock(RSU,残差U)。网络结构如下:
在这里插入图片描述
这些模块其实可以分为两种, E n c o d e r 1   E n c o d e r 4 、 D e c o d e r 1   D e c o d e r 4 Encoder1~Encoder4、Decoder1~Decoder4 Encoder1 Encoder4Decoder1 Decoder4采用的是同一种结构的残差块,只不过深度不同,而Encoder5、Encoder6、Decoder5 采用的是另一种结构的残差块。整体流程可概况为:

  • Encoder阶段:每通过一个模块后都会下采样两倍,使用的是torch.nn.MaxPool2d
  • Decoder阶段:每通过一个模块后都会上采用两倍,使用的是torch.nn.functional.interpolate()
  • 跳跃链接:与 U N e t UNet UNet网络思路相同,将编码器的输出与解码器输出的特征图进行拼接,最后得到分割后的图像。

1.1第一种block结构

  论文中给出了常见的四种特征提取模块(图 ( a ) − ( d ) (a)-(d) (a)(d)),以及提出的RSU模块(图 ( e ) (e) (e)):
在这里插入图片描述

  • PLN模块:常规网络特征提取模块,包含卷积、归一化、ReLU激活函数三种计算。
  • RES模块:残差块,增加了残差连接。
  • DSE模块:特征图融合。
  • INC模块:特征图融合。

而图 ( e ) (e) (e)即为本文提出的新型残差块 R S U − L RSU-L RSUL,其中,L代表RSU的深度,论文中给出的是深度为7的RSU,即 R S U − 7 RSU-7 RSU7。该结构图的真实结构如下图所示:
在这里插入图片描述
  回到 U 2 − N e t U^2-Net U2Net结构,该RSU的使用场景有:

  • Encoder1 和 Decoder1 采用的是 RSU-7 结构。
  • Encoder2 和 Decoder2 采用的是 RSU-6 结构。
  • Encoder3 和 Decoder3 采用的是 RSU-5 结构。
  • Encoder4 和 Decoder4 采用的是 RSU-4 结构。

可见,相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。代码实现如下:


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):    #实现conv2d+BN+ReLU操作                                                      
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        # dilation用于实现空洞卷积
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
 
def _upsample_like(src,tar):
    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     
    return src
 
 
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   
 
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2
 
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             
 
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
 
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
 
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
 
        hx = x
        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
 
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
 
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
 
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
 
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
 
        hx6 = self.rebnconv6(hx)
 
        hx7 = self.rebnconv7(hx6)                                  
 		#实现残差连接
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
 
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
 
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
 
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
 
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
 
        hx2dup = _upsample_like(hx2d,hx1)
 
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
 
        return hx1d + hxin

在这里插入图片描述

1.2第二种block结构

  数据经过 E n 1 − E n 4 En_1-En4 En1En4下采样处理后对应特征图的高与宽就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息。为保留上下文信息,在 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中将原始RSU中的上采样、下采样结构换成了空洞卷积操作,从而得到了 R S U − L F RSU-LF RSULF,其中 L L L表示RSU的深度。 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中使用的是RSU-4F。结构如下:在这里插入图片描述
需要注意,在 E n c o d e r 5 Encoder5 Encoder5中特征图大小已经到了18*18,非常小(也因此不需要再下采样),故采用了空洞卷积操作,目的在不改变特征图大小的情况下增大感受野。故在代码中使用了dalition=2、4、8 E n c o d e r 6 、 D e c o d e r 5 Encoder6、Decoder5 Encoder6Decoder5同理。这一特点在原图中显示为使用了虚线构成的长方体块。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):                                                          #CBL
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
        
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
 
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
 
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
        hx = x

        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
 
        hx4 = self.rebnconv4(hx3)
 
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
 
        return hx1d + hxin

1.3特征图融合

  在通过编码、解码器的运算后,最后通过特征图融合模块(红框标出)将 D e 1 、 D e 2 、 D e 3 、 D e 4 、 D e 5 、 E n 6 De_1、De_2、De_3、De_4、De_5、En_6 De1De2De3De4De5En6模块的输出分别通过一个3x3的卷积层(卷积层的卷积核个数均为1),并通过双线性插值将得到的特征图还原回输入图像的大小,之后将得到的6个特征图进行拼接(Concatenation),最后再经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的图像。
在这里插入图片描述

1.4损失函数

   U 2 N e t U^2Net U2Net使用多监督算法构建损失函数。网络输出不仅仅包含最终特征图,还包含前面6个不同尺度的特征图,即,不仅要监督网络输出,还要监督中间融合特征图。 损失函数计算公式:
在这里插入图片描述
其中, M = 1 , 2 , 3 , . . . , 6 M=1,2,3,...,6 M=1,2,3,...,6,表示特征图 S u p 1 、 S u p 2 、 . . . 、 S u p 6 Sup1、Sup2、...、Sup6 Sup1Sup2...Sup6的损失,而 l f u s e l_{fuse} lfuse表示最终特征图的损失, w w w则表示两种损失的权重参数(论文给出的源码中全为1)。 l s i d e l_{side} lside l f u s e l_{fuse} lfuse采用二值交叉熵(standard binary cross-entropy)进行计算:
在这里插入图片描述
其中,(r,c)表示像素坐标值,(H,W)表示图像高、宽,PG(r,c)表示标签图像素灰度值,PS(r,c)表示预测的图像素灰度值。

1.5总体网络架构

  研究中将3320320的图像裁剪为3288288大小输入模型,最终得到1288288的图像分割结果(二值图像):
在这里插入图片描述

1.6代码汇总

import torch
import torch.nn as nn
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

1.7普通残差块与RSU对比

在这里插入图片描述
  普通残差块的操作可概况为 H ( x ) = F 2 ( F 1 ( x ) ) + x H(x)=F_2(F_1(x))+x H(x)=F2(F1(x))+x,其中, F 1 、 F 2 F_1、F_2 F1F2代表权重层,此处设为卷积运算。RSU 和残差块之间的主要设计区别在于,RSU 用类似 U-Net 的结构替换了普通的单流卷积,并将原始特征替换为由权重层转换的局部特征: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x))+F_1(x) HRSU(x)=U(F1(x))+F1(x),其中 U U U表示多层U型结构。种设计更改使网络能够直接从每个残差块中提取来自多个尺度的特征。更值得注意的是,由于 U 结构导致的计算开销很小,因为大多数操作都应用于下采样的特征图。
  残差块性能比较:
在这里插入图片描述

  • PLN:普通卷积块。
  • RES:残差块。
  • DSE:密集块。
  • INC:初始块。
  • RSU:U型残差块。

二、代码复现

https://github.com/xuebinqin/U-2-Net/tree/master

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2165692.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

钢管加工长度检测系统源码分享

钢管加工长度检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer…

笔试编程-百战成神——Day01

1.数字统计 题目来源:数字统计——牛客网 测试用例 算法原理 根据题目我们知道,首先要输出两个数字确定一个区间,寻找这个区间内数字中所有包含2的个数,比如12包含一个2,22包含两个2,以此类推,所以我们的…

问题记录:end value has mixed support, consider using flex-end instead

一、问题记录 二、解决问题 根据提示改为flex-end 三、理解问题 ‌这个警告信息表明,在Flex布局中使用“end”属性时存在兼容性问题,建议使用“flex-end”代替。 当在Flex布局中使用“justify-content: end;”时,浏览器可能对“end”值的支…

嵌入式C语言的自我修养:内存泄漏与防范

内存泄漏与防范 一个内存泄漏的例子 如果我们使用malloc()申请的内存在使用结束后没有及时被释放&#xff0c;则C标准库中的内存分配器ptmalloc和内核中的内存管理子系统都失去了对这块内存的追踪和管理。 #include <stdlib.h> int main(void){ char *p; p(char *)mal…

plt常用函数介绍一

目录 前言plt.figure()plt.subplot()plt.subplots()plt.xticks()plt.xlim() 前言 Matplotlib是Python中的一个库&#xff0c;它是数字的-NumPy库的数学扩展。 Pyplot是Matplotlib模块的基于状态的接口。在Pyplot中可以使用各种图&#xff0c;例如线图&#xff0c;轮廓图&#…

C++独立开发开源大数计算库 CBigNum

项目简介&项目地址 CBigNum 是本人独立开发开源的一款大数计算库&#xff0c;支持任意位数整数带任意位数小数的浮点运算。您可以通过本库执行非常大的数据运算或非常高精度的除法运算(您可以随意指定除法的小数保留到第几位)以及各种科学计算(详见1.3)。 项目地址&#…

数字电路基础(锁存器+触发器)+Proteus仿真

1.锁存器 1.1.基本概念 1.1.1基本双稳态电路 下面电路中&#xff0c;具有0、1两种逻辑状态&#xff0c;一旦进入其中一种状态&#xff0c;就能长期保持不变的单元电路称为双稳态存储电路&#xff0c;简称双稳态电路。 锁存器和触发器都属于双稳态电路 该双稳态电路没有输入…

【Godot4自学手册】第四十八节创建雨粒子效果

今天我们要利用GPU粒子节点玩雨粒子效果&#xff0c;下雨天。 一、添加GPU粒子系统 添加GPUParticles2D节点。选择根节点&#xff0c;单击添加按钮&#xff0c;选择GPUParticles2D&#xff0c;完成添加。 二、修改属性 1.设置粒子数量。 在GPUParticles2D检查器中将Amount设…

智慧城市主要运营模式分析

(一)运营模式演变 作为新一代信息化技术落地应用的新事物,智慧城市在建设模式方面借鉴了大量工程建设的经验,如平行发包(DBB,Design-Bid-Build)、EPC工程总承包、PPP等模式等,这些模式在不同的发展阶段和条件下发挥了重要作用。 在智慧城市发展模式从政府主导、以建为主、…

求一个数的因子数(c语言)

1.计算并输出给定整数n的所有因子&#xff08;不包括1与n自身&#xff09;之和。规定n的值不大于1000。&#xff08;因子是能整除n的数 即n%i0&#xff09; // 例如&#xff0c;在主函数中从键盘给n输入的值为856&#xff0c;则输出为: sum763。 2.第一步我们先输入n的数&…

C++深入学习string类成员函数(1):默认与迭代

引言 在 C 编程中&#xff0c;std::string 类是处理字符串的核心工具之一。作为一个动态管理字符数组的类&#xff0c;它不仅提供了丰富的功能&#xff0c;还通过高效的内存管理和操作接口&#xff0c;极大地方便了字符串操作。通过深入探讨 std::string 的各类成员函数&#…

【java21】java21新特性之简单的Web服务器jwebserver

jwebserver是Java 18中引入的一个全新功能点&#xff0c;它允许用户通过命令行工具快速启动一个提供静态资源访问的迷你Web服务器。这个服务器不支持CGI和Servlet&#xff0c;因此其主要用途是轻量级的静态文件服务&#xff0c;如HTML、CSS、JavaScript和图片等。 其实在如Pyt…

热源检测系统源码分享

热源检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision …

大语言模型之LlaMA系列- LlaMA 2及LLaMA2_chat(上)

LlaMA 2是一个经过预训练与微调的基于自回归的transformer的LLMs&#xff0c;参数从7B至70B。同期推出的Llama 2-Chat是Llama 2专门为对话领域微调的模型。 在许多开放的基准测试中Llama 2-Chat优于其他开源的聊天模型&#xff0c;此外Llama 2-Chat还做了可用性与安全性评估。 …

大数据 flink 01 | 从零环境搭建 简单Demo 运行

什么是Flink Flink是一个开源的流处理和批处理框架,它能够处理无界和有界的数据流&#xff0c;具有高吞吐量、低延迟和容错性等特点 Flink 可以应用于多个领域如&#xff1a;实时数据处理、数据分析、机器学习、事件驱动等。 什么是流式处理&#xff1f;什么是批处理 流处理…

Python 如何使用 unittest 模块编写单元测试

Python 如何使用 unittest 模块编写单元测试 单元测试是软件开发过程中的重要环节&#xff0c;它帮助开发者验证代码的正确性&#xff0c;确保功能按预期工作。Python 提供了一个强大的内置模块 unittest&#xff0c;使得编写和执行单元测试变得非常方便。本文将深入探讨如何使…

计算机组成原理(笔记5原码和补码的乘法以及直接补码阵列乘法器 )

原码一位乘法 手算&#xff1a;过程 令x′|x|0.x1x2…xn-1xn&#xff0c;y′|y|0.y1y2…yn-1yn 同时令乘积P′ |P| x′ y′&#xff0c;有&#xff1a; x′ y′ x′(0.y1y2…yn-1yn) x′ (y12-1y22-2…yn-12-(n-1)yn2-n) 2-1(y1x′2-1(y2x′…2-1(yn-1x′2-1(ynx′0))…))…

使用awvs测试站点并输出漏洞报告教程

环境配置 pikachu靶场 awvs 使用步骤 1.访问本机3443端口&#xff08;安装时自己设定的端口&#xff09; 2.点击【Targets】--》【Add Targets】新建扫描目标&#xff0c;输入目标网址&#xff08;以pikachu靶场为例&#xff09;&#xff0c;点击【Save】开始扫描 2.点击【…

【AI大模型】股票价格预测精度增强,基于变分模态分解、PatchTST和自适应尺度加权层

简介 股票价格指数是金融市场和经济健康的晴雨表&#xff0c;准确预测对投资决策至关重要。股票市场的高频交易和复杂行为使得预测具有挑战性&#xff0c;需开发稳定、准确的预测模型。研究表明&#xff0c;估值比率、数据驱动模型&#xff08;如支持向量机&#xff09;、股票…

机器学习 | 使用scikit-learn学习Python中的PCA(主成分分析)

为什么选择PCA&#xff1f; 当有许多输入属性时&#xff0c;很难将数据可视化。在机器学习领域有一个非常著名的术语“维度诅咒”。基本上&#xff0c;它指的是数据集中的属性数量越多&#xff0c;对机器学习模型的准确性和训练时间产生不利影响。主成分分析&#xff08;PCA&a…