Training data-efficient image transformers & distillation through attention:通过注意力训练数据高效的图像转换器和蒸馏
论文地址: https://arxiv.org/abs/2012.12877
代码地址: https://github.com/facebookresearch/deit
这篇论文在2020年12月23日首次提交,也就是在ViT提出两个月后。ViT开创了Transformer在计算机视觉领域的先河,但是由于极其庞大的计算量和训练周期,难以应用在其他下游任务中,而本文在ViT的基础上提出了一种基于新型蒸馏方式的视觉Transformer,极大提高了训练速度,在不使用超大型数据集的情况下也能实现和ViT使用超大型数据集相媲美的结果,并且通过蒸馏得到了更好的结果。
1、前言
最近,纯粹基于注意力的神经网络被证明可以解决图像分类等图像理解任务,然而这些高性能视觉Transformer使用了大量的算力对数亿张图像进行了预训练。在这项工作中,作者通过在 Imagenet 上进行训练纯粹的 Transformer,使用一台计算机在3天内完成了训练。该模型仅含有 86M 的参数,在 ImageNet 上无需外部数据即可实现 83.1% 的 top-1 准确率。更重要的是,作者引入了针对Transformer 的师生策略。它依赖于蒸馏令牌,学生模型通过注意力向老师模型学习;当使用卷积网络作为教师模型时,在Imagenet上取得了高达 85.2% 的准确率;和转移到其他任务时报告的结果与卷积网络具有竞争力。
卷积神经网络一直是图像理解任务的主要设计范例,最初在图像分类任务中得到了证明。他们成功的因素之一是拥有大型的训练集,即Imagenet。ViT 是直接继承原始 NLP 架构的 Transformer,然后将原始图像分割成若干块作为输入,从而进行图像分类;ViT 通过使用超大型私有图像数据集(JFT-300M,3 亿张图像)进行训练然后取得了出色的结果;并且得出结论是,Transformer 在数据量不足的情况下训练不能很好地泛化,并且这些模型的训练涉及大量的计算资源。
而在本文中,作者指出不需要大型训练数据集(即仅使用 Imagenet1k),在单个 8-GPU 节点上训练 vision transformer 两到三天(53 小时的预训练,20 小时的微调),能够与具有相似参数量和效率的卷积神经网络模型相竞争,作者将其称为DeiT(数据高效的Transformer)。此外,作者解决了如何对这些大型模型进行蒸馏的问题,针对 transformer 架构,提出了基于token的策略,得到了DeiT⚗。
本文贡献:
1、作者提出的模型,基于纯粹的 Transformer 在没有额外数据的情况下,得到的结果可以与 ImageNet 上的最优结果相媲美。在具有 4 个 GPU 的单个节点上训练三天,得到的新模型 DeiT-S 和 DeiT-Ti 参数较少,可以看作 ResNet-50 和 ResNet-18 的对应模型。
2、作者提出了一种基于 distillation token 的新蒸馏过程,它与 ViT中的 class token 具有相同的作用,只不过它的目的是再现老师模型估计的标签,两个 token 在变压器中通过注意力进行交互。这种针对 Transformer 的蒸馏策略明显优于普通蒸馏;通过这种蒸馏,vision transformer 从卷积网络中学到的东西比从具有同样性能的另一个 transformer 中学到的东西更多。
2、Knowledge Distillation 知识蒸馏
知识蒸馏(Knowledge Distillation,简称KD)是一种模型压缩和优化的技术。它的核心思想是让一个轻量级的“学生”模型(student model)学习一个复杂且性能强大的“教师”模型(teacher model)的知识。传统的分类任务中,模型通常会输出一个预测概率分布(softmax函数输出),表示模型对于每个类别的预测概率。“硬标签”(hard label)是指模型预测概率最高的那个类别,而“软标签”(soft label)则是指模型对所有可能类别的预测概率分布。教师的“软标签”将具有与标签平滑技术类似的效果,以减少模型对于训练数据的过拟合。学生模型在训练时,不仅仅使用硬标签,而且使用教师模型提供的软标签,这意味着学生模型会尝试学习教师模型对所有类别的概率预测,而不仅仅是最有可能的类别。
此外,教师模型会考虑到数据增强带来的一些影响。例如,对于一张“猫”标签的图像,它代表一大片风景和角落里的一只小猫;如果经过旋转或缩放等数据增强,猫不再位于新的增强图片中,那么这张新图片在小模型中很可能就被预测为其他类别,比如风景类;但是教师模型由于其强大的性能和泛化能力,能够更好地处理这种“标签与图像的错位”问题,即使在猫被裁剪掉的情况下,教师模型可能仍然能够识别出这张图像曾经包含猫的信息,因此教师模型的预测会考虑到这种可能性,给出一个不仅仅基于当前裁剪图像的预测。通过知识蒸馏,学生模型学习教师模型的这种处理方式,从而在面对类似的错位情况时,能够做出更加合理和准确的预测。
归纳偏好是指模型对于数据的先验知识或假设,比如卷积神经网络(CNN)倾向于捕捉局部特征和空间层次结构。在没有教师模型的情况下,学生模型可能需要通过大量数据和训练来学习这些偏好,而通过知识蒸馏,教师模型可以直接将这些偏好以柔和的方式传递给学生模型,而不是学生模型通过自己“生硬”学习(即从数据中直接学习)来获得,教师模型自己则凭借强大的性能和泛化能力直接从数据中学习这些偏好。通过使用卷积神经网络作为教师模型,可以让学生 transformer 模型学习到CNN特有的归纳偏好。
3、Vision transformer overview ViT概述
Multi-head Self Attention layers (MSA):注意力机制是基于具有联想记忆,可训练的 Key-Value 键值向量对;使用内积将查询向量 q ∈ Rd 与 k 个 Key 键向量(打包在一起形成矩阵 K ∈ Rk×d)进行匹配;然后对这些内积进行缩放scale,使用 softmax 函数进行归一化以获得 k 个权重。注意力的输出是一组 k 个 value 值向量的加权求和(打包为 V ∈ Rk×d)。对于 N 个 Query 查询向量的序列(打包为 Q ∈ RN×d),它会生成一个输出矩阵(大小为 N × d):
N 表示序列中的元素数量(一条句子的单词数量),而 k 表示每个元素(或查询向量)将与序列其他元素进行比较的 key 键的数量(比较次数)。在标准的自注意力设置中,k 和 N 通常是相等的,这意味着每个元素都会与序列中的每个其他元素进行比较。然而,在某些情况下,k 可以小于 N,这表示每个查询向量只与序列中的部分键向量进行比较,从而减少计算量。
最后,多头自注意力层(MSA)是通过考虑 h 个注意力“头”来定义的,将 h 个自注意力函数应用于输入序列中。每个头输出一个大小为 N × d 的序列,共有 h 个序列,然后这些序列被重新排列为 N × (d*h) 序列,并由线性层重新投影为 N ×D,得到最终输出。
Transformer block for images:Transformer block 是构成整个模型的基本单元,每个块都包含多头自注意力(Multi-head Self-Attention, MSA)机制和前馈神经网络(Feed-Forward Network, FFN)。前馈网络 FFN 位于 MSA 层的顶部,由两个线性层组成,层与层之间由GeLu 激活函数分隔。第一个线性层将维度从 D 扩展到 4D,第二层将维度从 4D 减小回 D。MSA 层和 FFN 都采用了残差连接,残差连接可以让模型在每个子层(sub-layer)的输出中加入输入,然后这个组合的输出再传递到下一个子层;在MSA和FFN之后,使用层归一化来稳定训练过程。
为了让 Transformer 能处理图像,ViT 像处理输入序列一样处理输入图像,固定大小的 RGB 图像被分解成一批(N=14×14)大小为16×16像素的图像块 patches, N 表示图像被分解为196个图像块,每个图像块通过线性层进行投影,该线性层保持其整体维度不变。一个16×16像素的RGB图像块,其维度为 3×16×16=768,这意味着线性层将每个图像块映射到一个768维的向量。
原始图像的位置信息以固定位置嵌入(fixed positional embeddings)或可训练位置嵌入(trainable positional embeddings)的形式被加入,这些嵌入向量在第一个 Transformer block 之前添加到图像标记(patch tokens)中。加入了位置信息的图像块标记随后被送入一系列 transformer blocks 进行处理。
The class token:分类标记。分类标记是一个可训练的向量,附加到第一层之前的 patch tokens 中,它经过 transformer blocks,然后通过线性层进行投影以预测图像类别,这与自然语言处理(NLP)中使用的方法类似,其中分类标记也用于汇总整个序列的信息以进行分类。在计算机视觉中,传统的模型通常使用池化层(如平均池化或最大池化)来汇总特征图(feature maps)的信息,以进行最终的类别预测;transformer 模型则不使用池化层,而是使用分类标记来汇总全局信息。对于一个具有N个单词,每个单词有 d 维的句子序列,transformer 会处理N+1个token,多出来的这个token则是class token,只有 class token 被用来预测最终的输出类别。
fine-tune the network at the larger resolution:先在较低的分辨率下训练模型,然后在更高的分辨率下进行微调。这种方法可以加快整个训练过程,并在当前流行的数据增强方案下提高模型的准确性。当输入图像的分辨率提高时,图像块(patches)的大小仍然保持 16*16 不变,那么被分割出的图像块数量 N 会相应增加,与之相应的位置编码也需要调整,因为每个图像块都需要一个位置编码,总共有 N 个位置编码;在改变分辨率时通过插值方法更新位置编码,插值是一种常用的技术,可以用于估计在两个已知点之间的位置编码。例如,如果我们知道在低分辨率下的位置编码,可以通过插值来估计在高分辨率下的位置编码,而不需要重新训练整个模型。
具体步骤:
1、先在低分辨率(原始图像)下进行训练,patche大小不变,仍然是16*16,但是图像会被分割成更少的patche,可以加快训练速度。
2、训练完成后在更高分辨率的图像上进行微调,由于图像分辨率增加,相同尺寸的图像块数量 N 会增加,所以需要调整模型的位置编码以适应新的图像块数量(位置编码需要为每个图像块提供其在原始图像中的位置信息),从而提高模型在原始图像上的精度。
这种策略可以帮助模型在训练初期快速学习到有效的特征表示,然后在后期通过微调学习更精细的特征,提高模型对细节的识别能力,以适应更高分辨率的输入,从而在实际应用中达到更好的性能。
4、Distillation through attention:通过注意力进行蒸馏
在本节中,作者假设可以使用强大的图像分类器作为教师模型,它可以是一个卷积神经网络(CNN)或者任何其他类型的分类器组合,该模型已经具备了优秀的图像识别能力。作者展示了如何利用老师模型来训练一个基于 transformer 的学生模型,还提到了:硬蒸馏 Hard Distillation 与软蒸馏 Soft Distillation ,以及经典蒸馏 Classical Distillation 与蒸馏令牌 Distillation Token。
4.1、Soft distillation
软蒸馏旨在最小化教师模型和学生模型的softmax输出之间的Kullback-Leibler散度(KL散度),KL散度是一种衡量两个概率分布之间差异的方法。在训练过程中,学生模型尝试同时减少交叉熵损失(使其预测接近真实标签)和KL散度损失(使其预测接近教师模型的预测),通过调整 λ 和 τ,可以控制模型在模仿教师和预测准确性之间的平衡。
软蒸馏过程中使用的全局损失函数 Lglobal为:
公式第一项是学生模型的预测和真实标签之间的交叉熵损失,第二项是学生模型和教师模型softmax输出之间的KL散度损失。
Zt:教师模型的logits,即模型输出的原始分数,未经过softmax转换。
Zs:学生模型的logits。
τ:蒸馏温度,它是一个超参数,用于调整softmax输出的平滑程度。
λ:平衡系数,用于权衡KL散度损失和交叉熵损失之间的权重。
LCE:交叉熵损失函数,用于衡量学生模型预测和真实标签之间的差异。
ψ:softmax函数,用于将logits转换为概率分布。
y:真实标签。
4.2、Hard-label distillation
硬标签蒸馏是知识蒸馏(Knowledge Distillation)的一种变体,在硬标签蒸馏中,不是使用教师模型输出的软概率分布,而是直接使用教师模型最自信的预测(硬决策)作为真实标签来训练学生模型。
设yt=argmaxcZt(c)表示教师模型输出的硬决策,即在教师模型的输出logitsZt 中选择概率最大的类别 c 作为预测结果。
那么硬标签蒸馏过程中使用的全局损失函数 LglobalhardDistill 为:
公式的第一项是学生模型的预测和真实标签 y 之间的交叉熵损失,第二项是学生模型softmax输出和教师模型的硬决策 yt 之间的交叉熵损失;损失函数的两部分权重都是 1/2。
还是之前那个小猫图像的例子,如果经过旋转或缩放等数据增强,猫不再位于新的增强图片中,那么教师模型的硬标签 yt 可能会有所不同,比如变成了风景类,而原始的标签 y 还是猫类;也就是说在数据增强过程中,同一图像的不同变换可能导致教师模型对其预测结果发生变化。在硬标签蒸馏中,学生模型将尝试模仿教师模型对每个增强图像的硬预测。因此,如果教师模型对同一图像的不同增强版本有不一致的预测,学生模型也需要学会如何对这些变化做出反应。这种策略的优势在于它可以帮助学生模型学习到更多的泛化能力,使其不仅限于识别原始图像中的特征,还能适应图像的各种变换,从而使学生模型在面对真实世界的多样化数据时能够表现得更加鲁棒。
此外,硬标签也可以通过标签平滑(Label Smoothing)转换成软标签。在标签平滑中,真实标签 y 被赋予 1−ε 的概率,剩余的 ε 概率平均分配给其他类别,本文所有使用真实标签的实验中,参数 ε 被设置为 0.1。
4.3、Distillation token
作者提出的蒸馏令牌,向初始嵌入embdding(补丁和类标记)添加了一个新的标记,蒸馏标记。distillation token的使用类似于class token,通过自注意力机制与其他嵌入层进行交互,经过模型的所有层,在最后一层被输出。蒸馏令牌是通过蒸馏损失函数来实现的,这意味着它的输出会被用来计算损失函数,以指导模型的训练。蒸馏令牌允许模型从教师模型的输出中学习,这与常规的知识蒸馏相似;同时蒸馏令牌与类别令牌相互补充,它们共同参与模型的学习和预测过程。这种设计意味着类别令牌和蒸馏令牌虽然在目标上相似,但并不完全相同。类别令牌可能更专注于预测类别,而蒸馏令牌可能更专注于学习教师模型的预测,并将这些知识传递给学生模型。通过这种方式,学生模型能够更全面地学习和利用教师模型的知识,同时保持自身的特性和判断力。
作者观察到,经过训练后,类别令牌和蒸馏令牌会收敛到不同的向量上,这表明它们在学习过程中获得了不同的、但相互补充的表示。两者的初始平均余弦相似度为0.06,这意味着类别令牌和蒸馏令牌在初始状态下差异很大;但是随着模型层级的深入,两者逐渐变得相似(最后一层的余弦相似度为0.93),但仍然保持一定的差异。这种设计有助于模型在保留教师模型知识的同时,也能够保持一定的独立性和灵活性。
5、Experiments
本节介绍一些分析实验和结果。首先讨论蒸馏策略;然后比较分析了卷积网络和视觉Transformer的效率和准确性。
5.1、Transformer models
如前所述,作者的架构设计与 ViT 的架构设计相同,唯一的区别是训练策略和蒸馏令牌。此外,不使用 MLP 头进行预训练,仅仅使用线性分类器。如果没有指定,DeiT 指的是我们的参考模型 DeiT-Base,它与 ViT-Base 具有相同的架构;当以更大的分辨率微调 DeiT 时,会将最终的操作分辨率附加在末尾,例如 DeiT-B↑384。 最后,当使用蒸馏令牌时,用蒸馏器符号 DeiT⚗ 进行标识。
ViT-B(以及 DeiT-B)的参数固定为 D = 768(每个向量的特征维度)、h = 12(多头自注意力的头数) 和 d = D/h = 64(每个注意力头处理的维度大小)。通过改变头的数量,保持 d 固定,引入了两个较小的模型,即 DeiT-Small 和 DeiT-Tiny。
5.2、Distillation
使用不同教师模型进行知识蒸馏时学生模型(在这个案例中是DeiT-B模型)在ImageNet数据集上的性能比较。作者在 ImageNet-1k 上最佳模型的 top-1 准确度为 85.2%,优于在 JFT-300M 上以分辨率 384 预训练的最佳 Vit-B 模型 (84.15%)。作为参考,ViT-H 在 JFT-300M 上以 512 分辨率训练(600M 参数)取得了当前最佳水平 88.55%。
Convnets teachers:卷积网络作为教师模型。作者观察到,使用卷积网络作为教师比使用 transformer 具有更好的性能;这可能是由于 transformer 通过蒸馏继承了来自卷积网络的归纳偏差。后续的蒸馏实验,默认的教师模型是 RegNetY-16GF(84M 参数),使用与 DeiT 相同的数据和数据增强进行训练,在 ImageNet 上达到了 82.9% 的 top-1 准确率。下表比较了不同教师架构的蒸馏结果。
DeiT-B:
81.8%: 作为教师模型时,DeiT-B自身的top-1准确率。
81.9%: 使用DeiT-B作为教师模型进行预训练得到的 DeiT-B 学生模型top-1准确率。
83.1%: 在预训练基础上,进一步在384像素分辨率上微调后得到的DeiT-B⚗ 学生模型top-1准确率。
RegNetY-4GF:
80.0%: 作为教师模型时,RegNetY-4GF自身的top-1准确率。
82.7%: 使用RegNetY-4GF作为教师模型进行预训练得到的 DeiT-B 学生模型top-1准确率。
83.6%: 在预训练基础上,进一步在384像素分辨率上微调后得到的DeiT-B⚗ 学生模型top-1准确率。
Comparison of distillation methods:蒸馏方法的比较。下表中比较了不同蒸馏策略的性能:
label:表示使用真实标签进行训练。teacher:表示使用教师模型的软/硬标签进行训练。
Ti224、S224、B224:分别表示在224x224分辨率的图像上,使用Tiny、Small、Base尺寸的模型进行预训练。
B↑384:表示在224x224分辨率上预训练后,再在384x384分辨率上进行微调的Base尺寸模型。
DeiT-no distillation(无蒸馏):使用真实标签进行训练,没有使用蒸馏
DeiT-usual distillation(常规蒸馏):使用教师模型的软标签进行训练
DeiT-hard distillation(硬蒸馏):使用教师模型的硬标签(最大概率类别)训练
DeiT⚗: class embedding(类别嵌入):使用class token 和硬标签训练
DeiT⚗: distil. embedding(蒸馏嵌入):使用 distillation token 和硬标签训练
DeiT⚗: class+distillation(类别+蒸馏):结合class token 和 distillation token并使用硬标签
5.3、结论
蒸馏的有效性:从表中可以看出,使用蒸馏策略(无论是软蒸馏还是硬蒸馏)都能提高模型的预训练和微调性能。
硬蒸馏的优势:硬蒸馏通常比不使用蒸馏或使用常规软蒸馏的准确率更高。
结合类别和蒸馏嵌入:在最后一种方法中,通过结合类别嵌入和蒸馏嵌入,并在训练的最后阶段融合这些分类器(late fusion),可以获得最高的准确率,这表明两种嵌入能够互补,共同提高模型性能。