一、引言
论文: Align before Fuse: Vision and Language Representation Learning with Momentum Distillation
作者: Salesforce Research
代码: ALBEF
特点: 该方法使用ViT进行图像特征提取,提出将BERT分两部分,一部分进行文本特征提取,另一部分进行图像-文本交互的特征提取;提出使用image-text contrastive learning (ITC)损失、masked language modeling (MLM)损失、image-text matching (ITM)损失进行模型优化;提出Momentum Distillation策略以一个通过exponential moving average (EMA)的网络生成软伪标签提供另一个视角的优化方向。
⚠️ 在学习该方法前,建议补充ViT、BERT、CLIP、MoCo的相关知识。
二、详情
ALBEF的整体结构图如下:
可见,ALBEF在网络结构上主要包括1个图像编码器、1个文本编码器、1个多模态编码器和1个同样包含上述3个编码器的动量模型;ALBEF在损失上主要包括image-text contrastive learning (ITC)损失、masked language modeling (MLM)损失、image-text matching (ITM)损失;此外,ALBEF还引入了动量蒸馏。
2.1 网络结构
如图,ALBEF在网络结构上主要包括1个图像编码器
、1个文本编码器
、1个多模态编码器
和1个同样包含上述3个编码器的动量模型
。
2.1.1 图像编码器
ALBEF的图像编码器
使用ViT-B/16,共12个transformer模块,由在ImageNet-1k上进行预训练的权重初始化。输入图像转为token后会再扩充一个名为
[
CLS
]
[\text{CLS}]
[CLS]的token(初始化全0的可学习参数向量),用来表达图像的全局信息。最后输出的是经过12个transformer模块优化过的输入图像的token和
[
CLS
]
[\text{CLS}]
[CLS]的token,记为
{
v
cls
,
v
1
,
⋯
,
v
N
}
\{\boldsymbol{v}_{\text{cls}},\boldsymbol{v}_{1},\cdots,\boldsymbol{v}_N\}
{vcls,v1,⋯,vN}。
关于ViT的详情,请参考我之前的博客Vision Transformer。
2.1.2 文本编码器
ALBEF的文本编码器
使用6个transformer模块,由
BERT
base
\textbf{BERT}_{\textbf{base}}
BERTbase的前6层初始化。输入文本会在最前面扩充一个名为
[
CLS
]
[\text{CLS}]
[CLS]的token(直接放在句子最前面,例如原文本是“I am very happy today.”,则新文本应为“
[
CLS
]
[\text{CLS}]
[CLS] I am very happy today.”),用来表达文本的全局信息。最后输出经Tokenizer和6个transformer模块优化过的输入文本的token和
[
CLS
]
[\text{CLS}]
[CLS]的token,记为
{
w
cls
,
w
1
,
⋯
,
w
N
}
\{\boldsymbol{w}_{\text{cls}},\boldsymbol{w}_{1},\cdots,\boldsymbol{w}_N\}
{wcls,w1,⋯,wN}。
2.1.3 多模态编码器
ALBEF的多模态编码器
使用6个transformer模块(含交叉注意力),由
BERT
base
\textbf{BERT}_{\textbf{base}}
BERTbase的后6层初始化(BERT不包含交叉注意力,所以其中交叉注意力是随机初始化的)。输入为图像编码器和文本编码器输出的toekn,即
{
v
cls
,
v
1
,
⋯
,
v
N
}
\{\boldsymbol{v}_{\text{cls}},\boldsymbol{v}_{1},\cdots,\boldsymbol{v}_N\}
{vcls,v1,⋯,vN}和
{
w
cls
,
w
1
,
⋯
,
w
N
}
\{\boldsymbol{w}_{\text{cls}},\boldsymbol{w}_{1},\cdots,\boldsymbol{w}_N\}
{wcls,w1,⋯,wN},输出以图像token为指导经6个含交叉注意力的transformer模块优化过的文本token。
2.1.4 动量模型
ALBEF的动量模型
是额外保存的一个模型,同样包括图像编码器、文本编码器、多模态编码器,且初始化参数也一致。但动量模型不通过梯度优化,而是通过指数移动平均进行参数更新。为方便讲解,我们将正常通过梯度优化更新参数的模型称为梯度模型
,额外保存的通过指数移动平均更新参数的模型称为动量模型
。
指数移动平均的公式可表示为 θ t = λ θ t + ( 1 − λ ) θ g \theta_{t}=\lambda\theta_{t}+(1-\lambda)\theta_{g} θt=λθt+(1−λ)θg,其中 θ t \theta_{t} θt和 θ g \theta_{g} θg分别为动量模型和梯度模型的参数, λ \lambda λ取的很大这里为0.995,这是为了保证动量模型的稳定使其不会因为噪声干扰而偏离原本正确的优化方向。
简单来说,就是通过梯度更新过梯度模型后,新的梯度模型和旧的动量模型参数加权求和即得到了新的动量模型参数。
2.2 损失
如图,ALBEF主要包括3个损失,分别为image-text contrastive learning (ITC
)损失、masked language modeling (MLM
)损失、image-text matching (ITM
)损失。
2.2.1 image-text contrastive learning 损失
ITC
损失旨在更好地学习两个模态的特征表达,使两个模态的特征能够对齐,即图像特征与对应文本描述的特征更相似。相关模块如下图红框所示:
如上图,ITC
损失的计算需要用到梯度模型中图像编码器和文本编码器输出的
[
CLS
]
[\text{CLS}]
[CLS]token;还需要用到动量模型中图像编码器和文本编码器输出的一堆
[
CLS
]
[\text{CLS}]
[CLS]token。
这里用到了MoCo的自训练思想,关于MoCo 的详情请参考我之前的博客SAVC的第1.2.2节。
我们以一个批次中的一张图像为例讲解具体的计算过程:
假设一个批次包括 B B B个图像-文本对, I 1 I_1 I1是第一张图片, T 1 T_1 T1是与之配对的文本, T 2 , ⋯ , T B T_2,\cdots,T_B T2,⋯,TB是与当前批次其他图片配对的文本; T 1 , T 2 , ⋯ , T B T_1,T_2,\cdots,T_B T1,T2,⋯,TB和 Q B + 1 T , Q B + 2 T , ⋯ , Q M T Q^T_{B+1},Q^T_{B+2},\cdots,Q^T_{M} QB+1T,QB+2T,⋯,QMT组成一个队列, M = 65536 M=65536 M=65536,初始是从训练集随机选出的 M M M个文本,新批次的文本数据到来后会被推入队列,队列另一头的数据会被推出。
I 1 I_1 I1会经过梯度模型的图像编码器输出 [ CLS ] [\text{CLS}] [CLS]token,即 v cls \boldsymbol{v}_{\text{cls}} vcls; T 1 , T 2 , ⋯ , T B , Q B + 1 T , Q B + 2 T , ⋯ , Q M T T_1,T_2,\cdots,T_B,Q^T_{B+1},Q^T_{B+2},\cdots,Q^T_{M} T1,T2,⋯,TB,QB+1T,QB+2T,⋯,QMT会经过动量模型的文本编码器输出 [ CLS ] [\text{CLS}] [CLS]token,即 w cls , 1 ′ , w cls , 2 ′ , ⋯ , w cls , M ′ \boldsymbol{w}^{\prime}_{\text{cls},1},\boldsymbol{w}^{\prime}_{\text{cls},2},\cdots,\boldsymbol{w}^{\prime}_{\text{cls},M} wcls,1′,wcls,2′,⋯,wcls,M′。
梯度模型在图像编码器后增加一个全连接映射使token维度从768降低至256,并施加归一化,映射+归一化操作记为 g v ( ⋅ ) g_v(\cdot) gv(⋅),得到 q q q;动量模型在文本编码器后进行同样的操作,记为 g w ′ ( ⋅ ) g^{\prime}_w(\cdot) gw′(⋅),得到 k 0 , k 1 , ⋯ , k M k_0,k_1,\cdots,k_M k0,k1,⋯,kM。
q q q与 k 1 , k 2 , ⋯ , k M k_1,k_2,\cdots,k_{M} k1,k2,⋯,kM一一计算相似度,记为 s ( I , T m ) , m = 1 , 2 , ⋯ , M s(I,T_m),m=1,2,\cdots,M s(I,Tm),m=1,2,⋯,M。于是可以获取该图像的softmax-normalized image-to-text similarity:
M M M个概率值组成 p i2t ( I ) \boldsymbol{p}^{\text{i2t}}(I) pi2t(I),对应的真实标签则应该是 y i2t ( I ) = { 1 , 0 , ⋯ , 0 } \boldsymbol{y}^{\text{i2t}}(I)=\{1,0,\cdots,0\} yi2t(I)={1,0,⋯,0},真实标签是one-hot形式的, I I I与队列中哪个文本对应,对应位置就应该为 1 1 1,其余为 0 0 0。
相应地,以文本为基准利用梯度模型的文本编码器、动量模型的图像编码器、图像队列也可以获得
p
t2i
(
T
)
\boldsymbol{p}^{\text{t2i}}(T)
pt2i(T)和
y
t2i
(
T
)
\boldsymbol{y}^{\text{t2i}}(T)
yt2i(T)。最后便可得到ITC
损失:
其中 H ( ⋅ , ⋅ ) H(\cdot,\cdot) H(⋅,⋅)为标准交叉熵。
2.2.2 masked language modeling 损失
MLM
损失利用图像和文本上下文来预测被mask的单词,以此提升模型的理解能力。相关模块如下图红框所示:
对于一个图像-文本对,图像被完整送入图像编码器,文本会被随机mask,即由原来的单词替换为 [ MASK ] [\text{MASK}] [MASK],然后送入文本编码器。
mask的规则是每个单词有15%的概率被选中,被选中的单词中80%被替换为 [ MASK ] [\text{MASK}] [MASK],10%被随机替换成其他token,10%没有任何改变。
之所以不是直接选15%*90%的进行mask和替换,是因为15%*10%没有任何变化的单词也需要模型对其进行预测。
以“I am very happy today.”为例,讲解mask的过程。经过Tokenizer其变为:
tokens = [I, am, very, happy, today]
增加 [ CLS ] [\text{CLS}] [CLS]和 [ SEP ] [\text{SEP}] [SEP]得到:
tokens = [[CLS], I, am, very, happy, today, [SEP]]
假设我们以15%的概率选中I、happy、today,再以80%-10%-10%的概率进行调整后得到:
tokens = [[CLS], I, am, very, [MASK], good, [SEP]]
可见,I没有发生变化,happy被替换为 [ MASK ] [\text{MASK}] [MASK],today被替换为good。此时要求模型利用图像和 [ [ CLS ] , I , a m , v e r y , [ MASK ] , g o o d , [ SEP ] ] [[\text{CLS}], I, am, very, [\text{MASK}], good, [\text{SEP}]] [[CLS],I,am,very,[MASK],good,[SEP]]来预测出句子原本的单词。
图像信息图像编码器和交叉注意力与文本信息交互从而起到指导作用,文本信息经文本编码器和多模态解码器输出优化后的tokens。每个token后跟一个FFN和softmax进行当前位置对应单词的预测。 下图给出了一个“Paris is a beautiful city. I love Paris.”中city被替换为 [ MASK ] [\text{MASK}] [MASK]后模型的预测过程以帮助理解:
可见,该部分预测仍是一个概率分布,所以MLM
损失同样使用标准交叉熵:
其中, I I I和 T T T是原始图像-文本对, T ^ \hat{T} T^是经mask后的文本; p msk ( I , T ^ ) \boldsymbol{p}^{\text{msk}(I,\hat{T})} pmsk(I,T^)是对一个被mask的单词预测的概率分布, y msk \boldsymbol{y}^{\text{msk}} ymsk是该单词真实的one-hot标签。
⚠️ 由于
MLM
损失与其它损失,例如ITC
损失,的输入不同(有无mask),所以该损失会额外产生一次文本编码器和多模态编码器的forward。
2.2.3 image-text matching 损失
ITM
损失用来预测输入的图像-文本对是否匹配,匹配为1,不匹配为0,是一个二分类损失。相关模块如下图红框所示:
对于一个批次的图像-文本对,该损失是较简单的,因为非原配的图像-文本对很容易被判定为否,所以ALBEF利用在ITC
损失计算时得到的本批次图像-文本相似度来挑选hard的负例。
对于一个批次中的一张图像来说, { p 1 i2t ( I ) , p 2 i2t ( I ) , ⋯ , p B i2t ( I ) } \{p_1^{\text{i2t}}(I),p_2^{\text{i2t}}(I),\cdots,p_B^{\text{i2t}}(I)\} {p1i2t(I),p2i2t(I),⋯,pBi2t(I)}就是计算
ITC
损失时得到的相似度,其中非原配的最高相似度所对应的文本即为hard的负例文本。同样地,对于每个文本来说,也可以选出自己的hard负例图像。
图像和文本分别经过各自的编码器再通过相似度选出各自的hard负例之后每个图像或文本都有1个正例和1个负例与之对应,将它们的
[
CLS
]
[\text{CLS}]
[CLS]token送入多模态编码器即可得到优化后的
[
CLS
]
[\text{CLS}]
[CLS]token。在后面跟一个全连接映射和softmax即可进行二元预测判断图像-文本是否匹配。所以ITM
损失也可以使用标准交叉熵:
其中,一个图像-文本对可以产生3项损失,包括1个原配的图像-文本对、2个hard的图像-文本对(因为分别是梯度模型和动量模型的输出之间计算相似度,不一定是同一对图像-文本互为hard)。
3 动量蒸馏
动量模型不仅在计算ITC
损失时发挥作用,ALBEF还用它来应对从网络爬虫下来的图像-文本对富含噪声的问题。
首先,我们需要知道网络数据的噪声是什么样的。一般我们看到一个蛋糕图片后希望获取的是它的店铺位置从而去购买,所以我们从网上下载的数据很可能是一个蛋糕图片和一个对商铺的描述;但实际我们是希望与图片匹配的文本描述应该是针对图片中内容的描述,例如这个蛋糕的外观,如下图:
可能还有些数据的图像-文本是匹配的,但是明显有更合适的描述,如下图:
事实上,网络上很多都是这种噪声数据,如果我们使用one-hot形式的标签进行模型训练和学习,就会很大程度上被这些数据误导。于是,ALBEF使用动量模型来生成软伪标签约束和指导模型的学习。
软标签是相对one-hot形式的硬标签而言的。例如三分类问题中,one-hot只有一个值是1,其余均为0,例如 { 1 , 0 , 0 } \{1,0,0\} {1,0,0}、 { 0 , 1 , 0 } \{0,1,0\} {0,1,0}、 { 0 , 0 , 1 } \{0,0,1\} {0,0,1};软标签则只要求各个值的和为1,允许多个类别上有值,例如 { 0.6 , 0.3 , 0.1 } \{0.6,0.3,0.1\} {0.6,0.3,0.1}、 { 0.3 , 0.7 , 0 } \{0.3,0.7,0\} {0.3,0.7,0}等等。
伪标签是相对真实标签而言的,非原始的真实标签,而是通过其它手段生成的标签都称为伪标签。
其次,就是如何生成软伪标签。ALBEF是将动量模型的预测作为软伪标签。
例如,ITC
损失原本是将图像或文本输入梯度网络然后将队列输入动量网络再计算相互间的相似度得到预测,如果原本图像和文本是一对,则标签值为1,否则为0。动量蒸馏是将图像或文本以及队列均输入到动量网络中,然后计算动量网络输出间的相似度。
下图说明了两者相似度计算的差异:
可见,主要区别就是图像-文本对是送入梯度网络(原始)还是动量网络(动量蒸馏)。有了新的相似度之后,再通过softmax-normalized image-to-text similarity即可得到动量网络的预测,即伪软标签
q
i2t
(
I
)
\boldsymbol{q}^{\text{i2t}}(I)
qi2t(I)和
q
t2i
(
T
)
\boldsymbol{q}^{\text{t2i}}(T)
qt2i(T)。于是得到ITC
损失的动量蒸馏损失:
其中, α = 0.4 \alpha=0.4 α=0.4。由于 q i2t ( I ) \boldsymbol{q}^{\text{i2t}}(I) qi2t(I)和 q t2i ( T ) \boldsymbol{q}^{\text{t2i}}(T) qt2i(T)不是one-hot形式的,所以这里用KL散度衡量梯度网络的预测与动量网络的预测的一致性。
当真实图像-文本不太匹配时,这种操作能允许模型将图像或文本与其它文本或图像做匹配。但是我们又不希望随机找一个进行匹配,所以用比较稳定的动量网络提供一个合适的匹配。
类似地,将被mask后的文本输入动量网络,也能得到动量网络的预测
q
msk
(
I
,
T
^
)
\boldsymbol{q}^{\text{msk}}(I,\hat{T})
qmsk(I,T^),即软伪标签。于是,得到MLM
损失的动量蒸馏损失:
⚠️ 因为
ITM
损失就是根据原标签进行0和1的分配的,所以不太适合采用该策略,ALBEF没有对其进行修改。
作者还提供了一些例子,来证明软伪标签有时是更好更合适的:
上面3幅中,被mask的部分真实单词没有伪标签的单词合适;下面2幅中,原描述没有伪标签的描述合适。
致谢:
本博客仅做记录使用,无任何商业用途,参考内容如下:
是时候彻底弄懂BERT模型了
多模态论文串讲·上