论文地址:Segment Anything
项目地址:https://github.com/facebookresearch/segment-anything
在线Demo: https://segment-anything.com/demo
前言
近日,MetaAI团队开源了一个用于分割的10亿个masks,1100w张图片数据集SA-1B。并开开源了基于此数据集训练的模型。这应该是目前CV分割领域最强的基础模型。
回答3个问题就能搞懂SAM
segment anything model(SAM)能根据用户给定的prompt,分割出该prompt对应的mask。它的工作pipeline如下图所示,主要包含3个步骤:
- 通过image encode提取图片的embedding。这个过程很耗时,不过对每一张图片只需计算一次。
- 通过prompt encode对用户传入的prompt进行编码
- 通过mask decoder对image embedding与prompt emnbedding进行解码,提取所用感兴趣的分割。
SAM它能够实现与用户交互的实时性。这是因为最耗时的image enbedding会在用户上传图片的时候计算好,交互时只需推理prompt encode与mask decode这两个轻量的模块。下面我们来看它是怎么做的。
image encode工作机制
采用的是基于MAE1预训练的ViT模型,在最后接了2层卷积层进行降维。模型的输出是图片 X i n ∈ R 3 × 1024 × 1024 X_{in} \in \mathbb{R}^{3 \times 1024 \times 1024} Xin∈R3×1024×1024,模型的输出 X o u t ∈ R 256 × 64 × 64 X_{out} \in \mathbb{R}^{256 \times 64 \times 64} Xout∈R256×64×64,下采样了16倍。
Prompt encode的工作机制
两类prompt,离散型(point、boxes、text),稠密型(mask)
-
Points prompt 采用了文献2的位置编码方式,给定点的坐标与前景点/背景点的标签,返回 R B × 2 × 256 \mathbb{R}^{B \times 2 \times 256} RB×2×256的point embedding。只有一个点为什么是 2 × 256 2 \times 256 2×256这是为了和box promot处理方式保持一致,对point进行了padding操作。
-
box prompt。box可以视为2个点组成,故可采用point promot的位置编码操作,其返回也是 R B × 2 × 256 \mathbb{R}^{B \times 2 \times 256} RB×2×256
-
text prompt。论文中提出一个概念性的验证(proof-of-concept),可以用CLIP的text encode,最后在其维度进行处理。
mask decode如何融入根据image embedding与prompt embedding来预测mask
Mask decode的主要工作流程如下图所示。其输入包含两部分:image embedding与prompt embedding。其输出也包含两部分:预测的masks与每个mask对应的IoU分数。
2个值得注意的细节:
如何处理模糊感知
什么是模糊感知呢(ambiguity-aware)。举个🌰,假设一张图片有一个人身穿印有小汽车的T恤,当用户给了一个T恤上小汽车位置prompt,模型该输出什么mask呢? 是小汽车还是T恤还是人?
Meta的解决方案是,一个prompt对应多个输出,以上面的case为例,此时模型输出3个mask,分别是小汽车(sub-part)、T恤(part)和人(object)。在训练阶段,只反向传播着三个mask误差最低的mask。在推理时返回置信度最高的mask作为输出。
下图为论文中给的示例
如何构model-in-loop dataset (three stage)
assisted-manual
先用开源数据集进行训练SAM,再结合交互式的标注方法进行人工标注。随后再仅用标注的数据进行retrain。这个过程进行了6轮,通过模型的迭代,标注人员的标注时间从34s/mask降低到14s/mask(一张图片8-9分钟左右),并且分割的粒度不断细化,从20masks/image提升到44masks/image(说明模型效果不断提升)。这个阶段总计获得120k标注图片,总计4.3M masks。
值得注意的时,标注时标注人员也会打语义标签,它的标签是灵活的,不局限于给定的几类。
semi-automatic
这个阶段的主要目的是提升mask的多样性,核心目的是让模型能够segment anything。首先会通过预训练的模型检测出显著性较高的mask,随后让标注人员完善显著性不高的物体的标注。retrain过程重复了5次。由于标注的粒度更细,且更不显著,标注速率回升到34s/mask。在这个阶段总计标注了180k的图片得到5.9M的mask。每张图片的平均mask从44提升至72。通过上述两个阶段总计获得10.2M的mask,300k图片。
值得注意的是,为了获得显著性的mask,meta基于第一阶段的mask构建bounding box训练了一个检测模型。将被检测到的mask作为显著性mask。
Fully automatic
在全自动阶段,模型会默认给定32*32的点阵作为prompt(总计1024个点)。根据感知模糊规则,每个点有3个mask,部分(part)、子部分(sub-part)、整体(object)的mask (总计3072mask)。再根据模型的IOU预测模块给这个点的mask进行打分,并返回稳定的mask。最后再根据非极大抑制来过滤重复的mask。为了进一步提升小mask的质量,meta还对这部分区域进行了重叠crop再预测,再合并多个crop image的分割结果。通过这个全自动的样本制造过程,总计得到11M的图片(图片平均分辨率3300x4950),1.1B高质量的mask。
值得注意的是,稳定mask的判别条件:如果将mask 的semantic map分别按照阈值 0.5 + δ , 0.5 − δ 0.5 + \delta, 0.5-\delta 0.5+δ,0.5−δ进行二值化得到的mask一致,则认为是稳定的mask。
如何用segment anything进行zero-shot迁移
下面主要概述下如何用sam做edge detection,即如何基于text生成mask。
如何用sam进行edge detection
类似创建dataset的Fully automatic方案。这里会预先给定16 * 16的点阵作为point prompt。这样总计可以得到768个mask。随后通过非极大抑制过滤重复mask,再对每个mask用Sobel算子提取边缘,最后再对提取的边缘用非极大抑制进行细化。
(真的是强行zero-shot,上面的pipeline计算成本非常非常高)
如何用sam进行instance segmentation
先通过一个检测模型检测出object的bounding box,随后在用bounding box作为box prompt获得实例分割图。
如何用sam实现text-to-mask
这个任务根据输入的文本分割出图中所有符合的实例。meta还没有填这个坑只是提出了一个概念性的验证(proof-of-concept)。在训练过程中,将mask区域大小大于 10 0 2 100^2 1002的mask对应图片区域的CLIP image embedding作为prompt的输入。由于CLIP的image embedding和text embedding做了对齐,因此在推理阶段可以将文本的embedding作为prompt来检测相关的segment。
总结
SAM从算法层面来看,基本都是已有的算法。他的亮点主要在于model-in-loop创建数据集的pipeline和二阶段推理交互逻辑(先提取图片特征,在用一个轻量模型配合用户做交互)。之所以能够达到这么惊艳的效果,主要还是取决于训练的样本足够大、质量足够高。他让我们看到了现有模型的上界。
基于SAM 二次开发的工作
项目名称 | description | 相关信息 |
---|---|---|
Track-Anything | 视频目标跟踪与分割 | https://github.com/gaomingqi/Track-Anything |
Segment-Everything-Everywhere-All-At-Once | 基于多模态prompt分割 (point、box、mask、audio、text) | https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once |
caption anything | 生成sam分割实体的描述。底层架构:SAM, BLIP2, ChatGPT | https://github.com/ttengwang/Caption-Anything |
edit anything | 将sam集成到stable-diffusion中。简化inpaint gnerate过程。底层架构:SAM, StableDIffusion,ContorlNet, BLIP2 | https://github.com/sail-sg/EditAnything |
Semantic Segment Anything (SSA) | 给sam开源的SA-1B的mask打上语义标签。获得一个新的数据集SSA-1B | https://github.com/fudan-zvg/Semantic-Segment-Anything |
salt | 基于pypq5和SAM开发的segment标注工具 | https://github.com/anuragxel/salt |
grounded-segment-anything | 集成SAM,whisper,ChatGPT,diffusion model,BLIP等来解决一些复杂的cv问题,如inpaint 生成、control生成、基于whisper生成,实例分割,目标追踪等 | |
Segment Anything for Stable Diffusion WebUI | 将SAM集成到stable-diffusion-webUI中 | https://github.com/continue-revolution/sd-webui-segment-anything |
参考文献
[Masked autoencoders are scalable vision learners.]( ↩︎
Fourier features let networks learn high frequency functions in low dimensional domains. ↩︎