【ICML 2023】Hiera详解:一个简单且高效的分层视觉转换器

news2024/11/28 4:34:45

【ICML 2023】Hiera详解:一个简单且高效的分层视觉转换器

  • 0. 引言
  • 1. 模型介绍
  • 2. Hiera介绍
    • 2.1 为什么提出Hiera?
    • 2.2 Hiera 中的 Mask
    • 2.3 空间结构的分离和填充到底如何操作
    • 2.4 为什么使用Mask Unit Attn
  • 3. 简化版理解
  • 4. 总结

0. 引言

虽然现在各种各样版本的 Vision Transformer 模型带来了越来越高的精度,但是同样地,在各种不同版本中存在的各种复杂结构也带来了复杂性的增加。
然而,Hiera 文章的作者认为:增加的各种复杂结构是不必要的。作者提出了一个非常简单的分层视觉变压器 Hiera,它比以前的模型更准确,同时在推理和训练过程中都要快得多

论文名称:Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
论文地址:https://arxiv.org/abs/2306.00989
项目地址:https://github.com/facebookresearch/hiera

1. 模型介绍

为了方便大家对模型的理解,首先介绍模型整体结构,然后分别介绍各个不同的成分。
在这里插入图片描述
如上图所示为模型整体结构图。模型整体构成与 MAE 模型是类似的,包括将图形 mask 一部分,然后通过 Encoder-Decoder 结构重建原图。
具体而言,Hiera 模型的操作流程如下所示:

  1. 首先,将输入图片切分成不同的小的patch,然后根据mask比率对patch进行mask
  2. 然后,将没有maskpatch输入到 Hiera Encoder部分。在Hiera Encoder部分,分为四个阶段,每个新的阶段都会使用 Pooling 对数据进行下采样。此外,在前面两个阶段中使用 Mask Unit Attention 进行注意力的计算。
  3. 最后,将 Hiera Encoder部分得到的输出输入到 ViT Decoder 进行图片还原。

上图所示结构为Hiera-B模型,为了更好的理解模型。下表列举了Hiera-B模型和其它变种模型的参数介绍。
在这里插入图片描述

2. Hiera介绍

Hiera 中存在很多需要人们关注的细节,接下来会对文章细节进行分点描述。

2.1 为什么提出Hiera?

作者希望可以得到一个非常简单的模型。这个简单的模型使用分层 Vision Transformer,但是在分层 Vision Transformer中,不存在 卷积、窗口移位或注意力偏差 这些复杂的模块。为了完成这个任务,作者将MAE的思想用在了Hiera中。当然,在使用时需要进行一系列操作来满足Hiera的要求。

具体而言
在传统的分层 Vision Transformer模型中, 通过 卷积、窗口移位或注意力偏差,增加 Transformer 模型非常需要的空间(和时间)偏差,进而得到精度非常高的分类任务。然而,作者希望可以通过训练一个强的pretext task 来进行 空间(和时间)偏差的学习。而对于这个 pretext task ,作者选用 MAE,通过让网络重构掩码输入补丁。

2.2 Hiera 中的 Mask

MAE中,无需使用pooling结构,每个小的patch被作为一个整体输入到Transformer Block中,因此也不存在对数据之间关联关系的破坏。然而,在分层 Vision Transformer模型中,MAE稀疏的,因为MAE删除 masked tokens 破坏了分层模型依赖的图像的2D网格结构。具体内容如下图所示。
在这里插入图片描述
具体而言:图(b) 所示为:对于原始的MAE来说,MAE删除了掩码单元。如果此时使用CNN结构,两个卷积会跳跃(卷积在原始图像中的表示被切分成两个部分),即MAE破坏了图片的空间结构。图© 所示为:如果直接用掩码单元进行填充可以解决该问题,但是破坏了MAE结构4-10倍加速效果。图(d) 所示为:使用了空间结构(空间分离和填充),将每个掩码单元作为一个结构整体,在内部使用Conv结构。解决了上述问题,但是需要不必要的填充。图(e) 所示为:作者提出的Hiera。令Kernel_size=stride,这样的话任意Maxpooling之间就不会产生重叠

注意:图(d)部分说的空间结构的分离和填充同文章后续说的 shift the mask units to the batch dimension to separate them for pooling (effectively treating each mask unit as an “image”) 是一致的。

2.3 空间结构的分离和填充到底如何操作

空间结构的分离和填充也即将mask units 转移到批处理维度,即掩码单元作为一整个数据进行处理,对于各个掩码单元之间不进行处理。
作者的回答:

转向批处理(或分离和填充)技巧仅适用于我们对论文表2所做的中间MViTv1消融(因为内核重叠)。最终的 Hiera 模型实际上根本不使用它,因为正如您所说,我们可以跳过蒙版单元。

对于 MAE,我们强制要求在每个图像中遮罩相同数量的单位。这样,如果我们像您的示例一样屏蔽,批处理中的每个图像将始终留下 3 个单位。然后,为了回答您的问题(请注意,在此存储库中没有实现向批处理技巧的转变,因为 Hiera 不需要它),假设我们有 4 张 w=96、h=64 的图像,有 3 个通道。

然后我们的输入张量将如下所示:
input_image: shape = [4, 3, 64, 96]

每个令牌都是 4x4 像素,因此一旦我们对图像进行标记化,我们就会下降到:

请注意,分词器还会将通道调暗度提高到 144(例如,对于 L 型号)。tokenized_image: shape = [4, 3, 64, 96] -> tokenizer (patch embed) -> [4, 144, 16, 24]

然后我们提取掩码单元,每个掩码单元都是 8x8 标记:

在这里,每个图像包含 6 个 2x3 排列的掩码单元,如上面的示例所示,其中每个掩码单元是 8x8 (64) 个标记。tokenized_image_mu: shape = [4, 144, (2, 8), (3, 8)] -> permute -> [4, 144, (2, 3), (8, 8)] -> reshape -> [4, 144, 6, 64]

现在,我们从每个图像中删除相同数量的令牌,因此如果遮罩率为 50%,我们将从 3 张图像中的每一个中选择 4 个进行丢弃:
masked_image_mu: shape = [4, 144, 6, 64] -> discard 3 mus from ea. image -> [4, 144, 3, 64]

然后,最后转向批处理技巧:只需将“3”维度移动到批处理维度即可。
shifted_to_batch: shape = [4, 144, 3, 64] -> permute -> [(4, 3), 64, 144] -> reshape -> [12, 64, 144]

然后这是熟悉的形状,您可以传递到任何变压器中。池化和窗口 attn 可以在“64”维度上完成(即 8x8 -> 4x4 -> 2x2 等),如果你填充它,你可以像 MViT 一样做 3x3 内核(这就是为什么我们也称它为“分离和填充”)。[batch, tokens, embed_dim]

这是一种冗长的解释,但希望这是有道理的。

作者的原版回答请查看这个issue:How do we drop tokens?

2.4 为什么使用Mask Unit Attn

在这里插入图片描述
上图所述为 MViTv2Pooling AttnHiera 中的 Mask Unit Attn的区别。

具体而言MViTv2使用 Pooling Attn,通过 K K K V V V 的池化版本执行全局关注。对于大输入(例如视频)来说,这计算成本可能会很昂贵,所以作者选择用Mask Unit Attn 来代替它,它在掩码单元内执行局部注意。这没有开销,因为在前面的操作阶段已经将令牌分组为屏蔽单元。同时,不必像在Swin中那样担心转移(窗口之间没有联系,使用shift来获取全局注意力),因为作者在阶段3和4中使用了全局注意力

此外,Mask Unit AttnWindow Attn 最主要的区别就是:Window Attn 的窗口大小是固定的,Mask Unit Attn 的窗口大小可以在当前分辨率下调整窗口大小以适应掩码单元的大小。具体在论文中的内容如下所示:
在这里插入图片描述
为了更好地帮助大家理解内容,具体可见下面的论文源代码:

class MaskUnitAttention(nn.Module):
    """
    Computes either Mask Unit or Global Attention. Also is able to perform q pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units.
    See `Unroll` for more details.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
        - dim, dim_out: The input and output feature dimensions.
        - heads: The number of attention heads.
        - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
        - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
        - use_mask_unit_attn: Use Mask Unit or Global Attention.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.heads = heads
        self.q_stride = q_stride

        self.head_dim = dim_out // heads
        self.scale = (self.head_dim) ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Input should be of shape [batch, tokens, channels]. """
        B, N, _ = x.shape
        #	如果use_mask_unit_attn 为True,输入数据x经过线性变换得到qkv会根据q_stride和window_size来进行变化注意力窗口大小
        num_windows = (
            (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
        )

        qkv = (
            self.qkv(x)
            .reshape(B, -1, num_windows, 3, self.heads, self.head_dim)
            .permute(3, 0, 4, 2, 1, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_stride > 1:
            # Refer to Unroll to see how this performs a maxpool-Nd
            q = (
                q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim)
                .max(dim=3)
                .values
            )

        if hasattr(F, "scaled_dot_product_attention"):
            # Note: the original paper did *not* use SDPA, it's a free boost!
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            attn = (q * self.scale) @ k.transpose(-1, -2)
            attn = attn.softmax(dim=-1)
            x = (attn @ v)

        x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x

3. 简化版理解

可能看了上述的内容,大家对于 Hiera 的整体还是不太理解。这里对文章内容进行口语式解答来帮助大家理解文章内容。
Hiera 这篇文章总的来说是将 MAE 与分层 Vision Transformer 模型相结合,通过 MAE 框架来替代原始分层 Vision Transformer 模型中 卷积、窗口移位或注意力偏差等复杂框架 ,进而学习空间偏差来达到一个非常高的分类精度。在简化模型的同时带来了非常高的精度。

4. 总结

作者创建了一个简单的分层视觉变压器,通过现有的视觉变压器并去除其所有的信号,同时通过MAE预训练为模型提供空间偏差。由此产生的架构Hiera比目前在图像识别任务上的工作更有效,并且在视频任务上超越了最先进的技术。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

到此,有关TPS的内容就基本讲完了。如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/628775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker 安装Tomcat、实现Tomcat集群

文章目录 1、寻找Tomcat镜像2、下载tomcat镜像(下载最近版本)(1) docker pull tomcat 下载(2) 查看镜像是否安装成功 3、启动容器(跟安装Nginx一样)4、 测试tomcat(报错解决)5、 搭建Tomcat集群 1、寻找Tom…

访问修饰符private,default,protected,public访问等级区别

private:private是最严格的访问修饰符,它将成员声明为私有的。私有成员只能在声明它们的类内部访问,其他类无法直接访问私有成员。这样可以确保数据的封装性和安全性。 default(默认):如果没有明确指定访问…

Pytest 分组测试

有时候需要针对不同的测试环境跑不同的测试用例,如:冒烟测试、sit、uat、prd,所以给自动化测试用例做标记分组是很有必要的,pytest.mark 可以轻松实现这个功能。首先需要注册自定义标记。 注册marks 有3中方法注册marks&#xf…

【Apache Pinot】浅析 Pinot 的 Table,Index 和 Tenant 原理

本文属于基础篇幅,不会涉及过深入的原理,主要还是如何用好 Pinot 背景 单独讲 Table 概念可能有些许单薄,本文会扩展场景,讲解表的部分原理,表与表之间的相互影响,租户是怎么作用到表的,增加字…

一位年薪35W的测试被开除,回怼的一番话,令人沉思

一位年薪35W测试工程师被开除回怼道:“反正我有技术,在哪不一样” 一技傍身,万事不愁,当我们掌握了一技之长后,在职场上说话就硬气了许多,不用担心被炒,反过来还可以炒了老板,这一点…

基于深度学习的高精度袋鼠检测识别系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度袋鼠检测识别系统可用于日常生活中或野外来检测与定位袋鼠目标,利用深度学习算法可实现图片、视频、摄像头等方式的袋鼠目标检测识别,另外支持结果可视化与图片或视频检测结果的导出。本系统采用YOLOv5目标检测模型…

活跃主机发现技术指南

活跃主机发现技术指南 1.活跃主机发现技术简介2.基于ARP协议的活跃主机发现技术3.基于ICMP协议的活跃主机发现技术4.基于TCP协议的活跃主机发现技术5.基于UDP协议的活跃主机发现技术6.基于SCTP协议的活跃主机发现技术7.主机发现技术的分析 1.活跃主机发现技术简介 在生活中有这…

继电保护名词解释三

第三章:电网的相间电流、电压保护和方向性相间电流、电压保护 1. 瞬时电流速断保护:对于仅反应于电流增大而瞬时动作的电流保护。 2. 保护装置的起动电流:能够使保护装置起动的最小电流值。 3. 系统最大运行方式:通过保护装置的…

了解服务级别协议(SLA)在 ITSM 中的重要性

什么是服务级别协议 根据ITIL 4,SLA是服务提供商和客户之间的书面协议,用于确定所需的服务和预期的服务水平。这些协议可以是正式的,也可以是非正式的。 在 ITSM 的上下文中,SLA 有助于设置和管理最终用户在提出请求时的期望 或…

如何导出Axure原型设计中的图片?零基础入门教程

Axure 是一款广为人知的原型设计工具,特别适用于新手产品经理或产品设计初学者。然而,如果用户想要在浏览器中预览 Axure 原型图,需要安装插件才能实现。而安装完 Axure RP Chrome 插件后,还需要在扩展程序中选择 "允许访问文…

类和对象【5】日期类的实现

全文目录 引言实现日期类概述默认成员函数构造函数析构函数拷贝构造赋值重载 功能运算符重载日期间的比较日期与天数日期-与-天数日期前置与后置日期前置- -与后置- -日期 - 日期 输入输出重载(友元) 代码总览头文件源文件main函数 总结 引言 类和对象1…

详解Java异常和异常面试题(上)

1.异常的体系结构 2.从程序执行过程,看编译时异常和运行时异常 编译时异常:执行javac.exe命名时,可能出现的异常 运行时异常:执行java.exe命名时,出现的异常 1.运行时异常  是指编译器不要求强制处置的异常。一般是…

网络安全:信息收集专总结【社会工程学】

前言 俗话说“渗透的本质也就是信息收集”,信息收集的深度,直接关系到渗透测试的成败,打好信息收集这一基础可以让测试者选择合适和准确的渗透测试攻击方式,缩短渗透测试的时间。 一、思维导图 二、GoogleHacking 1、介绍 利用…

HVV常问的Web漏洞(护网蓝初面试干货)

目录 1、SQL注入 (1)原理 (2)分类 (3)防御 2、XSS (1)原理 (2)分类 3、中间件(解析漏洞) (1)IIS6.X …

【AI实战】开源大语言模型LLMs汇总

【AI实战】开源大语言模型LLM汇总 大语言模型开源大语言模型1、LLaMA2、ChatGLM - 6B3、Alpaca4、PandaLLM5、GTP4ALL6、DoctorGLM (MedicalGPT-zh v2)7、MedicalGPT-zh v18、Cornucopia-LLaMA-Fin-Chinese9、minGPT10、InstructGLM11、FastChat12、Luot…

在线原型设计工具推荐

原型设计是每个产品经理必备的基本技能。 本文从即时设计原型设计的步骤开始,帮助您快速使用即时设计制作高还原度、丰富互动的产品原型。 利用即时设计进行原型设计的优势 快速启动原型设计工作 借助即时设计内置设计系统和社区资源,包括大量原型设…

Hive执行计划之什么是hiveSQL向量化模式及优化详解

目录 文章目录 1.什么是hive向量化模式2.Hive向量化模式支持的使用场景2.1 hive向量化模式使用前置条件2.2 向量模式支持的数据类型2.3 向量化模式支持的函数 3.如何查看hiveSQL向量化运行信息3.1 explain vectorization only只查询向量化描述信息内容3.2 explain vectorizati…

javaScript蓝桥杯---JSON 生成器

目录 一、介绍二、准备三、目标四、代码五、完成 一、介绍 JSON 已经是大家必须掌握的知识点,JSON 数据格式为前后端通信带来了很大的便利。在开发中,前端开发工程师可以借助于 JSON 生成器快速构建一个 JSON 用来模拟数据。 本题请你开发一个简易的 J…

chatgpt赋能python:Python快速安装库

Python快速安装库 Python作为一种功能强大且易于学习的编程语言,已经成为许多开发人员的首选。在Python中,库是重要的一部分,它们提供了各种功能和工具来简化开发过程。安装这些库的过程可能会比较繁琐,但我们可以通过一些简单的…

PHP的学习--Traits新特性

自 PHP 5.4.0 起,PHP 实现了代码复用的一个方法,称为 traits。 Traits 是一种为类似 PHP 的单继承语言而准备的代码复用机制。Trait 为了减少单继承语言的限制,使开发人员能够自由地在不同层次结构内独立的类中复用方法集。Traits 和类组合的…