【深度学习】详解 MoCo

news2024/11/16 5:59:53


目录

摘要

一、引言

二、相关工作

三、方法

3.1 Contrastive Learning as Dictionary Look-up

3.2 Momentum Contrast

3.3 Pretext Task

四、实验

4.1 Linear Classification Protocol

总结 ☆

实现

参考资料


  • Title:Momentum Contrast for Unsupervised Visual Representation Learning
  • Paper:https://arxiv.org/pdf/1911.05722.pdf
  • Github:https://github.com/facebookresearch/moco 

摘要

        我们提出了 动量对比 (MoCo) 用于 无监督视觉表示学习。从对比学习作为字典查找 (look-up) 的角度来看,我们构建了一个 具有一个 队列 (queue) 和一个 移动平均编码器 (moving-averaged encoder) 的 动态字典。这使得动态 (on-the-fly) 建立一个大型且一致的字典能够促进对比无监督学习。MoCo 在 ImageNet 线性分类通用协议下提供了有竞争力的结果。更重要的是,MoCo 学习到的表示可以很好地迁移到下游任务中。MoCo 可在 PASCAL VOC、COCO 和其他数据集上的 7 个检测/分割任务中 优于 有监督的预训练竞争者,有时甚至远超。这表明,在许多视觉任务中,无监督和有监督表示学习之间的差距已在很大程度上被缩小。


一、引言

        无监督表示学习在 NLP 中非常成功,如 GPT 和 BERT。但是有监督预训练在 CV 中仍占主导地位 (be dominant in),而无监督 CV 方法通常是落后的 (lag behind)。其原因可能在于它们 各自的信号空间的差异语言 任务 具有 离散的信号空间 (words, sub-words 等),用于构建 tokenized 字典,该过程可以基于无监督学习。相比之下,CV 进一步关注字典构建,因为 (视觉的) 原始信号 处于一个 连续且高维的空间 中,且并非面向人类通信的结构 (例如,不像 words)。

        最近的几项研究提出使用 与对比损失相关的方法 进行无监督视觉表示学习,并展示出了有前景的结果。尽管受到各种动机的驱动 (driven by various motivations),这些方法可以被认为是 构建动态字典字典中的 key (tokens) 采样自数据 (如 images 或 patches),并由编码器网络表示无监督学习训练编码器 以实施 字典查找 (look-up)一个经编码的 query 应与其匹配的 key 相似,而与其他 keys 不相似学习被表述为 最小化对比损失

        从这个角度来看,假设构建字典的理想情况是:(i) 大型且 (ii) 在训练期间的演进/发展 (evolve) 具有一致性。直观地,一个更大的字典可以更好地采样潜在的 (underlying) 连续且高维的视觉空间,而字典中的 key 应由相同或相似的编码器来表示,以便它们与 query 的比较是一致的。然而,现有使用对比损失的方法会限制两个方面中一者 (稍后将在上下文中讨论)。

一个好的字典应同时具有以下两个特性

  • 字典足够大型:字典越大则 key 越多,所能表示的视觉信息、视觉特征就越丰富 ,从而用 query 参与对比学习时,才越能学到图片的特征。反之,若字典很小,则模型很容易学习一些捷径来区分正、负样本 (过度拟合简单样本/特征),对大量真实数据的泛化很差。
  • 编码的特征尽量保持一致性:字典里的 key 都应用相同/相似的编码器去编码得到,否则在模型查找 query 时,可以走一些捷径 —— 通过找到和 query 使用相同/相似的编码器的 key,而非真正与 query 含有相同/相似语义信息的 key。

对比学习方法在过去都至少被上述二者之一限制,而 MoCo 最大的贡献在于,使用队列以及动量编码器进行对比学习,解决该问题。

图 1:动量对比 (MoCo) 通过使用对比损失 将 “经编码的 query q” 与 “经编码的 key 的字典” 实施匹配来训练视觉表示编码器。
字典 keys {k0, k1, k2, …} 是由一组数据样本动态 (on-the-fly) 定义的。
keys 的字典被构建为一个队列,当前的 mini-batch 入队,最早的 mini-batch 出队,使之与 mini-batch 大小解耦。
keys 被一个缓慢更新 (slowly progressing) 的编码器编码,由 query 编码器的动量更新驱动。
这种方法为学习视觉表示提供了一个大型且一致的字典。

        我们提出了动量对比 (MoCo),作为一种构建大型且一致的字典的方法,用于具有对比损失的无监督学习 (图 1)。字典 被维护为一个 数据样本的队列:当前的 mini-batch 入队,最早的 mini-batch 出队。该队列 将字典大小与 mini-batch 大小解耦,从而 允许字典变得大型。此外,由于 字典 keys 来自于前面的几个 mini-batch,此处提出了一个 缓慢更新 (slowly progressing) 的 key 编码器作为 query 编码器的基于动量的移动平均 (momentum-based moving average) 来实现,以 保持一致性

        MoCo 是一种构建对比学习的动态字典的机制,可用于各种 前置/代理任务 (pretext task)。本文 following 最广泛应用的简单前置任务 —— 实例判别 (instance discrimination)若 query 和 key 是源自同一图像的经编码的视图 (views)则二者相匹配。利用这个前置任务,MoCo 显示出在 ImageNet 中在线性分类普通协议下的具有竞争力的结果。

        无监督学习的一个主要目的是 预训练可通过微调迁移到下游任务的表示 (即特征)。在 7 个与检测/分割相关的下游任务中,MoCo 无监督预训练可超过 ImageNet 有监督预训练,(甚至) 在某些情况下远超 (by nontrival margins)。在实验中,探索了在 ImageNet 或 10 亿个 Instagram 图像集上预先训练过的 MoCo,证明了 MoCo 可在更真实、十亿图像规模和相对未知的 (uncurated) 场景中很好地工作。这些结果表明,MoCo 在很大程度上在许多 CV 任务中缩小了无监督和有监督表示学习之间的差距,并且在一些应用中可作为 ImageNet 有监督预训练的替代方案。 

  • 详见文末总结

二、相关工作

        无监督/自监督 (自监督学习是无监督学习的一种形式。在现有文献中,它们的区别是非正式的。本文中,我们在 “没有人类注释的标签监督” 的意义上,使用了更经典的术语 “无监督学习”) 学习方法通常涉及两个方面:前置任务和损失函数。“前置” 一词意味着被解决的任务并非真正的兴趣,而为了学习良好的数据表示的用于真正目的 (如 下游任务)。损失函数 通常可独立于前置任务被调查/研究。MoCo 侧重于损失函数方面。接下来将讨论有关这两个方面的相关研究。

        损失函数。定义损失函数的一种常见方法是 衡量 模型的预测 和 一个固定的 target 之间的差异,例如通过 L1 或 L2 损失重建输入像素 (如 自动编码器),或通过交叉熵或 margin-based 的损失将输入分类为预定义的类别 (如 8 个位置、color bins)。如下面所述,其他的替代方案也是可能的。

        对比损失 衡量了 表示空间中样本对的相似度。有别于将输入与固定的 target 进行匹配,在对比损失公式中,target 可在训练过程中动态 (on-the-fly) 变化,并且可根据网络计算的数据表示来定义。对比学习是最近几项关于无监督学习的工作的核心,稍后将在上下文中详细阐述它(3.1 节)。

        对抗损失 衡量了 概率分布之间的差异。这是一种广泛成功的无监督数据生成技术。在 (Adversarial feature learning, Large scale adversarial representation learning) 中探讨了表示学习的对抗方法。生成对抗网络 (GAN)噪声对比估计 (noise-contrastive estimation, NCE) 之间存在关系 (Generative adversarial nets)

        前置任务。人们提出了范围广泛的 前置/代理任务 (pretext task)。例子方面,包括在某些损坏下恢复输入,如 去噪自动编码器、上下文自动编码器,或跨通道自动编码器 (colorization)。一些前置任务通过构造伪标签,如 单张 “样例 (exemplar)” 图像的转换、patch 排序、跟踪或分割视频中的 objects,或聚类特征。

        对比学习 vs 前置任务。各种前置任务都可基于某种形式的对比损失函数。实例判别 (Instance discrimination) 方法,与基于样例 (exemplar-based) 的任务和噪声对比估计相关。对比预测编码 (CPC) 中的前置任务是上下文自动编码的一种形式,而对比多视图 (multivies) 编码 (CMC) 中的前置任务与 colorization 有关。


三、方法

3.1 Contrastive Learning as Dictionary Look-up

        对比学习及其近期发展,可被视为是 为字典查找 (look-up) 任务训练一个编码器,如下所述。

        考虑一个经编码的 query q,和 一组经编码的样本 \{ k_0, k_1, k_2, ... \} —— 字典的 keys。假设字典中有一个 q 相匹配的 key (记为 k_{+})。对比损失是一个函数,当 q 与 positive key k_{+} 相似 且与所有其他 keys 不同 (被视为 q 的 negative keys) 时,损失函数值较低。利用 点积 衡量相似度,本文考虑了一种对比损失函数的形式,称为 InfoNCE

        其中,\tau 是每个 (Unsupervised feature learning via non-parametric instance discrimination) 的温度超参数。当 \tau = 1 时,InfoNCE 变为标签 CE 损失。InfoNCE 的总和包含了一个正样本 k_{+}K +1 个负样本 (即字典/队列里所有 keys)。直观地,InfoNCE 损失是一个 ( K + 1) 路的 softmax 分类器的对数损失,该分类器努力将 q 归类为 k_{+}。对比损失函数也可基于其他形式,如 margin-based 的损失和 NCE 损失的变体。

InfoNCE Loss

  • 公式 (1) 的分子表示 query 和正样本 key 计算,分母表示 query 和 K+1 个负样本 key 计算累加和。
  • 直接计算的复杂度很大,因为 MoCo 使用 instance discrimination 作为前置任务,那么 IN-1K 的 128 万张图片即可视为有 128 万个类别,相应地要设置 128 万分类的 Softmax,从而直接计算和训练是非常困难的。
  • NCE loss (Noise Contrastive Estimation Loss):将多分类改造为二分类 —— 数据类别 data sample (正类) 和噪声类别 noisy sample (负类),从而解决了巨量类别问题。
  • Estimation:意为近似。为降低计算复杂度,不是在每次迭代时遍历整个 IN-1K 的约 128 万个负样本,而是只从中选一些负样本来参与 Loss 计算 (即选队列字典中的 65536 个负样本),从而相当于一种近似。这也正是 MoCo 所强调的 —— 好的字典应足够大型,因为越大型的字典越能够提供越好的近似。
  • InfoNCE Loss 作为 NCE loss 的一个简单变体,认为如果只把问题视为二分类,可能对模型学习不是很友好,毕竟大量的噪声样本很有可能不属于一个类别,所以还是视为了多分类问题。
  • 公式 (1) 的 q * k_{+} 其实相当于 logit,也可类比为 Softmax 中的 z\tau 作为温度超参数,用于控制分布的形状 。\tau越大,分布中的数值越小,经过指数化 ( exp(·) ) 后会更小,分布就会变得更平滑,相当于对比损失对所有的负样本都一视同仁,导致学习的模型缺乏差异化关注。相反,\tau 越小,分布更集中,模型会更关注困难负样本,特别是那些作为困难负样本的潜在正样本,若模型过度关注负样本,会导致模型很难收敛,或学到的特征缺乏泛化性。

        对比损失 作为一个无监督的目标函数来 训练表示 query 和 keys 的编码器网络。通常,query q = f_q(x^q),其中 f_q 是一个 (query) 编码器网络,x^q 是一个 query 样本 (同理有 k = f_k(x^k))。它们的实例化 (instantiations) 取决于具体的前置任务。输入的 query x^q 和 key x^k 可以是 图像patches包含一系列上下文的 patches 等。使用的 query 编码器网络 f_qkey 编码器网络 f_k 可以是 完全相同/共享的 (如 Inva Spread 架构相同参数共享)、部分相同/共享的、或完全不同的 (如 CMC,多视角多编码器)


3.2 Momentum Contrast

        从上述角度来看,对比学习是一种 在图像等高维连续输入上构建离散字典 的一种方法。字典是动态的,因为 keys 是随机采样的,且 key 编码器在训练过程中演进 (evolves)。我们的假设是,好的特征可以通过一个包含大量负样本的大型字典来学习,而字典 keys 的编码器尽管还在演进,但仍尽可能地保持一致。基于这个动机,我们将呈现出下面所描述的动量对比。 

        字典作为队列。我们方法的核心是将字典作为一个数据样本的队列。这允许我们重用最靠近前面的 mini-batches 中的经编码的 keys。队列的引入可将字典大小与 mini-batch 大小解耦。我们的字典大小可以比一个典型的 mini-batch 大小大得多,并且可灵活独立地设为一个超参数。

        字典中的样本逐渐被替换。当前的 mini-batch 将入队字典,队列中最老的 mini-batch 将被移除出队。字典总是表示所有数据的一个采样子集 (类似基于所有数据的滑动窗口),而维护此字典的额外计算是可管理的。此外,删除最老的 mini-batch 可能是有益的,因为其中的 经编码的 keys 是最过时的,因此与最新的 keys 最不一致

        动量更新使用队列可使字典变得大型,但它也使通过反向传播更新 key 编码器变得困难 (梯度应传播到队列中的所有样本)。一个朴素的解决方案是从 query 编码器 f_q复制 key 编码器 f_k 而 忽略 这个 梯度。但这种解决方案在实验中产生的结果很差 (4.1 节)。我们假设,这种失败是由 快速变化的编码器降低了 key 表示的一致性 导致的。我们提出 动量更新 来解决该问题。

        形式上,将 f_k 的参数表示为 \theta_kf_q 的参数表示为 \theta_q,通过下式更新 \theta_k

        此处 m \in [0,1) 是一个动量系数。只有 query 编码器参数 \theta_q 才会通过反向传播进行更新,而当前的 key 编码器参数 \theta_k 是基于先前的 \theta_k 和当前的 \theta_q 实现间接动量更新。在公式 (2) 中的动量更新使 \theta_k 比 \theta_q 演进得更 smoothly。因此,尽管 队列中的 keys 由不同的编码器编码 (在不同的 mini-batches),但这些 编码器间的差异可以很小。在实验中,一个相对较大的动量 (如 m = 0.999,我们的默认值) 比一个较小值 (如 m = 0.9) 要好得多,这表明 一个缓慢演进的 key 编码器是利用队列的核心

图 2:三种对比损失机制的概念性比较 (实证比较见图 3 和表 3)。
此处将举例说明一对 query 和 key。​​这 3 种机制在如何维护 keys 和如何更新 key 编码器方面有所不同。
(a): 用于计算 query 和 key 表示的编码器通过反向传播进行端到端更新 (这两个编码器可以不同)
(b): key 表示从内存库 (memory bank) 中采样
(c): MoCo 通过一个动量更新的编码器动态地编码新 keys,并维护一个 keys 的队列 (详见图 1)

        与先前机制的关系。MoCo 是使用对比损失的通用机制。我们将其与图 2 中两种现有的通用机制进行了比较。它们在字典的大小和一致性上表现出不同的属性。

        通过反向传播进行的端到端更新 是一种自然的机制 (图 2a)。它使用 当前 mini-batch 中的所有样本作为字典,因此 keys 被一致地编码 (由相同的一组编码器参数编码)。但字典大小与 mini-batch 大小相耦合 (couple with),会受到 GPU 内存大小的限制。它也受到了大 mini-batch 优化的挑战。最近的一些方法是基于 由局部位置 (local positions) 驱动的前置任务,其中通过多个位置 (multiple positions) 可以使字典大小更大。但是这些前置任务可能需要特殊的网络设计,如 patchifying 输入 或 customizing 感受野大小,这可能会使这些网络向下游任务的迁移复杂化。

(a) SimCLR / Inva Spread 的端到端学习

  • 缺点:字典大小和 mini_batch 大小一致,但大 batch 难设置、难优化、难收敛,故效果有限。
  • 优点:梯度反向传播使得编码器可以实时更新,从而令字典中的 key 具有很高的特征一致性。
  • SimCLR 最终去 batch_size=8192 训练 (谷歌 TPU memory 很大),可以支持模型做对比学习。

        另一种机制是 内存库 (memory bank) 方法 (图 2b)内存库由数据集中所有样本的表示组成 (离线提取所有 keys 表示)每个 mini-batch 的字典都随机抽样自内存库而没有反向传播,因此它可以支持一个大型的字典。然而,当最后一次看到样本 (的表示) 时,内存库中样本的表示被更新,因此采样到的 keys 本质上是关于在整个过去 epoch 的多个不同 steps 的编码器,因此不够一致。在 (Unsupervised feature learning via non-parametric instance discrimination) 中的内存库采用了动量更新。其动量更新是在同一个样本的表示上,而非编码器。该动量更新与我们的方法无关,因为 MoCo 并没有追踪每个样本。此外,MoCo 的内存效率更高,并且可以在数十亿的规模数据上训练,这对于内存库来说是难以处理的。

(b) Memory Bank / InstDisc 模型

  • memory bank 中,query 的编码器是梯度更新的,但是字典中的 key 没有单独对应的可学习编码器。
  • memory bank 预存了整个数据集的嵌入特征,训练时只需要从中采样一些 keys 子集作为字典,然后正常计算 query 和 key 的 loss,通过梯度反向传播更新编码器。
  • 编码器更新后,重新编码​采样到的 keys 子集得到新的嵌入特征来替换原值,从而完成一个 step 的 memory bank 更新,依此类推.
  • ImageNet 虽有 128 万张图片 —— 128 万个 keys,但特征维度 dim = 128,用 memory bank 存下来只需 600M,尚且可行。但是对于亿级图片规模的数据,预提取和存储所有特征则要几十至几百 G 的 memory,故 memory bank 的扩展性不如 MoCo。

memory bank 的特征一致性很差

  • query 编码器的更新很频繁 (batch-wise),导致 key 的嵌入特征提取自不同时刻的编码器,特征一致性很差。
  • memory bank 预存了整个数据集的嵌入特征,使得模型要训练一个 epoch (所有 steps / iters) 才能把整个 memory bank 更新一遍。当下一个 epoch 训练开始时,第一个 step / iter 选中的 keys 的嵌入特征可能分别来自上一个 epoch 中不同时刻的编码器,导致 query和 key 的嵌入特征差异很大。
  • memory bank 通过另一个 loss (proximal optimization) 平滑训练过程,也提到了动量更新 (样本的表示/特征,而非编码器)。

        第 4 节对这三种机制进行了实证比较 (empiricaly compares)。


3.3 Pretext Task

        对比学习可以驱动各种前置任务。本文的重点不是设计一个新的前置任务,而是主要 following (Unsupervised feature learning via non-parametric instance discrimination) 中的实例判别任务并使用一个简单的前置任务,这与一些最近的工作有关。

        按照 (Unsupervised feature learning via non-parametric instance discrimination)如果一个 query 和一个 key 来自同一图像,则我们将它们视为正对,否则将它们视为负样本对。我们使用随机数据增强广 获取同一图像的两个随机 “视图 (views)” 以构成一个正对。而 query 和 key 分别由它们的编码器 f_q 和 f_k 进行编码。该编码器可以是任何 CNN。

算法 1

        算法 1 为这个前置任务提供了 MoCo 的伪代码。对于当前的 mini-batch,我们编码 query 及其对应的 keys,它们构成了正样本对。负样本则来自队列。

f_k.params = f_q.params  # key 编码器的参数初始化基于 query 编码器
for x in loader:  # 取出一个 mini-batch 的图像序列 x,包含 N = 256 张图片,但没有标签
    x_q = aug(x)  # 用作 query 的图(数据增广得到)
    x_k = aug(x)  # 用作 key 的图 (数据增广得到),与 x_q 构成正样本对
    q = f_q.forward(x_q)  # 提取 query 特征,q.shape = N×C,c 为 embed dim
    k = f_k.forward(x_k)  # 提取 key 特征,k.shape = N×C,c 为 embed dim
    k = k.detach()  # 不使用梯度更新 key 编码器 f_k 的参数,确保提取的特征的一致性

    # bmm 是分批矩阵乘法; 字典大小 K = 65536
    l_pos = bmm(q.view(N,1,C), k.view(N,C,1))  # l_pos.shape = N×1,q * k+ (query 与当前正样本的相似度)
    l_neg = mm(q.view(N,C), queue.view(C,K))  # l_neg.shape = N×K,q * k_ (query 与上一 mini-batch 或 queue 的所有负样本的相似度)
    logits = cat([l_pos, l_neg], dim=1)  # 拼接正负样本相似度,logits.shape = N×(1+K) = 256×(1+65536) -> 相当于 65537 分类
    labels = zeros(N)  # 按照以上实现方式,所有正样本永远在 logits 的 index = 0 的位置上
 
    # InfoNCE Loss,促进 query 与正样本 key 的相似度越来越高、与负样本 keys 的相似度越来越低
    loss = CrossEntropyLoss(logits/t, labels) 
    loss.backward()  # 计算梯度反向传播

    update(f_q.params)  # query 编码器 f_q 使用梯度立即更新
    f_k.params = m*f_k.params+(1-m)*f_q.params   # key 编码器 f_k 缓慢地动量更新

    enqueue(queue, k)  # 当前 mini-batch 的样本特征入队,作为下一个 mini-batch 中 query 的负样本
    dequeue(queue)  # 最早进入 queue 的 mini-batch 出队

        技术细节。我们采用 ResNet 作为编码器,其最后一个全连接层 (在全局平均池化之后) 具有固定维数的输出 (128-D)。输出向量由 L2-范数归一化。此即为 query 或 key 的表示。公式 (1) 中的温度系数 \tau 设为 0.07。数据增广设置遵循 (Unsupervised feature learning via non-parametric instance discrimination)从经随机 resized 的图像中裁剪出 224×224 的像素,然后随机颜色抖动 (color jittering)、随机水平翻转和 随机灰度转换,所有这些都可以在 PyTorch 的 torchvision 包中获得。以下展示了数据增广示例代码:

# https://github.com/facebookresearch/moco/blob/main/main_moco.py

    if args.aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]
    else:
        # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]

        Shuffling BN。编码器 f_q 和 f_k 都有 BN,如同标准 ResNet。实验中发现 使用 BN 会抑制模型学习良好的表示,就像在 (Data-efficient image recognition with contrastive predictive coding) 中所报道的那样 (它避免使用 BN)。模型似乎 “欺骗 (cheat)” 前置任务,且容易找到低损失的解决方案。这可能是因为 样本间的 intra-batch 通信 (由 BN 引起的) 泄漏了信息(前人认为 BN 使样本间数据不期望的发生交互,从而使模型倾向于找到与原训练目标不符的 low-loss 优化方式,故避免使用 BN)

        我们通过 shuffling BN 来解决该问题。我们使用多 GPU 训练,并为每个 GPU 独立地对样本执行 BN (正如在普通实践中所做的那样)。对于 key 编码器 f_k,在将其分布到各 GPU 之间前,shuffle 当前 mini-batch 中的样本顺序 (并在编码后 shuffle back);query 编码器 f_q 的 mini-batch 的样本顺序不变。这 确保了用于计算 query 及其 positive key 的 batch 统计信息 来自两个不同的子集。这有效地解决了作弊 (cheating) 问题,并允许训练受益于 BN。(由于每个 batch 内的样本之间计算 mean 和 std 导致信息泄露,产生退化解。MoCo 通过多 GPU 训练,分开计算 BN,并且 shuffle 不同 GPU 上产生的 BN 信息来解决问题)

        本文在我们的方法和对应的端到端消融中都使用了 shuffling BN (图 2a)。它与作为竞争者的内存库无关 (图2b),内存库不受此问题的影响,因为 positive keys 来自过去不同的 mini-batches。


四、实验

        我们研究基于以下数据集的无监督训练:

        ImageNet-1M (IN-1M):ImageNet 训练集基于 1000 个类别,有 ∼128 万张图像 (其实是 ImageNet-1K;我们计算图像数量 1M 而非类别 1K,因为无监督学习不用类别)。该数据集在类别分布上很平衡,其中的图像通常包含 objects 的标志性视图 (iconic view)。

        Instagram-1B (IG-1B):根据 (Exploring the limits of weakly supervised pretraining),这是一个来自 Instagram 的具有 ∼10亿 (940M) 公共图像的数据集。这些图片具有与 ImageNet 类别相关的∼1500 种散列标记 (hashtags)。与 IN-1M 相比,该数据集相对未被规整 (uncurated),并且具有真实世界数据的长尾、不平衡的分布。此数据集同时包含标志性 (iconic) objects 和场景级 (scene-level) 图像。

        训练。使用 SGD 优化器,权重衰减为 0.0001,动量为 0.9。对于 IN-1M,在 8 个 GPU 中使用 256 的 mini-batch (算法 1 中的 N),初始学习率为 0.03。在 120 和 160 个 epoch 时学习率乘 0.1,共训练 200 个 epochs,耗费 ∼53 小时训练 ResNet-50。对于 IG-1B,在 64 个 GPU 中使用 1024 的 mini-batch,学习率为 0.12,每 62.5k 次迭代 (64M 张图像) 学习率指数衰减 0.9×。训练 125 万 (1.25M) 次迭代 (IG-1B 的 ∼1.4 个 epoch),耗费 ∼6 天训练 ResNet-50。

# https://github.com/facebookresearch/moco/blob/main/main_moco.py

parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                    help='learning rate schedule (when to drop lr by 10x)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum of SGD solver')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
# moco specific configs:
parser.add_argument('--moco-dim', default=128, type=int,
                    help='feature dimension (default: 128)')  # embedding size = 128
parser.add_argument('--moco-k', default=65536, type=int,
                    help='queue size; number of negative keys (default: 65536)')  # len(queue) = 65536
parser.add_argument('--moco-m', default=0.999, type=float,
                    help='moco momentum of updating key encoder (default: 0.999)')  # m = 0.999
parser.add_argument('--moco-t', default=0.07, type=float,
                    help='softmax temperature (default: 0.07)')  # τ = 0.07
# https://github.com/facebookresearch/moco/blob/main/main_moco.py

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:  # stepwise lr schedule
        for milestone in args.schedule:  # default=[120, 160]
            lr *= 0.1 if epoch >= milestone else 1.  # 在 120 和 160 个 epoch 时学习率乘 0.1
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr  # 更新各 group 的优化器参数

4.1 Linear Classification Protocol

        我们首先验证了我们的方法 —— 通过对经冻结特征的线性分类,遵循一个普遍的协议。在本小节中,我们对 IN-1M 进行无监督预训练。然后冻结特征,训练一个有监督线性分类器 (FC + Softmax)。我们在一个 ResNet 的全局平均池化 (GAP) 特征上训练 100 个 epochs 该分类器。我们报告了 ImageNet 验证集上的 1-crop,top-1 分类准确率。

# https://github.com/facebookresearch/moco/blob/main/main_lincls.py

def main_worker(gpu, ngpus_per_node, args):
    # ...

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()

    # freeze all layers but the last fc
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False  # stop computing gradients to freeze layers

    # init the fc layer
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    # ...

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.SGD(parameters, 
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

def sanity_check(state_dict, pretrained_weights):
    """
    Linear classifier should not change any weights other than the linear layer.
    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
    """
    print("=> loading '{}' for sanity check".format(pretrained_weights))
    checkpoint = torch.load(pretrained_weights, map_location="cpu")
    state_dict_pre = checkpoint['state_dict']

    for k in list(state_dict.keys()):
        # only ignore fc layer
        if 'fc.weight' in k or 'fc.bias' in k:
            continue

        # name in pretrained model
        k_pre = 'module.encoder_q.' + k[len('module.'):] \
            if k.startswith('module.') else 'module.encoder_q.' + k

        assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")

        对于该分类器,实施网格搜索,发现最优初始学习率为 30 且权值衰减为 0。这些超参数在本小节中介绍的所有消融项中始终表现良好。这些超参数值意味着特征分布 (例如,规模 (magnitudes)) 可能与 ImageNet 有监督训练有本质上的差异,我们将在 4.2 节中重新讨论该问题。

        更多实验分析见原文。


总结 ☆

  • 虽然对比学习无需标签,但模型仍需知道图片中哪些相似、哪些不相似才可以训练,于是需要人为设计各种巧妙的代理任务来实现该目的。
  • 如同一图像的不同裁剪和数据增广的结果,虽有差异但仍被视为具有相似的语义信息,从而作为匹配的 正样本对,此时原图即为 基准点/锚点 (anchor),衍生的新图即为 正样本;与其他图像产生的样本即均为 负样本对
  • 从某种程度上,数据集中每一图像 (及其产生的样本) 可以视为 一个单独的类别,故对 IM-1K 而言,类别数不是 1000 而是 128 万
  • 划分正、负样本后,即可通过编码器编码所有正、负样本以提取嵌入特征。
  • 由于所有正、负样本均是基于 anchor 而言的,故 anchor 通常单独配置 一个 编码器 (如本文的 query 编码器),其他的正、负样本配置 另一个 编码器 (如本文的 key 编码器)。
  • 当然,query 编码器 和 key 编码器可以 完全相同部分相同 或 完全不同。但不同的编码器之间必须相似,以确保编码得到的特征具有一致性和比较的意义。
  • 获取到 anchor 和正、负样本的嵌入特征后,只需衡量它们的相似度,并 缩小 anchor 与正样本对的嵌入特征距离,拉大 anchor 与负样本的嵌入特征距离。
  • 确定代理任务并知道如何定义正、负样本后,就要用 目标函数 来告诉模型如何学习,如常见的对比学习目标函数 NCE Loss 等。
  • 事实上,对比学习最大的特性就是方法 灵活,可以设置各种不同的代理任务。只要找到或设计一种合理的方式来 定义正、负样本,就能走完剩下的一些较标准化的流程,从而实施对比学习。

实现

# https://github.com/facebookresearch/moco/blob/main/moco/builder.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

参考资料

李沐论文精读系列三:MoCo、对比学习综述(MoCov1/v2/v3、SimCLR v1/v2、DINO等)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/175410.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AlmaLinux 9 安装Kasm Workspaces

今天尝试一下AlmaLinux 9 安装Kasm Workspaces。 前提条件 安装了Docker和Docker Compose,已经最新版本要求, docker 18.06 docker compose 2.1.1 创建一个Swap分区 下面的步骤将创建一个2千兆字节(2048MB)的交换分区。请根据…

我的创作纪念日——“永远相信美好的事情即将发生”

作者:非妃是公主 专栏:《程序人生》 个性签:顺境不惰,逆境不馁,以心制境,万事可成。——曾国藩 文章目录序与CSDN的往事机缘收获憧憬碎碎念序 第一次写创作纪念日的文章!哈哈哈哈,今…

一起自学SLAM算法:7.5 基于因子图的状态估计

连载文章,长期更新,欢迎关注: 虽然式(7-90)所示的完全SLAM系统可以用滤波方法求解,比如著名的Fast-SLAM实现框架。但是,贝叶斯网络表示下的完全SLAM系统能很方面地转换成因子图表示,…

字符串匹配: BF与KMP算法

文章目录一. BF算法1. 算法思想2. 代码实现二. KMP算法1. 算法思想概述2. 理解基于最长相等前后缀进行匹配3. 代码中如何实现next数组5. 代码实现6. next数组的优化一. BF算法 1. 算法思想 BF 算法, 即暴力(Brute Force)算法, 是普通的模式匹配算法, 假设现在我们面临这样一个…

24/365 java 观测线程状态 线程优先级

1.观测线程 JDK中定义的线程的六个状态 &#xff1a; 可以用getState()来观测线程 public static void main(String[] args) throws InterruptedException {Thread thread new Thread(()->{for (int i 0; i < 10; i) {try {Thread.sleep(100);} catch (InterruptedExc…

2023适合新手的免费编曲软件FL Studio水果21中文版

水果软件即FL Studio&#xff0c;这是一款较为专业的编曲软件&#xff0c;这款软件自带高品质打击乐、钢琴、弦乐以及吉他等107种乐器效果&#xff0c;内置了包括经典电子音色、合成利器3xosc、sytrus、slicex等多种插件&#xff0c;可以帮助音乐制作人创作不同的音乐曲风&…

数据结构进阶 哈希表

作者&#xff1a;小萌新 专栏&#xff1a;数据结构进阶 作者简介&#xff1a;大二学生 希望能和大家一起进步&#xff01; 本篇博客简介&#xff1a;模拟实现高阶数据结构 哈希表 哈希表 哈希桶哈希概念举例哈希冲突哈希函数哈希冲突的解决方式之一闭散列 --开放定址法哈希表的…

Python CalmAn(Calcium Imaging Analysis)神经生物学工具包安装及环境配置过程

文章目录CalmAn简介安装要求我的设备1>CalmAn压缩包解压&#xff08;caiman文件夹要改名&#xff09;2>conda创建虚拟环境3>requirements依赖包配置&#xff08;包括tensorflow&#xff09;4>caiman安装(mamba install)5>caimanmanager.py install6>PyCharm添…

51单片机独立按键

文章目录前言一、按键原理图二、代码编写三、模块化管理按键总结前言 本篇文章将带大家学习独立按键按键的基本操作。 独立按键式直接用I/O口线构成的单个按键电路&#xff0c;其特点是每个按键单独占用一根I/O口线&#xff0c;每个按键的工作不会影响其他I/O口线的状态。 一…

MongoDB学习笔记【part5】基于 MongoRepository 开发CURD

一、MongoRepository Spring Data 提供了对 mongodb 数据访问的支持&#xff0c;只需继承 MongoRepository 类&#xff0c;并按照 Spring Data 规范就可以实现对 mongodb 的操作。 SpringData 方法定义规范&#xff1a; 注意事项&#xff1a; 不能随便声明&#xff0c;必须要…

汇编语言学习笔记 下

本文承接汇编语言学习笔记 上 上篇文章记录了汇编语言寄存器&#xff0c;汇编语言基本组成部分&#xff0c;数据传送指令&#xff0c;寻址指令&#xff0c;加减法指令&#xff0c;堆栈&#xff0c;过程&#xff0c;条件处理&#xff0c;整数运算的内容 高级过程 大多数现代编程…

day24-网络编程02

1.NIO 1.1 NIO通道客户端【应用】 客户端实现步骤 打开通道指定IP和端口号写出数据释放资源 示例代码 public class NIOClient {public static void main(String[] args) throws IOException {//1.打开通道SocketChannel socketChannel SocketChannel.open();//2.指定IP和端…

你是真的“C”——2023年除夕夜 牛客网刷题经验分享~

2023年除夕夜 牛客网刷题经验分享~&#x1f60e;前言&#x1f64c;BC89 包含数字9的数 &#x1f60a;BC90 小乐乐算多少人被请家长 &#x1f60a;BC91 水仙花数 &#x1f60a;BC92 变种水仙花 &#x1f60a;BC93 公务员面试 &#x1f60a;总结撒花&#x1f49e;&#x1f60e;博…

Android Studio 支持安卓手机投屏

有时当我们在线上做技术分享或者功能演示时&#xff0c;希望共享连接中的手机屏幕&#xff0c;此时我们会求助 ApowerMirror&#xff0c;LetsView&#xff0c;Vysor&#xff0c;Scrcpy 等工具。如果你是一个 Android Developer&#xff0c;那么现在你有了更好的选择。 Android…

Day59| 503. 下一个更大元素 II | 42. 接雨水 --三种方法:1.双指针 2.动态规划 3.单调栈

503. 下一个更大元素 II注意点&#xff1a;初始化了2倍的题目中的nums.size()&#xff0c;最后直接/2即可分清逻辑nums[i] > nums[st.top()]的时候&#xff0c;才进行st.pop()操作class Solution { public:vector<int> nextGreaterElements(vector<int>& nu…

STM32——独立看门狗

目录 看门狗产生背景 看门狗的作用 STM32 独立看门狗 独立看门狗特点 &#xff08;IWDG&#xff09; 独立看门狗常用寄存器 独立看门狗超时时间计算 独立看门狗工作原理 独立看门狗操作库函数 看门狗产生背景 当单片机的工作受到来自外界电磁场的干扰&#xff0c;造…

bert-bilstm-crf提升NER模型效果的方法

1.统一训练监控指标和评估指标评估一个模型的最佳指标是在实体级别计算它的F1值&#xff0c;而不是token级别计算它的的准确率&#xff09;。自定义一个f1值的训练监控指标传给回调函数PreliminaryTP&#xff1a;实际为P&#xff0c;预测为PTN&#xff1a;实际为N&#xff0c;预…

【Java IO流】缓冲流及原理详解

文章目录前言字节缓冲流原理字符缓冲流Java编程基础教程系列前言 前面我们已经学习了四种对文件数据操作的基本流&#xff0c;字节输入流&#xff0c;字节输出流&#xff0c;字符输入流&#xff0c;字符输出流。为了提高其数据的读写效率&#xff0c;Java中又定义了四种缓冲流…

LwIP系列--内存管理(堆内存)详解

一、目的小型嵌入式系统中的内存资源&#xff08;SRAM&#xff09;一般都比较有限&#xff0c;LwIP的运行平台一般都是资源受限的MCU。基于此为了能够更加高效的运行&#xff0c;LwIP设计了基于内存池、内存堆的内存管理以及在处理数据包时的pbuf数据结构。本篇的主要目的是介绍…

Ubuntu22.10安装和配置R/FSL/Freesurfer

2018年购入的台式机&#xff0c;一直处于吃灰状态&#xff0c;决定安装Ubuntu方便学习和使用一些基于Linux系统的软件。本文记录相关软件的安装和配置。>安装Ubuntu22.10<到官网下载ISO文件。将一个闲置U盘通过Rufus制作为安装盘。重启按F11设置引导盘&#xff0c;选择U盘…