info | |
---|---|
paper | https://arxiv.org/abs/2205.13147 |
code | https://github.com/RAIVNLab/MRL |
org | 华盛顿大学、Google、哈弗大学 |
个人博客位置 | http://www.myhz0606.com/article/mrl |
Motivation
我们平时做retrieval相关的工作,很多时候需要根据业务场景和计算资源对向量进行降维。受限开发周期,我们往往不会通过重新训练特征提取模型来调整向量维度,而是用PCA等方法来实现。但是当降维的scale较大时,PCA
等方法的效果较差。Matryoshka Representation Learning (MRL
)这篇paper介绍了一个很简单但有效的方法能实现一次训练,获取不同维度的表征提取。下面来看它具体是怎么做的吧。
Method
文中只描述MRL
最核心的部分,详细介绍请看原论文。
我们以一个图像分类任务为例,其pipeline如下。图片首先通过一个Feature extractor提取特征,flatten后用一个FC
来映射到表征空间,再接入一个classifier(也是个全连接层)得到该图片在类别上的概率分布。用这个方法训练,一次训练我们只能得到一种维度的图片表征(如图中是2048维)
为了一次训练获得不同维度的图片表征,最简单粗暴的方法就是我们可以用多个FC
及对应的Classifier进行联合训练。这无疑是有效的,但由于FC
和classifier多了,模型会大一些。
MRL
对上面做了一个优化,它能通过一组FC
和Classifier实现多种尺度的特征训练。pipeline如下图所示(图中同个颜色表示共享权重)。MRL
实现的核心就是:对同一组FC
和Classifier进行分片,从而实现不同维度的表征训练。
论文公式中的
F
(
x
i
;
θ
F
)
F(x_i; \theta_{F})
F(xi;θF)是我图中的Feature_extractor + FC
。
min { W ( m ) } m ∈ M , θ F 1 N ∑ i ∈ [ N ] ∑ m ∈ M c m ⋅ L ( W ( m ) ⋅ F ( x i ; θ F ) 1 : m ; y i ) , \min _ { \{ { \boldsymbol W } ^ { ( m ) } \} _ { m \in { \mathcal M } } , \, \theta _ { F } } \frac { 1 } { N } \sum _ { i \in [ N ] } \sum _ { m \in { \mathcal M } } c _ { m } \cdot { \mathcal L } ( { \boldsymbol W } ^ { ( m ) } \cdot F ( x _ { i } ; \theta _ { F } ) _ { 1 : m } \, ; \, y _ { i } ) \; , {W(m)}m∈M,θFminN1i∈[N]∑m∈M∑cm⋅L(W(m)⋅F(xi;θF)1:m;yi),
MRL
的实现源码如下图所示:
class MRL_Linear_Layer(nn.Module):
def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
super(MRL_Linear_Layer, self).__init__()
self.nesting_list = nesting_list
self.num_classes = num_classes # Number of classes for classification
self.efficient = efficient
if self.efficient:
setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))
else:
for i, num_feat in enumerate(self.nesting_list):
setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))
def reset_parameters(self):
if self.efficient:
self.nesting_classifier_0.reset_parameters()
else:
for i in range(len(self.nesting_list)):
getattr(self, f"nesting_classifier_{i}").reset_parameters()
def forward(self, x):
nesting_logits = ()
for i, num_feat in enumerate(self.nesting_list):
if self.efficient:
if self.nesting_classifier_0.bias is None:
nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
else:
nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
else:
nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
return nesting_logits
Result
该图对比了MRL
不同维度的表征在imagenet1K上linear classification和1-NN的准确率。
下图给出了scale model和dataset时MRL
依旧有效,并且MRL
提取的表征具备良好的插值性能。
更多实验结果见原论文。
小结
这篇文章虽然idea很简单,但很适合工程应用。
参考文献
Matryoshka Representation Learning