使用 PyTorch 进行高效图像分割:第 4 部分

news2025/1/14 1:24:28

一、说明

        在这个由 4 部分组成的系列中,我们将使用 PyTorch 中的深度学习技术从头开始逐步实现图像分割。本部分将重点介绍如何实现基于视觉转换器的图像分割模型。

 

图 1:使用视觉转换器模型架构运行图像分割的结果。

        从上到下,输入图像、地面实况分割掩码和预测分割掩码。来源:作者

二、文章大纲

        在本文中,我们将参观风靡深度学习世界的变压器架构。变压器是一种多模态架构,可以对语言、视觉和音频等不同模态进行建模。

        在本文中,我们将

  1. 了解变压器架构和所涉及的关键概念
  2. 了解视觉变压器架构
  3. 介绍从头开始编写的视觉转换器模型,以便您可以欣赏所有构建块和移动部件
  4. 跟踪输入到该模型的输入张量,并检查它如何改变形状
  5. 使用此模型对牛津 IIIT 宠物数据集执行图像分割
  6. 观察此分割任务的结果
  7. 简要介绍SegFormer,一种用于语义分割的最先进的视觉转换器

        在本文中,我们将引用此笔记本中的代码和结果进行模型训练。如果要重现结果,则需要一个 GPU 来确保第一个笔记本在合理的时间内完成运行。

三、本系列文章

        本系列面向所有深度学习经验水平的读者。如果您想了解深度学习和视觉AI的实践以及一些扎实的理论和实践经验,那么您来对地方了!这将是一个由 4 部分组成的系列,包含以下文章:

  1. 概念和想法
  2. 基于 CNN 的模型
  3. 深度可分离卷积
  4. 基于视觉变压器的模型(本文)

        让我们从对变压器架构的介绍和直观理解开始我们的视觉变压器之旅。

四、变压器架构

        我们可以将变压器架构视为交错的通信计算层的组合。图 2 直观地描述了这一想法。变压器有N个处理单元(图3中的N为2),每个单元负责处理输入的1/N部分。为了使这些处理单元产生有意义的结果,每个处理单元都需要具有输入的全局视图。因此,系统将有关每个处理单元中的数据的信息重复传达给每个其他处理单元;使用从每个处理单元到每个其他处理单元的红色、绿色和蓝色箭头进行显示。接下来是基于此信息进行的一些计算。在充分重复此过程后,模型能够产生预期的结果。

图 2:变压器中的交错通信和计算。该图像仅显示了 2 层通信和计算。

        值得注意的是,大多数在线资源通常会讨论变压器的编码器和解码器,如题为“注意力是你所需要的”的论文中所述。但是,在本文中,我们将仅描述变压器的编码器部分。

        让我们仔细看看变压器中的通信和计算构成。

4.1 变压器中的通信:注意

        在变压器中,通信由称为注意力层的层实现。在 PyTorch 中,这被称为 MultiHeadAttention。我们稍后会谈到这个名字的原因。

        文档说:

“允许模型共同关注来自不同表示子空间的信息,如论文中所述:注意力就是你所需要的

        注意力机制使用形状(批处理、长度、特征)的输入张量 x,并生成形状相似的张量 y,以便根据张量在同一实例中关注的其他输入更新每个输入的特征。因此,在大小为“长度”的实例中,长度为“特征”的每个张量的特征会根据其他每个张量进行更新。这就是注意力机制的二次成本的用武之地。

图3:相对于句子中其他单词显示的单词“it”的注意。我们可以看到,“它”是在同一句话中注意“动物”、“太”和“tire(d)”等词。 

        在视觉变压器的上下文中,变压器的输入是图像。假设这是一个 128 x 128(宽度、高度)的图像。我们将其分成多个较小的大小块(16 x 16)。对于 128 x 128 的图像,我们得到 64 个补丁(长度),每行 8 个补丁和 8 行补丁。

        这 64 个大小为 16 x 16 像素的块中的每一个都被视为变压器模型的单独输入。在不深入细节的情况下,将此过程视为由 64 个不同的处理单元驱动就足够了,每个处理单元都在处理单个 16x16 图像补丁。

        在每一轮中,每个处理单元中的注意力机制负责查看它负责的图像补丁,并查询其余 63 个处理单元中的每一个,以询问它们可能相关和有用的任何信息,以帮助它有效地处理自己的图像补丁。

        通过注意力的沟通步骤之后是计算,我们接下来将研究。

4.2 变压器中的计算:多层感知器

        变压器中的计算只不过是一个多层感知器(MLP)单元。该单元由 2 个线性层组成,介于两者之间具有 GeLU 非线性。也可以考虑使用其他非线性。该单元首先将输入投影到大小的 4 倍,然后将其重新投影回 1 倍,这与输入大小相同。

        在我们将在笔记本中看到的代码中,此类称为多层感知器。代码如下所示。

class MultiLayerPerceptron(nn.Sequential):
    def __init__(self, embed_size, dropout):
        super().__init__(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(p=dropout),
        )
    # end def
# end class

        现在我们了解了变压器架构的高级工作原理,让我们把注意力集中在视觉转换器上,因为我们将执行图像分割。

五、视觉转换器

        视觉转换器最初是由题为“图像价值16x16字:大规模图像识别的变压器”的论文介绍的。本文讨论了作者如何将原版变压器架构应用于图像分类问题。这是通过将图像拆分为大小为 16x16 的补丁,并将每个补丁视为模型的输入令牌来完成的。转换器编码器模型被馈送这些输入令牌,并被要求预测输入图像的类。

图 4:来源:用于大规模图像识别的变压器。

        在我们的例子中,我们对图像分割感兴趣。我们可以将其视为像素级分类任务,因为我们打算预测每个像素的目标类。

        我们对原版视觉转换器进行了一个小但重要的更改,并更换了MLP头,以便由MLP头进行像素级分类。我们在输出中有一个线性层,由每个补丁共享,其分割掩模由视觉变压器预测。此共享线性层预测作为模型输入发送的每个补丁的分割掩码。

        在视觉转换器的情况下,大小为 16x16 的补丁被视为等效于特定时间步长的单个输入令牌。

图 5:用于图像分割的视觉转换器的端到端工作。使用此笔记本生成的图像。

5.1 在视觉转换器中构建张量维度的直觉

        当使用深度CNN时,我们大部分使用的张量维度是(N,C H,W),其中字母代表以下内容:

  • N:批量大小
  • C:通道数
  • H:身高
  • W:宽度

        您可以看到这种格式面向 2D 图像处理,因为它闻起来非常特定于图像的特征。

        另一方面,有了变压器,事情变得更加通用和领域无关。我们将在下面看到的内容适用于视觉、文本、NLP、音频或其他输入数据可以表示为序列的问题。值得注意的是,当张量流经我们的视觉转换器时,在张量的表示中几乎没有视觉特定偏差。

        在使用转换器和一般情况下,我们希望张量具有以下形状:(B,T,C),其中字母代表以下内容:

  • B:批量大小(与CNN相同)
  • T:时间维度或序列长度。此维度有时也称为 L。在视觉变压器的情况下,每个图像块对应于这个维度。如果我们有 16 个图像补丁,那么 T 维度的值将为 16
  • C:通道或嵌入大小维度。此维度有时也称为 E。处理图像时,大小为 3x16x16(通道、宽度、高度)的每个补丁通过补丁嵌入层映射到大小为 C 的嵌入。我们稍后会看到如何做到这一点。

        让我们深入了解输入图像张量在预测分割掩码的过程中如何变异和处理。

5.2 视觉转换器中张量的旅程

        在深度CNN中,张量的旅程看起来像这样(在UNet,SegNet或其他基于CNN的架构中)。

        输入张量通常是形状为 (1, 3, 128, 128)。该张量经过一系列卷积和最大池化操作,其中其空间维度减小,通道维度增加,通常每个增加 2 倍。这称为特征编码器。在此之后,我们执行反向操作,增加空间维度并减少通道维度。这称为特征解码器。在解码过程之后,我们得到一个形状的张量(1,64,128,128)。然后将其投影到我们希望的输出通道 C 的数量中,使用 1x128 无偏差的逐点卷积作为 (128, C, 1, 1)。

图 6:张量形状通过用于图像分割的深度 CNN 的典型进展。 

        使用视觉变压器时,流程要复杂得多。让我们看一下下面的一张图片,然后尝试了解张量如何在每一步中转换形状。

图 7:张量形状通过用于图像分割的视觉转换器的典型进展。 

        让我们更详细地看一下每个步骤,看看它如何更新流经视觉转换器的张量的形状。为了更好地理解这一点,让我们为张量维度取具体值。

  1. 批量规范化:输入和输出张量具有形状 (1, 3, 128, 128)。形状保持不变,但值归一化为零均值和单位方差。
  2. 图像到补丁:形状 (1, 3, 128, 128) 的输入张量被转换为 16x16 图像的堆叠块。输出张量具有形状 (1, 64, 768)。
  3. 补丁嵌入:补丁嵌入层将 768 个输入通道映射到 512 个嵌入通道(在本例中)。输出张量的形状为 (1, 64, 512)。补丁嵌入层基本上只是一个 nn。PyTorch 中的线性层。
  4. 位置嵌入:位置嵌入层没有输入张量,但有效地贡献了一个可学习的参数(PyTorch 中的可训练张量),其形状与补丁嵌入相同。这是形状(1,64,512)。
  5. 加:贴片和位置嵌入分段地加在一起,以产生视觉变压器编码器的输入。这个张量的形状是(1,64,512)。您会注意到,视觉变压器的主要主力,即编码器基本上保持这种张量形状不变。
  6. 变压器编码器:形状为(1,64,512)的输入张量流经多个变压器编码器块,每个转换器编码器块具有多个注意头(通信),后跟一个MLP层(计算)。张量形状保持不变,如 (1, 64, 512)。
  7. 线性输出投影:如果我们假设要将每个图像分成 10 个类,那么我们需要每个大小为 16x16 的补丁有 10 个通道。该 nn.用于输出投影的线性层现在会将 512 个嵌入通道转换为 16x16x10 = 2560 个输出通道,此张量将类似于 (1, 64, 2560)。在上图中 C' = 10。理想情况下,这将是一个多层感知器,因为MLP 是通用函数近似器,但我们使用单个线性层,因为这是一项教育练习
  8. 补丁到映像:该层将编码为 (64, 1, 64) 张量的 2560 个补丁转换回看起来像分割掩码的东西。这可以是 10 个单通道图像,或者在本例中是单个 10 通道图像,每个通道是 10 个类别之一的分割掩码。输出张量的形状为 (1, 10, 128, 128)。

         就是这样 — 我们已经使用视觉转换器成功分割了输入图像!接下来,让我们看一个实验以及一些结果。

5.3 视觉变压器的实际应用

        此笔记本包含此部分的所有代码。

        就代码和类结构而言,它非常模仿上面的框图。上面提到的大多数概念都与此笔记本中的类名 1:1 对应。

        有一些与注意力层相关的概念是我们模型的关键超参数。我们之前没有提到多头关注的细节,因为我们提到它超出了本文的范围。如果您对变压器中的注意力机制没有基本的了解,我们强烈建议您在继续之前阅读上述参考资料。

        我们将以下模型参数用于视觉变压器进行分割。

  1. 补丁嵌入层的 768 个嵌入维度
  2. 12 变压器编码器块
  3. 每个变压器编码器块中有 8 个注意头
  4. 多头注意力和 MLP 中 20% 的辍学率

这种配置可以在 VisionTransformerArgs Python 数据类中看到。

@dataclass
class VisionTransformerArgs:
    """Arguments to the VisionTransformerForSegmentation."""
    image_size: int = 128
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    embed_size: int = 768
    num_blocks: int = 12
    num_heads: int = 8
    dropout: float = 0.2
# end class

        在模型训练和验证期间使用了与以前类似的配置。配置指定如下。

  1. 随机水平翻转颜色抖动数据增强应用于训练集以防止过度拟合
  2. 在非宽高比保留调整大小操作中将图像大小调整为 128x128 像素
  3. 不会对图像应用任何输入归一化,而是使用批量归一化层作为模型的第一层
  4. 该模型使用 LR 为 50.0 的 Adam 优化器和每 0004 个 epoch 将学习率衰减 0.8 倍的 StepLR 调度器训练 12 个 epoch
  5. 交叉熵损失函数用于将像素分类为属于宠物、背景或宠物边框

        该模型具有 86.28M 参数,经过 85 个训练周期后,验证准确率为 89.50%。这低于深度 CNN 模型在 88 个训练周期后达到的 28.20% 的准确率。这可能是由于一些需要通过实验验证的因素。

  1. 最后一个输出投影图层为单个 nn。线性而非多层感知器
  2. 16x16 色块大小太大,无法捕获更细粒度的细节
  3. 训练时期不足
  4. 没有足够的训练数据 - 众所周知,与深度CNN模型相比,转换器模型需要更多的数据来有效训练
  5. 学习率太低

我们绘制了一个 gif,显示了模型如何学习预测验证集中 21 张图像的分割掩码。

图 8:显示图像分割模型的视觉转换器预测的分割掩码进程的 gif。 

        我们在早期训练时期注意到一些有趣的事情。预测的分割掩码有一些奇怪的阻塞伪影。我们能想到的唯一原因是,我们将图像分解为大小为 16x16 的补丁,经过很少的训练时期,模型除了一些非常粗略的信息之外,没有学到任何有用的东西关于这个 16x16 补丁通常被宠物或背景像素覆盖。

图 9:使用视觉转换器进行图像分割时,预测分割中看到的阻塞伪影会掩盖。 

        现在我们已经看到了一个基本的视觉转换器,让我们把注意力转向用于分割任务的最先进的视觉转换器。

5.4 SegFormer:使用转换器进行语义分割

        本文于 2021 年提出了 SegFormer 架构。我们在上面看到的转换器是SegFormer 架构的简化版本。

图 10:SegFormer 架构。资料来源: 

        最值得注意的是,SegFormer:

  1. 生成 4 组映像,其中包含大小为 4x4、8x8、16x16 和 32x32 的修补程序,而不是具有大小为 16x16 的修补程序的单个修补映像
  2. 使用 4 个变压器编码器块,而不仅仅是 1 个。这感觉就像一个模型合奏
  3. 在自我注意的前阶段和后期阶段使用卷积
  4. 不使用位置嵌入
  5. 每个变压器模块以空间分辨率 H/4 x W/4、H/8 x W/8、H/16 x W/16 和 H/32、W/32 处理图像
  6. 同样,当空间维度减小时,通道也会增加。这感觉类似于深度CNN
  7. 对多个空间维度的预测进行上采样,然后在解码器中合并在一起
  8. MLP 将所有这些预测结合起来,提供最终预测
  9. 最终的预测是在空间维度H/4,W/4,而不是在H,W。

六、结论

在本系列的第 4 部分中,我们特别介绍了变压器架构和视觉变压器。我们对视觉变压器的工作原理以及视觉变压器的通信和计算阶段所涉及的基本构建块有了直观的理解。我们看到了视觉转换器采用的基于补丁的独特方法,用于预测分割掩模,然后将预测组合在一起。

我们回顾了一个实验,该实验显示了视觉转换器的实际作用,并能够将结果与深度CNN方法进行比较。虽然我们的视觉转换器不是最先进的,但它能够取得相当不错的结果。我们提供了对最先进的方法的一瞥,例如SegFormer。

现在应该很清楚,与基于深度CNN的方法相比,变压器具有更多的活动部件,并且更复杂。从原始FLOP的角度来看,变压器有望提高效率。在变压器中,唯一计算繁重的实层是nn。线性。这是在大多数架构上使用优化的矩阵乘法实现的。由于这种架构的简单性,与基于深度CNN的方法相比,变压器有望更容易优化和加速。

恭喜你走到了这一步!我们很高兴您喜欢阅读有关 PyTorch 中高效图像分割的系列文章。如果您有任何问题或意见,请随时将其留在评论部分。

七、延伸阅读

注意力机制的细节超出了本文的范围。此外,您还可以参考许多高质量的资源来详细了解注意力机制。以下是我们强烈推荐的一些内容。

  1. 图解变压器
  2. 使用 PyTorch 从头开始 NanoGPT

我们将在下面提供文章的链接,这些文章提供了有关视觉转换器的更多详细信息。

  1. 在 PyTorch 中实现视觉转换器 (ViT):本文详细介绍了在 PyTorch 中实现用于图像分类的视觉转换器。值得注意的是,它们的实现使用 einops,我们避免这样做,因为这是一个以教育为中心的练习(我们建议学习和使用 einops 以提高代码可读性)。我们改用原生 PyTorch 运算符来排列和重新排列张量维度。此外,作者在一些地方使用 Conv2d 而不是线性图层。我们希望构建一个完全不使用卷积层的视觉转换器实现。
  2. 视觉转换器:AI之夏
  3. 在 PyTorch 中实现 SegFormer

德鲁夫·马塔尼

·

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/889160.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文件批量改名,一键将西班牙语文件批量改名为中文!

亲爱的用户们,您是否经常面对大量以西班牙语命名的文件,需要将其转换为中文?手动逐个改名无疑是一项繁琐且耗时的任务。现在,我们为您带来一款便捷的简繁转换工具,让您一键将西班牙语文件批量改名为中文! …

STM32--ADC模数转换

文章目录 ADC简介逐次逼近型ADCADC框图转换模式数据对齐转换时间校准ADC基本结构ADC单通道工程代码: ADC简介 STM32的ADC(Analog-Digital Converter)模拟-数字转换器,是一种逐次逼近型模拟数字转换器,可以将引脚上连续…

node获取抖音直播间Id

node获取抖音直播间Id 信息位置 直播间信息存放在id是RENDER_DATA的script标签里 安装依赖 npm install fetch cheerio # 或 pnpm install fetch cheerionode代码 // room.js const fetch require("fetch"); const cheerio require("cheerio"); // co…

LVS-DR的RS进行ARP抑制的原因和LVS持久连接配置

一.RS的ARP抑制 1.为什么要抑制 2.如何抑制 (1)修改/etc/sysctl.conf文件,增加以下内容 (2)命令行临时设置 二.LVS持久连接 1.客户端持久连接 2.端口持久连接 3.防火墙标记持久连接 一.RS的ARP抑制 1.为什么要…

提示丢失vcomp140.dll怎么办?如何快速修复vcomp140.dll丢失问题

最近我遇到了一个程序启动失败的问题,错误提示显示缺少了vcomp140.dll文件。经过一番研究和尝试,我终于成功修复了这个问题。在这里,我将分享一下我的修复方法。 目录 vcomp140.dll是什么? 如何快速修复呢? vcomp140…

mysql 01.三范式,数据类型

01.概念的区分: mysql是属于DBMS层次的,sql语句是用于DBMS的语句。 02.sql语句详细介绍: SQL的概述Structure Query Language(结构化查询语言)简称SQL,它被美国国家标准局(ANSI)确定为关系型数据库语言的美国标准,后…

Nginx安全加固,版本隐藏及HTTP请求头修改方法

1 隐藏nginx版本号 1.1 引言 nginx作为目前较为流行的http server软件,其相关的安全漏洞也非常多,攻击者可以根据我们的nginx版本来了解到相关的漏洞从而针对性的进行攻击。 通过新版本的nginx都会修复一些老版本的已知漏洞,但有时候我们生…

Android Studio实现读取本地相册文件并展示

目录 原文链接效果 代码activity_main.xmlMainActivity 原文链接 效果 代码 activity_main.xml 需要有一个按钮和image来展示图片 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk…

华硕win11笔记本双系统deepin 解决更改硬盘模式为AHCI后无法进入Windows的问题

华硕win11笔记本双系统deepin 解决更改硬盘模式为AHCI后无法进入Windows的问题 重新将硬盘模式改为Intel RST Premium With Intel Optane System Acceleration(RAID)然后才能进入Windows&#xff0c;但是改了之后又不能进deepin了&#xff0c;需要再将硬盘模式改为AHCI才能进de…

【100天精通python】Day37:GUI界面编程_PyQT从入门到实战(上)

目录 专栏导读 1 PyQt6 简介&#xff1a; 1.1 安装 PyQt6 和相关工具&#xff1a; 1.2 PyQt6 基础知识&#xff1a; 1.2.1 Qt 的基本概念和组件&#xff1a; 1.2.2 创建和使用 Qt 窗口、标签、按钮等基本组件 1.2.3 布局管理器&#xff1a;垂直布局、水平布局、网格布局…

生信豆芽菜-细胞丰度比较

网址&#xff1a;生信豆芽菜-细胞丰度比较 一、使用方法 1、数据准备 这里需要上传一个行为样本&#xff0c;列为细胞评分的矩阵数据 分组信息 2、选择检验的方法&#xff0c;其中两组的可以选择用wilcox.test/test&#xff0c;三组的可以选择用kruskat.test/anova 3、分组…

最强自动化测试框架Playwright(29)-文件选择对象

FileChooser对象通过page.on("filechoose")事件监听。 如下代码实现点击百度搜图按钮&#xff0c;上传文件进行搜索。 from playwright.sync_api import Playwright, sync_playwright, expectdef run(playwright: Playwright) -> None:browser playwright.chro…

php+echarts实现数据可视化实例2

效果: 代码 php <?php include(includes/session.inc); include(includes/SQL_CommonFunctions.inc); ?> <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible"…

notepad++正则表达式匹配方括号及里面的内容

可以用于去除注释 在notepad也可以直接使用 (\[.*?\])

【Redis】Redis 的学习教程(五)之 SpringBoot 集成 Redis

在前几篇文章中&#xff0c;我们详细介绍了 Redis 的一些功能特性以及主流的 java 客户端 api 使用方法。 在当前流行的微服务以及分布式集群环境下&#xff0c;Redis 的使用场景可以说非常的广泛&#xff0c;能解决集群环境下系统中遇到的不少技术问题&#xff0c;在此列举几…

【项目管理】PMP备考宝典-第三章《人》

文章目录 第一节&#xff1a;概述1.项目涉及的人2.项目经理3.团队4.干系人 第二节&#xff1a;原则1.有效的干系人参与2.成为勤勉尊重关心他人的管家3.创建协作的项目团队环境4.展现领导力行为 第三节&#xff1a;任务1.定义团队的基本原则2.建设团队3.领导团队4.管理冲突5.凝聚…

2022年工作架构分析

mpmw自动化流程工具 schema动态数据 Schema 本身是一个JSON &#xff0c;Schema 通过一些特定字段描述和定义 JSON的数据结构。 最常见的表单通过类XML语法定义。一些库支持通过一些特定结构的 JSON (Schema)来生成类XML标签。 formily 是其中实现之一。 表单设计器通过可视…

S03-快速填充,批量提取和组合数据的神奇

视频教程 快速填充&#xff08;Ctrl➕E&#xff09; 作用&#xff1a;对数据进行拆分重组合并 方式 1 CtrlE 2 双击加号选择智能填充 快速填充数据<>智能填充 开始☞填充☞快速填充&#xff08;注意附近一个单元格&#xff0c;是一定要有数据的&#xff0c;不能出现单独…

前后端分离-毕业生就业服务平台SpringBoot+Redis+Vue校园实习招聘指导java jsp源代码

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 前后端分离-毕业生就业服务平台SpringBootRedisVue 系…

招生老师如何制作发布录取通知书文案?这个开发教程一看就会

作为一名负责招生的老师&#xff0c;录取通知的公布是整个招生环节最重要的一环&#xff0c;如何快速搞定这项工作&#xff1f;传统的公布方式需要设及技术开发、服务器搭建等&#xff0c;一起来看看传统方法制作录取通知查询系统的教程&#xff08;结尾有惊喜&#xff09;&…