1. 解决了什么问题?
非监督学习在自然语言处理非常成功,如 GPT 和 BERT。但在计算机视觉任务上,监督预训练方法要领先于非监督的方法。这种差异可能是因为各自的信号空间不同,语言任务有着离散的信号空间(单词、短语等)来构建非监督学习所需的字典。而计算机视觉则很难构建一个字典,因为原始信号位于连续的高维空间,不像单词一样是结构化的。
最近的非监督表征学习方法使用对比损失取得了不错的效果,它们基本是构建了一个动态字典。从数据中采样,产生字典的 keys/tokens,由编码器网络表征。非监督学习训练编码器来进行字典查询:query 应该与匹配到的 key 距离近,而与其它 keys 距离远。通过最小化对比损失来进行训练。
本文假设字典的构建应该满足两个条件,一是足够大,二是在训练过程中不断地更新,是连续的。字典足够大,能更好地从连续的高维空间中采样,keys 由近似的编码器表征,这样 key 和 query 的比较才是连续的。而目前的方法只能满足上述两个条件中的一个。
非监督/自监督学习一般包括两个方面:pretext 任务和损失函数。Pretext 的意思是待解决的任务不是我们真正关心的,我们真正的目的是学习好的数据表征。损失函数可以独立于 pretext 任务来研究,MoCo 聚焦在损失函数上。
2. 提出了什么方法?
针对非监督视觉表征学习任务,提出了 Momentum Contrast,构建了一个包含队列和滑动平均编码器的动态字典。MoCo 针对非监督学习,利用对比损失构建了一个足够大且连续的字典,该字典用一个队列维护:当前 mini-batch 的表征加入队列,最早的 mini-batch 表征从队列中剔除。字典的 keys 来自于之前的多个 mini-batches,通过一个基于动量的滑动平均编码器实现该缓慢演进的 key 编码器,保证连续性。
2.1 Contrastive Learning as Dictionary Look-up
对比学习就是针对字典查询任务,训练一个编码器。给定一个编码后的 query q q q和字典的一组编码样本 { k 0 , k 1 , k 2 , . . . } \lbrace k_0,k_1,k_2,...\rbrace {k0,k1,k2,...}。假设字典中有一个 q q q匹配到的 key,记做 k + k_+ k+。对比损失中,当 q q q与 k + k_+ k+相似而与其它 keys 不相似时,损失值就小。相似度用点积表示,是对比损失函数的一种形式,叫做 InfoNCE:
L q = − log exp ( q ⋅ k + / τ ) ∑ i = 0 K exp ( q ⋅ k i / τ ) \mathcal{L}_q = -\log \frac{\exp\left(q\cdot k_+/\tau\right)}{\sum_{i=0}^K \exp \left(q\cdot k_i /\tau\right)} Lq=−log∑i=0Kexp(q⋅ki/τ)exp(q⋅k+/τ)
τ \tau τ是一个调节超参数。除数是对一个正样本和 K K K个负样本求和。该损失是基于 Softmax 的 ( K + 1 ) − way (K+1)-\text{way} (K+1)−way分类器,将 q q q分类为 k + k_+ k+。对比损失函数也可基于其它形式,比如 margin-based 损失和 NCE 损失变体。
对比损失是训练编码器的非监督目标函数,该编码器表征 query 和 key。通常,query 表征为 q = f q ( x q ) q=f_q(x^q) q=fq(xq), f q f_q fq是编码器网络, x q x^q xq是 query 样本。同样, k = f k ( x k ) k=f_k(x^k) k=fk(xk)。输入 x q x^q xq和 x k x^k xk可以是图像、图块或图块构成的 context。网络 f q f_q fq和 f k f_k fk可以是一样的,也可以部分共享的,也可以是完全不同的。
2.2 Momentum Contrast
对比学习从高维连续输入(如图像)中构建离散字典。该字典是动态的,keys 通过随机采样得到,在训练过程中 key 编码器不断地更新。本文假设,如果一个字典足够大,涵盖了丰富的负样本,就能用该字典学习好的特征。而且编码器在更新过程中是连续的。
Dictionary as a queue
本方法的核心就是,字典用一个样本队列来维护。这样我们就可复用不久前 mini-batches 的 keys。该字典的大小可以远大于 mini-batch 的大小,作为一个超参灵活地设定。
字典中的样本被逐步替换掉。当前 mini-batch 样本加入到字典中,最早的 mini-batch 则被剔除。字典只代表了数据集的一个子集,维持这个字典的计算量是可以控制的。剔除最早的 mini-batch,它所编码的 keys 过时了,与最新的 mini-batch 连续性最低。
Momentum update
队列表示能让字典很大,但无法通过反向传播来更新编码器(梯度应该回传给队列中所有的样本)。简单的办法就是复制 query 编码器 f q f_q fq到 key 编码器 f k f_k fk,不管梯度。但这个办法实验效果不行。作者认为,编码器的迅速变化,降低了 key 表征的连续性。于是提出了动量更新,解决这个问题。
将 f k f_k fk的参数记做 θ k \theta_k θk, f q f_q fq的参数为 θ q \theta_q θq,更新 θ k \theta_k θk:
θ k ← m θ k + ( 1 − m ) θ q \theta_k \leftarrow m\theta_k + (1-m)\theta_q θk←mθk+(1−m)θq
m ∈ [ 0 , 1 ) m\in [0,1) m∈[0,1)是动量系数。反向传播只用更新 θ q \theta_q θq。动量更新使 θ k \theta_k θk的更新更加平滑。这样,尽管队列中的 keys 是用不同的编码器(不同的 mini-batches)编码的,这些编码器的差异很小。在实验中,大动量系数(比如 m = 0.9 m=0.9 m=0.9)的表现要好于小的系数,表明缓慢更新的 key 编码器是使用队列的关键。
Relations to previous mechanisms
MoCo 对于对比损失是通用的。在下图中,MoCo 和现有的两种对比损失机制进行了比较,在字典大小和连续性方面有着不同的特性。它们差异体现在 keys 是如何维护的,以及 key 编码器是如何更新的。
-
计算 query 和 key 的编码器通过端到端的反向传播更新,两个编码器可以不一样。它使用当前 mini-batch
的样本作为字典,keys 是连续编码的,因为编码器参数是一样的。但是字典大小与 mini-batch 大小是耦合的,受 GPU
显存大小限制。 -
从 memory bank 中采样得到 keys 表征。Memory bank 包括了数据集所有样本的表征。每个
mini-batch 的字典都是从 memory bank 中随机采样得到,无需反向传播,因此字典规模可以很大。但是 memory
bank 的样本表征是看到了才会更新,因此采样的 keys 是 epoch 中不同步骤的编码器生成的,彼此缺乏连续性。 -
MoCo 使用动量更新编码器来编码新的 keys,维护了一个 keys 队列。Moco
并不记录每个样本,因此对内存更加有效,可以在数以亿计的数据上训练。
2.3 Pretext Task
如果一对 query 和 key 来自于同一图像,则是正样本对,否则为负样本对。使用数据增强得到同一图像的两个随机视角,产生正样本对。用各自的编码器 f q f_q fq和 f k f_k fk编码得到 query 和 key。算法1 是该 pretext 任务的 MoCo 伪码。对于当前 mini-batch,编码 queries 和相应的 keys,得到正样本对。负样本对则来自于队列。
技术细节
编码器采用 ResNet,全局平均池化层后的最后一个全连接层有固定维度( 128 128 128维)的输出。用 L 2 − norm L2-\text{norm} L2−norm对输出向量归一化。这就是 query 或 key 的表征。 τ \tau τ设为 0.07 0.07 0.07。数据增强方法如下:对随机缩放的图像裁剪一块 224 × 224 224\times 224 224×224大小的区域,然后使用随机色彩变动、随机水平翻转和随机灰度转换。
Shuffling BN
f q f_q fq和 f k f_k fk都在 ResNet 中使用了 BN。BN 会阻碍模型学习高质量表征。模型似乎在作弊,欺骗 pretext 任务,很容易就找到了低损失值的方案。这可能是因为 BN 在 batch 内部交流信息造成了信息泄露。
于是作者使用了 shuffling BN。使用了多个 GPU 训练,对于每个 GPU 的样本独立完成 BN 操作。对于 key 编码器 f k f_k fk,shuffle 当前 mini-batch 样本的顺序,然后再将其分配到各个 GPU,编码后再 shuffle 回来。Query 编码器 f q f_q fq的 mini-batch 样本顺序不变。这保证了计算 query 和它的正样本 key 所需的 batch 统计信息来自于两个不同的子集。这就有效解决了作弊问题。
在上图(a)和©中,作者使用了 shuffling BN,(b) 中没有用,因为 memory bank 提供的正样本 keys 来自于之前产生的、不同的 mini-batches。