主要内容包含segment-anything项目的安装、基于SamPredictor对单点输入生成mask、基于SamPredictor对多点输入生成mask、基于SamAutomaticMaskGenerator自动生成mask。
Segment Anything项目是一个可以对任何图像进行分割的项目,其论文介绍可以查看https://blog.csdn.net/a486259/article/details/131137939,其测试网站为 https://segment-anything.com
这里对Segment Anything项目的使用进行初步总结,绝大部分内容源自https://github.com/facebookresearch/segment-anything 。
注:segment-anything训练VIT模型时的输入size为1024x1024,其输出的feature size为256x64x64,进行了16倍的下采样
1、安装segment-anything
下载segment-anything项目,进入目录后执行pip install -e .
安装项目。
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
该项目依赖opencv-python pycocotools matplotlib onnxruntime onnx torch等包,安装命令如下
pip install opencv-python pycocotools matplotlib onnxruntime onnx torch
segment-anything模型是基于torch框架实现的
2. 根据提示输入生成mask
Segment Anything Model (SAM) 预测对象mask,给出所需识别出对象的提示输入(对象的粗略位置信息)。该模型首先将图像转换为图像嵌入,然后解码器根据用户输入的提示(粗略位置信息)可以生成高质量的掩模。
SamPredictor类为模型调用提供了一个简单的接口,用于提示模型的输入。它先让用户使用“set_image”方法设置图像,该方法会将图像输入转换到特征空间嵌入。然后,可以通过“predict”方法输入提示信息,以根据这些提示有效地预测掩码。predict函数支持将点和框提示以及上一次预测迭代中的mask作为输入。
2.1 前置函数库
前置库导入和函数实现
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
2.2 显示样图
读取图片并展示
image = cv2.imread('images/truck.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()
2.3 加载SAM模型
sam_vit_b_01ec64模型的下载地址为: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
这里需要注意要使用cuda
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
#predictor.set_image(image)
其它版本的模型下载地址为
default
orvit_h
: ViT-H SAM model.vit_l
: ViT-L SAM model.vit_b
: ViT-B SAM model.
通过调用“SamPredictor.set_image”处理图像以生成图像嵌入(特征向量)。“SamPrejector”会记住此特征向量,并将其用于后续掩码预测。
predictor.set_image(image)
2.4 单点输入生成mask
要选择卡车,可以卡车上选择一个点。点以(x,y)格式输入到模型中,并带有标签1(前景点)或0(背景点)。可以输入多个点;这里我们只使用一个。所选的点将在图像上显示为星形。
此时代码及执行效果如下:
input_point = np.array([[250, 187]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
使用“SamPredictor.prdict”进行预测。该模型返回掩码(masks)、掩码的分数(scores)以及可传递到下一次预测迭代的低分辨率掩码(logits)。
在“multimask_output=True”(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。此设置用于存在不明确输入提示的时候(光凭一个点无法有效识别出用户意图是组件局部、组件还是整体),并帮助模型消除与提示一致的不同对象的歧义。当为“multimask_output=False”时,它将返回一个掩码。对于单点等不明确的提示,建议使用“multimask_output=True”,即使只需要一个掩码;可以通过选择在“分数”中返回的分数最高的一个来选择最佳的单个掩码。这通常会得到更好的mask。
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
print(masks.shape) # (number_of_masks) x H x W | output (3, 600, 900)
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
3、多输入生成mask
3.1 多点输入生成mask
单个输入点不明确,需要让模型返回了与其一致的多个对象。要获得单个对象,可以提供多个点。如果可用,还可以将先前迭代的掩码(logits值)提供给模型以帮助预测。当使用多个提示指定单个对象时,可以通过设置“multimask_output=False”来请求获取单个掩码。
input_point = np.array([[250, 184], [562, 322]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
print(masks.shape) #output: (1, 600, 900)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
输入负点明确区域
input_label与input_point相对应,为0时表示是负点
input_point = np.array([[250, 187], [561, 322]])
input_label = np.array([1, 0])#为0时表示是负点,即第二个点[561, 322]是负点
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
3.2 boxes输入生成mask
支持将xyxy格式的box作为输入,将框内的主体目标识别出来(类似于实例分割)
input_box = np.array([212, 300, 350, 437])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
3.3 同时输入点与boxes生成mask
point和boxes可以同时输入,只需将这两种类型的提示都包括在预测器中即可。在这里,这可以用来只选择卡车的轮胎(将车轴部分设置为负点),而不是整个车轮。
input_box = np.array([215, 310, 350, 430]) #只能默认框住正类
input_point = np.array([[287, 375]])
input_label = np.array([0]) #将车轴部分设置为负点
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
3.4 同时输入多个boxes生成mask
SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。
input_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=predictor.device) #假设这是目标检测的预测结果
input_boxes=input_boxes/2
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
print(masks.shape) # (batch_size) x (num_predicted_masks_per_input) x H x W | output: torch.Size([4, 1, 600, 900])
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
3.5 端到端的批量推理
如果所有提示输入都已经明确的,则可以以端到端的方式直接运行SAM。这允许SAM对图像进行批处理,以下代码构建了2个image和boxes。
image1 = cv2.imread('images/truck.jpg')
image1_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=sam.device)
image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
[450, 170, 520, 350],
[350, 190, 450, 350],
[500, 170, 580, 350],
[580, 170, 640, 350],
], device=sam.device)
图像和提示都作为PyTorch张量输入,这些张量(图像和提示输入)已经被编码为特征向量。所有的输入数据都被封装为list,每个元素都是一个dict,它的key如下:
image
: CHW格式的PyTorch tensor .original_size
: 图像原始大小, (H, W) format.point_coords
: 一批输入点格式.point_labels
: 每个输入点所对应的类型(正例或负例).boxes
: 一批输入的boxe(只能是正例).mask_inputs
: 一批输入的mask.
如果没有相应的信息,可以不进行输入,但image必须输入
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
def prepare_image(image, transform, device):
image = transform.apply_image(image)
image = torch.as_tensor(image, device=device.device)
return image.permute(2, 0, 1).contiguous()
batched_input = [
{
'image': prepare_image(image1, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
'original_size': image1.shape[:2]
},
{
'image': prepare_image(image2, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
'original_size': image2.shape[:2]
}
]
batched_output = sam(batched_input, multimask_output=False)
print(batched_output[0].keys()) # output:dict_keys(['masks', 'iou_predictions', 'low_res_logits'])
输出是每个输入图像的结果列表,其中元素是字典对象,其key为:
masks
: 一批mask,tensor张量iou_predictions
: 与mask相对应的iou预测值.low_res_logits
: 每个掩码的低分辨率logits,可以在以后的迭代中作为掩码输入再次调用模型。
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')
ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')
plt.tight_layout()
plt.show()
4、自动生成mask
4.1 基础前置库
这里加载了一些基础库,并读取images/dog.jpg作为样例数据
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
for ann in sorted_anns:
m = ann['segmentation']
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)
image = cv2.imread('images/dog.jpg')
image = cv2.resize(image,None,fx=0.5,fy=0.5)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()
4.2 自动生成mask
要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
print(masks[0])
plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
代码输出的文字信息如下:
42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{'segmentation': array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., False, False, False],
[ True, True, True, ..., False, False, False],
[False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}
所生成的图像如下
masks = mask_generator.generate(image)
Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:
segmentation
: np的二维数组,为二值的mask图片area
: mask的像素面积bbox
: mask的外接矩形框,为XYWH格式predicted_iou
: 该mask的质量(模型预测出的与真实框的iou)point_coords
: 用于生成该mask的point输入stability_score
: mask质量的附加指标crop_box
: 用于以XYWH格式生成此遮罩的图像裁剪
4.3 自动mask的参数
在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:
mask_generator_2 = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,#控制采样点的间隔,值越小,采样点越密集
pred_iou_thresh=0.86,#mask的iou阈值
stability_score_thresh=0.92,#mask的稳定性阈值
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=50, #最小mask面积,会使用opencv滤除掉小面积的区域
)
masks2 = mask_generator_2.generate(image)
print(len(masks2)) # 69
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()