SAM:Segment Anything 代码复现和测试 基本使用

news2025/1/11 11:30:11

相关地址

代码:
https://github.com/facebookresearch/segment-anything
在线网站:
https://segment-anything.com/demo

环境配置

建议可以clone下来学习相关代码,安装可以不依赖与这个库

git clone https://github.com/facebookresearch/segment-anything.git

1.创建environment.yaml

name: sam
channels:
  - pytorch
  - conda-forge
dependencies:
  - python=3.8
  - pytorch=1.9.0
  - torchvision=0.10.0
  - cudatoolkit=11.1
  - pip
conda env create -f environment.yaml
conda activate raptor

2.安装

pip install git+https://github.com/facebookresearch/segment-anything.git

3.其他库

pip install opencv-python pycocotools matplotlib onnxruntime onnx

目前安装的版本

Successfully installed coloredlogs-15.0.1 contourpy-1.1.1
cycler-0.12.1 flatbuffers-23.5.26 fonttools-4.43.1 humanfriendly-10.0
importlib-resources-6.1.0 kiwisolver-1.4.5 matplotlib-3.7.3
mpmath-1.3.0 numpy-1.24.4 onnx-1.15.0 onnxruntime-1.16.1
opencv-python-4.8.1.78 packaging-23.2 protobuf-4.24.4
pycocotools-2.0.7 pyparsing-3.1.1 python-dateutil-2.8.2 six-1.16.0
sympy-1.12 zipp-3.17.0

初阶测试

1.下载模型
https://github.com/facebookresearch/segment-anything#model-checkpoints

2.测试代码

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

img_path = '/data/qinl/code/segment-anything/notebooks/images/dog.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

masks = mask_generator.generate(image)

'''
Mask generation returns a list over masks, where each mask is a dictionary containing various data about the mask. These keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format
'''

print(len(masks))
print(masks[0].keys())

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

3.输出

65
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

在这里插入图片描述

进阶测试

图片预处理部分

其他instruction,都是在这个基础上进行处理

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))   


sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

img_path = '/data/qinl/code/segment-anything/notebooks/images/truck.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# 预处理输入图片
predictor.set_image(image)

输入的instruction为point的情况

# 输入为point的情况
    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    # 可以用来显示一下点的位置
    # plt.figure(figsize=(10,10))
    # plt.imshow(image)
    # show_points(input_point, input_label, plt.gca())
    # plt.axis('on')
    # plt.show()  

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    print('masks.shape',masks.shape)  # (number_of_masks) x H x W

    # 输出3个mask,分别有不同的score
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10,10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_point, input_label, plt.gca())
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()  

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

多点输入(都视为前景点)

# 输入为多个point的情况(前景点)
    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # additional points
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 1])

    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    
    print('masks.shape',masks.shape) # only 1 x H x W

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show() 

在这里插入图片描述

多点输入(前景点加后景点)

决定这个点是前景点还是后景点的就是label,0就是背景的意思

修改标签,得到不一样的结果

    # input_point = np.array([[500, 375], [1125, 625]])
    # input_label = np.array([1, 1])

    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 0])

在这里插入图片描述

使用box框具体物体

# 输入为additional points
    input_box = np.array([425, 600, 700, 875])
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    plt.axis('off')
    plt.show()

在这里插入图片描述

结合points和box

    # 输入为point和box
    input_box = np.array([425, 600, 700, 875])
    input_point = np.array([[575, 750]])
    input_label = np.array([0])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

在这里插入图片描述

batch prompt inputs

    # batch prompt inputs
    input_boxes = torch.tensor([
        [75, 275, 1725, 850],
        [425, 600, 700, 875],
        [1375, 550, 1650, 800],
        [1240, 675, 1400, 750],
    ], device=predictor.device)

    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    print(masks.shape)  # (batch_size) x (num_predicted_masks_per_input) x H x W

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box in input_boxes:
        show_box(box.cpu().numpy(), plt.gca())
    plt.axis('off')
    plt.show()

在这里插入图片描述

End-to-end batched inference

    ## End-to-end batched inference
    image1 = image  # truck.jpg from above
    image1_boxes = torch.tensor([
        [75, 275, 1725, 850],
        [425, 600, 700, 875],
        [1375, 550, 1650, 800],
        [1240, 675, 1400, 750],
    ], device=sam.device)

    image2 = cv2.imread('./notebooks/images/groceries.jpg')
    image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
    image2_boxes = torch.tensor([
        [450, 170, 520, 350],
        [350, 190, 450, 350],
        [500, 170, 580, 350],
        [580, 170, 640, 350],
    ], device=sam.device)

    # Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. 
    # Inputs are packaged as a list over images, which each element is a dict that takes the following keys:
    # * `image`: The input image as a PyTorch tensor in CHW format.
    # * `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.
    # * `point_coords`: Batched coordinates of point prompts.
    # * `point_labels`: Batched labels of point prompts.
    # * `boxes`: Batched input boxes.
    # * `mask_inputs`: Batched input masks.

    from segment_anything.utils.transforms import ResizeLongestSide
    resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    def prepare_image(image, transform, device):
        image = transform.apply_image(image)
        image = torch.as_tensor(image, device=device.device) 
        return image.permute(2, 0, 1).contiguous()
    
    batched_input = [
        {
            'image': prepare_image(image1, resize_transform, sam),
            'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
            'original_size': image1.shape[:2]
        },
        {
            'image': prepare_image(image2, resize_transform, sam),
            'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
            'original_size': image2.shape[:2]
        }
    ]

    batched_output = sam(batched_input, multimask_output=False)

    # The output is a list over results for each input image, where list elements are dictionaries with the following keys:
    # * `masks`: A batched torch tensor of predicted binary masks, the size of the original image.
    # * `iou_predictions`: The model's prediction of the quality for each mask.
    # * `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration.

    print('batched_output[0].keys()',batched_output[0].keys())

    fig, ax = plt.subplots(1, 2, figsize=(20, 20))

    ax[0].imshow(image1)
    for mask in batched_output[0]['masks']:
        show_mask(mask.cpu().numpy(), ax[0], random_color=True)
    for box in image1_boxes:
        show_box(box.cpu().numpy(), ax[0])
    ax[0].axis('off')

    ax[1].imshow(image2)
    for mask in batched_output[1]['masks']:
        show_mask(mask.cpu().numpy(), ax[1], random_color=True)
    for box in image2_boxes:
        show_box(box.cpu().numpy(), ax[1])
    ax[1].axis('off')

    plt.tight_layout()
    plt.show()

在这里插入图片描述

高阶测试

模型训练(waiting)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1155182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计选题推荐-大学生校园兼职微信小程序/安卓APP-项目实战

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

unittest与pytest的区别

Unittest vs Pytest 主要从用例编写规则、用例的前置和后置、参数化、断言、用例执行、失败重运行和报告这几个方面比较unittest和pytest的区别: 用例编写规则 用例前置与后置条件 断言 测试报告 失败重跑机制 参数化 用例分类执行 如果不好看,可以看下面表格&…

我与“云栖大会”剪不断的缘分

目录 前言首次参会经历技术前沿与创新洞察交流与合作项目展示与学习收获激励与成长之旅结束语 前言 作为开发者,想必大家对“云栖大会”并不陌生,“云栖大会”作为中国最具规模和影响力的云计算盛会,每年吸引着众多科技从业者、企业家和开发…

【C语言初学者周冲刺计划】2.3有3个字符串,要求找出其中“最大者

目录 1解题思路: 2代码: 3代码运行结果:​编辑 4总结: 1解题思路: 比较字符串大小的依据:26个大、小写字母“A-Z”,“a-z”中,字母越往后面的越大,小写字母比大写字母…

哪款进销存软件好用,企业该如何选择进销存软件?

哪个进销存软件好用?企业该如何选择进销存软件? 对于这个问题,企业首先应该考虑的不是所谓的哪个进销存软件是免费的,哪个进销存软件便宜,企业对于业务系统的选型可不像你双十一凑单买日用品那么简单。 如果你想要完…

【扩散模型】理解扩散模型的微调(Fine-tuning)和引导(Guidance)

理解扩散模型的微调Fine-tuning和引导Guidance 1. 环境准备2. 加载预训练过的管线3. DDIM——更快的采样过程4. 微调5. 引导6. CLIP引导参考资料 微调(Fine-tuning)指的是在预先训练好的模型上进行进一步训练,以适应特定任务或领域的过程。这…

使用 Authing 快速实现一套类似 OpenAI 的认证、API Key 商业权益授权机制

如果你有经常使用 OpenAI 或者 HuggingFace 这一类面向开发者的 SaaS 服务,对于 API Key 肯定不会陌生。我们在使用这些服务时,通常都会在其平台上面创建一套 API Key,之后我们才能在代码中通过这一串 API key 访问其服务;同时&am…

qt5工程打包成可执行exe程序

一、编译生成.exe 1.1、在release模式下编译生成.exe 1.2、建一个空白文件夹package,再将在release模式下生成的.exe文件复制到新建的文件夹中package。 1.3、打开QT5的命令行 1.4、用命令行进入新建文件夹package,使用windeployqt对生成的exe文件进行动…

Android Button修改背景颜色及实现科技感效果

目录 效果展示 实现科技感效果 修改Button背景 结语 效果展示 Android Button修改背景颜色及实现科技感效果效果如下: 实现科技感效果 操作方法如下: 想要创建一个富有科技感的按钮样式时,可以使用 Android 的 Shape Drawable 和 Sele…

阿里云发布通义千问2.0,模型参数达千亿级

10月31日,阿里云正式发布千亿级参数大模型通义千问2.0。在10个权威测评中,通义千问2.0综合性能超过GPT-3.5,正在加速追赶GPT-4。当天,通义千问APP在各大手机应用市场正式上线,所有人都可通过APP直接体验最新模型能力。…

精密数据工匠:探索 Netty ChannelHandler 的奥秘

通过上篇文章(Netty入门 — Channel,把握 Netty 通信的命门),我们知道 Channel 是传输数据的通道,但是有了数据,也有数据通道,没有数据加工也是没有意义的,所以今天学习 Netty 的第四…

一种支持热插拔的服务端插件设计思路

定位 服务端插件是一个逻辑扩展平台,提供了一个快速托管逻辑的能力。 核心特点 高性能:相对于RPC调用,没有网络的损耗,性能足够强劲。 高可靠:基于线程隔离,保证互不影响,插件的资源占用或崩溃等问题不直接影响业务。 部署快:不需要发布审核流程, 插件本身逻辑简短,…

有一个 3*4 的矩阵,找出其中值最大的元素,及其行列号

1解题思路&#xff1a; 首先学会输入二维数组&#xff1b;然后知道如何比较求最大值&#xff1b;最后就是格式问题&#xff1b; 2代码&#xff1a; #include<stdio.h> int main() {int a[3][4];int i,j,max,row,line;for(i0;i<3;i){printf("请输入二维数组\n&…

【JAVA】类与对象的重点解析

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️初识JAVA】 文章目录 前言类与对象的关系JAVA源文件有关类的重要事项static关键字 前言 Java是一种面向对象编程语言&#xff0c;OOP是Java最重要的概念之一。学习OOP时&#xff0c;学生必须理解面向…

架构设计之大数据架构(Lambda架构、Kappa架构)

大数据架构 一. 大数据技术生态二. 大数据分层架构三. Lambda架构3.1 Lambda架构分解为三层3.2 优缺点3.3 实际案例 四. Kappa架构4.1 结构图4.2 优缺点4.3 实际案例 五. Lambda架构与Kappa架构对比 其它相关推荐&#xff1a; 系统架构之微服务架构 系统架构设计之微内核架构 鸿…

杂货铺 | 报错记录(持续更新)

文章目录 ⚠️python SyntaxError: Non-UTF-8 code starting with ‘\xb3‘ in file⚠️partially initialized module ‘‘ has no attribute ‘‘(most likely due to a circular import)⚠️AttributeError: ‘DataFrame‘ object has no attribute ‘append‘ ⚠️python S…

OpenCV官方教程中文版 —— 分水岭算法图像分割

OpenCV官方教程中文版 —— 分水岭算法图像分割 前言一、原理二、示例三、完整代码 前言 本节我们将要学习 • 使用分水岭算法基于掩模的图像分割 • 函数&#xff1a;cv2.watershed() 一、原理 任何一副灰度图像都可以被看成拓扑平面&#xff0c;灰度值高的区域可以被看成…

企业知识库知识分类太有必要了,是省时省力的关键!

企业知识库是存储、组织和共享企业内部知识的重要工具。在现代企业中&#xff0c;知识是一项宝贵的资产&#xff0c;对于提高企业的竞争力和创新能力至关重要。而通过企业知识库进行知识分类&#xff0c;可以将海量信息有序划分和组织&#xff0c;让企业员工能够快速定位、理解…

贪心算法学习------优势洗牌

目录 一&#xff0c;题目 二&#xff0c;题目接口 三&#xff0c;解题思路和代码 全部代码&#xff1a; 一&#xff0c;题目 给定两个数组nums1和nums2,nums1相对于nums2的优势可以用满足nums1[i]>nums2[i]的索引i的数目来描述。 返回nums1的任意排序&#xff0c;使其优…

标签推荐Top-N列表优化算法_朱小兵

2算法模型 2&#xff0e;1 Top-N推荐列表重排序算法