【小贪】项目实战——Zero-shot根据文字提示分割出图片目标掩码

news2025/1/6 17:54:06

目标描述

给定RGB视频或图片,目标是分割出图像中的指定目标掩码。我们需要复现两个Zero-shot的开源项目,分别为IDEA研究院的GroundingDINO和Facebook的SAM。首先使用目标检测方法GroundingDINO,输入想检测目标的文字提示,可以获得目标的anchor box。将上一步获得的box信息作为SAM的提示,分割出目标mask。具体效果如下(测试数据来自VolumeDeform数据集):

在这里插入图片描述

其中GroundingDINO根据white shirt的文字输入计算的box信息为:"shirt_000500": "[194.23726, 2.378189, 524.09503, 441.5135]"。项目实测下来单张图片的预测速度GroundingDINO要慢于SAM。GroundingDINO和SAM均会给出多个预测结果,当选择置信度最高的结果时两个模型也会存在预测不准确的情况。

论文简介

GroundingDINO

GroundingDINO extends a closedset detector DINO by performing vision-language modality fusion at multiple phases, including a feature enhancer, a language-guided query selection module, and a cross-modality decoder. Such a deep fusion strategy effectively improves open-set object detection.

在这里插入图片描述

SAM

  • 简介:使用三个组件建立图像分割的foundation model,解决一系列下游分割问题,可zero-shot生成
  • 关键技术:
    1. promptable分割任务:使用prompt engineering,prompt不确定时输出多目标mask
    2. 分割模型:image encoder + prompt encoder -> mask decoder
    3. 数据驱动:SA-1B(1B masks from 11M imgs)手工标注->半自动->全自动
  • Limitation:存在不连贯不精细的mask结果;交互式实时mask生成但是img encoder耗时;text-to-mask任务效果不鲁棒

在这里插入图片描述
在这里插入图片描述

项目实战

两个项目的复现很简单,按照github的readme配置相关环境并运行程序。当然也可以直接使用一站式项目Grounded Segment Anything等。当需要分割的图片较多时,可以修改GroundingDINO的demo.shdemo/inference_on_a_image.py文件将检测结果保存至json文件。

demo/inference_on_a_image.py文件

# 修改plot_boxes_to_image函数输出box信息
image_with_box, mask, box_coor = plot_boxes_to_image(image_pil, pred_dict)
# obj为目标名称,i为当前图片的索引
obj = 'shirt'
data = {f'{obj}_{str(i).zfill(6)}': str(list(box_coor.cpu().detach().numpy()))}
with open("box.json", "r", encoding="utf-8") as f:
    old_data = json.load(f)
    old_data.update(data)
with open("box.json", "w", encoding="utf-8") as f:
    json.dump(old_data, f, indent=4)
    # f.write(json.dumps(old_data, indent=4, ensure_ascii=False))
f.close()

然后SAM再读取json文件获取box信息,将SAM的输入提示改为box。

测试代码

import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import json

coords = []

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def on_click(event):
    global coords
    if event.button == 1:
        x, y = event.xdata, event.ydata
        print(f"鼠标左键点击:x={x:.2f}, y={y:.2f}")
        coords.append([x, y])
        # if len(coords) == 2:
        #     fig.canvas.mpl_disconnect(cid)
    elif event.button == 3:
        print("鼠标右键点击")


def get_mask(image, mask_id=1, click_coords=False, choose_mask=False, box=None):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # plt.axis('on')

    if click_coords:
        global coords
        fig, ax = plt.subplots()  # 创建画布和子图对象
        fig.set_size_inches(30, 20)  # 设置宽度和高度,单位为英寸(inch)
        ax.imshow(image)
        cid = fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()
    else:  # 如果使用 必须全局
        coords = []

    from segment_anything import SamPredictor, sam_model_registry
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    input_point = np.array(coords)
    input_label = np.array([1] * len(coords))

    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # show_points(input_point, input_label, plt.gca())
    # plt.axis('on')
    # plt.show()

    input_box = box
    if len(coords) == 0:
        input_point = None
        input_label = None
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=True)

    if choose_mask:
        plt.figure(figsize=(60, 20))
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        show_mask(masks[0], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 0, Score: {scores[0]:.3f}", fontsize=18)
        plt.subplot(1, 3, 2)
        plt.imshow(image)
        show_mask(masks[1], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 1, Score: {scores[1]:.3f}", fontsize=18)
        plt.subplot(1, 3, 3)
        plt.imshow(image)
        show_mask(masks[2], plt.gca())
        # show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask 2, Score: {scores[1]:.3f}", fontsize=18)
        plt.show()
        mask_id = int(input())  # 通过输入idx或者设置特定的idx输出

    mask = masks[mask_id]
    mask = np.tile(np.expand_dims(mask, axis=-1), 3)
    mask_data = np.where(mask, 255, 0)
    # mask_image = np.where(mask, image/255, 0.)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(mask_image)
    # plt.show()
    if click_coords: coords.clear()
    return mask_data


if __name__ == '__main__':
    obj = 'shirt'
    color_path = f'/Data/VolumeDeformData/{obj}/data/'
    mask_path = f'/Data/VolumeDeformData/{obj}/mask/'
    if not os.path.exists(mask_path):
        os.makedirs(mask_path)

    img_paths = []
    for extension in ["jpg", "png", "jpeg"]:
        img_paths += glob.glob(os.path.join(color_path, "*.{}".format(extension)))

    json_path = 'GroundingDINO-main/box.json'
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
        for i in range(len(img_paths) // 2):
            img_name = f'frame-{str(i).zfill(6)}.color.png'
            img = cv2.imread(color_path + img_name)
            id = f'{obj}_{str(i).zfill(6)}'
            box = np.array(list(map(float, data[id][1:-1].split(','))))
            mask = get_mask(img, mask_id=2, click_coords=False, choose_mask=False, box=box)
            cv2.imwrite(mask_path + str(i).zfill(6) + '.png', mask)
            print(img_name)
    f.close()

相关链接

  • GroundingDINO github arXiv
  • SAM Demo github arXiv
  • Grounded Segment Anything github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1883149.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

primetime中cell和net的OCV

文章目录 前言一、Cell OCV1. POCV coefficient file2. POCV Slew-Load Table in Liberty Variation Format(LVF lib) 二、Net OCV三、如何check OCV是否已加上?总结 前言 在生产中,外界环境的各种变化,比如PVT&#…

grpc学习golang版( 六、服务器流式传输 )

系列文章目录 第一章 grpc基本概念与安装 第二章 grpc入门示例 第三章 proto文件数据类型 第四章 多服务示例 第五章 多proto文件示例 第六章 服务器流式传输 第七章 客户端流式传输 第八章 双向流示例 文章目录 一、前言二、定义proto文件三、拷贝任意文件进项目四、编写serve…

vscode搭建suricata调试环境

一、环境 windows10 wsl2 $ lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 20.04.2 LTS Release: 20.04 Codename: focal二、编译 2.1 下载源码 wget https://www.openinfosecfoundation.org/download/suri…

为什么越来越多的人选择做债务重组?

说到债务重组,很多人可能一头雾水。但简单来说,就是帮你优化债务结构,减轻还款压力。 为什么现在这么多人会选择做债务重组? 保护工作和名声:有些在好单位上班的人,怕债务问题影响工作,不想让单…

解决Python用xpath爬取不到数据的一个思路

前言 最近在学习Python爬虫的知识,既然眼睛会了难免忍不住要实践一把。 不废话直接上主题 代码不复杂,简单的例子奉上: import requests from lxml import etreecookie 浏览器F12网络请求标头里有 user_agent 浏览器F12网络请求标头里有…

论文翻译 | (DSP)展示-搜索-预测:为知识密集型自然语言处理组合检索和语言模型

摘要 检索增强式上下文学习已经成为一种强大的方法,利用冻结语言模型 (LM) 和检索模型 (RM) 来解决知识密集型任务。现有工作将这些模型结合在简单的“检索-读取”流程中,其中 RM 检索到的段落被插入到 LM 提示中。 为了充分发挥冻结 LM 和 RM 的…

API-本地存储

学习目标: 掌握本地存储 学习内容: 本地存储介绍本地存储分类存储复杂数据类型 本地存储介绍: 以前我们页面写的数据一刷新页面就没有了,是不是? 随着互联网的快速发展,基于网页的应用越来越普遍,同时也…

反向沙箱技术:安全隔离上网

在信息化建设不断深化的今天,业务系统的安全性和稳定性成为各公司和相关部门关注的焦点。面对日益复杂的网络威胁,传统的安全防护手段已难以满足需求。深信达反向沙箱技术,以其独特的设计和强大的功能,成为保障政务系统信息安全的…

MSPG3507——蓝牙接收数据显示在OLED,滴答定时器延时500MS

#include "ti_msp_dl_config.h" #include "OLED.h" #include "stdio.h"volatile unsigned int delay_times 0;//搭配滴答定时器实现的精确ms延时 void delay_ms(unsigned int ms) {delay_times ms;while( delay_times ! 0 ); } int a0; …

MySQL-数据操作类型的角度理解 S锁 X锁

文章目录 1、S锁和S锁互相兼容2、S锁和X锁互斥3、X锁和X锁也互斥4、X锁和S锁也互斥5、select * from account for update;6、select * from account for update nowait;7、select * from account for update skip locked; 1、S锁和S锁互相兼容 2、S锁和X锁互斥 3、X锁和X锁也互…

换天空背景的软件有哪些?摄影师必备,让背景从灰暗到绚烂

在摄影的世界里,背景往往能够为照片增添一种难以言喻的情感色彩。 有时,一个简单的天空背景更换,就能让整张照片焕发出全新的生命力,表达出摄影师想要传达的情感和故事。 如今,随着科技的发展,一些换天空…

开源205W桌面充电器,140W+65W升降压PD3.1快充模块(2C+1A口),IP6557+IP6538

开源一个基于IP6557和IP6538芯片的205W升降压快充模块(140W65W),其中一路C口支持PD3.1协议,最高输出28V5A,另一路是A口C口,最高输出65W(20V3.25A),可搭配一个24V10A的开关…

LLM对程序员的冲击和影响

1LLM 在软件开发过程中的单点提效 我这里罗列一些更多的可能用途: 智能代码提示代码片段智能生成SQL 语句的智能生成与调优更高效更精准的静态代码检查与自动修复(非 rule-based)智能辅助的代码评审与代码重构单元测试和接口测试代码的自动…

ARM功耗管理软件之时钟电源树

安全之安全(security)博客目录导读 思考:功耗管理软件栈及示例?WFI&WFE?时钟&电源树?DVFS&AVS? 目录 一、时钟&电源树简介 二、时钟树示例 三、电源树示例 一、时钟&电源树简介 时钟门控与自…

炎黄数智人:国家体育总局冬运中心——AI裁判与教练“观君”赋能冰雪运动新篇章

在科技创新的浪潮下,国家体育总局冬季运动管理中心(以下简称“冬运中心”)揭开了人工智能在体育领域应用的新篇章。隆重宣布推出革命性的AI裁判与教练系统——“观君”,该系统将在冰雪运动项目中大放异彩,为运动员的训…

【Kaggle】Telco Customer Churn 电信用户流失预测案例

⭐️前言:案例学习说明与案例建模流程 我们将围绕Kaggle中的电信用户流失数据集(Telco Customer Churn)进行用户流失预测。在此过程中,将综合应用此前所介绍的各种方法与技巧,并在实践中提炼总结更多实用技巧。 ⭐️对…

prometheus 安装node_exporter, node_exporter 安装最新版 普罗米修思安装监控服务器client

1. 本文介绍两种安装方式,一种安装为service,使用systemctl start node_exporter管理,第二种为安装docker内 容器内使用。 1.1 安装到系统内: 1.1.1 github地址: Releases prometheus/node_exporter GitHub ​ 1.1.2 下载命…

基于移动端的助农电商系统的设计与实现08655

基于移动端的助农电商系统的设计与实现 XXX专业XX级XX班:XXX 指导教师:XXX 摘要 近年来,电子商务的快速发展引起了行业和学术界的高度关注。基于移动端的助农电商系统旨在为用户提供一个简单、高效、便捷的农产品购物体验,它不…

嵌入式以太网硬件构成与MAC、PHY芯片功能介绍

一.以太网电路基本构成 1.总体介绍 对于上述三部分,并不一定都是独立的芯片,主要有以下几种情况: CPU内部集成了MAC和PHY,难度较高; CPU内部集成MAC,PHY采用独立芯片(主流方案); CPU不集成MAC和PHY&#…

安卓应用开发学习:通过腾讯地图SDK实现定位功能

一、引言 这几天有些忙,耽误了写日志,但我的学习始终没有落下,有空我就会研究《 Android App 开发进阶与项目实战》一书中定位导航方面的内容。在我的手机上先后实现了“获取经纬度及地理位置描述信息”和“获取导航卫星信息”功能后&#x…