Segment Anything网络结构详解
论文链接:http://arxiv.org/abs/2304.02643
代码链接:https://github.com/facebookresearch/segment-anything
一、整体框架
二、图像编码器image encoder
使用一个MAE预训练好的ViT模型(ViT-H/16 使用了 14 × 14 14 \times 14 14×14的窗口注意力和四个等步长的全局注意力模块),最后输出特征宽度为原图大小的1/16。
使用 1024 × 1024 1024 \times 1024 1024×1024大小的图像作为输入,缩放图像和填充最短边到1024,得到的图像特征大小为 64 × 64 64 \times 64 64×64,为减少特征维度,使用一个 1 × 1 1 \times 1 1×1的卷积核将特征缩放到1024个通道,接着使用一个1024通道 3 × 3 3 \times 3 3×3的卷积核,每个卷积核后面都带有层归一化。同一张图片仅需一次推理。
输入:image
输出:image embedding
三、Prompt encoder
Prompt encoder:将特征映射为256维的向量embedding。
一个点代表点位置编码的总和,其中一个可学习的embedding表示前景还是背景。
一个框有一对embedding表示,分别为左上角点的位置编码和右下角点的位置编码。
对于文字,则使用CLIP的文字编码。
Dense prompts(例如mask)对应图像上空间位置,输入mask设置为原图像大小的1/4倍。使用两个 2 × 2 2 \times 2 2×2,步长为2的卷积,输出通道分别为4和16。最后加上一个 1 × 1 1 \times 1 1×1的卷积映射到256维特征。每层都使用GELU激活函数和层归一化操作。mask和图像embedding进行元素相加。如果没有mask。就设置为一个可学习的embedding。
输入:point、box point、text、mask
输出:prompt tokens (prompt token + dense prompts token)
四、轻量级的mask解码器lightweight mask decoder
轻量级的mask解码器:将图像embedding和prompt embedding映射为一个mask输出。在输入前插入一个可学习的输出token embedding。
输入:image embedding(含位置编码)、output token(dense prompts token) + prompt token(含位置编码)
输出:mask和IOU置信度
每个解码器层执行4个步骤:
1)token之间的自注意力;
2)token(作为查询)到图像嵌入的交叉注意力token to image attn(更新token).;
3)点积MLP更新每个token;
4)图像嵌入(作为查询Q)到token的交叉注意力image to token attn.(更新image embedding)。这一步骤更新了图像embedding,包含了prompt信息。
在交叉注意力过程中,图像嵌入被视为一组
6
4
2
64^{2}
642个256维的向量。每个自/交叉注意力和MLP都有残差连接 [49],层归一化,以及训练时丢失率为0.1的dropout [93]。下一解码器层将前一层更新的token和更新的图像嵌入作为输入。使用两层的解码器
。
解码器中每当参与注意力层,位置编码都会被添加到图像嵌入中,同时还会将原始prompt tokens(包含它们的位置编码ouput token)重新添加到更新的token中。-> 增强prompt token的几何位置和类型有很强的依赖。
将解码器后更新的图像嵌入使用两个转置卷积上采样4倍。然后将token再次嵌入到图像嵌入中,将更新后的输出token传递到一个小型3层MLP(多层感知器),该MLP输出一个与放大图像嵌入通道维度匹配的向量。最后通过上采样图像嵌入和MLP输出的空间点积预测一个掩码mask。再将更新后的输出token经过一个MLP输出IOU对应的置信度。
五、结构详细说明
1) Transformer使用输出256维的嵌入维度,Transformer中的MLP中间层使用2048维,在交叉注意层中使用一个 64 × 64 64 \times 64 64×64的图像嵌入,并将查询Q、键K和值V的通道维度减半到128维,使用8个头的注意力层。
2) 用于上采样输出图像嵌入的2层转置卷积是 2 × 2 2 \times 2 2×2,步长为2,输出通道维度分别是64和32,含有层归一化层和GELU激活函数,token经过3层MLP后,两者点乘获得最后的mask。