摘要
自Meta研究团队发布SAM(Segment Anything Model)项目依赖,因其令人惊艳的零样本迁移特性和与其他视觉应用兼容的高通用性,引起了极大的关注。由于大多数类似的应用都需要运行在资源限制的边缘设备,如手机,因此,本文的目标是通过使用轻量化的encoder替换原始计算量大的encoder使其称为移动友好型模型。一个简单的思路是按照SAM原文训练一个新的轻量化的SAM,但是效果不理想,尤其是在有限训练资源的情况下。作者发现效果不理想是由于图像编码器和mask解码器的耦合优化(coupled optimization)引起的,因此,提出了解耦蒸馏法。具体而言,就是将原始SAM的图像编码器的知识蒸馏给一个轻量化的encoder,它可以与原始SAM中的mask解码器自动兼容。MoibleSAM的训练过程可以在1天之内单张GPU上完成,相同性能的情况下,参数量少了60多倍。在推理速度上,MobileSAM平均10ms一张图像,特征提取8ms和mask解码2ms。凭借卓越的性能和更高的通用性,MobileSAM比同期的FastSAM小7倍,快5倍,更适合移动端使用。此外,MobileSAM可以在CPU上更加流畅的推理运行。
太长不看版:
作者发现由于图像编码器和mask解码器的耦合优化,使得单纯重新训练一个轻量化的SAM效果不佳,提出了解耦蒸馏法。具体而言,就是将原始SAM的图像编码器的知识蒸馏给一个轻量化的encoder,它可以与原始SAM中的mask解码器自动兼容。
MoibleSAM的训练过程可以在1天之内单张GPU上完成,相同性能的情况下,参数量少了60多倍。在推理速度上,MobileSAM平均10ms一张图像,特征提取8ms和mask解码2ms
模型结构
SAM
在介绍MoibleSAM之前,先了解一下SAM模型。如下图所示,SAM包含两个部分:图像编码器和mask解码器,相比于mask解码器的4M参数量,图像编码器的参数量为632M,这使得部署SAM在移动端非常困难。因此,使用轻量化的模型替换原始的图像编码器成为优化手段。
解耦蒸馏
如下图所示,左边为coupled distillation,指的是采用知识蒸馏的方法将ViT-H的知识蒸馏给一个更小的图像编码器。然而,这种直接替换在重新训练的难度主要在于图像编码器和mask解码的耦合优化。基于分而治之的思路,作者提出将知识蒸馏的任务分为两个子任务:图像编码器蒸馏和mask编码器微调。如右图所示,将mask解码器冻结之后进行知识蒸馏,这种蒸馏法称为seim-coupled distillation。
然而,根据经验发现这种优化仍然存在挑战性,因为prompt的选择是随机的,使得mask解码器可变,从而增加了优化难度。因此,提出直接把ViT-H蒸馏到小的图像编码器中,该方法称为decoupled distillation,如下图所示。在图像embedding上直接进行蒸馏的好处是可以使用简单的MSE损失,而不需要使用focal loss和dice loss。
微调mask解码器的必要性
使用decoupled蒸馏法后,发现由于student图像编码器生成的embedding接近teacher,使得微调mask解码器变成可选项。如果需要微调mask解码器,可以冻结图像编码器或者一起微调可能进一步提升性能。
效果对比
将原始SAM生成的mask作为GT,输入相同点比较coupled和的coupled蒸馏法的效果差异。在coupled蒸馏中,使用ViT-B作为图像编码器,在SA-1B数据集上训练180个迭代。相反,在decoupled蒸馏中,使用SA-1B的1%数据训练55个迭代。结果对比如下表所示,decoupled蒸馏法取得0.75的mIoU。
实验结果
轻量化图像编码器
本文的目标是通过使用一个轻量化模型替换原始SAM的ViT-H来得到一个高效的SAM。作为一个ViT-based基础网络,ViT-Tiny与Deit-Tiny有相似的参数量但效果更好。例如,在ImageNet-1K数据集上,Deit-Tiny取得72.2%精度但ViT-Tiny是79.1%。因此采用ViT-Tiny作为MobileSAM的图像编码器,验证本文提出的解耦蒸馏法有效性。
训练和评估细节
为了蒸馏图像编码器,本文使用1%的SA-1B数据在单张GPU上训练8个epoch。由于在知识蒸馏过程中,教师图像编码器占用大部分的计算资源,因此,为了加速蒸馏过程,首先将数据集的图像的embedding保存下来。至此,在单张GPU下,训练不到一天就可以得到MobileSAM。在更多GPU上训练更长时间可以得到更好的效果。最开始的研究结果可知通过微调mask解码器可以进一步提升MobileSAM的性能,但是本文中为了简单跳过了该步骤。为了评估MoblieSAM的效果,计算原始SAM和MobileSAM预测的mask之间的mIoU。
可视化结果
MobileSAM给出了基于点和目标框的两种预测mask结果,以及分割一切的结果对比,如下图所示。
消融实验
如下图所示,可以发现在相同迭代次数情况下,增大batch size可以提升模型性能。此外,在相同batch size下,通过增加训练epoch可以提升性能。
与FastSAM对比
如下表所示,FastSAM包含68M参数量,推理一张图片用时40ms。相反,MoibleSAM参数量小于10M,推理时间进行12ms。
FastSAM推理mask时需要输入多个points,因此将其中1个点设为前景,其余都为背景点。下图为对比结果,可以看到FastSAM的mIoU远小于MobileSAM表明FastSAM预测的mask与原始SAM的偏差较大。
结论
本文工作的目标是通过使用轻量化的模型替换原始SAM的图像编码器,使其移动端友好。由于SAM中图像编码器和mask解码器的耦合优化,使得直接训练一个轻量化模型效果不佳,因此提出了解耦蒸馏,将图像编码器ViT-H的知识蒸馏给一个轻量化的编码器。MobileSAM的参数量比原始SAM少了60多倍,同时MobileSAM延续原始SAM模型的接口,仅仅替换了图像编码器,因此,可以与SAM模型的使用方法无缝对接。