sagment-anything官方代码使用详解

news2024/9/28 21:23:45

文章目录

  • 一. sagment-anything官方例程说明
    • 1. 结果显示函数说明
    • 2. SamAutomaticMaskGenerator对象
      • (1) SamAutomaticMaskGenerator初始化参数
    • 3. SamPredictor对象
      • (1) 初始化参数
      • (2) set_image()
      • (3) predict()
  • 二. SamPredictor流程说明
    • 1. 导入所需要的库
    • 2. 读取图像
    • 3. 加载模型
    • 4. 生成预测对象
    • 5. 设置要检测的图像
    • 6. 根据不同输入需求对图像进行掩膜预测
      • (1) 根据输入一个点,输出对于这个点的三个不同置信度的掩膜
      • (2) 通过多个点获取一个对象的掩膜
      • (3) 通过设置反向点反选掩膜
      • (4) boxes输入生成掩膜
      • (5) 同时输入点与boxes生成掩膜
      • (6) 多个输入输出不同预测结果
  • 三. SamAutomaticMaskGenerator预测流程
    • 1. 导入所需要的库
    • 2. 读取图像
    • 3. 加载模型
    • 4. 生成预测对象
    • 5. 设置要检测的图像
    • 6. 给分割出来的物体上色,显示分割效果
  • 四. SamAutomaticMaskGenerator不同参数下的检测效果
    • 1. points_per_side参数测试
    • 2. pred_iou_thresh参数测试
    • 3. stability_score_thresh参数测试
    • 4. box_nms_thresh参数测试
    • 5. crop_nms_thresh参数测试

一. sagment-anything官方例程说明

1. 结果显示函数说明

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

2. SamAutomaticMaskGenerator对象

(1) SamAutomaticMaskGenerator初始化参数

  • model (Sam): 用于掩模预测的Sam模型。
  • points_per_side (int or None): 沿图像一侧要采样的点的数量。总点数为points_per_side 2 ^2 2。如果为None,则point_grids必须提供显式点采样。默认为32
  • points_per_batch (int): 设置模型同时检测的点数。更高的数字可能更快,但使用更多的GPU内存。默认为64
  • pred_iou_thresh (float): [0,1]中的滤波阈值,使用模型的预测掩码质量。默认值为0.88
  • stability_score_thresh (float): [0,1]中的滤波阈值,使用掩码在截断值变化下的稳定性,用于对模型的掩码预测进行二值化。默认值为0.95
  • stability_score_offset (float): 计算稳定性分数时,偏移截止值的量。默认值为1.0
  • box_nms_thresh (float): 非最大抑制用于过滤重复掩码的框IoU截止。默认值为0.7
  • crop_n_layers (int): 如果>0,将对图像的裁剪再次运行掩膜预测。设置要运行的层数,其中每层具有2*i_layer数量的图像裁剪。默认值为0
  • crop_nms_thresh (float): 非最大抑制用于过滤不同物体之间的重复掩码的框IoU截止。默认值为0.7
  • crop_overlap_ratio (float): 设置物体重叠的程度。在第一个裁剪层中,裁剪将重叠图像长度的这一部分。物体较多的后几层会缩小这种重叠。默认值为512 / 1500
  • crop_n_points_downscale_factor (int): 在层n中采样的每侧的点数按比例缩小crop_n_points_downscale_factor n ^n n。默认值为1
  • point_grids (list(np.ndarray) or None): 用于采样的点的显式网格上的列表,归一化为[0,1]。列表中的第n个栅格用于第n个裁剪层。与points_per_side独占。默认值为None
  • min_mask_region_area (int): 如果>0,将应用后处理来移除面积小于min_mask_region_area的掩膜来中断开连接的区域和孔。需要opencv。默认为0
  • output_mode (str): 表单掩码在中返回。可以是binary_maskuncompressed_rlecoco_rlecoco_rle需要pycocotools。对于大分辨率,binary_mask可能会消耗大量内存。默认为'binary_mask'
    “”"

3. SamPredictor对象

(1) 初始化参数

  • model (Sam): 用于掩模预测的Sam模型。

(2) set_image()

说明:
	设置检测的图像
参数:
	image(np.ndarray):用于计算掩码的图像。应为HWC uint8格式的图像,像素值为[0,255]。
	image_format(str):图像的颜色格式,以'RGB''BGR'为单位。

(3) predict()

说明:
	使用当前设置的图像预测给定输入提示的掩码。
参数:
	point_coords(np.ndarray或None):存放指向图像中物体的点的Nx2数组。每个点都以像素为单位(X,Y)。
	point_labels(np.ndarray或None):点提示的长度为N的标签阵列。1表示前景点,0表示背景点。
	box(np.ndarray或None):长度为4的数组,以XYXY格式向模型提供长方体提示。
	mask_input(np.ndarray):输入到模型的低分辨率掩码,通常来自先前的预测迭代。形式为1xHxW,其中对于SAM,H=W=256。
	multimask_output(bool):如果为true,则模型将返回三个掩码。对于不明确的输入提示(如单击),这通常会产生比单个预测更好的掩码。
	                          如果只需要单个遮罩,则可以使用模型的预测质量分数来选择最佳遮罩。对于非模糊提示,例如多个输入提示,
	                          multimask_output=False可以提供更好的结果。
	return_logits(bool):如果为true,则返回未阈值掩码logits,而不是二进制掩码。
返回值:
    (np.ndarray):CxHxW格式的输出掩码,其中C是掩码的数量,(H,W)是原始图像大小。
    (np.ndarray):长度为C的数组,包含模型对每个掩码质量的预测。
    (np.ndarray):形状为CxHxW的数组,其中C是掩码的数量,H=W=256。这些低分辨率logits可以作为掩码输入传递给后续迭代。

二. SamPredictor流程说明

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_predictor = SamPredictor(sam)  # 生成sam预测对象

5. 设置要检测的图像

predictor.set_image(image)

6. 根据不同输入需求对图像进行掩膜预测

(1) 根据输入一个点,输出对于这个点的三个不同置信度的掩膜

input_point = np.array([[250, 187]])
input_label = np.array([1])

# 在'multimask_output=True'(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。
masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=True,)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

(2) 通过多个点获取一个对象的掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [273, 259]])
input_label = np.array([1, 1])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(3) 通过设置反向点反选掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [319, 274]])
input_label = np.array([1, 0])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(4) boxes输入生成掩膜

input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(5) 同时输入点与boxes生成掩膜

input_point = np.array([[237, 244]])
input_label = np.array([1])
input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_points(input_point, input_label, plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(6) 多个输入输出不同预测结果

SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。

input_boxes = torch.tensor([
    [228, 230, 280, 276],
    [495, 90, 554, 125],
    [447, 499, 494, 548],
    [162, 346, 214, 390],
], device=predictor.device) #假设这是目标检测的预测结果

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

masks, _, _ = predictor.predict_torch(point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False)

plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

三. SamAutomaticMaskGenerator预测流程

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_generator = SamAutomaticMaskGenerator(model=sam,
                                           points_per_side=32,
                                           points_per_batch=64,
                                           pred_iou_thresh=0.88,
                                           stability_score_thresh=0.95,
                                           stability_score_offset=1.0,
                                           box_nms_thresh=0.7,
                                           crop_n_layers=0,
                                           crop_nms_thresh=0.7,
                                           crop_overlap_ratio=0.34133,
                                           crop_n_points_downscale_factor=1,
                                           point_grids=None,
                                           min_mask_region_area=0,
                                           output_mode='binary_mask')

5. 设置要检测的图像

# 将图像送入推理对象进行推理分割,输出结果为一个列表,其中存的每个字典对象内容为:
# segmentation : 分割出来的物体掩膜(与原图像同大小,有物体的地方为1其他地方为0)
# area : 物体掩膜的面积
# bbox : 掩膜的边界框(XYWH)
# predicted_iou : 模型自己对掩模质量的预测
# point_coords : 生成此掩码的采样输入点
# stability_score : 掩模质量的一个附加度量
# crop_box : 用于以XYWH格式生成此遮罩的图像的裁剪
masks = mask_generator.generate(image)

6. 给分割出来的物体上色,显示分割效果

# 给分割出来的物体上色,显示分割效果
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

在这里插入图片描述

四. SamAutomaticMaskGenerator不同参数下的检测效果

1. points_per_side参数测试

  1. points_per_side=4,检测到9个物体
    在这里插入图片描述

  2. points_per_side=16,检测到211个物体
    在这里插入图片描述

  3. points_per_side=64,检测到683个物体
    在这里插入图片描述

  4. points_per_side=256,检测到872个物体
    在这里插入图片描述

2. pred_iou_thresh参数测试

  1. pred_iou_thresh=1, 检测到1个物体
    在这里插入图片描述
  2. pred_iou_thresh=0.95, 检测到274个物体
    在这里插入图片描述
  3. pred_iou_thresh=0.8, 检测到792个物体
    在这里插入图片描述

3. stability_score_thresh参数测试

  1. stability_score_thresh=1,检测到0个物体
    kjui
  2. stability_score_thresh=0.95,检测到683个物体
    在这里插入图片描述
  3. stability_score_thresh=0.95,检测到764个物体
    在这里插入图片描述

4. box_nms_thresh参数测试

  1. box_nms_thresh=1,检测到4680个物体
    在这里插入图片描述

  2. box_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. box_nms_thresh=0.4,检测到621个物体
    在这里插入图片描述

  4. box_nms_thresh=0.1,检测到458个物体
    在这里插入图片描述

  5. box_nms_thresh=0,检测到201个物体
    在这里插入图片描述

5. crop_nms_thresh参数测试

  1. crop_nms_thresh=1,检测到683个物体
    在这里插入图片描述

  2. crop_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. crop_nms_thresh=0.1,检测到683个物体
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT之QString

QT之QString 添加容器 点击栅格布局 添加容器,进行栅格布局 布局总结:每一个模块放在一个Group中,排放完之后,进行栅格布局。多个Group进行并排时,先将各个模块进行栅格布局,然后都选中进行垂直布…

Python实现交易策略评价指标-夏普比率

1.夏普比率的定义 在投资的过程中,仅关注策略的收益率是不够的,同时还需要关注承受的风险,也就是收益风险比。 夏普比率正是这样一个指标,它表示承担单位的风险会产生多少超额收益。用数学公式描述就是: S h a r p R…

Java中三种定时任务总结(schedule,quartz,xxl-job)

目录 1、Spring框架的定时任务 2、Quartz Quartz的用法 3、xxl-job 3.1 docker 安装xxl-job 3.2 xxl-job编程测试 补充:Java中自带的定时任务调度 1. java.util.Timer和java.util.TimerTask 2. java.util.concurrent.Executors和java.util.concurrent.Sche…

TI 毫米波雷达器件中的自校准功能(TI文档)

摘要 TI 的毫米波雷达传感器包括一个内部处理器和硬件架构,支持自校准和监控。校准可确保在温度和工艺变化范围内维持雷达前端的性能。监控可以周期性测量射频/模拟性能参数并检测潜在故障。 本应用手册简要介绍了校准和监控机制,主要侧重于内部…

整数的立方和

系列文章目录 进阶的卡莎C++_睡觉觉觉得的博客-CSDN博客数1的个数_睡觉觉觉得的博客-CSDN博客双精度浮点数的输入输出_睡觉觉觉得的博客-CSDN博客足球联赛积分_睡觉觉觉得的博客-CSDN博客大减价(一级)_睡觉觉觉得的博客-CSDN博客小写字母的判断_睡觉觉觉得的博客-CSDN博客纸币(…

卷积神经网络(CNN):乳腺癌识别.ipynb

文章目录 一、前言一、设置GPU二、导入数据1. 导入数据2. 检查数据3. 配置数据集4. 数据可视化 三、构建模型四、编译五、训练模型六、评估模型1. Accuracy与Loss图2. 混淆矩阵3. 各项指标评估 一、前言 我的环境: 语言环境:Python3.6.5编译器&#xf…

线程池、及Springboot线程池实践

摘要 本文介绍了线程池基本概念、线程及线程池状态、java中线程池提交task后执行流程、Executors线程池工具类、最后介绍在springboot框架下使用线程池和定时线程池,以及task取消 线程池基本 背景 线程池 线程池是一种多线程处理形式,处理过程中将任务…

HCIP——交换综合实验

一、实验拓扑图 二、实验需求 1、PC1和PC3所在接口为access,属于vlan2;PC2/4/5/6处于同一网段,其中PC2可以访问PC4/5/6;但PC4可以访问PC5,不能访问PC6 2、PC5不能访问PC6 3、PC1/3与PC2/4/5/6/不在同一网段 4、所有PC通…

【Java】类和对象之超级详细的总结!!!

文章目录 前言1. 什么是面向对象?1.2面向过程和面向对象 2.类的定义和使用2.1什么是类?2.2类的定义格式2.3类的实例化2.3.1什么是实例化2.3.2类和对象的说明 3.this引用3.1为什么会有this3.2this的含义与性质3.3this的特性 4.构造方法4.1构造方法的概念4…

数据结构第六课 -----链式二叉树的实现

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

Java生成word[doc格式转docx]

引入依赖 <!-- https://mvnrepository.com/artifact/org.freemarker/freemarker --><dependency><groupId>org.freemarker</groupId><artifactId>freemarker</artifactId><version>2.3.32</version></dependency> doc…

针对net core 使用CSRedis 操作redis的三种连接实例方式

1、主从访问 2、哨兵模式 3、集群访问 写法一&#xff1a;写任意一个地址即可&#xff0c;其它节点在运行过程中自动增加&#xff0c;确保每个节点密码一致。如&#xff1a;Console.WriteLine("集群测试");RedisHelper.Initialization(new CSRedis.CSRedisClient(&q…

Http和WebSocket

客户端发送一次http请求&#xff0c;服务器返回一次http响应。 问题&#xff1a;如何在客户端没有发送请求的情况下&#xff0c;返回服务端的响应&#xff0c;网页可以得服务器数据&#xff1f; 1&#xff1a;http定时轮询 客户端定时发送http请求&#xff0c;eg&#…

回溯和分支算法

状态空间图 “图”——状态空间图 例子&#xff1a;农夫过河问题——“图”状态操作例子&#xff1a;n后问题、0-1背包问题、货郎问题(TSP) 用向量表示解&#xff0c;“图”由解向量扩张得到的解空间树。 ——三种图&#xff1a;n叉树、子集树、排序树 ​ 剪枝 不满住条件的…

链表【3】

文章目录 &#x1f433;23. 合并 K 个升序链表&#x1f41f;题目&#x1f42c;算法原理&#x1f420;代码实现 &#x1f437;25. K 个一组翻转链表&#x1f416;题目&#x1f43d;算法原理&#x1f367;代码实现 &#x1f433;23. 合并 K 个升序链表 &#x1f41f;题目 题目链…

Sentinel基础知识

Sentinel基础知识 资源 1、官方网址&#xff1a;https://sentinelguard.io/zh-cn/ 2、os-china: https://www.oschina.net/p/sentinel?hmsraladdin1e1 3、github: https://github.com/alibaba/Sentinel 一、软件简介 Sentinel 是面向分布式服务架构的高可用流量防护组件…

Unity 关于SetParent方法的使用情况

在设置子物体的父物体时&#xff0c;我们使用SetParent再常见不过了。 但是通常我们只是使用其中一个语法&#xff1a; public void SetParent(Transform parent);使用改方法子对象会保持原来位置&#xff0c;跟使用以下方法效果一样&#xff1a; public Transform tran; ga…

【数值计算方法(黄明游)】函数插值与曲线拟合(二):Newton插值【理论到程序】

​ 文章目录 一、近似表达方式1. 插值&#xff08;Interpolation&#xff09;2. 拟合&#xff08;Fitting&#xff09;3. 投影&#xff08;Projection&#xff09; 二、Lagrange插值1. 拉格朗日插值方法2. Lagrange插值公式a. 线性插值&#xff08;n1&#xff09;b. 抛物插值&…

UDS 诊断报文格式

文章目录 网络层目的N_PDU 格式诊断报文的分类&#xff1a;单帧、多帧 网络层目的 N_PDU(network protocol data unit)&#xff0c;即网络层协议数据单元 网络层最重要的目的就是把数据转换成符合标准的单一数据帧&#xff08;符合can总线规范的&#xff09;&#xff0c;从而…