一、引言
论文: iBOT🤖: Image BERT Pre-Training with Online Tokenizer
作者: ByteDance
代码: iBOT
注意: 该方法是在另一个自监督预训练方法基础上的改进,学习之前建议掌握DINO。
特点: 对于一张图片,该方法首先进行两次全局裁剪与增强得到两张全局视图,之后进行随机mask再产生两张带mask的全局视图,并分别送入教师和学生网络;学生与教师网络均有两个映射头,一个用于映射
[
CLS
]
[\text{CLS}]
[CLS],另一个用于映射特征图,之后以两个网络的输出一致性为损失进行学生网络的更新;教师网络由指数移动平均更新,还采用了中心化和锐化操作避免模式崩溃。
二、详情
- 对于一张图片,进行2次全局裁剪(面积占比在
[
0.14
,
1
]
[0.14,1]
[0.14,1],resize为224),之后对2个切片进行随机增强(翻转、色彩变化、高斯模糊、归一化等)得到2个正常的全局视图;接着按照16*16的
patch
分配14*14的mask
,mask为1表示被遮掩,为1的概率为0.3,被遮掩的patch被设置为全0,于是又得到2个带mask的全局视图。 - 2个正常的全局视图送入教师网络,另外两个带mask的全局视图送入学生网络。教师网络和学生网络的结构相同初始参数也相同,可以是ViT、Swin Transformer等等。教师网络和学生网络都有两个映射头,最后输出维度相同都是8192,学生网络的两个头不共享参数,教师网络的两个头共享参数。一个映射头用于映射 [ CLS ] [\textbf{CLS}] [CLS],输出2*8192,另一个映射头用于映射图像特征,输出2*196*8192。
- 教师网络的
[
CLS
]
[\text{CLS}]
[CLS]输出先减去
center1
(初始全0)再除以temp1=0.04
之后求softmax;教师网络的图像特征输出先减去center2
(初始全0)再除以temp2
(前30个epoch等间隔在0.04-0.07取值,后70个epoch全为0.07)之后求softmax。学生网络的输出均先除以temp=0.1
,然后求softmax再取log;最后,教师与学生网络的 [ CLS ] [\text{CLS}] [CLS]输出之间计算损失,图像特征输出之间计算损失(其实就是在标准的交叉熵损失 − p t log p s -p_t\log p_s −ptlogps中对教师网络输出引入了锐化和中心化,对学生网络输出引入了锐化)。
除以temp的操作称为锐化(sharping),减去中心的操作称为中心化(centering),两个操作叠加是为了避免模式崩溃(无论哪个图像网络输出softmax后始终是某一项很大或始终接近均匀分布)。锐化能放大分布中某一个值减小其他值,中心化能促使分布更接近均匀分布,两个相反的操作相互作用使得模式崩溃得以避免。
⚠️ 图像特征输出mask=0的部分不计算损失,也就是说该项损失是希望学生网络能够借助教师网络的指导通过非mask部分预测出被mask的部分。 [ CLS ] [\text{CLS}] [CLS]对应的损失则是希望网络能够捕捉图像中的高层语义信息。
- 之后更新中心
center1
和center2
,两个center均用下式更新:
center = center_momentum * center + (1 - center_momentum) * batch_center
其中,center_momentum=0.9,batch_center是当前批次所有全局视图经教师网络输出的均值( [ CLS ] [\text{CLS}] [CLS]输出的均值更新center1,图像特征输出的均值更新center2)。
其实这个操作就是指数移动平均,一般公式为 b = λ b + ( 1 − λ ) a b=\lambda b+(1-\lambda)a b=λb+(1−λ)a,简单来说就是用另外一个参数更新自己,但是保留自己的一部分。
- 根据3中的损失更新学生网络,教师网络不用损失更新,而是用指数移动平均更新,见下式:
θ t = λ θ t + ( 1 − λ ) θ s \theta_{t}=\lambda\theta_{t}+(1-\lambda)\theta_{s} θt=λθt+(1−λ)θs,
其中, λ \lambda λ在训练时是遵循cosine schedule,从0.996到1变化。即用学生网络更新教师网络,但保留教师网络的一部分。
⚠️ 因教师网络的 [ CLS ] [\text{CLS}] [CLS]和图像特征共用映射头,所以只需要用上式更新 [ CLS ] [\text{CLS}] [CLS]对应的映射头。
伪代码如下: