medsam ,数入xml +img, 根据检测框,原图显示分割效果,加上点的减少处理

news2024/11/24 9:48:49

1、输入每张图片的多个检测框,得到这张图片的sam 分割结果

import numpy as np
import matplotlib.pyplot as plt
import os

join = os.path.join
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F
import argparse


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, gt.shape)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (256, 256)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg


# %% load model and image
parser = argparse.ArgumentParser(
    description="run inference on testing set based on MedSAM"
)
parser.add_argument(
    "-i",
    "--data_path",
    type=str,
    default="assets/img_demo.png",
    help="path to the data folder",
)
parser.add_argument(
    "-o",
    "--seg_path",
    type=str,
    default="assets/",
    help="path to the segmentation folder",
)
parser.add_argument(
    "--box",
    type=list,
    default=[95, 255, 190, 350],
    help="bounding box of the segmentation target",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument(
    "-chk",
    "--checkpoint",
    type=str,
    default="work_dir/MedSAM/medsam_vit_b.pth",
    # default="/home/syy/code/sam/MedSAM-LiteMedSAM/carotid_MedSAM-Lite-Box-20240508-1808/medsam_lite_best1.pth",
    help="path to the trained model",
)
args = parser.parse_args()

device = args.device
medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint)
medsam_model = medsam_model.to(device)
medsam_model.eval()
print("=====================================> 模型加载完毕")


import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
import os
import random 


import os
import xml.etree.ElementTree as ET
import cv2



def parse_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    image_name = root.find('filename').text
 
    boxes = []
    labels = []

    for obj in root.findall('object'):
        label = obj.find('name').text
        bbox = obj.find('bndbox')
        x1 = int(bbox.find('xmin').text)
        y1 = int(bbox.find('ymin').text)
        x2 = int(bbox.find('xmax').text)
        y2 = int(bbox.find('ymax').text)
        boxes.append((x1, y1, x2, y2))
        labels.append(label)

    return image_name, boxes, labels

def process_xmls(xmls_dir):
    results = []
    xml_lists = os.listdir(xmls_dir)
    xml_lists.sort()
    for xml_file in xml_lists[0:200]:
        if xml_file.endswith('.xml'):
            xml_path = os.path.join(xmls_dir, xml_file)
            result = parse_xml(xml_path)
            results.append(result)

    return results



def show_mask(mask, ax, random_color=False):
    #  mask  模型预测的分割图 01  目标和背景
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.1]) #透明度0.3
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) #将掩码和颜色相乘,得到最终的带有颜色的掩码图像


    ax.imshow(mask_image) # 不显示mask区域

    #########################################
    # 找到掩码的轮廓
    contours, _ = cv2.findContours((mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # 对最大的轮廓进行逼近处理,减少轮廓点的数量
    reduction_factor = 0.002 #0  #0.005
    if contours:  #没有会返回空
        areas = [cv2.contourArea(cnt) for cnt in contours]
        # 找到最大面积的轮廓的索引
        max_area_index = np.argmax(areas)
        # 获取最大面积的轮廓
        largest_contour = contours[max_area_index]           
        # 对每个轮廓进行逼近处理,减少轮廓

        if reduction_factor > 0.000001:
            epsilon = reduction_factor * cv2.arcLength(largest_contour, True)
            approx = cv2.approxPolyDP(largest_contour, epsilon, True)  # 最大轮廓的操作,平滑轮廓点
            # 绘制轮廓,减少的点,平滑的不是很好,换一个
            print("点有没有减少,len(approx),len(contours)",len(approx),len(largest_contour))
            ax.plot(approx[:, 0, 0], approx[:, 0, 1], color='red', linewidth=1)
        else:
            ax.plot(largest_contour[:, 0, 0], largest_contour[:, 0, 1], color='red', linewidth=0.3)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='yellow', facecolor=(0,0,0,0), lw=1))



def prompt_box_pred(xmls_dir,imgs_dir,save_dir):
    # 示例用法
    results = process_xmls(xmls_dir)
    for ind, res in enumerate(results):
        image_name, boxes, labels = res
        print(ind,': Image:', image_name)

        # 读取图片和xml 文件,获取坐标
        img_path = os.path.join(imgs_dir,image_name)
        # image = cv2.imread(img_path)
        # if image is None:
        #     print("=======================> 图片路径不存在",img_path)
        #     continue
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 
        # image_height, image_width = image.shape[:2]


        img_np = io.imread(img_path)
        if len(img_np.shape) == 2:
            img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
        else:
            img_3c = img_np
        H, W, _ = img_3c.shape
        # %% image preprocessing
        img_1024 = transform.resize(
            img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )  # normalize to [0, 1], (H, W, 3)
        # convert the shape to (3, H, W)
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
        )        


        plt.figure(figsize=(10, 10))  #画布的大小
        plt.imshow(img_3c)

        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            print('  Label:', label)
            print('  Box:', x1, y1, x2, y2)

            input_box = np.array(box) 
            box_np = np.array([box]) 
            # transfer box_np t0 1024x1024 scale
            box_1024 = box_np / np.array([W, H, W, H]) * 1024
            #  预测图片的分割标签
            with torch.no_grad():
                image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)

            medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)  #分割最后输出原图大小
        
            # print(medsam_seg.shape) #(127, 212)
            # print(img_3c.shape) # (127, 212, 3)

            show_mask(medsam_seg, plt.gca())
            show_box(input_box, plt.gca())

        plt.axis('off')
        # plt.show()
        ###  bbox_inches='tight'表示将图像边缘紧贴画布边缘,pad_inches=0表示不添加额外的边距
        plt.savefig(save_dir + image_name,bbox_inches='tight', pad_inches=0) #) # 一张图保存多个框   
            
if __name__ == "__main__":
    xmls_dir = '/home/syy/data/甲乳/breast/image2/xmls'
    imgs_dir = '/home/syy/data/甲乳/breast/image2/images' 
    save_dir = "/home/syy/data/甲乳/breast/image2/medsam/"   
    
    os.makedirs(save_dir,exist_ok=True)
    prompt_box_pred(xmls_dir,imgs_dir,save_dir)    

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1714532.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

透视AI技术:探索折射技术在去衣应用中的奥秘

引言: 随着人工智能技术的飞速发展,其在图像处理和计算机视觉领域的应用日益广泛。其中,AI去衣技术作为一种颇具争议的应用,引发了广泛的讨论和关注。本文将深入探讨折射技术在AI去衣中的应用及其背后的原理。 一、AI去衣技术简介…

AI Agent智能体概述及原理

AI Agent概述 AI Agent旨在理解、分析和响应人类输入,像人类一样执行任务、做出决策并与环境互动。它们可以是遵循预定义规则的简单系统,也可以是根据经验学习和适应的复杂、自主的实体;可以是基于软件的实体,也可以是物理实体。…

深入理解统计学中的最大值与最小值

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、统计学中的基础概念:最大值与最小值 1. 创建数组与数据导入 2. 求解整体数…

重磅发布,2024精选《制造业商业智能BI最佳实践合集 》

在数字时代,中国制造业正面临着前所未有的深刻变革。 商业环境的复杂性与多变性、全球化竞争的激烈程度、消费需求的快速演变,以及新技术的持续进步等多种因素共同推动着制造企业积极加入数字化转型的潮流。 在这个转型的过程中,转型的速度…

yq—2024/5/29—零钱兑换

代码实现&#xff1a; #define min(a, b) ((a) > (b) ? (b) : (a))int coinChange(int *coins, int coinsSize, int amount) {int dp[amount 1];// 初始化for (int i 0; i < amount 1; i) {dp[i] INT32_MAX;}dp[0] 0;// 01背包 -----先遍历物品&#xff0c;再遍历背…

oracle数据回显时候递归实战

太简单的两篇递归循环 orcale 在项目里递归循环实战 先看资产表T_ATOM_ASSET结构 看业务类别表T_ATOM_BUSI_CATEGORY结构 问题出现 页面显示 实际对应的归属业务分类 涉及到oracle递归实战(这里不会如何直接在atomAsset的seelct里面处理递归回显) 直接在实现层看atomAs…

CTF_RE典例

PZCTF Xor 分组异或 0&#xff0c;1&#xff0c;2&#xff0c;3 不变, 4 , 5 &#xff0c;6&#xff0c;7只异或Str[0], 8,9,10,11要先后异或Str[0],Str[1] s [0x50, 0x5a, 0x43, 0x54, 0x16, 0x2b, 0x11, 0xf, 0x3b, 0x63,0x7e, 0x7e, 0x78, 0x2c, 0x16, 0x3a, 0x71, 0x2e…

The First项目报告:一场由社区驱动的去中心化加密冒险—Turbo

2023年3月14日&#xff0c;由OpenAI公司开发自回归语言模型GPT-4发布上线&#xff0c;一时之间引发AI智能领域的轩然大波&#xff0c;同时受到影响的还有加密行业&#xff0c;一众AI代币纷纷出现大幅度拉升。与此同时&#xff0c;一款名为Turbo的Meme代币出现在市场中&#xff…

DNSlog环境搭建

阿里云域名公网VPS地址 购买阿里云域名后设置“自定义DNSHOST” DNS服务器填写ns1和ns2 如&#xff1a;ns1.aaa.com IP地址填写你的VPS地址 如&#xff1a;1.1.1.1 填写解析记录&#xff0c;一个A记录、一个NS记录 NS记录就是*.域名指向记录值ns1.域名 如&#xff1a;*.aaa…

计算机图形学入门03:二维基本变换

变换(Transformation)可分为模型(Model)变换和视图(Viewing)变换。在3D虚拟场景中相机的移动和旋转&#xff0c;角色人物动画都需要变换&#xff0c;用来描述物体运动。将三维世界投影变换到2D屏幕上成像出来&#xff0c;也需要变换。 1.缩放变换 如上图所示&#xff0c;把一个…

社区供稿丨GPT-4o 对实时互动与 RTC 的影响

以下文章来源于共识粉碎机 &#xff0c;作者AI芋圆子 前面的话&#xff1a; GPT-4o 发布当周&#xff0c;我们的社区伙伴「共识粉碎机」就主办了一场主题为「GPT-4o 对实时互动与 RTC 的影响」讨论会。涉及的话题包括&#xff1a; GPT-4o 如何降低延迟&#xff08;VAD 模块可…

图片怎样在线改像素大小?电脑快速修改图片大小的方法

在设计图片的时候下载的图片尺寸一般会比较大&#xff0c;在网上使用经常会因为尺寸的问题导致无法正常上传&#xff0c;那么如何快速在线改图片大小呢&#xff1f;想要修改图片尺寸可以在直接选择网上的图片改大小工具的功能来快速完成修改&#xff0c;操作简单方便使用&#…

M功能-支付平台(六)

target&#xff1a;离开柬埔寨倒计时-217day 今天突然发现我在csdn居然把我ip属地搞出来了&#xff0c;之前都没注意到&#xff0c;哎 前言 M功能演示版本做到后期(也就是第二周的后面3天)真的很心酸&#xff0c;这边安排的4后端后面都放弃了&#xff0c;觉得做不出来&#…

python Z-score标准化

python Z-score标准化 Zscore标准化sklearn库实现Z-score标准化手动实现Z-score标准化 Zscore标准化 Z-score标准化&#xff08;也称为标准差标准化&#xff09;是一种常见的数据标准化方法&#xff0c;它将数据集中的每个特征的值转换为一个新的尺度&#xff0c;使得转化后的…

设置自动刷新数据透视表的数据源

数据透视表数据源的自动刷新 一般情况操作&#xff1a; 自动刷新操作&#xff1a; 1、定义名称名称 引用位置&#xff1a;OFFSET(Sheet1!$A$1,0,0,COUNTA(Sheet1!$A:$A),COUNTA(Sheet1!$1:$1)) 2、数据透视表的数据源更改为【源数据】—— 即前面定义的名称 3、数据——全部…

香港优才计划找中介是否是智商税,靠谱中介又该如何找?

关于香港优才计划的申请&#xff0c;找中介帮助还是自己DIY&#xff0c;网络上充斥的声音太多&#xff0c;对不了解的人来说&#xff0c;难以抉择的同时还怕上当受骗。 这其中很容易误导人的关键在于——信息差&#xff01; 今天这篇文章的目的就是想让大家看清一些中介和DIY…

ResNet 原理剖析以及代码复现

原理 ResNet 解决了什么问题&#xff1f; 一言以蔽之&#xff1a;解决了深度的神经网络难以训练的问题。 具体的说&#xff0c;理论上神经网络的深度越深&#xff0c;其训练效果应该越好&#xff0c;但实际上并非如此&#xff0c;层数越深会导致越差的结果并且容易产生梯度爆炸…

Scapy:用Python编写自己的网络抓包工具

随着Python越来越流行&#xff0c;在安全领域的用途也越来越多。比如可以用requests 模块撰写进行Web请求工具&#xff1b;用sockets编写TCP网络通讯程序&#xff1b;解析和生成字节流可以使用struct模块。而要解析和处理网络包在网络安全领域更加普遍&#xff0c;时常我们会使…

Vue——事件修饰符

文章目录 前言阻止默认事件 prevent阻止事件冒泡 stop 前言 在官方文档中对于事件修饰符有一个很好的说明&#xff0c;本篇文章主要记录验证测试的案例。 官方文档 事件修饰符 阻止默认事件 prevent 在js原生的语言中&#xff0c;可以根据标签本身的事件对象进行阻止默认事件…

数组-给出最大容量,求能获得的最大值

一、问题描述 二、解题思路 这个题目其实是求给出数组中&#xff0c;子数组和不大于M中&#xff0c;和最大值的子数组。 求子数组使用双指针就可以解决问题&#xff0c;相对比较简单。&#xff08;如果是子序列&#xff0c;则等价于0-1背包问题&#xff0c;看题目扩展中的问题…