Prerequisite: bounding boxes (bbox) annotated with labelImg, saved as Pascal VOC XML files.
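For reference, labelImg's Pascal VOC output is one XML file per image, stored in the label directory under the same base name as the image; the script below only reads the object/bndbox coordinates. A trimmed example with a made-up file name and values:

<annotation>
    <filename>example.jpg</filename>
    <size>
        <width>640</width>
        <height>480</height>
        <depth>3</depth>
    </size>
    <object>
        <name>cat</name>
        <bndbox>
            <xmin>100</xmin>
            <ymin>120</ymin>
            <xmax>300</xmax>
            <ymax>360</ymax>
        </bndbox>
    </object>
</annotation>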
1. Code
import torch
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
import os
import glob
import xml.etree.ElementTree as ET
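# Note: the ViT-H SAM checkpoint (sam_vit_h_4b8939.pth) can be downloaded via the
# links in the segment-anything README (model type "vit_h"); it is assumed here
# to be placed under ./weight/.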
checkpoint = "./weight/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device='cuda')
predictor = SamPredictor(sam)
image_dir = r"D:\Desktop\mult_test\images"
# collect every image file in the directory (jpg/jpeg/png, etc., matched by extension)
image_files = glob.glob(os.path.join(image_dir, '*.[jJpPeEgG]*'))
# directory where the generated masks will be saved
save_dir = r"D:\Desktop\mult_test\mask"
# directory holding the labelImg XML annotations
xml_dir = r'D:\Desktop\mult_test\label'
# iterate over the images
for image_file in image_files:
    image = cv2.imread(image_file)
    # SamPredictor expects an RGB image, while cv2.imread returns BGR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)
    # image file name without the extension
    image_filename = os.path.splitext(os.path.basename(image_file))[0]
    # path of the matching annotation file
    xml_file = os.path.join(xml_dir, image_filename + '.xml')
    tree = ET.parse(xml_file)
    root = tree.getroot()
    data_list = []
    # iterate over the objects in the XML annotation
    for object_elem in root.findall('object'):
        # bounding-box coordinates of the object
        bbox_elem = object_elem.find('bndbox')
        xmin = int(bbox_elem.find('xmin').text)
        ymin = int(bbox_elem.find('ymin').text)
        xmax = int(bbox_elem.find('xmax').text)
        ymax = int(bbox_elem.find('ymax').text)
        data = [xmin, ymin, xmax, ymax]
        data_list.append(data)
    # prompt SAM with all boxes for this image in one batched call
    input_boxes = torch.tensor(data_list, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # masks has shape (num_boxes, 1, H, W); merge all objects into a single mask
    # where object pixels are 0 and background is 255
    first_mask = np.where(masks[0].cpu().numpy()[0, :, :] == 1, 0, 1) * 255
    for i in range(1, len(masks)):
        first_mask &= np.where(masks[i].cpu().numpy()[0, :, :] == 1, 0, 1) * 255
    image_filename = os.path.basename(image_file)
    # cv2.imwrite expects an 8-bit image, while np.where produces int64
    cv2.imwrite(os.path.join(save_dir, image_filename), first_mask.astype(np.uint8))
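To spot-check a result, the saved mask can be overlaid on the original image. The sketch below assumes the directory layout above and uses a placeholder file name; note that the script stores object pixels as 0 and background as 255.

import cv2

# placeholder file name; use an image that was actually processed
image = cv2.imread(r"D:\Desktop\mult_test\images\example.jpg")
mask = cv2.imread(r"D:\Desktop\mult_test\mask\example.jpg", cv2.IMREAD_GRAYSCALE)

# object pixels were written as 0, background as 255
object_region = mask == 0

# tint the segmented objects red for a quick visual check
overlay = image.copy()
overlay[object_region] = (0, 0, 255)
preview = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
cv2.imwrite(r"D:\Desktop\mult_test\overlay_example.jpg", preview)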
2. Results