使用 AMD GPU 实现 Segment Anything

news2024/11/19 7:37:33

Segment Anything with AMD GPUs — ROCm Blogs

作者: Sean Song

发布日期:2024年6月4日

介绍

分割任务——识别图像中哪些像素属于某对象——是计算机视觉中的一个基础任务,应用广泛,从科学图像分析到照片编辑。Segment Anything 模型(SAM)是一个先进的图像分割模型,它通过提示分割(promptable segmentation)实现了前所未有的多功能性,使图像分析任务变得更加简单。SAM 可以用于帮助在需要查找和分割图像中任何对象的领域内应用。对于AI研究社区和其他相关领域,SAM 很有可能成为大型AI系统中的关键组件,在多模态环境中实现对世界的全面理解。

在这篇博客中,我们将演示如何在使用 ROCm 的 AMD GPU 上运行 Segment Anything 模型。

SAM

SAM经过训练可以根据各种提示返回有效的**分割掩码**,这些提示包括前景和背景**点**、大致的**框**或**掩码**、非结构化的**文本**,或任何其他指示图像中要分割内容的标识。对有效掩码的要求仅仅意味着即使提示是模糊的,可能指向多个物体(例如,衣服上的一个点可能表示衣服或穿衣服的人),输出结果也应是这些物体中的一个合理掩码。

SAM的先进设计使其能够在没有事先知识的情况下适应新的图像分布和任务——这种特性被称为零样本迁移。SAM是在庞大的SA-1B 数据集上进行训练的,该数据集包含超过10亿个掩码,分布在1100多万张精心挑选的图像上(见下方来自SAM论文的示例),已经展示了令人印象深刻的零样本性能,在许多情况下超越了之前的全监督结果。

png

图片来源:Segment Anything(Kirillov et al.)。

png

图片来源:Introducing Segment Anything (Meta Research)。

SAM 模型由三个关键模块组成:

  • 图像编码器: 为了实现可扩展性和强大的预训练方法,SAM 采用经过 Masked AutoEncoder (MAE) 预训练的 Vision Transformer (ViT),经过最小化调整后以处理输入图像。图像编码器每个图像只运行一次,可以在提示模型之前应用。

  • 提示编码器: SAM 考虑了两组提示:稀疏提示(点、框、文本)和密集提示(掩码)。它通过位置编码加上每种提示类型的学习嵌入来表示点和框,并使用来自 CLIP 的开箱即用文本编码器对自由文本进行编码。密集提示(即掩码)使用卷积进行嵌入,并与图像嵌入按元素相加。

  • 掩码解码器: 掩码解码器高效地将图像嵌入、提示嵌入和输出令牌映射到掩码。这一设计采用修改后的 transformer 解码器块,随后是动态掩码预测头。修改后的解码器块使用提示自注意力和交叉注意力(从提示到图像嵌入和反之亦然)来更新所有嵌入。之后,模型对图像嵌入进行上采样,并通过多层感知机(MLP)将输出令牌映射到动态线性分类器,然后计算每个图像位置的掩码前景概率。

接下来,我们将在三个部分中展示 Segment Anything 模型在具有 ROCm 的 AMD GPU 上的流畅执行:

  • 自动生成所有对象的掩码

  • 使用点作为提示生成掩码

  • 使用框作为提示生成掩码

注意:在撰写此博客时,尽管 SAM 论文中探讨了文本提示的功能,但此功能尚未完全发布。本文不涉及文本提示。

我们在本博客文章中引用了 SAM 的 GitHub 代码仓库。可以在 facebookresearch/segment-anything 找到模型的源代码。

设置

本演示使用以下硬件和软件环境。有关全面的支持详细信息,请参阅 ROCm 文档。

  • 硬件 & 操作系统:

    • AMD Instinct GPU

    • Ubuntu 22.04.3 LTS

  • 软件:

    • ROCm 5.7.0+

    • PyTorch 2.0+

准备

首先让我们安装所需的包。

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install matplotlib opencv-python

添加必要的导入

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

检查测试图像

image = cv2.imread('./images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

png

下载检查点

Meta 提供了三个预训练模型,其中 vit_h 具有它们中最强大的视觉编码器。点击以下链接来下载对应模型类型的检查点。

  • vit_h (默认): ViT-H SAM model.

  • vit_l: ViT-L SAM model.

  • vit_b: ViT-B SAM model.

在这个演示中,我们使用最大的 vit_h 模型。

wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

自动掩码生成

运行自动掩码生成时,需要将 SAM 模型提供给 SamAutomaticMaskGenerator 类。将 SAM 检查点的路径在下方设置。

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device="cuda")
mask_generator = SamAutomaticMaskGenerator(sam)

检查 sam 模型。

print(sam)

输出:

    Sam(
      (image_encoder): ImageEncoderViT(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
        )
        (blocks): ModuleList(
          (0-31): 32 x Block(
            (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
            (attn): Attention(
              (qkv): Linear(in_features=1280, out_features=3840, bias=True)
              (proj): Linear(in_features=1280, out_features=1280, bias=True)
            )
            (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
            (mlp): MLPBlock(
              (lin1): Linear(in_features=1280, out_features=5120, bias=True)
              (lin2): Linear(in_features=5120, out_features=1280, bias=True)
              (act): GELU(approximate='none')
            )
          )
        )
        (neck): Sequential(
          (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): LayerNorm2d()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (3): LayerNorm2d()
        )
      )
      (prompt_encoder): PromptEncoder(
        (pe_layer): PositionEmbeddingRandom()
        (point_embeddings): ModuleList(
          (0-3): 4 x Embedding(1, 256)
        )
        (not_a_point_embed): Embedding(1, 256)
        (mask_downscaling): Sequential(
          (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
          (1): LayerNorm2d()
          (2): GELU(approximate='none')
          (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
          (4): LayerNorm2d()
          (5): GELU(approximate='none')
          (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (no_mask_embed): Embedding(1, 256)
      )
      (mask_decoder): MaskDecoder(
        (transformer): TwoWayTransformer(
          (layers): ModuleList(
            (0-1): 2 x TwoWayAttentionBlock(
              (self_attn): Attention(
                (q_proj): Linear(in_features=256, out_features=256, bias=True)
                (k_proj): Linear(in_features=256, out_features=256, bias=True)
                (v_proj): Linear(in_features=256, out_features=256, bias=True)
                (out_proj): Linear(in_features=256, out_features=256, bias=True)
              )
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (cross_attn_token_to_image): Attention(
                (q_proj): Linear(in_features=256, out_features=128, bias=True)
                (k_proj): Linear(in_features=256, out_features=128, bias=True)
                (v_proj): Linear(in_features=256, out_features=128, bias=True)
                (out_proj): Linear(in_features=128, out_features=256, bias=True)
              )
              (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (mlp): MLPBlock(
                (lin1): Linear(in_features=256, out_features=2048, bias=True)
                (lin2): Linear(in_features=2048, out_features=256, bias=True)
                (act): ReLU()
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (cross_attn_image_to_token): Attention(
                (q_proj): Linear(in_features=256, out_features=128, bias=True)
                (k_proj): Linear(in_features=256, out_features=128, bias=True)
                (v_proj): Linear(in_features=256, out_features=128, bias=True)
                (out_proj): Linear(in_features=128, out_features=256, bias=True)
              )
            )
          )
          (final_attn_token_to_image): Attention(
            (q_proj): Linear(in_features=256, out_features=128, bias=True)
            (k_proj): Linear(in_features=256, out_features=128, bias=True)
            (v_proj): Linear(in_features=256, out_features=128, bias=True)
            (out_proj): Linear(in_features=128, out_features=256, bias=True)
          )
          (norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (iou_token): Embedding(1, 256)
        (mask_tokens): Embedding(4, 256)
        (output_upscaling): Sequential(
          (0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))
          (1): LayerNorm2d()
          (2): GELU(approximate='none')
          (3): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
          (4): GELU(approximate='none')
        )
        (output_hypernetworks_mlps): ModuleList(
          (0-3): 4 x MLP(
            (layers): ModuleList(
              (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
              (2): Linear(in_features=256, out_features=32, bias=True)
            )
          )
        )
        (iou_prediction_head): MLP(
          (layers): ModuleList(
            (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
            (2): Linear(in_features=256, out_features=4, bias=True)
          )
        )
      )
    )

要生成掩码,请在图像上运行生成过程。

masks = mask_generator.generate(image)
print(masks[0])

输出:

{'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]]), 'area': 632681, 'bbox': [86, 282, 1621, 566], 'predicted_iou': 1.0396634340286255, 'point_coords': [[1378.125, 581.25]], 'stability_score': 0.9835065603256226, 'crop_box': [0, 0, 1800, 1200]}

掩码生成返回一个掩码列表,每个掩码是一个包含各种数据的字典。键包括:

  • segmentation : 二值掩码,表示感兴趣的区域。

  • area : 掩码在图像中的面积。

  • bbox : 掩码的矩形边界,包括它的左上角坐标 (X, Y) 及其宽度 (W) 和高度 (H),格式为 (X, Y, W, H)。

  • predicted_iou :模型预测的交并比值,表示掩码的质量。

  • point_coords : 生成此掩码的输入点。

  • stability_score : 用于评估掩码质量的附加指标或分数。

  • crop_box : 用于生成此掩码的原始图像裁剪区域,以 XYWH 格式。

将所有掩码叠加显示在图像上。

def process_annotations(annotations):
    if len(annotations) == 0:
        return
    sorted_annotations = sorted(annotations, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_annotations[0]['segmentation'].shape[0], sorted_annotations[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for annotations in sorted_annotations:
        m = annotations['segmentation']
        img[m] = np.concatenate([np.random.random(3), [0.35]])
    ax.imshow(img)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
process_annotations(masks)
plt.show() 

png

自动掩码生成有多个可调参数,这些参数控制如何密集采样点以及删除低质量或重复掩码的阈值。你可以在 文档中找到关于设定参数的更多细节。

使用点作为提示进行掩码生成

通过调用 SamPredictor.set_image 处理图像以生成图像嵌入。`SamPredictor` 会记住这个嵌入,并将其用于后续的掩码预测。

predictor = SamPredictor(sam)
predictor.set_image(image)

要选择卡车,请在其上选择一个点。点以 (x, y) 格式输入模型,并带有标签 1(前景点)或 0(背景点)。可以提供多个点,尽管这里我们只使用一个。所提供的点将显示为图像上的星号。

input_point = np.array([[500, 375]])
input_label = np.array([1]) # A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.

使用 SamPredictor.predict 进行预测。模型返回掩码、这些掩码的质量预测以及可以传递到下一次预测迭代的低分辨率掩码 logits。

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

当 multimask_output=True(默认设置)时,SAM 输出 3 个掩码,其中 scores 提供模型自己对这些掩码质量的评估。此设置旨在处理模棱两可的输入提示,帮助模型区分与提示一致的不同对象。当设置为 False 时,它将返回一个掩码。对于单点等模糊提示,即使只需要一个掩码,也建议使用 multimask_output=True;通过选择 scores 返回的最高得分的掩码,可以得到最好的单个掩码。

def display_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def display_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    display_mask(mask, plt.gca())
    display_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  

png

png

png

单个输入点是模糊的,模型返回了与该点一致的多个对象。要获取单个对象,可以提供多个点。有关更高级的用法,请参阅 facebookresearch/segment-anything。

使用框生成蒙版

SamPredictor可以处理多个指定格式为 (x_min, y_min, x_max, y_max) 的边界框输入,用于给定图像的处理。

此方法假设输入点已经表示为 torch 张量,并且已经转换以匹配输入框架。

import torch
input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)

将边界框转换到输入框架,然后预测蒙版。

masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2]),
    multimask_output=False,
)

def display_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))  
    
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    display_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    display_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

png

在上图中,我们发现三个边界框已被适当地用于分割汽车的不同区域。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2243250.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Cloud Stream实现数据流处理

1.什么是Spring Cloud Stream? 我看很多回答都是“为了屏蔽消息队列的差异,使我们在使用消息队列的时候能够用统一的一套API,无需关心具体的消息队列实现”。 这样理解是有些不全面的,Spring Cloud Stream的核心是Stream&#xf…

无人机飞手入门指南

无人机飞手入门指南旨在为初学者提供一份全面的学习路径和实践建议,帮助新手快速掌握无人机飞行技能并了解相关法规知识。以下是一份详细的入门指南: 一、了解无人机基础知识 1. 无人机构造:了解无人机的组成部分,如机身、螺旋桨…

使用Mac下载MySQL修改密码

Mac下载MySQL MySQL官网链接MySQL​​​​​​ 当进入到官网后下滑到community社区,进行下载 然后选择community sever下载 这里就是要下载的界面,如果需要下载之前版本的话可以点击archives, 可能会因为这是外网原因,有时候下…

两大新兴开发语言大比拼:Move PK Rust

了解 Move 和 Rust 的差异有助于开发者根据项目的具体需求选择最合适的语言。选择不恰当的语言可能会导致项目后期出现技术债务。不同语言有其独特的优势。了解 Move 和 Rust 的差异可以帮助开发者拓展技术视野,发现不同语言在不同领域的应用潜力。 咱们直奔主题&a…

three.js 对 模型使用 视频进行贴图修改材质

three.js 对 模型使用 视频进行贴图修改材质 https://threehub.cn/#/codeMirror?navigationThreeJS&classifyapplication&idvideoModel import * as THREE from three import { OrbitControls } from three/examples/jsm/controls/OrbitControls.js import { GLTFLoad…

【论文分享】利用多源大数据衡量街道步行环境的老年友好性:以中国上海为例

本次给大家带来一篇SCI论文的全文翻译!该论文考虑了绿化程度、可步行性、安全性、形象性、封闭性和复杂性这六个指标,提出了一种基于多源地理空间大数据的新型定量评价模型,用于从老年人和专家的角度评估街道步行环境的老年友好程度&#xff…

计算机网络安全 —— 对称加密算法 DES (一)

一、对称加密算法概念# ​ 我们通过计算机网络传输数据时,如果无法防止他人窃听, 可以利用密码学技术将发送的数据变换成对任何不知道如何做逆变换的人都不可理解的形式, 从而保证了数据的机密性。这种变换被称为加密( encryptio…

6.C操作符详解,深入探索操作符与字符串处理

C操作符详解,深入探索操作符与字符串处理 C语言往期系列文章目录 往期回顾: C语言是什么?编程界的‘常青树’,它的辉煌你不可不知VS 2022 社区版C语言的安装教程,不要再卡在下载0B/s啦C语言入门:解锁基础…

微信小程序 最新获取用户头像以及用户名

一.在小程序改版为了安全起见 使用用户填写来获取头像以及用户名 二.代码实现 <view class"login_box"><!-- 头像 --><view class"avator_box"><button wx:if"{{ !userInfo.avatarUrl }}" class"avatorbtn" op…

Uni-APP+Vue3+鸿蒙 开发菜鸟流程

参考文档 文档中心 运行和发行 | uni-app官网 AppGallery Connect DCloud开发者中心 环境要求 Vue3jdk 17 Java Downloads | Oracle 中国 【鸿蒙开发工具内置jdk17&#xff0c;本地不使用17会报jdk版本不一致问题】 开发工具 HBuilderDevEco Studio【目前只下载这一个就…

【Android、IOS、Flutter、鸿蒙、ReactNative 】屏幕适配

Android Java 屏幕适配 参考 今日头条适配依赖配置 添加设计屏幕尺寸 设置字体大小 通过切换不同屏幕尺寸查看字体大小 设置文本宽高 通过切换不同屏幕尺寸查看文本宽高 Android Compose 屏幕适配 <

从视频帧生成点云数据、使用PointNet++模型提取特征,并将特征保存下来的完整实现。

文件地址 https://github.com/yanx27/Pointnet_Pointnet2_pytorch?spm5176.28103460.0.0.21a95d27ollfze Pointnet_Pointnet2_pytorch\log\classification\pointnet2_ssg_wo_normals文件夹改名为Pointnet_Pointnet2_pytorch\log\classification\pointnet2_cls_ssg "E:…

Websocket如何分块处理数据量超大的消息体

若我们服务端一次性最大处理的字节数是1M,而客户端发来了2M的数据&#xff0c;此时服务端的数据就要被切割成两次传输解码。Http协议中有分块传输&#xff0c;而在Websocket也可以分块处理超大的消息体。在jsr356标准中使用javax.websocket.MessageHandler.Partial可以分块处理…

论文复现_How Machine Learning Is Solving the Binary Function Similarity Problem

1. 内容概述 前言&#xff1a;此代码库支持 USENIX Security 22 论文 《How Machine Learning Is Solving the Binary Function Similarity Problem》&#xff0c;作者包括 Andrea Marcelli 等人&#xff0c;提供了相关代码、数据集和技术细节。 关键内容&#xff1a;技术报告…

【视觉SLAM】2-三维空间刚体运动的数学表示

读书笔记&#xff1a;学习空间变换的三种数学表达形式。 文章目录 1. 旋转矩阵1.1 向量运算1.2 坐标系空间变换1.3 变换矩阵与齐次坐标 2. 旋转向量和欧拉角2.1 旋转向量2.2 欧拉角 3. 四元数 1. 旋转矩阵 1.1 向量运算 对于三维空间中的两个向量 a , b ∈ R 3 a,b \in \R^3 …

【WPF】Prism学习(六)

Prism Dependency Injection 1.依赖注入&#xff08;Dependency Injection&#xff09; 1.1. Prism与依赖注入的关系&#xff1a; Prism框架一直围绕依赖注入构建&#xff0c;这有助于构建可维护和可测试的应用程序&#xff0c;并减少或消除对静态和循环引用的依赖。 1.2. P…

多账号登录管理器(淘宝、京东、拼多多等)

目录 下载安装与运行 解决什么问题 功能说明 目前支持的平台 功能演示 登录后能保持多久 下载安装与运行 下载、安装与运行 语雀 解决什么问题 多个账号的快捷登录与切换 功能说明 支持多个电商平台支持多个账号的登录保持支持快捷切换支持导入导出支持批量删除支持…

UniAPP快速入门教程(一)

一、下载HBuilder 首先需要下载HBuilder开发工具&#xff0c;下载地址:https://www.dcloud.io/hbuilderx.htmlhttps://www.dcloud.io/hbuilder.html 选择Windows正式版.zip文件下载。下载解压后直接运行解压目录里的HBuilderX.exe就可以启动HBuilder。 UniApp的插件市场网址…

PyAEDT:Ansys Electronics Desktop API 简介

在本文中&#xff0c;我将向您介绍 PyAEDT&#xff0c;这是一个 Python 库&#xff0c;旨在增强您对 Ansys Electronics Desktop 或 AEDT 的体验。PyAEDT 通过直接与 AEDT API 交互来简化脚本编写&#xff0c;从而允许在 Ansys 的电磁、热和机械求解器套件之间无缝集成。通过利…

SpringBoot源码解析(四):解析应用参数args

SpringBoot源码系列文章 SpringBoot源码解析(一)&#xff1a;SpringApplication构造方法 SpringBoot源码解析(二)&#xff1a;引导上下文DefaultBootstrapContext SpringBoot源码解析(三)&#xff1a;启动开始阶段 SpringBoot源码解析(四)&#xff1a;解析应用参数args 目录…