Unet网络解析

news2024/11/27 22:33:10

1 Unet网络概述

论文名称:U-Net: Convolutional Networks for Biomedical Image Segmentation

发表会议及时间 :MICCA ( 国际医学图像计算和 计算机辅 助干预会 议 ) 2 0 1 5

Unet提出的初衷是为了解决医学图像分割的问题。

Unet网络非常的简单,前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构,由于网络的整体结构是一个大些的英文字母U,所以叫做U-net。其实可以将图像->高语义feature map的过程看成编码器,高语义->像素级别的分类score map的过程看作解码器

  • Encoder:左半部分,由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块;
  • Decoder:右半部分,由一个上采样的卷积层(反卷积层),特征拼接concat,两个3x3的卷积层,非线性ReLU层;

在当时,Unet相比更早提出的FCN网络,使用拼接来作为特征图的融合方式。

  • FCN是通过特征图对应像素值的相加来融合特征的;
  • U-net通过通道数的拼接,这样可以形成更厚的特征,当然这样会更佳消耗显存;

2 Unet与FCN网络的区别

U-Net和FCN非常的相似,U-Net比FCN稍晚提出来,但都发表在2015年,和FCN相比,U-Net的第一个特点是完全对称,也就是左边和右边是很类似的,而FCN的decoder相对简单。第二个区别就是skip connection,FCN用的是加操作(summation),U-Net用的是叠操作(concatenation)。这些都是细节,重点是它们的结构用了一个比较经典的思路,也就是编码和解码(encoder-decoder)结构。其实可以将图像->高语义feature map的过程看成编码器,高语义->像素级别的分类score map的过程看作解码器

此外, 由于UNet也和FCN一样, 是全卷积形式, 没有全连接层(即没有固定图的尺寸),所以容易适应很多输入尺寸大小,但并不是所有的尺寸都可以,需要根据网络结构决定

3 为什么Unet在医疗图像分割种表现好

  • 医疗影像语义较为简单、结构固定。因此语义信息相比自动驾驶等较为单一,因此并不需要去筛选过滤无用的信息。医疗影像的所有特征都很重要,因此低级特征和高级语义特征都很重要,所以U型结构的skip connection结构(特征拼接)更好派上用场

  • 医学影像的数据较少,获取难度大,数据量可能只有几百甚至不到100,因此如果使用大型的网络例如DeepLabv3+等模型,很容易过拟合。大型网络的优点是更强的图像表述能力,而较为简单、数量少的医学影像并没有那么多的内容需要表述,因此也有人发现在小数量级中,分割的SOTA模型与轻量的Unet并没有优势

  • 医学影像往往是多模态的。比方说ISLES脑梗竞赛中,官方提供了CBF,MTT,CBV等多中模态的数据(这一点听不懂也无妨)。因此医学影像任务中,往往需要自己设计网络去提取不同的模态特征,因此轻量结构简单的Unet可以有更大的操作空间。

因此,大多数医疗影像语义分割任务都会首先用Unet作为baseline

4 Unet网络结构

Unet网络是建立在FCN网络基础上的,它的网络架构如下图所示,总体来说与FCN思路非常类似。这里需要注意的是,U-Net的输入大小是572x572,但是输出却是388x388,按理说它们应该相等(因为图像分割相当于逐像素进行分类,所以要求输入图像和输出图像大小一致),但是为什么这里的输入尺寸要比输出尺寸大呢?那是因为下图这个结构图是当年论文作者绘制的,该作者对输入图像的边缘进行了镜像填充,通过镜像填充将边界区域进行扩大,这样可以给模型提供更多信息来完成模型的分割。

按照论文中的解释,镜像填充的原因是:因为图像 的边界的外面是空白的,没有其它有效像素,而我们预测图像中的像素类别时往往需要参考它的周围像素作为上下文信息,这样才能保持分割的准确性,为了预测边界像素,论文对边界区域进行镜像,来补全边界周围缺失的内容,然后进行预测。这种策略叫做"overlap-tile"

这里的输入是单通道的原因是因为输入图片是灰度图,而输出是两通道是因为这里是对像素进行二分类(前景和背景),所以输出通道是2
在这里插入图片描述
整个网络由编码部分(左) 和 解码部分(右)组成,类似于一个大大的U字母,具体介绍如下:

1、编码部分是典型的卷积网络架构:

  • 它主要的作用是进行特征提取
  • 架构中含有着一种重复结构,每次重复中都有2个 3 x 3卷积层、非线性ReLU层和一个 2 x 2 max pooling层(stride为2)。(图中的蓝箭头、红箭头,没画ReLu)
  • 每一次下采样后我们都把特征通道的数量加倍

2、解码部分也使用了类似的模式:

  • 它主要的作用是进行上采样 (上采样可以让包含高级抽象特征的低分辨率图片在保留高级抽象特征的同时变为高分辨率)
  • 架构中包含有一种重复结构,每次重复都有一个上采样的卷积层(反卷积层),特征拼接concat,两个3x3的卷积层,非线性ReLU层
  • 每一步都首先使用反卷积(up-convolution),每次使用反卷积都将特征通道数量减半,特征图大小加倍。(图中绿箭头)
  • 反卷积过后,将反卷积的结果与编码部分中对应步骤的特征图拼接起来(concat)(也就是将深层特征与浅层特征进行融合,使得信息变得更丰富)。(白/蓝块)
  • 编码部分中的特征图尺寸稍大,将其修剪过后进行拼接(这里是将两个特征图的尺寸调整一致后按通道数进行拼接)。(左边深蓝虚线部分就是要裁剪的部分,它对应右边的白色长方块部分)
  • 对拼接后的map再进行2次3 x 3的卷积。(右侧蓝箭头)
  • 最后一层的卷积核大小为1 x 1,将64通道的特征图转化为特定类别数量(分类数量)的结果。(图中青色箭头)

5 代码复现

下面使用pytorch框架对论文中的unet进行复现:

#编码器(论文中称之为收缩路径)的基本单元
def contracting_block(in_channels, out_channels):
    block = torch.nn.Sequential(
        #这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
        nn.Conv2d(kernel_size=(3,3), in_channels=in_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(kernel_size=(3,3), in_channels=out_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels)
    )
    return block

#解码器(论文中称之为扩张路径)的基本单元
class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()
        
        #每进行一次反卷积,通道数减半,尺寸扩大2倍
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=(3,3), stride=2, padding=1, 
                                     output_padding=1)
        self.block = nn.Sequential(
            # 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
            nn.Conv2d(kernel_size=(3,3), in_channels=in_channels, out_channels=mid_channels),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=(3,3), in_channels=mid_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )
    
    def forward(self, e, d):
        d = self.up(d)
        #concat
        #e是来自编码器部分的特征图,d是来自解码器部分的特征图,它们的形状都是[B,C,H,W]
        diffY = e.size()[2]-d.size()[2]
        diffX = e.size()[3]-d.size()[3]
        #裁剪时,先计算e与d在高和宽方向的差距diffY和diffX,然后对e高方向进行裁剪,具体方法是两边分别裁剪diffY的一半,
        #最后对e宽方向进行裁剪,具体方法是两边分别裁剪diffX的一半,
        #具体的裁剪过程见下图一
        e = e[:,:, diffY//2:e.size()[2]-diffY//2, diffX//2:e.size()[3]-diffX//2]
        cat = torch.cat([e, d], dim=1)#在特征通道上进行拼接
        out = self.block(cat)
        return out

#最后的输出卷积层
def final_block(in_channels, out_channels):
    block = nn.Sequential(
        nn.Conv2d(kernel_size=(1,1), in_channels=in_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    )
    return block

class UNet(nn.Module):
    
    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        
        #编码器 (Encode)
        self.conv_encode1 = contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        #编码器与解码器之间的过渡部分(Bottleneck)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=(3,3), in_channels=512, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(kernel_size=(3,3), in_channels=1024, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024)
        )
        
        # 解码器(Decode)
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)
        
        self.final_layer = final_block(64, out_channel)
    
    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_pool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_pool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_pool3(encode_block3)
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_pool4(encode_block4)

        # Bottleneck
        bottleneck = self.bottleneck(encode_pool4)
        
        # Decode
        decode_block4 = self.conv_decode4(encode_block4, bottleneck)
        decode_block3 = self.conv_decode3(encode_block3, decode_block4)
        decode_block2 = self.conv_decode2(encode_block2, decode_block3)
        decode_block1 = self.conv_decode1(encode_block1, decode_block2)
        
        final_layer = self.final_layer(decode_block1)
        return final_layer

模型测试:

image = torch.rand((1, 3, 572, 572))
unet = UNet(in_channel=3, out_channel=2)
mask = unet(image)
print(mask.shape)
#输出结果:
torch.Size([1, 2, 388, 388])

图一:图像裁剪过程演示:
这里演示的是将64x64的特征图裁剪为56x56大小的过程
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/166761.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode 2246. Longest Path With Different Adjacent Characters(不同相邻字母的最长路径)

给一棵以节点0为根的树(不一定是二叉树),共有n个节点,0~n-1, 同样的,有长度为n的数组parent, parent[i]表示第 i 个节点的parent, 0的parent是-1. 长度为n的字符串,s(i)表示第i个节点…

JSPmvc

一、JSP 概述 JSP(全称:Java Server Pages):Java 服务端页面。是一种动态的网页技术,其中既可以定义HTML、JS、CSS等静态内 容,还可以定义 Java代码的动态内容,也就是 JSP HTML Java 。如下就…

【金融】新成立基金建仓时点、行业分布与市场行情关系探究

需要进一步交流,获取数据和代码的同学欢迎私信奥~基于新成立基金建仓带入市场的巨量资金会推动市场行情这一逻辑,开展了一系列研究。首先提出了通过基金净值识别建仓行为(累计绝对值涨跌幅法)和通过基金β值识别建仓行为&#xff…

Vue知识系列-VS Code的安装+Vue环境的搭建+Vue指令

一、VS Code下载地址 Visual Studio Code - Code Editing. Redefined 二、VS Code初始化设置 1.安装插件 在安装好的VSCode软件的扩展菜单中查找安装如下4个插件 2、创建项目 vscode本身没有新建项目的选项,所以要先创建一个空的文件夹,如project_xx…

自主异常检测算法(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

AppScan 扫描web应用程序

系列文章 AppScan介绍和安装 第二节-AppScan 扫描web应用程序 1.环境布置 我们准备了如下一个靶场用来做实验 2.扫描步骤 1.启动AppScan 2.选择 【扫描web应用程序】 3.输入起始URL,点击【下一步】 http://127.0.0.1:83/4.选择【不使用代理】,点击【下一…

ctfshow php特性[125-135]

&#x1f60b;大家好&#xff0c;我是YAy_17&#xff0c;是一枚爱好网安的小白&#xff0c;自学ing。 本人水平有限&#xff0c;欢迎各位大佬指点&#xff0c;一起学习&#x1f497;&#xff0c;一起进步⭐。⭐此后如竟没有炬火&#xff0c;我便是唯一的光。⭐web 125<?php…

word怎么转换成pdf?其实很简单,看这里即可!

转眼间又到了校招的季节&#xff0c;想必许多小伙伴都在忙着编辑自己的简历吧。不过&#xff0c;咱们编辑的时候常常用到的都是word文件&#xff0c;但是当我们要将文件投递出去的时候就需要用到pdf了。其实不仅仅是投递简历&#xff0c;许多地方在要求我们发送正式文件的时候都…

vue+node+mysql全栈项目完整记录

文章目录vuenodemysql全栈项目完整记录写在前面项目最终界面展示项目框架搭建后端创建后端项目编写入口文件数据库及数据库使用前端创建前端项目使用elementUI必要包安装设置跨域访问&#xff0c;全局挂载axios删除无用的文件和代码设置统一的页面样式主页面页面设计路由设计登…

【C语言】柔性的数组是什么?C/C++程序的内存开辟又是?

本文主要讲解柔性数组的相关知识点&#xff0c;并穿插一下C/C程序的内存开辟&#xff0c;涉及到动态内存管理函数&#xff0c;如有不了解的&#xff0c;请参考这一篇文章【C语言】小王带您轻松实现动态内存管理&#xff08;简单易懂&#xff09;_小王学代码的博客-CSDN博客 目录…

【C++】类和对象【下篇】--初始化列表,static成员,友元,内部类,匿名对象

文章目录一、再谈构造函数1.构造函数体赋值2.初始化列表1.概念2.特性二、隐式类型转换1.概念2.构造函数的类型转换3.explict关键字4.类型转换的意义三、Static成员1.概念2.static成员变量3.static成员函数四、友元1.友元函数2.友元类五、内部类六、匿名对象七、拷贝对象时的一些…

深入使用noexcept

深入使用noexcept简介好处坏处适用场景不适用场景实验结果总结参考资料简介 noexcept是C11引入的&#xff0c;表明函数是否会抛出异常。正确使用它可以优化性能&#xff0c;错误使用则会带来麻烦。 noexcept使用语法有两种&#xff1a; noexcpetnoexcept(expression) 第二种…

如何提高系统稳定性?

1、系统稳定性的评判标准 在开始谈稳定性保障之前&#xff0c;我们先来聊聊业内经常提及的一个词SLA&#xff01;业内喜欢用SLA &#xff08;服务等级协议&#xff0c;全称&#xff1a;service level agreement&#xff09;来衡量系统的稳定性&#xff0c;对互联网公司来说&am…

测试开发知识总结(一)

本文内容顺序&#xff1a;测试基础理论、测试岗经常被问到的场景题、智力题、测试岗高频算法题、数据库、Linux知识点。常用自动化测试工具1、Appium官网&#xff1a;http://appium.ioAppUI自动化测试Appium 是一个移动端自动化测试开源工具&#xff0c;支持iOS 和Android 平台…

为什么很少拿神经网络来直接做滤波器呢?

其实无论是IIR&#xff08;RNN&#xff09;还是FIR(CNN)滤波器都可以看成一个简单神经网络&#xff0c;而且有严格的推理&#xff0c;可解释性比神经网络强多了&#xff0c;而已易于工程实现&#xff0c;因此在工程中大量应用。你说的含色噪声和其他乱七八糟的噪声难以滤除时&a…

ROS | Realsense中的IMU解算orientation

文章目录概述一、定义介绍二、操作教程(一)、下载并编译imu_tools功能包1.创建工作空间并初始化2.下载imu_tools并编译(二)、修改配置1.修改imu_tools源码2.修改launch文件3.启动解算概述 本文详细介绍了如何使用ROS自带的工具解算6轴IMU&#xff0c;获取其位姿。 一、定义介绍…

mybatis之动态SQL常见标签的使用

引入where标签的原因&#xff1a; 在上篇文章使用if语句的查询中&#xff0c;我们在SQL语句后面都写入了where 11&#xff0c;以保证每次都能够查询出结果&#xff0c;但这种方法并不是最合理的&#xff0c;假设我们现在将where后面的11去掉&#xff1a; 如下所示&#xff1a…

上午摆摊,下午写代码,35岁程序员的双面人生超爽!

最近看到一个程序员发帖分享自己的工作&#xff1a;白天出摊卖馄饨&#xff0c;下午在家为海外公司全职远程工作。“年入百万是可以的&#xff0c;并且我老家是三线城市&#xff0c;没有房租、通勤费用&#xff0c;性价比还是很高的。” 对比在大城市天天996的程序员&#xff0…

【JavaEE】多线程之线程安全(synchronized篇),死锁问题

目录 线程安全问题 观察线程不安全 线程安全问题的原因 从原子性入手解决线程安全问题 ——synchronized synchronized的使用方法 synchronized的互斥性和可重入性 死锁 死锁的三个典型情况 死锁的四个必要条件 破除死锁 线程安全问题 在前面的章节中&#xff0c…

Wav2Vec HuBert 自监督语音识别模型

文章目录Wav2Vec: Unsupervised pre-training for speech recognitionabstractmethodwav2vec 2.0: A Framework for Self-Supervised Learning of Speech RepresentationsabstractintroductionmethodMODEL arch损失函数finetuneexprimentHuBERT: Self-Supervised Speech Repres…