【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析
Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。本博客将讲解Prompt encoder模块的深度学习网络代码。
文章目录
- 【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析
- 前言
- PromptEncoder网络简述
- SAM模型关于ProEnco网络的配置
- ProEnco网络结构与执行流程
- ProEnco网络基本步骤代码详解
- Embed_Points
- Embed_Boxes
- Embed_Masks
- PositionEmbeddingRandom
- 总结
前言
在详细解析SAM代码之前,首要任务是成功运行SAM代码【win10下参考教程】,后续学习才有意义。本博客讲解Prompt encoder模块的深度网络代码,不涉及其他功能模块代码。
PromptEncoder网络简述
SAM模型关于ProEnco网络的配置
博主以sam_vit_b为例,详细讲解ViT网络的结构。
代码位置:segment_anything/build_sam.py
def build_sam_vit_b(checkpoint=None):
return _build_sam(
# 图像编码channel
encoder_embed_dim=768,
# 主体编码器的个数
encoder_depth=12,
# attention中head的个数
encoder_num_heads=12,
# 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)
encoder_global_attn_indexes=[2, 5, 8, 11],
# 权重
checkpoint=checkpoint,
)
sam模型中prompt_encoder模块初始化
prompt_encoder=PromptEncoder(
# 提示编码channel(和image_encoder输出channel一致,后续会融合)
embed_dim=prompt_embed_dim,
# mask的编码尺寸(和image_encoder输出尺寸一致)
image_embedding_size=(image_embedding_size, image_embedding_size),
# 输入图像的标准尺寸
input_image_size=(image_size, image_size),
# 对输入掩码编码的通道数
mask_in_chans=16,
),
ProEnco网络结构与执行流程
Prompt encoder源码位置:segment_anything/modeling/prompt_encoder.py
ProEnco网络(PromptEncoder类)结构参数配置。
def __init__(
self,
embed_dim: int, # 提示编码channel
image_embedding_size: Tuple[int, int], # # mask的编码尺寸
input_image_size: Tuple[int, int], # 输入图像的标准尺寸
mask_in_chans: int, # 输入掩码编码的通道数
activation: Type[nn.Module] = nn.GELU, # 激活层
) -> None:
super().__init__()
self.embed_dim = embed_dim # 提示编码channel
self.input_image_size = input_image_size # 输入图像的标准尺寸
self.image_embedding_size = image_embedding_size # mask的编码尺寸
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # 4个点:正负点,框的俩个点
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] # 4个点的嵌入向量
# nn.ModuleList它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器
self.point_embeddings = nn.ModuleList(point_embeddings) # 4个点的嵌入向量添加到网络
self.not_a_point_embed = nn.Embedding(1, embed_dim) # 不是点的嵌入向量
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) # mask的输入尺寸
self.mask_downscaling = nn.Sequential( # 输入mask时 4倍下采样
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim) # 没有mask输入时 嵌入向量
SAM模型中ProEnco网络结构如下图所示:
ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:
- Embed_Points:标记点编码(标记点由点转变为向量)
- Embed_Boxes:标记框编码(标记框由点转变为向量)
- Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 获得 batchsize 当前predict为1
bs = self._get_batch_size(points, boxes, masks)
# -----sparse_embeddings----
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
# -----sparse_embeddings----
# -----dense_embeddings----
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
# -----dense_embeddings----
return sparse_embeddings, dense_embeddings
获取batchsize
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
获取设备型号
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
ProEnco网络基本步骤代码详解
Embed_Points
标记点预处理,将channel由2变成embed_dim(PositionEmbeddingRandom),然后再加上位置编码权重。
2:坐标(h,w)
embed_dim:提示编码的channel
Embed_Points结构如下图所示:
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
# 移到像素中心
points = points + 0.5
# points和boxes联合则不需要pad
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) # B,1,2
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) # B,1
points = torch.cat([points, padding_point], dim=1) # B,N+1,2
labels = torch.cat([labels, padding_label], dim=1) # B,N+1
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) # B,N+1,2f
# labels为-1是非标记点,设为非标记点权重
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
# labels为0是背景点,加上背景点权重
point_embedding[labels == 0] += self.point_embeddings[0].weight
# labels为1的目标点,加上目标点权重
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
个人理解:pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。
Embed_Boxes
标记框预处理,将channel由4到2再变成embed_dim(PositionEmbeddingRandom),然后再加上位置编码权重。
4:坐标(h1,w1,h2,w2) -->起始点与末位点
2:坐标(h,w)–>4 reshape 成 2×2
embed_dim:提示编码的channel
Embed_Boxes结构如下图所示:
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
# 移到像素中心
boxes = boxes + 0.5
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) #
# 目标框起始点的和末位点分别加上权重
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
个人理解:boxes reshape 后 batchsize是会增加的,B,N,4–>BN,2,2
因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个。
Embed_Masks
mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。
Embed_Masks结构如下图所示:
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
# mask下采样4倍
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
# 在PromptEncoder的__init__定义
self.mask_downscaling = nn.Sequential( # 输入mask时 4倍下采样
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask。
# 在PromptEncoder的forward定义
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
PositionEmbeddingRandom
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
# 理解为模型的常数 [2,f]
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
将标记点的坐标具体的位置转变为[0~1]之间的比例位置
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
coords = coords_input.clone()
# 将坐标位置缩放到[0~1]之间
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
# B,N+1,2-->B,N+1,2f
return self._pe_encoding(coords.to(torch.float))
标记点位置编码
因为sin和cos,编码的值归一化至 [-1,1],源码注释是[0,1],博主经过实验发现注释不对
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
# B,N+1,2 × 2,f --> B,N+1,f
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
# B,N+1,2f
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
总结
尽可能简单、详细的介绍SAM中Prompt encoder模块的ProEnco网络的代码。后续会讲解SAM的其他模块的代码。