U-Net 算法详解

news2024/11/25 16:47:52

目录

1.任务概述

2.编码器-解码器

3.跳跃连接

4.实现细节

5.损失函数

6.上采样方法

不填充还是填充?

7.U-Net 的运作方式

8.结论


1.任务概述

U-Net 是为语义分割任务开发的。当神经网络接受图像作为输入时,我们可以选择一般性地分类对象或按实例分类。我们可以预测图像中包含的对象(图像分类),所有对象的位置(图像定位/语义分割),或个别对象的位置(对象检测/实例分割)。

下图显示了这些计算机视觉任务之间的差异。为了简化问题,我们仅考虑一个类别和一个标签的分类。

图片

 

在分类任务中,我们输出一个大小为 k 的向量,其中 k 是类别的数量。在检测任务中,我们需要输出定义边界框的向量 x、y、高度、宽度、类别。

但在分割任务中,我们需要输出与原始输入具有相同维度的图像。这代表了一个相当大的工程挑战:神经网络如何从输入图像中提取相关特征,然后将它们投影到分割掩模中?

2.编码器-解码器

如果你对编码器-解码器不熟悉,建议你阅读这篇文章:

https://towardsdatascience.com/understanding-latent-space-in-machine-learning-de5a7c687d8d

编码器-解码器之所以相关,是因为它们产生了与我们想要的类似的输出:具有与输入相同维度的输出。我们是否可以将编码器-解码器的概念应用于图像分割?我们可以生成一个一维的二进制掩模,并使用交叉熵损失来训练网络。

我们的网络由两部分组成:编码器从图像中提取相关特征,解码器部分则采用提取的特征并重建分割掩模。

图片

 

在编码器部分,使用了卷积层,然后使用 ReLU 和最大池作为特征提取器。在解码器部分,使用了转置卷积来增加特征图的大小并减少通道数。使用了填充来保持卷积操作后特征图的大小相同。

你可能注意到的一件事是,与分类网络不同,这个网络没有全连接/线性层。这是一个完全卷积网络(FCN)的示例。FCN已经被证明在分割任务上表现良好,始于Shelhamer等人的论文“全卷积网络用于语义分割”[1]。

然而,这个网络存在一个问题。随着编码器和解码器层的增加,我们实际上会越来越“缩小”特征图。因此,编码器可能会丢弃更详细的特征,以获得更一般的特征。如果我们处理医学图像分割,每个像素被分类为患病/正常可能都很重要。我们如何确保这个编码器-解码器网络接受既有一般性又有详细性的特征?

3.跳跃连接

https://towardsdatascience.com/introduction-to-resnets-c0a830a288a4

由于深度神经网络在通过连续层传递信息时可能“遗忘”某些特征,跳跃连接可以重新引入这些特征,使学习更强大。跳跃连接是在残差网络(ResNet)中引入的,并显示出分类改进以及更平滑的学习梯度。受到这一机制的启发,我们可以将跳跃连接添加到 U-Net 中,以使每个解码器包含其对应编码器的特征图。这是 U-Net 的一个定义特征。

图片

 

U-Net 是一个带有跳跃连接的编码器-解码器分割网络。作者提供的图片。U-Net 具有两个定义特性:

  1. 编码器-解码器网络,深入进行时提取更一般性的特征。

  2. 跳跃连接,重新引入解码器中的详细特征。这两个特性意味着 U-Net 可以使用既详细又一般的特征进行分割。U-Net 最初是为生物医学图像处理引入的,其中分割的准确性非常重要[2]。

4.实现细节

图片

 

前面的部分提供了 U-Net 的非常一般的概述以及它为什么有效。然而,细节在一般理解和实际实施之间起着重要作用。在这里,我将概述一些 U-Net 的实现选择。

5.损失函数

因为目标是二进制掩模(像素值为1表示像素包含对象),用于将输出与地面实况进行比较的常见损失函数是分类交叉熵损失(或在单标签情况下的二元交叉熵损失)。

图片

 

在原始的 U-Net 论文中,额外的权重被添加到损失函数中。这个权重参数有两个作用:它补偿了类别不平衡,并且赋予了分割边界更高的重要性。在我找到的许多 U-Net 实现中,这个额外的权重因子通常没有被使用。

另一个常见的损失函数是 Dice 损失。Dice 损失通过比较两组图像的交集区域与它们的总区域来衡量它们的相似性。请注意,Dice 损失与交并比(IOU)不同。它们衡量了类似的内容,但分母不同。Dice 系数越高,Dice 损失越低。

图片

 

在这里,添加了一个 epsilon 项以避免除以0(epsilon 通常为1)。一些实现,如Milletari等人的实现,在求和之前会将分母中的像素值平方[3]。与交叉熵损失相比,Dice 损失对于不平衡的分割掩模非常鲁棒,这在生物医学图像分割任务中很常见。

6.上采样方法

另一个细节是解码器的上采样方法的选择。以下是一些常见的方法:

双线性插值。该方法使用线性插值来预测输出像素。通常,通过这种方法进行上采样之后会跟随一个卷积层。

最大反池化。这个方法是最大池化的反操作。它使用最大池化操作的索引,并将这些索引填充为最大值。所有其他值设为0。通常,在最大反池化之后会跟随一个卷积层以“平滑”所有缺失的值。

反卷积/转置卷积。有许多关于反卷积的博文。我建议阅读这篇文章作为一个好的视觉指南。

https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d

反卷积有两个步骤:首先在原始图像的每个像素周围添加填充,然后应用卷积。在最初的 U-Net 中,使用了一个2x2的转置卷积,步长为2,用于改变空间分辨率和通道深度。

像素重排。这种方法在超分辨率网络中如 SRGAN 中被使用。首先,我们使用卷积将C x H x W特征图转换为(Cr^2) x H x W。然后,像素重排会以马赛克的方式“重排”这些像素,以产生尺寸为C x (Hr) x (Wr)的输出。

不填充还是填充?

卷积层,如果内核大于1x1且没有填充,将产生比输入更小的输出。这对于 U-Net 是个问题。回想一下前面部分中 U-Net 图中,我们将图像的一部分与其解码的部分连接起来。如果我们不使用填充,那么解码后的图像将具有较小的空间尺寸,与编码后的图像相比。

然而,原始的 U-Net 论文没有使用填充。虽然没有给出理由,但我认为这是因为作者不想在图像边缘引入分割错误。相反,他们在连接之前对编码的图像进行了中心裁剪。对于输入尺寸为572 x 572的图像,输出将为388 x 388,损失约为50%。如果要在不填充的情况下运行 U-Net,需要在重叠的图块上多次运行以获取完整的分割图像。

7.U-Net 的运作方式

在这里,我们实现了一个非常简单的类 U-Net 网络,只用于分割椭圆。这个 U-Net 只有3层深度,使用相同的填充,和二元交叉熵损失。更复杂的网络可以在每个分辨率上使用更多的卷积层,或根据需要扩展深度。

import torch
import numpy as np
import torch.nn as nn

class EncoderBlock(nn.Module):        
    # Consists of Conv -> ReLU -> MaxPool
    def __init__(self, in_chans, out_chans, layers=2, sampling_factor=2, padding="same"):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_chans, out_chans, 3, 1, padding=padding))
        self.encoder.append(nn.ReLU())
        for _ in range(layers-1):
            self.encoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.encoder.append(nn.ReLU())
        self.mp = nn.MaxPool2d(sampling_factor)
    def forward(self, x):
        for enc in self.encoder:
            x = enc(x)
        mp_out = self.mp(x)
        return mp_out, x

class DecoderBlock(nn.Module):
    # Consists of 2x2 transposed convolution -> Conv -> relu
    def __init__(self, in_chans, out_chans, layers=2, skip_connection=True, sampling_factor=2, padding="same"):
        super().__init__()
        skip_factor = 1 if skip_connection else 2
        self.decoder = nn.ModuleList()
        self.tconv = nn.ConvTranspose2d(in_chans, in_chans//2, sampling_factor, sampling_factor)

        self.decoder.append(nn.Conv2d(in_chans//skip_factor, out_chans, 3, 1, padding=padding))
        self.decoder.append(nn.ReLU())

        for _ in range(layers-1):
            self.decoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.decoder.append(nn.ReLU())

        self.skip_connection = skip_connection
        self.padding = padding
    def forward(self, x, enc_features=None):
        x = self.tconv(x)
        if self.skip_connection:
            if self.padding != "same":
                # Crop the enc_features to the same size as input
                w = x.size(-1)
                c = (enc_features.size(-1) - w) // 2
                enc_features = enc_features[:,:,c:c+w,c:c+w]
            x = torch.cat((enc_features, x), dim=1)
        for dec in self.decoder:
            x = dec(x)
        return x

class UNet(nn.Module):
    def __init__(self, nclass=1, in_chans=1, depth=5, layers=2, sampling_factor=2, skip_connection=True, padding="same"):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        out_chans = 64
        for _ in range(depth):
            self.encoder.append(EncoderBlock(in_chans, out_chans, layers, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans*2

        out_chans = in_chans // 2
        for _ in range(depth-1):
            self.decoder.append(DecoderBlock(in_chans, out_chans, layers, skip_connection, sampling_factor, padding))
            in_chans, out_chans = out_chans, out_chans//2
        # Add a 1x1 convolution to produce final classes
        self.logits = nn.Conv2d(in_chans, nclass, 1, 1)

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x, enc_output = enc(x)
            encoded.append(enc_output)
        x = encoded.pop()
        for dec in self.decoder:
            enc_output = encoded.pop()
            x = dec(x, enc_output)

        # Return the logits
        return self.logits(x)

图片

 

正如我们所看到的,即使没有跳跃连接,U-Net 也可以产生可接受的分割结果,但添加跳跃连接可以引入更精细的细节(请看右侧两个椭圆之间的连接部分)。

8.结论

如果要用一句话来解释 U-Net,那就是 U-Net 就像是用于图像的编码器-解码器,但通过跳跃连接来确保细节不会丢失。U-Net 在许多分割任务中经常使用,近年来还在图像生成任务中取得了成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1321195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

传统软件集成AI大模型——Function Calling

传统软件和AI大模型的胶水——Function Calling 浅谈GPT对传统软件的影响Function Calling做了什么,为什么选择Function CallingFunction Calling简单例子,如何使用使用场景 浅谈GPT对传统软件的影响 目前为止好多人对chatGPT的使用才停留在OpenAI自己提…

【Linux】sed命令使用

sed 命令 sed全称是:Stream EDitor。 sed 命令是利用脚本来处理文本文件。sed 一次只读取一行文本到缓冲区,然后读取命令,对此行进行编辑,然后读取下一行,重复此过程直到结束。 sed 与 vi 的区别 【Linux】 vi / v…

【DOM笔记二】操作元素(修改元素内容,修改常见元素/表单元素/元素样式属性,排他思想,自定义属性操作,应用案例!)

文章目录 4 操作元素4.1 修改元素的内容4.2 修改常见元素的属性案例:分时问候 4.3 修改表单元素属性案例:登录时隐藏/显示密码 4.4 修改元素样式属性4.4.1 行内样式操作 element.style案例1:关闭二维码广告案例2:遍历精灵图案例3&…

C# 图解教程 第5版 —— 第19章 枚举器和迭代器

文章目录 19.1 枚举器和可枚举类型19.2 IEnumerator 接口19.3 IEnumerable 接口19.4 泛型枚举接口19.5 迭代器19.5.1 迭代器块19.5.2 使用迭代器来创建枚举器19.5.3 使用迭代器来创建可枚举类型 19.6 常见迭代器模式19.7 产生多个可枚举类型19.8 将迭代器作为属性19.9 迭代器的…

Postman使用总结--参数化

将 测试数据,组织到 数据文件中,通过脚本的反复迭代,使用不同的数据,达到测试不同用例的目标 数据文件有两种: CSV (类似于excel) 格式简单用这个 文件小 JSON(字典列表&#x…

Vue3-22-组件-插槽的使用详解

插槽是干啥的 插槽 就是 组件中的一个 占位符, 这个占位符 可以接收 父组件 传递过来的 html 的模板值,然后进行填充渲染。 就这么简单,插槽就是干这个的。要说它的优点吧,基本上就是可以使子组件的内容可以被父组件控制&#xf…

Amazon CodeWhisperer 体验

文章作者:jiangbei 1. CodeWhisperer 安装 1.1 先安装 IDEA,如下图,IDEA2022 安装为例: 亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者…

MongoDB中的关系

本文主要介绍MongoDB中的关系。 目录 MongoDB的关系嵌入关系引用关系 MongoDB的关系 MongoDB是一个非关系型数据库,它使用了键值对的方式来存储数据。因此,MongoDB没有像传统关系型数据库中那样的表、行和列的概念。相反,MongoDB中的关系是通…

主馆位置即将售罄“2024北京国际信息通信展会”众多知名企聚京城

2024北京国际信息通信展,将于2024年9月份在北京国家会议中心盛大召开。作为全球信息通信技术领域的重要盛会,此次展会将汇集业内顶尖企业,展示最新的技术成果和产品。 目前,主馆位置即将售罄,华为、浪潮、中国移动、通…

数据结构(Chapter Two -02)—顺序表基本操作实现

在前一部分我们了解线性表和顺序表概念,如果有不清楚可以参考下面的博客: 数据结构(Chapter Two -01)—线性表及顺序表-CSDN博客 首先列出线性表的数据结构: #define MaxSize 50 //定义顺序表最大长度 typedef struct{ElemType data…

【面试】Java最新面试题资深开发-微服务篇(1)

问题九:微服务 什么是微服务架构?它与单体架构相比有哪些优势和劣势?解释一下服务发现和服务注册是什么,它们在微服务中的作用是什么?什么是API网关(API Gateway)?在微服务中它有何…

什么是关键词排名蚂蚁SEO

关键词排名是指通过搜索引擎优化(SEO)技术,将特定的关键词与网站相关联,从而提高网站在搜索引擎中的排名。关键词排名对于网站的流量和用户转化率具有至关重要的影响,因此它是SEO工作中最核心的部分之一。 如何联系蚂…

任务十六:主备备份型防火墙双机热备

目录 目的 器材 拓扑 步骤 一、基本配置 配置各路由器接口的IP地址【省略】 1、配置BGP协议实现Internet路由器之间互联 2、防火墙FW1和FW2接口IP配置与区域划分 3、配置区域间转发策略 4、配置NAPT和默认路由 5、配置VRRP组,并加入Active/standby VGMP管…

06-部署knative-eventing

环境要求 For prototyping purposes 单节点的Kubernetes集群,有2个可用的CPU核心,以及4g内存; For production purposes 单节点的Kubernetes集群,需要至少有6个CPU核心、6G内存和30G磁盘空间多节点的Kubernetes集群中,…

Redis设计与实现之慢查询日志

目录 一、慢查询日志 1、相关数据结构 2、慢查询日志的记录 3、慢查询日志的操作 4、如何设置慢查询的阈值? 5、如何查看慢查询日志的内容? 6、如何分析慢查询日志以找出性能瓶颈? 7、如何优化慢查询以提高Redis的性能? 8…

人工智能_机器学习069_SVM支持向量机_网格搜索_交叉验证参数优化_GridSearchCV_找到最优的参数---人工智能工作笔记0109

然后我们再来说一下SVC支持向量机的参数优化,可以看到 这次我们需要,test_data这个是测试数据,容纳后 train_data这个是训练数据 这里首先我们,导出 import numpy as np 导入数学计算包 from sklearn.svm import SVC 导入支持向量机包 分类器包 def read_data(path): wit…

纵横字谜的答案 Crossword Answers

纵横字谜的答案 Crossword Answers - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 翻译后大概是&#xff1a; 有一个 r 行 c 列 (1<r,c<10) 的网格&#xff0c;黑格为 * &#xff0c;每个白格都填有一个字母。如果一个白格的左边相邻位置或者上边相邻位置没有白格&…

【Vue2】Component template should contain exactly one root element.

问题描述 [plugin:vite:vue2] Component template should contain exactly one root element. If you are using v-if on multiple elements, use v-else-if to chain them instead.原因分析 这个错误通常是由于 Vue 组件的模板中包含多个根元素导致的。Vue 要求组件模板中只…

【计算机网络】TCP协议——3. 可靠性策略效率策略

前言 TCP是一种可靠的协议&#xff0c;提供了多种策略来确保数据的可靠性传输。 可靠并不是保证每次发送的数据&#xff0c;对方都一定收到&#xff1b;而是尽最大可能让数据送达目的主机&#xff0c;即使丢包也可以知道丢包。 目录 一. 确认应答和捎带应答机制 二. 超时重…

Linear Regression线性回归(一元、多元)

目录 介绍&#xff1a; 一、一元线性回归 1.1数据处理 1.2建模 二、多元线性回归 2.1数据处理 2.2数据分为训练集和测试集 2.3建模 介绍&#xff1a; 线性回归是一种用于预测数值输出的统计分析方法。它通过建立自变量&#xff08;也称为特征变量&#xff09;和因变…