Fast-iTPN: Integrally Pre-Trained Transformer Pyramid Network with Token Migration
https://github.com/sunsmarterjie/iTPN/blob/main
https://arxiv.org/pdf/2211.12735
Introduction
背景
近年来,视觉模型取得了两大进展,一是将Vision Transformer(ViT)作为网络主干,二是使用Masked Image Modeling(MIM)方法进行模型预训练。这两者的结合在多种下游任务中取得了先进性能,包括图像分类、目标检测和实例/语义分割。
挑战
然而,预训练与下游微调之间的迁移差距仍然存在。具体来说,下游任务(尤其是细粒度识别任务如检测和分割)需要层次化特征,但大多数预训练任务(如BEiT和MAE)都是基于简单的ViT,缺乏层次化设计。即使使用层次化ViT,预训练也仅影响主干网络,而特征金字塔(neck部分)未经过训练,这增加了下游任务微调的风险。
iTPN的提出
目的:为了缓解这一问题,本文提出了积分预训练的金字塔Transformer网络(iTPN),旨在联合优化网络主干和特征金字塔,从而最小化表示模型与下游任务之间的迁移差距。
方法:iTPN基于HiViT(一种MIM友好的层次Transformer),并为其配备了特征金字塔。通过两个关键技术贡献来联合优化主干和特征金字塔:1) 在预训练阶段插入特征金字塔进行重构,并在微调阶段重用这些权重;2) 提出掩码特征建模(MFM)以更好地预训练特征金字塔。
MFM的优势
MFM通过两个步骤来预训练特征金字塔:首先,将原始图像输入到移动平均主干中以计算中间目标;然后,使用金字塔各阶段的输出来重建这些中间目标。MFM与MIM互补,提高了重建和识别的准确性,并可以吸收来自预训练教师(如CLIP)的知识,进一步提高性能。
Fast-iTPN的改进
改进背景:使用层次化架构和特征金字塔时,全局自注意力导致的计算成本会累积。为了缓解这一问题,本文对iTPN进行了升级,提出了Fast-iTPN。
改进方法:Fast-iTPN通过两个灵活的设计来加速推理并减少内存开销:1) 令牌迁移(Token Migration),即根据相似性度量从主干中丢弃冗余令牌,并在没有自注意力操作的特征金字塔中补充这些令牌;2) 令牌收集(Token Gathering),通过引入少量收集令牌来聚合来自所有窗口的全局信息,从而用窗口注意力替换全局注意力,显著加速推理过程且性能损失可忽略不计。
ITPN
动机与背景
在视觉模型的发展中,尽管Vision Transformer(ViT)架构和Masked Image Modeling(MIM)方法结合取得了显著进展,但在上游预训练与下游微调之间仍存在较大的迁移差距。特别是对于需要层次特征的下游任务(如检测和分割),这一差距尤为明显。传统的预训练方法往往只针对骨干网络(backbone)进行优化,而忽略了特征金字塔(neck,如特征金字塔网络FPN)的预训练。
为了缓解这一问题,本文提出了积分预训练的金字塔Transformer网络(iTPN),旨在同时优化骨干网络和特征金字塔,从而最小化表示模型与下游任务之间的迁移差距。
技术贡献
首个预训练特征金字塔
插入特征金字塔:在预训练阶段,iTPN在骨干网络(如HiViT)后插入了一个特征金字塔。这样,在预训练过程中,特征金字塔就能够被优化,并在下游任务中复用其权重。
统一上下游的脖子:通过将特征金字塔整合到预训练阶段(用于重建)并在微调阶段复用其权重(用于识别),实现了上下游脖子的一致性。
掩码特征建模(MFM)
计算中间目标:MFM通过将一个移动平均的骨干网络应用于原始图像来计算中间目标。
多阶段监督:使用特征金字塔的每一阶段的输出来重建这些中间目标,从而实现对特征金字塔的多阶段监督。
适应教师模型:MFM还可以吸收来自预训练教师模型(如CLIP)的知识,进一步提高性能。
网络架构
iTPN的整体架构由两部分组成:骨干网络(如HiViT)和特征金字塔。骨干网络用于提取初步的视觉特征,而特征金字塔则对这些特征进行进一步的处理和聚合,以适应不同的下游任务。
骨干网络
本文采用HiViT作为骨干网络,HiViT通过引入两个基于MLP的阶段来构建层次特征,避免了在全局注意力阶段使用卷积操作或窗口注意力,从而保证了计算效率和与MIM的兼容性。
特征金字塔
特征金字塔通过逐步上采样和融合来自骨干网络的不同层次特征来构建多尺度特征表示。在预训练阶段,这些特征被用于重建由移动平均骨干网络计算的中间目标。
训练过程
iTPN的训练过程分为两个阶段:预训练和微调。
预训练
输入:原始图像被划分为一系列的图像块(tokens),其中一部分被随机掩码。
前向传播:掩码后的图像块通过骨干网络进行处理,生成初步的特征表示。这些特征随后被传递到特征金字塔中进行进一步的处理。
重建目标:特征金字塔的每一阶段都试图重建由移动平均骨干网络计算的中间目标。
损失函数:重建损失(例如,均方误差)用于监督特征金字塔的训练。
微调
在微调阶段,预训练的特征金字塔和骨干网络的权重被冻结或微调,以适应特定的下游任务(如图像分类、目标检测和语义分割)。
技术优势
积分预训练:同时优化骨干网络和特征金字塔,减少了迁移差距。
多阶段监督:MFM通过对特征金字塔的每一阶段进行监督,提高了特征的泛化能力。
适应性强:iTPN可以灵活地适应不同的预训练目标和下游任务。
Fast-iTPN
Fast-iTPN(快速整体预训练Transformer金字塔网络)是在iTPN基础上进行的改进,旨在通过两种灵活的设计来减少计算内存开销并加速推理过程。这两种设计分别是:Token迁移和Token聚合。这些设计不仅保持了模型在下游任务中的性能,还显著提升了模型的推理速度。
Fast-iTPN的主要设计
Token迁移(Token Migration)
Token迁移机制通过两个步骤来实现:
丢弃冗余Token:根据一个相似性度量(如余弦相似度),从主干网络(如HiViT)中丢弃一部分冗余的Token。这些被丢弃的Token通常是信息重复的或者对全局表示贡献较小的。
补充到特征金字塔:将被丢弃的Token补充到特征金字塔中,但不进行自注意力操作。这一步骤有效地利用了这些被丢弃的Token,并且因为特征金字塔中没有自注意力操作,所以补充Token的计算成本相对较低。
Token聚合(Token Gathering)
Token聚合通过引入少量的聚合Token来进一步减少全局自注意力操作的计算成本。这些聚合Token的作用是从所有窗口聚合全局信息,从而使得全局注意力可以被窗口注意力替代。具体来说:
聚合Token的作用:聚合Token在每个窗口中接收来自其他Token的信息,并将这些信息进行聚合,然后传递给下一层。通过这种方式,每个聚合Token都能够捕获到来自整个输入图像的全局信息。
减少计算成本:由于聚合Token只需要在窗口内进行自注意力操作,因此相比于全局自注意力,这种机制可以显著降低计算成本。同时,由于聚合Token的数量远少于输入Token的总数,因此总体计算量也大大减少。
技术细节
特征金字塔的设计
Fast-iTPN在HiViT的基础上构建了一个特征金字塔,该金字塔通过上采样和特征融合操作将不同层的特征图整合到一起。这种设计使得模型在不同尺度上都能够学习到丰富的特征表示,从而更好地适应下游任务的需求。
掩码特征建模(Masked Feature Modeling, MFM)
MFM是Fast-iTPN中用于预训练特征金字塔的方法。它通过以下两个步骤来实现:
计算中间目标:将原始图像输入到一个移动平均主干网络中,以计算得到中间目标。这些中间目标包含了丰富的特征信息,可以作为预训练过程中的监督信号。
使用金字塔各阶段输出进行重建:利用特征金字塔每个阶段的输出来重建这些中间目标。这一步骤不仅预训练了特征金字塔本身,还通过重建任务促进了金字塔内部不同层之间的特征交互。
模型训练
Fast-iTPN在训练过程中同时优化了主干网络和特征金字塔。由于特征金字塔在预训练阶段就被引入了,因此它在后续的微调阶段能够直接使用预训练得到的权重,从而减少了迁移差距并提升了下游任务的性能。