【ICML 2023】Hiera详解:一个简单且高效的分层视觉转换器
- 0. 引言
- 1. 模型介绍
- 2. Hiera介绍
- 2.1 为什么提出Hiera?
- 2.2 Hiera 中的 Mask
- 2.3 空间结构的分离和填充到底如何操作
- 2.4 为什么使用Mask Unit Attn
- 3. 简化版理解
- 4. 总结
0. 引言
虽然现在各种各样版本的 Vision Transformer 模型带来了越来越高的精度
,但是同样地,在各种不同版本中存在的各种复杂结构
也带来了复杂性
的增加。
然而,Hiera
文章的作者认为:增加的各种复杂结构是不必要的
。作者提出了一个非常简单的分层视觉变压器 Hiera
,它比以前的模型更准确
,同时在推理和训练过程中都要快得多
。
论文名称:Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
论文地址:https://arxiv.org/abs/2306.00989
项目地址:https://github.com/facebookresearch/hiera
1. 模型介绍
为了方便大家对模型的理解,首先介绍模型整体结构,然后分别介绍各个不同的成分。
如上图所示为模型整体结构图
。模型整体构成与 MAE
模型是类似的,包括将图形 mask
一部分,然后通过 Encoder-Decoder
结构重建原图。
具体而言,Hiera
模型的操作流程如下所示:
- 首先,将输入图片切分成不同的小的
patch
,然后根据mask比率
对patch进行mask
- 然后,将没有
mask
的patch
输入到Hiera Encoder
部分。在Hiera Encoder
部分,分为四个阶段
,每个新的阶段都会使用Pooling
对数据进行下采样
。此外,在前面两个阶段
中使用Mask Unit Attention
进行注意力的计算。 - 最后,将
Hiera Encoder
部分得到的输出输入到ViT Decoder
进行图片还原。
上图所示结构为Hiera-B
模型,为了更好的理解模型。下表列举了Hiera-B
模型和其它变种模型的参数介绍。
2. Hiera介绍
在 Hiera
中存在很多需要人们关注的细节,接下来会对文章细节进行分点描述。
2.1 为什么提出Hiera?
作者希望可以得到一个非常简单
的模型。这个简单的模型使用分层 Vision Transformer
,但是在分层 Vision Transformer
中,不存在 卷积、窗口移位或注意力偏差
这些复杂的模块。为了完成这个任务,作者将MAE
的思想用在了Hiera
中。当然,在使用时需要进行一系列操作来满足Hiera
的要求。
具体而言:
在传统的分层 Vision Transformer
模型中, 通过 卷积、窗口移位或注意力偏差
,增加 Transformer
模型非常需要的空间(和时间)偏差
,进而得到精度非常高
的分类任务。然而,作者希望可以通过训练一个强的pretext task
来进行 空间(和时间)偏差的学习
。而对于这个 pretext task
,作者选用 MAE
,通过让网络重构掩码输入补丁。
2.2 Hiera 中的 Mask
在MAE
中,无需使用pooling
结构,每个小的patch
被作为一个整体输入到Transformer Block
中,因此也不存在对数据之间关联关系的破坏
。然而,在分层 Vision Transformer
模型中,MAE
是稀疏
的,因为MAE
删除 masked tokens
破坏了分层模型依赖的图像的2D网格结构
。具体内容如下图所示。
具体而言:图(b) 所示为:对于原始的MAE
来说,MAE
删除了掩码单元
。如果此时使用CNN结构
,两个卷积会跳跃(卷积在原始图像中的表示被切分成两个部分
),即MAE
破坏了图片的空间结构
。图© 所示为:如果直接用掩码单元进行填充
可以解决该问题,但是破坏了MAE
结构4-10倍
的加速
效果。图(d) 所示为:使用了空间结构(空间分离和填充)
,将每个掩码单元作为一个结构整体,在内部使用Conv结构。解决了上述问题,但是需要不必要的填充
。图(e) 所示为:作者提出的Hiera
。令Kernel_size=stride
,这样的话任意Maxpooling
之间就不会产生重叠
。
注意:图(d)部分说的空间结构的分离和填充同文章后续说的 shift the mask units to the batch dimension to separate them for pooling (effectively treating each mask unit as an “image”)
是一致的。
2.3 空间结构的分离和填充到底如何操作
空间结构的分离和填充也即将mask units 转移到批处理维度,即掩码单元作为一整个数据进行处理,对于各个掩码单元之间不进行处理。
作者的回答:
转向批处理(或分离和填充)技巧仅适用于我们对论文表2所做的中间MViTv1消融(因为内核重叠)。最终的 Hiera 模型实际上根本不使用它,因为正如您所说,我们可以跳过蒙版单元。
对于 MAE,我们强制要求在每个图像中遮罩相同数量的单位。这样,如果我们像您的示例一样屏蔽,批处理中的每个图像将始终留下 3 个单位。然后,为了回答您的问题(请注意,在此存储库中没有实现向批处理技巧的转变,因为 Hiera 不需要它),假设我们有 4 张 w=96、h=64 的图像,有 3 个通道。
然后我们的输入张量将如下所示:
input_image: shape = [4, 3, 64, 96]
每个令牌都是 4x4 像素,因此一旦我们对图像进行标记化,我们就会下降到:
请注意,分词器还会将通道调暗度提高到 144(例如,对于 L 型号)。tokenized_image: shape = [4, 3, 64, 96] -> tokenizer (patch embed) -> [4, 144, 16, 24]
然后我们提取掩码单元,每个掩码单元都是 8x8 标记:
在这里,每个图像包含 6 个 2x3 排列的掩码单元,如上面的示例所示,其中每个掩码单元是 8x8 (64) 个标记。tokenized_image_mu: shape = [4, 144, (2, 8), (3, 8)] -> permute -> [4, 144, (2, 3), (8, 8)] -> reshape -> [4, 144, 6, 64]
现在,我们从每个图像中删除相同数量的令牌,因此如果遮罩率为 50%,我们将从 3 张图像中的每一个中选择 4 个进行丢弃:
masked_image_mu: shape = [4, 144, 6, 64] -> discard 3 mus from ea. image -> [4, 144, 3, 64]
然后,最后转向批处理技巧:只需将“3”维度移动到批处理维度即可。
shifted_to_batch: shape = [4, 144, 3, 64] -> permute -> [(4, 3), 64, 144] -> reshape -> [12, 64, 144]
然后这是熟悉的形状,您可以传递到任何变压器中。池化和窗口 attn 可以在“64”维度上完成(即 8x8 -> 4x4 -> 2x2 等),如果你填充它,你可以像 MViT 一样做 3x3 内核(这就是为什么我们也称它为“分离和填充”)。[batch, tokens, embed_dim]
这是一种冗长的解释,但希望这是有道理的。
作者的原版回答请查看这个issue:How do we drop tokens?
2.4 为什么使用Mask Unit Attn
上图所述为 MViTv2
中 Pooling Attn
和 Hiera
中的 Mask Unit Attn
的区别。
具体而言:MViTv2
使用 Pooling Attn
,通过
K
K
K 和
V
V
V 的池化版本执行全局关注
。对于大输入(例如视频)来说,这计算成本
可能会很昂贵,所以作者选择用Mask Unit Attn
来代替它,它在掩码单元内执行局部注意。这没有开销
,因为在前面的操作阶段已经将令牌分组为屏蔽单元。同时,不必像在Swin
中那样担心转移(窗口之间没有联系,使用shift来获取全局注意力
),因为作者在阶段3和4
中使用了全局注意力
。
此外,Mask Unit Attn
和Window Attn
最主要的区别就是:Window Attn 的窗口大小是固定的,Mask Unit Attn 的窗口大小可以在当前分辨率下调整窗口大小以适应掩码单元的大小
。具体在论文中的内容如下所示:
为了更好地帮助大家理解内容,具体可见下面的论文源代码:
class MaskUnitAttention(nn.Module):
"""
Computes either Mask Unit or Global Attention. Also is able to perform q pooling.
Note: this assumes the tokens have already been flattened and unrolled into mask units.
See `Unroll` for more details.
"""
def __init__(
self,
dim: int,
dim_out: int,
heads: int,
q_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
):
"""
Args:
- dim, dim_out: The input and output feature dimensions.
- heads: The number of attention heads.
- q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
- window_size: The current (flattened) size of a mask unit *after* pooling (if any).
- use_mask_unit_attn: Use Mask Unit or Global Attention.
"""
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.heads = heads
self.q_stride = q_stride
self.head_dim = dim_out // heads
self.scale = (self.head_dim) ** -0.5
self.qkv = nn.Linear(dim, 3 * dim_out)
self.proj = nn.Linear(dim_out, dim_out)
self.window_size = window_size
self.use_mask_unit_attn = use_mask_unit_attn
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
# 如果use_mask_unit_attn 为True,输入数据x经过线性变换得到qkv会根据q_stride和window_size来进行变化注意力窗口大小
num_windows = (
(N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
)
qkv = (
self.qkv(x)
.reshape(B, -1, num_windows, 3, self.heads, self.head_dim)
.permute(3, 0, 4, 2, 1, 5)
)
q, k, v = qkv[0], qkv[1], qkv[2]
if self.q_stride > 1:
# Refer to Unroll to see how this performs a maxpool-Nd
q = (
q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim)
.max(dim=3)
.values
)
if hasattr(F, "scaled_dot_product_attention"):
# Note: the original paper did *not* use SDPA, it's a free boost!
x = F.scaled_dot_product_attention(q, k, v)
else:
attn = (q * self.scale) @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
x = (attn @ v)
x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x
3. 简化版理解
可能看了上述的内容,大家对于 Hiera
的整体还是不太理解。这里对文章内容进行口语式解答来帮助大家理解文章内容。
Hiera 这篇文章总的来说是将 MAE 与分层 Vision Transformer 模型相结合,通过 MAE 框架来替代原始分层 Vision Transformer 模型中 卷积、窗口移位或注意力偏差
等复杂框架 ,进而学习空间偏差来达到一个非常高的分类精度。在简化模型的同时带来了非常高的精度。
4. 总结
作者创建了一个简单的分层视觉变压器
,通过现有的视觉变压器并去除其所有的信号,同时通过MAE预训练为模型提供空间偏差。由此产生的架构Hiera
比目前在图像识别任务上的工作更有效,并且在视频任务上超越了最先进的技术
。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。
到此,有关TPS
的内容就基本讲完了。如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。