Follow the steps from the original project, GitHub - facebookresearch/segment-anything: "The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model."
Install Segment Anything:
pip install git+https://github.com/facebookresearch/segment-anything.git
or
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
However, this fails with an error.
Switching the URL from https to git still raises the same error, so instead download the project directly, upload it to the server, and run:
cd segment-anything; pip install -e .
Once the environment is set up, download a model checkpoint.
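For reference, the default ViT-H checkpoint can be fetched with a short script; the URL below is the one listed in the repository README at the time of writing, so double-check it there if the download fails.
import urllib.request

# Minimal sketch: download the default ViT-H checkpoint.
# URL taken from the segment-anything README; verify it in the repo if it has moved.
ckpt_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
ckpt_path = "sam_vit_h_4b8939.pth"
urllib.request.urlretrieve(ckpt_url, ckpt_path)
print("saved checkpoint to", ckpt_path)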
Open segment-anything/notebooks/automatic_mask_generator_example.ipynb and run it, or collect the code into a single .py file and run it directly:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
def show_anns(anns):
    # Overlay every mask on the current axes in a random translucent color.
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0  # start fully transparent
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
# image = cv2.imread('images/dog.jpg')
image = cv2.imread('image path')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
sam_checkpoint = "checkpoint path"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
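# Each mask record is a dict; per the SamAutomaticMaskGenerator documentation its keys
# include 'segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords',
# 'stability_score' and 'crop_box'.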
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
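If the default settings produce too many or too few masks, SamAutomaticMaskGenerator also accepts tuning parameters. The parameter names below are the ones exposed by the class (and used in the example notebook); the values are only illustrative and will need adjusting for your images.
# Illustrative sketch: a more densely sampled, more aggressively filtered generator.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,                 # density of the prompt point grid
    pred_iou_thresh=0.86,               # keep masks the model itself rates highly
    stability_score_thresh=0.92,        # keep masks that are stable under thresholding
    crop_n_layers=1,                    # also run prompts on crops of the image
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,           # drop tiny disconnected regions (requires opencv)
)
masks2 = mask_generator_2.generate(image)
print(len(masks2))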