【人工智能前沿弄潮】—— SAM从提示生成物体mask

news2024/9/28 1:22:40

SAM从提示生成物体mask

Segment Anything Model(SAM)根据指示所需的对象来预测对象掩码。该模型首先将图像转换为图像嵌入,从而可以从提示中高效地生成高质量的掩码。

SamPredictor类为模型提供了一个简单的接口来提示模型。用户可以首先使用set_image方法设置图像,该方法会计算所需的图像嵌入。然后,可以通过predict方法提供提示,以从这些提示中高效地预测掩码。模型可以接受点和框提示以及先前迭代预测的掩码作为输入。

设置

导入所需的库和用于显示点、框和掩码的辅助函数。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

示例图像

image = cv2.imread('images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()


请添加图片描述

使用SAM选择对象

首先,加载SAM模型和预测器。将下面的路径更改为指向SAM检查点。为了获得最佳结果,建议在CUDA上运行并使用默认模型。

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

通过调用SamPredictor.set_image来处理图像,以生成图像嵌入。SamPredictor会记住这个嵌入,并在后续的掩码预测中使用它。

predictor.set_image(image)

要选择卡车,选择其上的一个点。点以(x,y)格式输入模型,并附带标签1(前景点)或0(背景点)。可以输入多个点;在这里我们只使用一个。所选点将在图像上显示为星号。

input_point = np.array([[500, 375]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  


请添加图片描述

使用SamPredictor.predict进行预测。模型将返回掩码、这些掩码的质量预测以及可以传递给下一次预测迭代的低分辨率掩码logits。

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

使用 multimask_output=True(默认设置),SAM会输出3个掩码,其中 scores 给出了模型对这些掩码质量的自我估计。这个设置旨在处理模糊的输入提示,有助于模型根据提示消除不同的一致对象。当设置为 False 时,它会返回一个单独的掩码。对于模糊的提示,比如一个单独的点,建议使用 multimask_output=True,即使只需要一个单独的掩码;最佳的单个掩码可以通过选择在 scores 中返回的最高分的掩码来选择。这通常会导致更好的掩码。

masks.shape  # (number_of_masks) x H x W
(3, 1200, 1800)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  

请添加图片描述
请添加图片描述


请添加图片描述

使用额外的点指定特定对象

单个输入点是模糊的,并且模型返回了与之一致的多个对象。要获取单个对象,可以提供多个点。如果可用,还可以将来自上一次迭代的掩码提供给模型以帮助预测。在使用多个提示指定单个对象时,可以通过设置 multimask_output=False 来请求单个掩码。

input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
masks.shape
(1, 1200, 1800)
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 


请添加图片描述

为了排除汽车并仅指定窗户,可以提供一个背景点(标签为0,这里以红色显示)。

input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 


请添加图片描述

使用框指定特定对象

模型还可以接受框作为输入,以xyxy格式提供。

input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()


请添加图片描述

结合点和框

可以通过将两种类型的提示都包含在预测器中来组合点和框。在这里,可以使用这种方法仅选择卡车的轮胎,而不是整个车轮。

input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()


请添加图片描述

批量的提示输入

SamPredictor可以使用predict_torch方法为同一图像接受多个输入提示。此方法假定输入点已经是torch张量,并且已经被转换为输入帧。例如,假设我们从目标检测器中有几个框输出。

input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)

将框转换为输入帧,然后预测掩码。SamPredictor将所需的转换存储在transform字段中,以便轻松访问,尽管它也可以直接实例化,用于例如数据加载器中的使用(参见segment_anything.utils.transforms)。

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
masks.shape  # (batch_size) x (num_predicted_masks_per_input) x H x W
torch.Size([4, 1, 1200, 1800])
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()


在这里插入图片描述

端到端批量推断

如果所有提示都提前准备好,就可以直接以端到端的方式运行SAM。这也允许对图像进行批处理。

image1 = image  # truck.jpg from above
image1_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=sam.device)

image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
    [450, 170, 520, 350],
    [350, 190, 450, 350],
    [500, 170, 580, 350],
    [580, 170, 640, 350],
], device=sam.device)

这两个图像和提示都作为已经转换到正确帧的PyTorch张量进行输入。输入被打包成图像列表,其中每个元素是一个字典,包含以下键:

  • image:输入图像,以CHW格式的PyTorch张量形式。
  • original_size:在将图像转换为输入SAM之前的图像尺寸,以(H,W)格式表示。
  • point_coords:批量点提示的坐标。
  • point_labels:批量点提示的标签。
  • boxes:批量输入的框。
  • mask_inputs:批量输入的掩码。

如果没有提示,可以排除相应的键。

from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

def prepare_image(image, transform, device):
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device.device) 
    return image.permute(2, 0, 1).contiguous()
batched_input = [
     {
         'image': prepare_image(image1, resize_transform, sam),
         'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
         'original_size': image1.shape[:2]
     },
     {
         'image': prepare_image(image2, resize_transform, sam),
         'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
         'original_size': image2.shape[:2]
     }
]

运行模型。

batched_output = sam(batched_input, multimask_output=False)

输出是一个列表,其中包含每个输入图像的结果,列表元素是带有以下键的字典:

  • masks: 预测的二进制掩码的批处理torch张量,尺寸与原始图像相同。
  • iou_predictions: 模型对每个掩码质量的预测。
  • low_res_logits: 每个掩码的低分辨率logits,可以在后续迭代中作为掩码输入传回模型。
batched_output[0].keys()
dict_keys(['masks', 'iou_predictions', 'low_res_logits'])
fig, ax = plt.subplots(1, 2, figsize=(20, 20))

ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
    show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
    show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')

ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
    show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
    show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')

plt.tight_layout()
plt.show()


请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/853588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML——格式化文本与段落

😊HTML——格式化文本与段落 🌏前言🎭HTML文本标签🎯主体内容body标签🎯标题字标签🎯空格和特殊字符 🎭格式化文本标签🎯文本修饰标签🎯计算机输出标签(成对标…

基于MATLAB小波变换的信号突变点检测

之前在不经意间也有接触过求突变点的问题。在我看来,与其说是求突变点,不如说是我们常常玩的"找不同"。给你两幅图像,让你找出两个图像中不同的地方,我认为这其实也是找突变点在生活中的应用之一吧。回到找突变点位置上…

Linux部署Zabbix主机监控

192.168.136.55 服务端 192.168.136.56 客户端 一、服务端 1.1 安装lamp环境 #关闭防火墙以及SELINUX systemctl disable firewalld systemctl stop firewalld sed -i s/SELINUXenforcing$/SELINUXdisabled/g /etc/selinux/config setenforce 0设置yum源 yum install epe…

Cocos Creator 3.8 后期效果 Shader 编写(2/2) 进阶篇

前言 在上一篇文章中,麒麟子给大家分享了如何在 Cocos Creator 3.8 中的自定义管线中,添加属于自己的后期效果 Shader。 但基于 BlitScreen 的方案,我们只能编写最简单后效 Shader,如果我们想要支持更多复杂的 Shader&#xff0c…

pc端与flutter通信失效, Method not found

报错情况描述:pc端与flutter通信,ios端能实现通信,安卓端通信报错 报错通信代码: //app消息通知window.callbackName function (res) {window?.jsBridge && window.jsBridge?.postMessage(JSON.stringify(res), "…

axios的使用和接口请求统一封装处理

axios官网:axios中文网|axios API 中文文档 | axios 简单封装:配置基础路径和超时时间,还有请求拦截器和响应拦截器 //对axios进行二次封装 import axios from axios//1、利用axios对象的方法create,去创建一个axios实例 const requests …

Redux基础知识,Redux部分源码分析(手写)

复合组件通信的两种方案: 基于props属性实现父子组件通信(或具备相同父亲的兄弟组件)基于context上下文实现祖先和后代组件间的通信(或具备相同祖先的平行组件) 除了以上方案,其实还可以基于公共状态管理(Redux)实现组件间的通信…

有哪些pdf修改方法?这几种方法学会就够了

有哪些pdf修改方法?PDF是一种非常常见的电子文档格式,它有很多优点,例如可读性强、易于保护、易于打印等等。但是,有时候我们需要对PDF进行修改,例如添加、删除或修改文本、更改图片、合并或分割文件等等。那么今天就给…

对强缓存和协商缓存的理解

浏览器缓存的定义: 浏览器缓存是浏览器在本地磁盘对用户最近请求过的文档进行存储,当访问者再次访问同一页面时,浏览器就可以直接从本地磁盘加载文档。 浏览器缓存分为强缓存和协商缓存。 浏览器是如何使用缓存的: 浏览器缓存…

天津市城市管理委员会莅临道本科技,共同探讨加快推进城市综合执法数字化新模式

2023年8月4日,市城管委处长李春利带队莅临道本科技考察指导,与道本科技董事长王智勇共同探讨加快推进城市综合执法数字化新模式。 会议上,董事长王智勇着重介绍了道本科技最新研发上线的法治大数据应用产品“合规数知法用法平台”。他表示&am…

微信开发之检测僵尸粉的技术实现

简要描述: 检测好友状态 请求URL: http://域名地址/checkZombie 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明…

《算法和数据结构》算法篇

前言 我大学的时候比较疯狂,除了上课的时候,基本都是在机房刷题,当然,有时候连上课都在想题目,纸上写好代码,一下课就冲进机房把代码敲了,目的很单纯,为了冲排行榜,就像玩…

C++ 派生类成员的标识与访问——作用域分辨符

在派生类中,成员可以按访问属性分为以下四种: (1)不可访问成员。这是从基类私有成员继承下来的,派生类或是建立派生类对象的模块都无法访问到它们,如果从派生类继续派生新类,也是无法访问的。 &…

OpenLayers入门,OpenLayers视图飞行动画,OpenLayers飞行到指定经纬度位置

专栏目录: OpenLayers入门教程汇总目录 前言 本章实现OpenLayers视图飞行动画,根据经纬度和动画持续时长,飞行到指定地图位置。 上一章中可以直接通过修改中心点和层级跳转到指定位置:《Openlayers入门,Openlayers调整中心点坐标、Openlayers调整缩放级别、Openlayers调…

第八章:Linux信号

系列文章目录 文章目录 系列文章目录前言linux中的信号进程对信号的处理信号的释义 信号的捕捉信号的捕捉signal()信号的捕捉sigaction() 信号的产生通过终端按键产生信号前台进程与后台进程 kill()用户调用kill向操作系统发送信号raise()进程自己给自己发任意信号(…

Qt事件过滤器

1 介绍 事件过滤器是一种机制,当某个QObject没有所需要的事件功能时,可将其委托给其它QObject,通过eventFilter成员函数来过滤实现功能。 2 主要构成 委托: ui->QObject1->installEventFilter(QObject2); eventFilter声明 …

【C++精华铺】4.C++类和对象(上)面向对象、类、this指针

目录 1. 面向过程和面向对象 2. 类的引入 3. 类的定义 4. 类的访问限定符和封装 4.1 类的访问限定符 4.2 封装 5. 类的作用域 6. 类的实例化 7. 类对象模型 7.1 类对象的存储方式 7.2 类的大小 7.2.1 空类的大小 7.2.2 结构体内存对齐规则 8. this关键字深入讲解 8.1…

如何选择适合自己的考试培训系统

随着考试的逐渐增多和竞争的加剧,许多人开始关注考试培训系统,以提高他们的考试成绩。然而,选择适合自己的考试培训系统并不容易,因为市场上有许多不同的培训系统可供选择。 1. 确定目标 在选择培训系统之前,首先要明…

Java并发编程(一)多线程基础概念

概述 多线程技术:基于软件或者硬件实现多个线程并发执行的技术 线程可以理解为轻量级进程,切换开销远远小于进程 在多核CPU的计算机下,使用多线程可以更好的利用计算机资源从而提高计算机利用率和效率来应对现如今的高并发网络环境 并发编程…

无涯教程-Perl - getpriority函数

描述 此函数返回进程(PRIO_PROCESS),进程组(PRIO_PGRP)或用户(PRIO_USER)的当前优先级。 参数WHICH指定要为PRIO_PROCESS,PRIO_PGRP或PRIO_USER之一设置优先级的实体,WHO是要设置的进程ID或用户ID。 WHO的值为0定义了当前流程,流程组或用户。这会在不支持系统getpriority()函…