ViT模型何时才能破万亿?
Transformer无疑是促进自然语言处理领域繁荣的最大功臣,也是GPT-4等大规模语言模型的基础架构。
不过相比语言模型动辄成千上万亿的参数量,计算机视觉领域吃到Transformer的红利就没那么多了,目前最大的视觉Transformer模型ViT-e的参数量还只有40亿参数。
最近谷歌发布了一篇论文,研究人员提出了一种能够高效且稳定训练大规模Vision Transformers(ViT)模型的方法,成功将ViT的参数量提升到220亿。
论文链接:https://arxiv.org/abs/2302.05442
为了实现模型的扩展,ViT-22B结合了其他语言模型(如PaLM模型)的思路,使用 QK 归一化改进了训练稳定性,提出了一种异步并行线性操作(asynchronous parallel linear operations)的新方法提升训练效率,并且能够在硬件效率更高的Cloud TPU上进行训练。
在对ViT-22B模型进行实验以评估下游任务性能时,ViT-22B也表现出类似大规模语言模型的能力,即随着模型规模的扩大,性能也在不断提升。
ViT-22B 还可以应用于PaLM-e中,与语言模型结合后的大模型可以显著提升机器人任务的技术水平。
研究人员还进一步观察到规模带来的其他优势,包括更好地平衡公平性和性能,在形状/纹理偏见方面与人类视觉感知的一致性,以及更好的稳健性。
模型架构
ViT-22B 是一个基于Transformer架构的模型,和原版ViT架构相比,研究人员主要做了三处修改以提升训练效率和训练稳定性。
并行层(parallel layers)
ViT-22B并行执行注意力块和MLP块,而在原版Transformer中为顺序执行。
PaLM模型的训练也采用了这种方法,可以将大模型的训练速度提高15%,并且性能没有下降。
query/key (QK) normalization
在扩展ViT的过程中,研究人员在80亿参数量的模型中观察到,在训练几千步之后训练损失开始发散(divergence),主要是由于注意力logits的数值过大引起的不稳定性,导致零熵的注意力权重(几乎one-hot)。
为了解决这个问题,研究人员在点乘注意力计算之前对Query和Key使用LayerNorm
在80亿参数模型上的实验结果如下图所示,归一化可以缓解发散问题。
删除QKV投影和LayerNorms上的偏置项
和PaLM模型一样,ViT-22B从QKV投影中删除了偏置项,并且在所有LayerNorms中都没有偏置项(bias)和centering,使得硬件利用率提高了3%,并且质量没有下降。
不过与PaLM不同的是,ViT-22B对(内部和外部)MLP稠密连接层使用了偏置项,可以观察到质量得到了改善,并且速度也没有下降。
ViT-22B的编码器模块中,嵌入层,包括抽取patches、线性投影和额外的位置嵌入都与原始ViT中使用的相同,并且使用多头注意力pooling来聚合每个头中的per-token表征。
ViT-22B的patch尺寸为14×14,图像的分辨率为224×224(通过inception crop和随机水平翻转进行预处理)。
异步并联线性运算(asynchronous parallel linear operations)
大规模的模型还需要分片(sharding),即将模型参数分布在不同的计算设备中,除此之外,研究人员还把激活(acctivations,输入的中间表征)也进行分片。
因为输入和矩阵本身都是分布在各种设备上的,即使是像矩阵乘法这样简单的操作也需要特别小心。
研究人员开发了一种称为异步并行线性运算的方法,可以在矩阵乘法单元(在TPU 中占据绝大多数计算能力的单元)中计算时,同时对设备之间的激活和权值进行通信。
异步方法最小化了等待传入通信的时间,从而提高了设备效率。
异步并行线性运算的目标是计算矩阵乘法 y = Ax,但矩阵 A 和激活 x 都分布在不同的设备上,需要通过跨设备的重叠通信和计算来实现这一点。矩阵 A 在设备之间进行列分片(column-shard),每个矩阵包含一个连续的切片,每个块表示为 Aij,更多细节请看原始论文。
实验结果
为了说明ViT-22B学习到的表征非常丰富,研究人员使用LiT-tuning训练一个文本模型来生成一些表征用来对齐文本和图像。
下面是用Parti 和 Imagen 生成的分布外(out-of-distribution)图像得到的实验结果,可以看到ViT-22B的zero-shot图像分类泛化能力非常强,仅从web上爬取的自然图像就能识别出没见过的物体和场景。
论文中还讨论了ViT-22B在视频分类、深度估计和语义分割任务上的效果。
与人类目标识别对齐
为了验证 ViT-22B 分类决策与人类分类决策的一致性,研究人员对 ViT-22B 进行了微调,对分布外(OOD)数据集的不同分辨率进行了微调,其中人类比较数据可通过model-vs-human toolbox获得。
该工具箱主要衡量三个关键指标: 模型如何处理失真(准确性) ?人和模型的精度(精度差)有什么不同?人和模型的错误模式(错误一致性)有多相似?
形状偏差评估(值越大代表更多的形状偏差)。许多视觉模型具有低形状/高纹理偏差,而在 ImageNet 上进行微调的 ViT-22B具有迄今为止在 ML 模型中记录的最高形状偏差,更接近于人类形状偏见
实验结果显示,虽然并非所有的微调解决方案都表现得很好,但 ViT-22B 变体在所有三个指标上都达到了新高。
此外,ViT-22B 模型在视觉模型中也有最高的形状偏差记录。这意味着他们主要使用目标的形状,而不是目标的纹理来进行分类决策,策略结果类似于人类的感知(其形状偏差为96%)。
标准模型(例如,ResNet-50有20-30% 的形状偏差)通常根据纹理来分类,而高形状偏差的模型则倾向于关注形状(下图识别为猫),尽管人类和模型的感知之间仍然存在许多差异,但是 ViT-22B 显示出与人类视觉对象识别更多的相似性。
猫还是大象?车还是钟?鸟还是自行车?具有某个物体的形状和另一个不同物体纹理的图像,可用于测量形状/纹理偏差
分布外(out-of-distribution)性能
测量 OOD 数据集的性能有助于评估模型泛化性。
在这个实验中,研究人员构建了从 JFT 到 ImageNet 的标签映射,以及从 ImageNet 到不同的分布外数据集(如 ObjectNet)的标签映射。
对这些数据进行预训练后的结果如下图所示,然后在 ImageNet 上对模型进行完全微调。
可以观察到缩放 Vision Transformers 可以提高 OOD 性能: 即使 ImageNet 的精度达到饱和,也可以看到 ObjectNet 上从 ViT-e 换成 ViT-22B 模型可以显著提升性能。
线性探测Linear Probe
线性探测是一种将单个线性层置于冻结模型之上的技术,与完全微调相比,这种方法的训练成本更低,设置起来也更容易。
在 ImageNet 上训练的线性探测结果,在 ImageNet-Real,ImageNet-v2,ObjectNet,ImageNet-R 和 ImageNet-A 数据集上评估,提供高分辨率微调 ViT-e/14作为参考
从结果中可以观察到,ViT-22B 的线性探测性能接近于使用高分辨率图像对较小模型进行全面微调的最先进水平,其中具有较高分辨率的训练通常要昂贵得多,但可以在许多任务上取得更好的结果。
蒸馏
利用蒸馏法,可以将较大模型的知识转化为较小模型的知识,可以提升成本更高、运行速度更慢的大模型的运行效率。
从实验结果中可以发现,ViT-22B 的知识可以迁移到更小的模型,如 ViT-B/16和 ViT-L/16,并在同等模型尺寸下在ImageNet上刷新了性能记录。
公平性与偏见
机器学习模型容易受到意想不到的不公平偏见的影响,例如找到错误的相关性或者在各个子群体之间存在性能差距,研究人员发现,扩大模型规模有助于缓解这些问题。
首先,规模是一个有前景的权衡方式,即使模型经过训练后再进行后处理,将其人口平等(demographic parity)水平控制在规定的、可容忍的水平之下,性能也会随着规模的增加而提高。
上图: 去偏前 CelebA 中每个子组的精度。下图: y 轴显示了在这个例子中突出显示的两个特定亚组(女性和男性)的表现的绝对差异。与较小的 ViT 模型相比,ViT-22B 在性能的差距很小。
更重要的是,这不仅适用于以准确性衡量性能的情况,而且适用于其他度量,例如校准,即对模型估计概率的真实性的统计测量,所有子群的分类随着规模的增大而趋于改善,并且ViT-22B 降低了各子群之间的性能差距。
结论
研究人员提出了一个目前最大的视觉Transformer模型 ViT-22B,包含220亿参数。
通过对原始模型架构进行微小但关键的修改后,实现了更高的硬件利用率和训练稳定性,从而得到了一个在几个基准测试上提高了模型的上限性能。
使用冻结模型生成嵌入,只需要在顶部训练几层,即可获得很好的性能,并且评估结果进一步表明,与现有模型相比,ViT-22B 在形状和纹理偏差方面显示出与人类视知觉更多的相似性,并且在公平性和稳健性方面提供了优势。
参考资料:
https://ai.googleblog.com/2023/03/scaling-vision-transformers-to-22.html