Segment Anything Model(SAM)是 Facebook 的一个 AI 模型,旨在推广分割技术。在我们之前的文章中,我们讨论了 SAM 的一般信息,现在让我们深入了解其技术细节。SAM 模型的结构如下图所示,图像经过编码器得到其嵌入,并且任何掩码都可以被实现。提示可以以文本、边界框或自由形式的点的形式出现。我们对提示进行编码,并将其与图像嵌入一起传递给解码器,该解码器生成我们的掩码。
SAM 最有趣的特性之一是其轻量级的编码器和解码器,可以实现实时性能。您可以使用在 GitHub 上可用的打包版本在 Python 中使用 SAM:https://github.com/kadirnar/segment-anything-video
然而,如果您在使用过程中遇到问题,您可以使用原始 GitHub 页面上提供的 Colab 文件。以下是您可以开始使用的步骤:
using_colab = True
if using_colab:
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
首先,导入 Torch 和 Torchvision,这是项目所必需的,然后使用 pip 安装 Segment Anything。
从https://github.com/facebookresearch/segment-anything#model-checkpoints 下载模型检查点,稍后我们将使用它。
接下来,创建一个图像目录,您可以将测试图像放在其中。您还可以通过替换以下命令中的 URL 来使用自己的图像:
!mkdir images
!wget -O images/image.jpg https://live.staticflickr.com/65535/49894878561_14a39c6c35_b.jpg
一旦您拥有了图像,您可以导入必要的包,包括 numpy、Torch、Matplotlib 和 OpenCV。
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
您可以使用以下函数绘制注释:
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))
现在,使用 OpenCV 读取图像并将通道从 BGR 更改为 RGB。然后,使用 Matplotlib 显示图像。
image = cv2.imread('images/image.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()
要创建您的掩码生成器,您需要定义您的 SAM 模型并使用 SamAutomaticMaskGenerator:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
我们的 SAM 模型注册表需要检查点和模型类型以提供模型,记得将您的运行时设置为 GPU。SamAutomaticMaskGenerator 使用您的模型来制作您的掩码生成器。您所需要做的就是将您的输入传递给这个函数以获取您的掩码。
masks = mask_generator.generate(image)
掩码对象包含关于区域和稳定性分数的多个信息,稍后会将标签添加到此掩码对象中。让我们来看一下输出:
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
您还可以通过更改以下变量来调整遮罩生成器的参数:
mask_generator_2 = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,
pred_iou_thresh=0.86,
stability_score_thresh=0.92,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # Requires open-cv to run post-processing
)
masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()
不同参数的另一个输出
您可以在这里找到源代码:
https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
· END ·
HAPPY LIFE