Vision - 开源视觉分割算法框架 Grounded SAM2 配置与推理 教程 (1)

news2024/11/26 16:49:15

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/143388189

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


Grounded SAM2

Grounded SAM2 集成多个先进模型的视觉 AI 框架,融合 GroundingDINO、Florence-2 和 SAM2 等模型,实现开放域目标检测、分割和跟踪等多项视觉任务的突破性进展,通过自然语言描述来定位图像中的目标,生成精细的目标分割掩码,在视频序列中持续跟踪目标,保持 ID 的一致性。

Paper: Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks,SAM 版本由 1.0 升级至 2.0

1. 环境配置

GitHub: Grounded-SAM-2

git clone https://github.com/IDEA-Research/Grounded-SAM-2
cd Grounded-SAM-2

准备 SAM 2.1 模型,格式是 pt 的,GroundingDINO 模型,格式是 pth 的,即:

wget https://huggingface.co/facebook/sam2.1-hiera-large/resolve/main/sam2.1_hiera_large.pt?download=true -O sam2.1_hiera_large.pt
wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth

最新模型位置:

cd checkpoints
ln -s [your path]/llm/workspace_comfyui/ComfyUI/models/sam2/sam2_hiera_large.pt sam2_hiera_large.pt

cd gdino_checkpoints
ln -s [your path]/llm/workspace_comfyui/ComfyUI/models/grounding-dino/groundingdino_swinb_cogcoor.pth groundingdino_swinb_cogcoor.pth
ln -s [your path]/llm/workspace_comfyui/ComfyUI/models/grounding-dino/groundingdino_swint_ogc.pth groundingdino_swint_ogc.pth

激活环境:

conda activate sam2

测试 PyTorch:

import torch
print(torch.__version__)  # 2.5.0+cu124
print(torch.cuda.is_available())  # True
exit()
echo $CUDA_HOME

安装 Grounding DINO:

pip install --no-build-isolation -e grounding_dino
pip show groundingdino

安装 SAM2:

pip install --no-build-isolation -e .
pip install --no-build-isolation -e ".[notebooks]"  # 适配 Jupyter
pip show SAM-2

配置参数:视觉分割开源算法 SAM2(Segment Anything 2) 配置与推理

依赖文件:

cd grounding_dino/
pip install -r requirements.txt --verbose

2. 测试图像

测试脚本:grounded_sam2_local_demo.py

导入相关的依赖包:

import os
import cv2
import json
import torch
import numpy as np
import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict

from PIL import Image
import matplotlib.pyplot as plt

配置数据,以及依赖环境,其中包括:

  • 输入文本提示,例如 袜子(socks) 和 吉他(guitar)
  • 输入图像
  • SAM2 模型 v2.1 版本,以及配置
  • GroundingDINO (DETR with Improved deNoising anchOr boxes, 改进的去噪锚框的DETR) 模型,以及配置
  • Box 阈值、文本阈值
  • 输出文件夹与Json

即:

TEXT_PROMPT = "socks. guitar."
#IMG_PATH = "notebooks/images/truck.jpg"
IMG_PATH = "[your path]/llm/vision_test_data/image2.png"

image = Image.open(IMG_PATH)
plt.figure(figsize=(9, 6))
plt.title(f"annotated_frame")
plt.imshow(image)

SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
DUMP_JSON_RESULTS = True

# create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

加载 SAM2 模型,获得 sam2_predictor,即:

# build SAM2 image predictor
sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model)

加载 GroundingDINO 模型,获得 grounding_model,即:

# build grounding dino model
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG, 
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=DEVICE
)

SAM2 加载图像数据,即:

text = TEXT_PROMPT
img_path = IMG_PATH

# image(原图), image_transformed(正则化图像)
image_source, image = load_image(img_path)
sam2_predictor.set_image(image_source)

GroudingDINO 预测 Bounding Box,输入模型、图像、文本、Box和Text阈值,即:

  • load_image()predict() 都来自于 GroundingDINO,数据和模型匹配。
boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=text,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

适配不同 Box 的格式:

h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

SAM2 依赖的 PyTorch 配置:

# FIXME: figure how does this influence the G-DINO model
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

SAM2 预测图像:

masks, scores, logits = sam2_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

后处理预测结果:

"""
Post-process the output of the model to get the masks, scores, and logits for visualization
"""
# convert the shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)


confidences = confidences.numpy().tolist()
class_names = labels

class_ids = np.array(list(range(len(class_names))))

labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence
    in zip(class_names, confidences)
]

输出结果可视化:

"""
Visualize image with supervision useful API
"""
img = cv2.imread(img_path)
detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    mask=masks.astype(bool),  # (n, h, w)
    class_id=class_ids
)

box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
plt.figure(figsize=(9, 6))
plt.title(f"annotated_frame")
plt.imshow(annotated_frame[:,:,::-1])

mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
plt.figure(figsize=(9, 6))
plt.title(f"annotated_frame")
plt.imshow(annotated_frame[:,:,::-1])

GroundingDINO 的 Box 效果,准确检测出 袜子 和 吉他,两类实体:

Box

SAM2 的分割效果,如下:
Seg

转换成 COCO 数据格式:

def single_mask_to_rle(mask):
    rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
    rle["counts"] = rle["counts"].decode("utf-8")
    return rle

if DUMP_JSON_RESULTS:
    # convert mask into rle format
    mask_rles = [single_mask_to_rle(mask) for mask in masks]

    input_boxes = input_boxes.tolist()
    scores = scores.tolist()
    # save the results in standard format
    results = {
        "image_path": img_path,
        "annotations" : [
            {
                "class_name": class_name,
                "bbox": box,
                "segmentation": mask_rle,
                "score": score,
            }
            for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
        ],
        "box_format": "xyxy",
        "img_width": w,
        "img_height": h,
    }
    
    with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
        json.dump(results, f, indent=4)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2232011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3674矢量网络分析仪-003噪声系数测量选件

3674X-003 噪声系数测量选件 3674系列矢量网络分析仪 综述 3674X-003噪声系数测量一次连接,可同时测试S参数、噪声系数、噪声参数、增益压缩和变频增益等多种参数。基于冷源噪声系数测试方法,可进行精确的噪声系数和噪声参数测试。 •特点 • 3674X…

DDRPHY数字IC后端设计实现系列专题之后端设计导入,IO Ring设计

本章详细分析和论述了 LPDDR3 物理层接口模块的布图和布局规划的设计和实 现过程,包括设计环境的建立,布图规划包括模块尺寸的确定,IO 单元、宏单元以及 特殊单元的摆放。由于布图规划中的电源规划环节较为重要, 影响芯片的布线资…

如何将MySQL彻底卸载干净

目录 背景: MySQL的卸载 步骤1:停止MySQL服务 步骤2:软件的卸载 步骤3:残余文件的清理 步骤4:清理注册表 步骤五:删除环境变量配置 总结: 背景: MySQL卸载不彻底往往会导致重新安装失败…

Vue生成名片二维码带logo并支持下载

一、需求 生成一张名片,名片上有用户信息以及二维码,名片支持下载功能(背景样式可更换,忽略本文章样图样式)。 二、参考文章 这不是我自己找官网自己摸索出来的,是借鉴各位前辈的,学以致用&am…

【AIGC】深入探索『后退一步』提示技巧:激发ChatGPT的智慧潜力

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | ChatGPT 文章目录 💯前言💯“后退一步”技巧介绍技巧目的 💯“后退一步”原理“后退一步”提示技巧与COT和TOT的对比实验验证 💯如何应用“后退一步”策略强调抽象思考引导提…

Python | Leetcode Python题解之第520题检测大写字母

题目: 题解: class Solution:def detectCapitalUse(self, word: str) -> bool:# 若第 1 个字母为小写,则需额外判断第 2 个字母是否为小写if len(word) > 2 and word[0].islower() and word[1].isupper():return False# 无论第 1 个字…

【python ASR】win11-从0到1使用funasr实现本地离线音频转文本

文章目录 前言一、前提条件安装环境Python 安装安装依赖,使用工业预训练模型最后安装 - torch1. 安装前查看显卡支持的最高CUDA的版本,以便下载torch 对应的版本的安装包。torch 中的CUDA版本要低于显卡最高的CUDA版本。2. 前往网站下载[Pytorch](https://pytorch.o…

Java日志脱敏——基于logback MessageConverter实现

背景简介 日志脱敏 是常见的安全需求,最近公司也需要将这一块内容进行推进。看了一圈网上的案例,很少有既轻量又好用的轮子可以让我直接使用。我一直是反对过度设计的,而同样我认为轮子就应该是可以让人拿去直接用的。所以我准备分享两篇博客…

Java面试经典 150 题.P26. 删除有序数组中的重复项(003)

本题来自:力扣-面试经典 150 题 面试经典 150 题 - 学习计划 - 力扣(LeetCode)全球极客挚爱的技术成长平台https://leetcode.cn/studyplan/top-interview-150/ 题解: class Solution {public int removeDuplicates(int[] nums) …

Prometheus套装部署到K8S+Dashboard部署详解

1、添加helm源并更新 helm repo add prometheus-community https://prometheus-community.github.io/helm-charts helm repo update2、创建namespace kubectl create namespace monitoring 3、安装Prometheus监控套装 helm install prometheus prometheus-community/prome…

如何选择到印尼的海运代理

如何选择到印尼的海运代理 选择合适的海运代理的重要性 海运代理负责安排货物从发货地到目的地的整个运输过程,包括装运、清关、仓储等服务。一个可靠的海运代理能确保货物安全准时到达,并帮助企业节省时间和成本。 选择海运代理需考虑的主要因素 公司…

python常用的第三方库下载方法

方法一:打开pycharm-打开项目-点击左侧图标查看已下载的第三方库-没有下载搜索后点击install即可直接安装--安装成功后会显示在installed列表 方法二:打开dos窗口输入命令“pip install requests“后按回车键,看到successfully既安装成功&…

FFmpeg 4.3 音视频-多路H265监控录放C++开发八,使用SDLVSQT显示yuv文件 ,使用ffmpeg的AVFrame

一. AVFrame 核心回顾,uint8_t *data[AV_NUM_DATA_POINTERS] 和 int linesize[AV_NUM_DATA_POINTERS] AVFrame 存储的是解码后的数据,(包括音频和视频)例如:yuv数据,或者pcm数据,参考AVFrame结…

jenkins 构建报错 Cannot run program “sh”

原因 在 windows 操作系统 jenkins 自动化部署的时候, 由于自动化构建的命令是 shell 执行的,而默认windows 从 path 路径拿到的 shell 没有 sh.exe ,因此报错。 解决方法 前提是已经安装过 git WINR 输入cmd 打开命令行, 然后输入where git 获取 git 的路径, …

基于Spring Boot的高校物品捐赠管理系统设计与实现,LW+源码+讲解

摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装高校物品捐赠管理系统软件来发挥其高效地信息处理的作用&a…

AndroidStudio通过Bundle进行数据传递

作者:CSDN-PleaSure乐事 欢迎大家阅读我的博客 希望大家喜欢 使用环境:AndroidStudio 目录 1.新建活动 2.修改页面布局 代码: 效果: 3.新建类ResultActivity并继承AppCompatActivity 4.新建布局文件activity_result.xml 代…

测试分层:减少对全链路回归依赖的探索!

引言:测试分层与全链路回归的挑战 在软件开发和测试过程中,全链路回归测试往往是一个复杂且耗费资源的环节,尤其在系统庞大且模块众多的场景下,全链路测试的集成难度显著提高。而“测试分层”作为一种结构化的测试方法&#xff0…

【python】OpenCV—findContours(4.5)

文章目录 1、功能描述2、原理分析3、代码实现4、效果展示5、完整代码6、参考 1、功能描述 输入图片,计算出图片中的目标到相机间的距离 2、原理分析 用最简单的三角形相似性 已知参数,物体的宽度 W W W,物体到相机的距离 D D D&#xff0…

jmeter基础01-3_环境准备-Linux系统安装jdk

Step1. 查看系统类型 打开终端,命令行输入uname -a,显示所有系统信息,包括内核名称、主机名、内核版本等。 如果输出是x86_64,则系统为64位。如果输出是i686 或i386,则系统为32位。 Step2. 官网下载安装包 https://www…

获取JSON对象的时候,值会自动带上双引号

问题:当使用下方代码,获取JsonNode对象的时候,从该对象中通过键获取的值会自动带上双引号。 JsonNode jsonNode new ObjectMapper().readTree("JSON字符串"); 注意:以上方法是获得的JsonNode对象,不是JSO…