3、Segment Anything

news2024/11/27 14:43:33

github

创建anaconda环境

conda create -n ASM python=3.8

下载依赖包

# pytorch>=1.7 and torchvision>=0.8
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install opencv-python pycocotools matplotlib onnxruntime onnx

预训练权重
default or vit_h:ViT-H SAM model
vit_l:vit_l
vit_b:vit_b

example
详细的官网example
Automatically generating

代码使用

工具方法

读取图片

def read_image(path="./data/000.png"):
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

展示标记框

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

展示标记点

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)

展示掩膜

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

简单使用

main方法

if __name__ == '__main__':
	# 初始化模型
    sam = init()
    predictor = SamPredictor(sam)
    # 读取图片
    image = read_image("./data/000.png")
    # 绑定图片
    predictor.set_image(image)
    # 调用自定义方法
    predict_box(image, predictor)

加载模型

def init(model_type="vit_h", sam_checkpoint="/devdata/chengan/SAM_checkpoint/sam_vit_h_4b8939.pth", device="cuda"):
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return sam

点语义分割

def sample_use(image, predictor):
    input_points = np.array([
        [300, 300]
    ])
    # 1 (foreground point) or 0 (background point)
    input_labels = np.array([
        1
    ])
    # 掩膜,置信度,低分辨率掩码逻辑
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True
    )

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(input_points, input_labels, plt.gca())
        plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
    print("\nmask shape", masks.shape)

点语义分割迭代

def predict_dir(image, predictor):
    input_points = np.array([
        [300, 300]
    ])
    # 1 (foreground point) or 0 (background point)
    input_labels = np.array([
        1
    ])

    # 第一次语义
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True
    )

    # Choose the model's best mask
    mask_input = logits[np.argmax(scores), :, :]
    # 第二次语义
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_points, input_labels, plt.gca())
    plt.axis('off')
    plt.show()

box语义分割

def predict_box(image, predictor):
    input_box = np.array([425, 600, 600, 700])
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    plt.axis('off')
    plt.show()

点 box 语义分割

def predict_box_point(image, predictor):
    input_box = np.array([425, 600, 700, 700])
    input_point = np.array([[575, 750]])
    input_label = np.array([0])
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[0], plt.gca())
    show_box(input_box, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

多个box

def predict_boxs(image, predictor):
    input_boxes = torch.tensor([
        [75, 275, 725, 750],
        [425, 600, 700, 775],
        [375, 550, 650, 700],
        [240, 675, 400, 750],
    ], device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # (batch_size) x (num_predicted_masks_per_input) x H x W
    print(masks.shape)
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box in input_boxes:
        show_box(box.cpu().numpy(), plt.gca())
    plt.axis('off')
    plt.show()

batch all images

def predict_batch(images, sam):
    resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    image1 = images[0]
    # dual with image
    image1 = resize_transform.apply_image(image1)
    image1 = torch.as_tensor(image1, device=sam.device)
    image1 = image1.permute(2, 0, 1).contiguous()
    # box
    image1_boxes = torch.tensor([
        [75, 275, 725, 750],
        [425, 600, 700, 775],
        [375, 550, 650, 800],
        [240, 675, 400, 750],
    ], device=sam.device)

    image2 = images[1]
    image2 = resize_transform.apply_image(image2)
    image2 = torch.as_tensor(image2, device=sam.device)
    image2 = image2.permute(2, 0, 1).contiguous()
    image2_boxes = torch.tensor([
        [450, 170, 520, 350],
        [350, 190, 450, 350],
        [500, 170, 580, 350],
        [580, 170, 640, 350],
    ], device=sam.device)

    """
        image: The input image as a PyTorch tensor in CHW format.
        original_size: The size of the image before transforming for input to SAM, in (H, W) format.
        point_coords: Batched coordinates of point prompts.
        point_labels: Batched labels of point prompts.
        boxes: Batched input boxes.
        mask_inputs: Batched input masks.
    """
    batched_input = [
        {
            'image': image1,
            'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
            'original_size': image1.shape[:2]
        },
        {
            'image': image2,
            'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
            'original_size': image2.shape[:2]
        }

    ]

    batched_output = sam(batched_input, multimask_output=False)
    """
    masks: A batched torch tensor of predicted binary masks, the size of the original image.
    iou_predictions: The model's prediction of the quality for each mask.
    low_res_logits: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration.
    """
    print(batched_output[0].keys())
    fig, ax = plt.subplots(1, 2, figsize=(20, 20))

    ax[0].imshow(image1)
    for mask in batched_output[0]['masks']:
        show_mask(mask.cpu().numpy(), ax[0], random_color=True)
    for box in image1_boxes:
        show_box(box.cpu().numpy(), ax[0])
    ax[0].axis('off')

    ax[1].imshow(image2)
    for mask in batched_output[1]['masks']:
        show_mask(mask.cpu().numpy(), ax[1], random_color=True)
    for box in image2_boxes:
        show_box(box.cpu().numpy(), ax[1])
    ax[1].axis('off')

    plt.tight_layout()
    plt.show()

多语义实例分割

多语义分割图片展示

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

默认方法

def sample_use(image, sam):
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
    print(len(masks))
    print(masks[0].keys())

调整输入参数

def improved_use(image, sam):
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Requires open-cv to run post-processing
    )
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
    print(len(masks))
    print(masks[0].keys())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1259580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

13、LCD1602调试工具

LCD1602调试工具 使用LCD1602液晶屏作为调试窗口&#xff0c;提供类似Printf函数的功能&#xff0c;可实时观察单片机内部数据的变化情况&#xff0c;便于调试和演示。 main.c #include <REGX52.H> #include "LCD1602.h" #include "Delay.h"//存储…

快速搭建一个SpringCloud、SpringBoot项目 || 项目搭建要点

1. 基本结构 建立springcloud项目从表入手&#xff0c;分析好需求建立表结构后&#xff0c;使用mybatis-plux生成POJO类&#xff0c;在对应的model模块中。 2. 微服务部分架构 2.1 依赖 service 微服务模块的依赖仅包含如下&#xff0c;数据库等依赖包含在model中&#xff0c…

【解决视觉引导多个位置需要标定多个位置的问题】

** 以下只针对2D定位&#xff0c;就是只有X、Y、Rz三个自由度的情况。** 假设一种情况&#xff0c;当视觉给机器人做引导任务时&#xff0c;零件有多个&#xff0c;分布在料框里&#xff0c;视觉需要走多个位置去拍&#xff0c;那么只需要对第一个位置确定拍照位&#xff0c;确…

力扣6:N字形变化

代码&#xff1a; class Solution { public:string convert(string s, int numRows){int lens.size();if(numRows1){return s;}int d2*numRows-2;int count0;string ret;//第一行&#xff01;for(int i0;i<len;id){rets[i];}//第k行&#xff01;for(int i1;i<numRows-1;…

智能优化算法应用:基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.教与学算法4.实验参数设定5.算法结果6.参考文献7.…

超越GPT-4,拥有联网能力,Kimi-Chat大模型已免费使用,国内直接访问

目前ChatGPT的所有免费用户都已可以使用带有语音功能的ChatGPT。 人吧&#xff0c;总是贪婪的&#xff0c;我还想要ChatGPT Plus用户独享的“联网”功能。 目前对于ChatGPT来说&#xff0c;不想交钱&#xff0c;别拥有“联网”能力了&#xff0c;于是我找到了一个后起之秀&…

【差旅游记】新疆哈密回王府印象

哈喽&#xff0c;你好啊&#xff0c;我是雷工&#xff01; 2023年11月4号&#xff0c;那天的风的确挺大&#xff0c;逛完哈密博物馆考虑要不要去旁边的哈密回王府逛逛。想着来都来了&#xff0c;虽然网上评价不太好&#xff0c;还是去溜达一圈吧&#xff0c;于是决定自己去转转…

为啥网络安全那么缺人,但很多人却找不到工作?

文章目录 一、学校的偏向于学术二、学的东西太基础三、不上班行不行 为什么网络安全的人才缺口那么大&#xff0c;但是大学毕业能找到网安工作的人却很少&#xff0c;就连招聘都没有其他岗位多&#xff1f; 明明央视都说了网络安全的人才缺口还有300多万&#xff0c;现在找不到…

C++ 用ifstream读文件

输入流的继承关系: C++ 使用标准库类来处理面向流的输入和输出: iostream 处理控制台 IOfstream 处理命名文件 IOstringstream 完成内存 string 的 IO每个IO 对象都维护一组条件状态 flags (eofbit, failbit and badbit),用来指出此对象上是否可以进行 IO 操作。如果遇到错误…

vue实战——登录【详解】(含自适配全屏背景,记住账号--支持多账号,显隐密码切换,登录状态保持)

效果预览 技术要点——自适配全屏背景 https://blog.csdn.net/weixin_41192489/article/details/119992992 技术要点——密码输入框 自定义图标切换显示隐藏 https://blog.csdn.net/weixin_41192489/article/details/133940676 技术要点——记住账号&#xff08;支持多账号&…

「江鸟中原」有关HarmonyOS-ArkTS的Http通信请求

一、Http简介 HTTP&#xff08;Hypertext Transfer Protocol&#xff09;是一种用于在Web应用程序之间进行通信的协议&#xff0c;通过运输层的TCP协议建立连接、传输数据。Http通信数据以报文的形式进行传输。Http的一次事务包括一个请求和一个响应。 Http通信是基于客户端-服…

进程等待讲解

今日为大家分享有关进程等待的知识&#xff01;希望读完本文&#xff0c;大家能有一定的收获&#xff01; 正文开始&#xff01; 进程等待的引进 既然我们今天要讲进程等待这个概念&#xff01;那么只有我们把下面这三个方面搞明白&#xff0c;才能真正的了解进程等待&#x…

形象建设、生意经营、用户运营,汽车品牌如何在小红书一举多得?

随着小红书在多领域的持续成长&#xff0c;现在来小红书看汽车的用户&#xff0c;需求逐渐多元化与专业化。近1年的时间&#xff0c;有超过1亿人在小红书「主动搜索」过汽车内容&#xff0c;大家已经不仅限于玩车、用车&#xff0c;更是扩展到了百科全书式的看、选、买、学各个…

Python3 selenium 设置元素等待的三种方法

为什么要设置元素等待&#xff1f; 当你的网络慢的时候&#xff0c;打开网页慢&#xff0c;网页都没完全打开&#xff0c;代码已经在执行了&#xff0c;但是没找到你定位的元素&#xff0c;此时python会报错。 当你的浏览器或电脑反应慢&#xff0c;网页没完全打开&#xff0c;…

12、模块化编程

模块化编程 1、传统方式编程&#xff1a;所有的函数均放在main.c里&#xff0c;若使用的模块比较多&#xff0c;则一个文件内会有很多的代码&#xff0c;不利于代码的组织和管理&#xff0c;而且很影响便朝着的思路 2、模块化编程&#xff1a;把各个模块的代码放在不同的.c文件…

java--单继承、Object

java是单继承的&#xff0c;java中的类不支持多继承&#xff0c;但是支持多层继承。 反证法&#xff1a; 如果一个类同时继承两个类&#xff0c;然后两个类中都有同样的一个方法&#xff0c;哪当我创建这个类里的方法&#xff0c;是调用哪父类的方法 所以java中的类不支持多继…

PostgreSQL + SQL Server = WiltonDB

WiltonDB 是一个基于 PostgreSQL 的开源数据库&#xff0c;通过 Babelfish 插件支持 Microsoft SQL Server 协议以及 T-SQL 语句。 Babelfish 是亚马逊提供的一个开源项目&#xff0c;使得 PostgreSQL 数据库同时具有 Microsoft SQL Server 数据查询和处理的能力。Babelfish 可…

11、动态数码管显示

数码管驱动方式 1、单片机直接扫描&#xff1a;硬件设备简单&#xff0c;但会消耗大量的单片机CPU时间 2、专用驱动芯片&#xff1a;内部自带显存、扫描电路&#xff0c;单片机只需告诉他显示什么即可 #include <REGX52.H> //数组代表显示亮灯的内容0、1、2、3、4、5、…

单车模型及其线性化

文章目录 1 单车模型2 线性化3 实现效果4 参考资料 1 单车模型 这里讨论的是以后轴为中心的单车运动学模型&#xff0c;由下式表达&#xff1a; S ˙ [ x ˙ y ˙ ψ ˙ ] [ v c o s ( ψ ) v s i n ( ψ ) v t a n ( ψ ) L ] \dot S \begin{bmatrix} \dot x\\ \dot y\\ \d…

C++ : 初始化列表 类对象作为类成员

传统方式初始化 C 提供了初始化列表语法&#xff0c;用来初始化属性 初始化列表 语法&#xff1a; 构造函数()&#xff1a;属性1(值1), 属性2&#xff08;值2&#xff09;... {} class Person { public://传统方式初始化 Person(int a, int b, int c) {m_A a;m_B b;m_C c…