人工智能咨询培训老师叶梓 转载标明出处
计算机视觉中,表征学习是一个核心问题。如何让机器像人类一样理解图像内容,是实现高级视觉任务的关键。传统的自监督学习方法往往依赖于数据的变换来预测其变化,例如遮蔽图像建模(Masked Image Modeling)等。然而,这些方法在学习过程中往往忽略了学习到的世界模型在下游任务中的应用潜力。由Quentin Garrido、Mahmoud Assran等人在Meta的FAIR实验室提出的图像世界模型(Image World Models,简称IWM)的核心在于,它不仅学习如何对图像进行有效的表征,还学习了一个能够预测图像全局光度变换的模型,并将这个模型应用于解决多样化的任务。
图像世界模型(IWM)的方法论
图像世界模型(IWM)是依据联合嵌入预测架构(JEPA)构建的,这一架构与Assran等人在2023年提出的I-JEPA有相似之处。在JEPA的框架下,预测器实际上扮演了世界模型的角色。IWM的理念是,如果一个世界模型能够在潜在空间中实施变换,并且能够学习到与变换相适应的表征,即等变表征,那么这个模型就被认为是具备能力的。基于这个理念,能够实施变换的模型被称作等变模型,而不能实施变换的模型则被称为不变模型。
JEPA的一个显著优势在于,与需要依赖不变损失来提升表征质量的对比方法不同,JEPA不需要这样的损失函数。无论是显式地(如Gupta等人在2023年的研究,或者Garrido等人在2023b的研究)还是隐式地(如Chavhan等人在2023a的研究),对比方法通常需要不变损失来确保学习到的表征对数据的变换具有不变性。相比之下,JEPA风格的方法通过潜在空间的填充来学习表征的语义方面,这种方式不要求损失函数的辅助。另外在潜在空间中进行操作还可以让网络排除那些无关紧要或难以预测的信息。这一点尤为重要,因为在重建方法中,重建的质量并不总是与表征的质量正相关的,如Chen等人在2024年的研究中所指出的。因此,JEPA提供了一种更加灵活且有效的学习表征的方法,它通过减少对特定损失函数的依赖,使得模型能够更加自由地学习数据的本质特征。
为了训练IWM,首先需要从图像I生成源视图(source view)x和目标视图(target view)y(如图2所示)。目标视图y的生成是通过在原始图像I上应用随机的水平翻转、裁剪和颜色抖动(亮度、对比度、饱和度、色相)来完成的。为了确保目标尽可能包含多的信息,没有对目标应用破坏性增强,例如灰度化。源视图x的生成则是从目标y开始,进一步转换。首先应用另外一组颜色抖动,以及破坏性增强:灰度化、模糊和曝光过度(solarization)。这些增强与对比性SSL中使用的那些相同。最后根据I-JEPA的方法对图像的部分区域进行遮蔽。遮蔽掩码Mx定义为四个矩形掩码的并集。
变换参数ax→y与x到y的变换相关联,即初始变换过程的逆。这个参数包含了x和y之间颜色抖动的差异信息,以及是否应用了每种破坏性增强的信息。在潜在空间中应用变换的世界模型被称为pϕ。源和目标分别通过编码器fθ及其指数移动平均值fEMAθ进行输入,得到表示zx = fθ(x)和zy = fEMAθ(y)。使用EMA网络对于避免解决方案崩溃至关重要。为了调节预测器,即图像世界模型,它接收以掩码标记形式的几何信息以及ax→y。这些掩码标记表示为ma,对应于MCx中的位置。预测器pϕ然后以嵌入的源补丁xc、变换参数ax→y和掩码标记ma作为输入。其目标是匹配pϕ(zx, ax→y, ma) = zˆy到zy。
所使用的损失函数是预测zˆy和它们的目标zy之间的平方L2距离:L(x, y) = Σ(i∈MCx) ||pϕ(fθ(x), ax→y, ma)i - fEMAθ(y)i||^2。
IWM的编码器是一个视觉变换器(Vision Transformer),具体使用的是ViT-B/16架构。预测器基于相同的架构,但具有不同的深度和嵌入维度。研究者将IWM的实例表示为IWMZX,Y,其中X是预测器的深度,Y是其嵌入维度,Z根据世界模型的能力是Inv(不变)还是Equi(等变)。例如,IWMEqui18,384意味着预测器有18层深度,具有384维嵌入,并表现出等变行为,即学习到了一个多功能的世界模型。
学习图像世界模型以提升表征学习
在表征学习中,学习等变表征和学习世界模型是密切相关的问题。为了评估训练出的世界模型的质量,研究者们借鉴了等变理论中的指标,主要依赖平均倒数排名(Mean Reciprocal Rank,MRR)作为衡量标准。MRR的计算过程涉及到生成一个增强目标图像库(实际中包含256张图像),然后将干净图像的表征通过预测器传递,目标是预测目标图像。接着计算预测结果与增强表征库之间的距离,得到目标在最近邻图中的排名。通过对多个图像和变换求倒数排名的平均值得到MRR,这反映了世界模型的质量。MRR接近1意味着世界模型能够应用变换,而接近0则意味着它不能。
为了构建一个高性能的IWM,研究者们确定了三个关键因素:预测器在变换(或动作)上的条件化、控制变换的复杂性,以及控制预测器的容量。研究表明,如果这些因素没有得到适当的处理,将导致表征的不变性。研究者们探索了两种条件化预测器以反映变换信息的方法:
- 序列条件化:通过向预测器的输入添加代表变换的标记来实现。为了打破变换器预测器的排列等变性,每个标记都通过一个独特的线性层进行传递,使网络能够以可区分的方式转换信息。
- 特征条件化:另一种选择是通过将变换信息作为额外维度添加,然后将掩码标记通过1x1卷积神经网络混合信息,再映射回正确的维度,从而在变换和掩码标记之间混合信息。
如表1所示,没有条件化会导致世界模型无法应用变换,而使用序列或特征轴进行条件化则能产生良好的世界模型。实际中采用特征条件化,因为它能够带来更高的下游性能。
研究者们依赖于对比方法中使用的数据增强,包括颜色抖动(亮度、色调、对比度、饱和度)以及灰度化、模糊和曝光过度等破坏性增强。这些增强的强度也必须适当,以学习有用的世界模型。如果预测任务太容易,预测器将无法学到有用的信息。如表2所示,增强越强,学习强大的世界模型就越容易。
如果变换复杂,预测器需要更多的容量才能应用它,这使得容量成为学习图像世界模型的一个关键因素。如表2所示,更深的预测器使之能够在更广泛的增强范围内学习强大的世界模型,这是IWM成功的关键。对于12层的预测器,颜色抖动等变性在5次尝试中只有1次能够实现,而对于18层的预测器,5次中有4次能够实现。因此,预测器的容量是强大世界模型的关键组成部分。
与计算MRR的方式相同,研究者们可以比较预测的表征与一组变换图像,并观察与预测最近邻相关的图像。如图1所示,IWM学习到的世界模型能够在潜在空间正确应用变换。然而,在反转灰度时可以看到一些不准确之处,因为灰度不是完全可逆的。这些可视化有助于强化IWM能够学习图像变换的强大世界模型的事实。
利用世界模型处理下游任务
虽然在图像上学习到的世界模型能够执行颜色抖动或给图像上色等任务,但这些并不是推动计算机视觉应用的主要任务。这与大模型(LLMs)不同,后者的主要应用之一是预测下一个词。因此,研究者们探索了如何将世界模型用于视觉领域,以处理超越应用变换的任务,重点是分类和图像分割等判别性任务。
对于任何任务,评估头部需要理解学习到的潜在空间,并利用它来解决手头的问题。预测器能够做到这一点,表明它学习到了有用的信息,这些信息不一定存在于编码器中。然而,由于预测器被训练为预测另一种有效的表征,如果直接使用,其输出没有理由会带来更好的下游性能。因此,需要对预测器进行微调,以解决判别性任务。研究者们专注于与He等人(2021年)提出的微调协议进行比较。所有研究的方法都在ImageNet上进行了预训练并进行了评估,并使用ViT-B/16作为编码器。
在微调预测器时仍然需要使用它进行预测任务。在表3中,研究者们研究了定义预测任务的不同方式及其对性能的影响。首先注意到的是,使用教师网络比使用学生网络提高了性能。是否使用随机变换并不是一个重要因素,最重要的是预测另一个完整的图像。这使得评估更加灵活,因为不必为了评估而重用预训练目标。使用CLS标记来聚合信息而不是完整的图像预测也是一个很好的策略,尽管这会降低一半的性能。这项技术的优点是成本更低(N + 1个标记对2N个标记),因此根据用例,它可以是一个很好的选择。总的来说,最简单的方法是最好的:预测完整图像的未变换版本。这使得微调协议很容易重用,因为它不依赖于预训练任务。
在图3中,研究者们研究了预测器微调与编码器微调的效率。当考虑可比较的参数数量时,使用IWM的预测器微调在性能上比MAE编码器微调高出约1个百分点,比IWM高出1.5个百分点。这意味着预测器微调不仅在性能上是一个有竞争力的协议,而且在适应效率上也是如此。
与编码器微调相比,可以进一步提高效率。表征学习的一个主要目标是获得可以用于多种任务的表征。就像预测器被训练来解决多种任务(上色、填充、改变颜色)一样,研究表明它可以在多个任务上进行微调,受到前缀调整(Li和Liang,2021年)和指令调整(Wei等人,2022年;Zhang等人,2023年)的启发。一般的思想是给预测器新的学习标记,以指示它正在尝试解决的任务。对于每个任务,都有一个任务标记,以及特定于任务的头部和/或损失函数。然后所有任务的损失被结合起来,预测器以及特定于任务的头部被更新。研究者们研究了一个简单的场景,其中批量均匀地分配给任务,注意到其他采样策略可能会导致进一步改善性能。
在表7中,研究者们评估了在ImageNet、iNaturalist18、SUN397和Places205上预训练的IWMEqui18,384。对于每个任务,他们都训练了一个单任务基线,其中总迭代次数与多任务训练相同。因此,训练所有四个单任务基线的成本与多任务训练完全相同,尽管它会产生四个不同的模型而不是一个。多任务预测器能够实现与单任务预测器相似的性能,在大多数任务上略有下降,但在SUN397上性能显著提高。平均来看,它实现了与单任务预测器相同的性能。这进一步证明了利用好的世界模型的效率收益,现在参数在所有任务中共享,使得预测器微调在推理时对每个任务都很轻量。
总的来说,当学习到一个好的世界模型时,可以通过微调它来重用它进行下游任务。这导致的性能与编码器微调相当,但成本只有一小部分。通过进行多任务微调,可以使其更加高效,突显了这种方法的多功能性。
图像世界模型使表征更加灵活
为了完善对IWM在表征学习中的分析,研究者们探究了它在自监督学习中常用的轻量级评估协议上的表现,特别关注线性探测(linear probing)和注意力探测(attentive probing)。
如表8所示,当IWM学习到一个不变的世界模型时,它展现出与MoCov3等对比方法相似的行为,在线性评估中相比MIM或其他基于JEPA的方法有显著的性能提升。同样地,当IWM学习到一个等变的世界模型时,它的行为类似于MAE等MIM方法,在线性评估中表现较差,但在注意力探测中的性能更具竞争力。这表明不同方法之间的主要区别不必然在于表征的质量,而在于它们的抽象层次,即提取信息的容易程度。线性探测是最简单的评估方式,注意力探测稍微复杂一些,微调则是一个更复杂的协议。
在图4中,我们可以看到最适合的评估协议与世界模型的等变性之间存在明确的联系。更加不变的世界模型在线性评估中表现优异,而等变的世界模型在具有较大评估头部的预测器微调中表现出色。
可以在图5中看到,对比方法占据了表征抽象的高端,信息可以通过简单的协议轻松提取。然而,当忽略适应成本时,它们在峰值性能上会有所损失,如表5所示。另一方面,MIM位于这一谱系的低端,它在复杂的评估中如微调中提供更强的性能,但在线性探测中表现不佳,因为信息不是那么容易获取。通过改变世界模型的等变性,IWM能够在对比方法和MIM之间的谱系中占据位置,如图4和表8中IWMInv12,384和IWMEqui18,384作为IWM谱系的两个极端所示。
这个谱系可以被SSL的理念“学习可预测的内容”所概括。用一个弱的世界模型学习意味着它不能正确地模拟世界,编码器会移除那些无法预测的信息。另一方面,如果世界模型非常强大,表征就不需要那么抽象或语义化,因为它可以在任何情况下找到预测表征的方法。这意味着学习世界模型提供了一种可衡量的方式来控制表征的抽象层次。
IWM允许在表征的抽象层次上进行调整,从而在不同的评估协议中实现最佳性能。这种灵活性是通过改变世界模型的等变性来实现的,从而可以根据特定任务的需求来定制表征的抽象程度。例如,当我们需要一个能够轻松提取信息的表征时,可以训练一个更抽象的模型;而当我们需要一个保留更多原始信息的表征时,可以训练一个等变性更强的模型。
通过这种方式,IWM不仅提供了一种强大的表征学习方法,而且还提供了一种根据下游任务的具体需求来调整模型的能力。这种灵活性和适应性使得IWM成为一个在多种视觉任务中都具有潜力的强大工具。
论文链接:https://arxiv.org/abs/2403.00504