论文:RMT: Retentive Networks Meet Vision Transformers - AMiner
摘要
这篇论文探讨了将Retentive Network(RetNet)的概念引入到计算机视觉领域,并与Vision Transformer结合,提出了一种新的模型RMT(Retentive Networks Meet Vision Transformers)。RetNet最初在自然语言处理(NLP)领域展现出色性能,作者们提出疑问,是否将RetNet的思想迁移到视觉领域也能带来卓越的性能。RMT通过引入与空间距离相关的显式衰减机制,为视觉模型带来了空间先验知识。此外,为了降低全局建模的计算成本,作者们沿图像的两个坐标轴分解了建模过程。大量实验表明,RMT在多种计算机视觉任务上表现出色,例如,在ImageNet-1k数据集上达到了84.1%的Top1准确率,且计算量仅为4.5G FLOPs。
拟解决的问题
Transformer架构虽然在计算机视觉任务中表现出色,但RetNet在NLP任务中展现出更强的性能,这激发了作者们探索将RetNet迁移到视觉任务的可能性,以期提高视觉模型的性能。
创新之处
- 提出了Retentive Self-Attention(ReSA)机制,将空间距离相关的先验知识引入到视觉模型中。
- 通过沿图像的两个坐标轴分解ReSA,降低了计算复杂度,同时最小化了对模型性能的影响。
方法论
RMT的整体架构与传统的backbone相似,分为四个阶段。前三个阶段使用分解的ReSA,最后一个阶段使用原始的ReSA。此外,模型中还整合了条件位置编码(CPE)。
Retentive Network(RetNet)
在RetNet中,这种记忆机制是通过一种特殊的“保留”操作来实现的。我们可以把这种操作想象成一个过滤器,它允许网络在考虑每个数据点(在NLP中是单词,在图像中是像素)时,给予它们周围邻近数据点不同级别的重视。
保留机制是RetNet的基础,它通过引入衰减系数来控制每个token对其邻近token的关注程度。这种衰减是根据token之间的距离来确定的,距离越远,注意力权重越小。
保留操作的数学表达式可以写为:
其中,是第n个token的输出,𝑄、𝐾、𝑉分别是查询(Query)、键(Key)和值(Value)矩阵,𝜃θ是角度参数,是衰减系数,它取决于token n和m之间的距离。
衰减系数用于控制序列中不同位置token之间的注意力权重。在RetNet中,衰减系数通常设计为随着token间距离的增加而指数级减少,从而引入了一维距离的先验知识。
并行训练中,保留操作为:
其中𝑋是输入序列,⊙表示Hadamard积。
相比于传统的Transformer模型,RetNet通过引入显式的衰减机制,能够更好地捕捉序列数据中的长距离依赖关系,从而在多种NLP任务中取得了更强的性能。
Retentive Self-Attention (ReSA)
ReSA机制考虑了图像中像素(或称为tokens)之间的空间距离。与传统的Transformer模型不同,ReSA允许模型根据像素之间的距离来调整它们之间的相互影响。距离较近的像素会有更强的相互关系,而距离较远的像素则关系较弱。
:衰减矩阵,其元素根据像素对之间的二维空间距离计算。
衰减矩阵是根据像素之间的曼哈顿距离来计算的,其元素定义如下:
其中:
Decomposed ReSA
为了降低计算成本,特别是在图像的早期阶段,ReSA可以沿图像的两个轴(水平和垂直)分解。分解的ReSA可以分别计算行和列方向上的注意力,公式如下:
和:分别代表行和列方向上的注意力机制。
Local Enhancement Module
为了增强局部特征的表达能力,ReSA还包括一个局部增强模块,通常使用深度卷积(DWConv)来实现。最终的输出可以表示为:
LCE(X):局部增强模块的输出。
总的来说,ReSA是一种结合了空间距离先验知识的注意力机制,它通过调整像素之间的注意力权重,提高了模型对局部特征的敏感度,从而在视觉任务中取得了更好的性能。
结论
RMT模型通过引入显式的空间距离衰减,有效地提高了视觉模型的性能,尤其是在下游任务如目标检测、实例分割等领域展现出显著的优势。