用SAM2和Cutie模型目标追踪

news2024/11/24 13:59:51

一、数据集

视频:每个视频文件夹以图片帧的形式存储

box:给出每个视频第一帧要追踪的物体的box

二、将数据格式转换成SAM2所需要的格式

主要是将box转换成mask的格式,下面这个代码就是将box转换成mask的代码,具体转换原理如下:

box作为prompt输入进入SAM2,SAM2生成第一帧的mask,在mask上选五个positive point作为新的prompt,结合bbox作为新的prompt再过一遍SAM2 tracking

import os
import json
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from PIL import Image

# 初始化SAM2预测器
checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

# 图像和bbox目录路径
image_base_path = "/data/WebDAV/VidOR/video_frames_extract"  #图像帧路径
bbox_base_path = "/data/WebDAV/VidOR/all_training_bbox"  #bbox的 路径
output_mask_path = "/data/WebDAV/VidOR/mask_final" #
output_mask_path_initial = "/data/WebDAV/VidOR/mask_initial" #box作为prompt输入SAM2的结果

# 创建保存mask的文件夹
os.makedirs(output_mask_path, exist_ok=True)
os.makedirs(output_mask_path_initial, exist_ok=True)

# # 从bbox区域采样点(这里随机采样一些点作为示例)
# def sample_points_from_bbox(bbox, num_points=5):
#     xmin, ymin, xmax, ymax = bbox
#     x_coords = np.random.randint(xmin, xmax, size=num_points)
#     y_coords = np.random.randint(ymin, ymax, size=num_points)
#     return np.column_stack((x_coords, y_coords))

# 从生成的掩码中采样正点
def sample_points_from_mask(mask, num_points=5):
    # 获取mask中所有正点的位置 (非零点)
    y_coords, x_coords = np.where(mask > 0)
    # 如果正点数量不足,取所有正点
    if len(x_coords) < num_points:
        sampled_indices = np.arange(len(x_coords))
    else:
        sampled_indices = np.random.choice(len(x_coords), num_points, replace=False)
    sampled_points = np.column_stack((x_coords[sampled_indices], y_coords[sampled_indices]))
    return sampled_points

# 遍历video_frames_extract目录中的所有子文件夹
for video_segment in os.listdir(image_base_path):
    # 提取video_id和帧信息
    video_id, start_frame, end_frame = video_segment.split('_')

    if int(start_frame) == 0:
        start_frame = str(int(start_frame) + 1)

    # 获取视频段的第一帧路径
    first_frame_path = os.path.join(image_base_path, video_segment, "{:04d}.png".format(int(start_frame)))

    if not os.path.exists(first_frame_path):
        print(f"帧 {first_frame_path} 不存在,跳过该视频段.")
        continue

    # 打开第一帧图像
    image = Image.open(first_frame_path).convert("RGB")
    image_array = np.array(image)

    # 找到对应的bbox json文件
    json_file_path = os.path.join(bbox_base_path, f"{video_id}.json")
    if not os.path.exists(json_file_path):
        print(f"对应的bbox文件 {json_file_path} 不存在,跳过该视频段.")
        continue

    # 读取json文件获取bbox
    with open(json_file_path, "r") as f:
        data = json.load(f)

    # 查找对应开始帧的bbox(假设trajectories中的索引与帧一一对应)
    frame_index = int(start_frame) - 1  # 帧索引从0开始,所以减去1
    if frame_index >= len(data['trajectories']):
        print(f"帧 {start_frame} 超出 {video_id} 的轨迹范围.")
        continue

    frame_bboxes = data['trajectories'][frame_index]

    # 对每个object生成mask
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        predictor.set_image(image_array)

        for obj in frame_bboxes:
            tid = obj['tid']
            bbox = obj['bbox']

            # Step 1: 使用bbox生成初步的SAM2 mask
            bbox_array = np.array([bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']])
            masks, _, _ = predictor.predict(box=bbox_array, multimask_output=False)
            # # 转换bbox为np数组
            # bbox_array = np.array([bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']])

            # 将mask转换为uint8类型
            initial_mask = (masks[0] * 255).astype(np.uint8)

            # 保存初步的mask
            initial_mask_pil = Image.fromarray(initial_mask)
            initial_mask_filename = os.path.join(output_mask_path_initial, f"{video_id}_{start_frame}_{end_frame}_{tid}.png")
            initial_mask_pil.save(initial_mask_filename)
            # print(f"保存初步的mask: {initial_mask_filename}")

            # Step 2: 从生成的初步mask中采样正点
            sampled_points = sample_points_from_mask(initial_mask)
            point_labels = np.ones(len(sampled_points))  # 所有采样点都是正点

            # Step 3: 使用采样点、mask和bbox作为prompt,再次输入SAM2进行掩码生成
            final_masks, _, _ = predictor.predict(
                # mask_input=masks,
                box=bbox_array,
                point_coords=sampled_points,
                point_labels=point_labels,
                multimask_output=False
            )
            # # 从bbox区域采样正点
            # point_coords = sample_points_from_bbox(bbox_array, num_points=5)
            # point_labels = np.ones(len(point_coords))  # 所有点都标记为正点(标签为1)

            # # 生成mask
            # masks, _, _ = predictor.predict(box=bbox_array, multimask_output=False)

            # 将最终生成的mask保存
            final_mask = (final_masks[0] * 255).astype(np.uint8)
            final_mask_pil = Image.fromarray(final_mask)
            final_mask_filename = os.path.join(output_mask_path, f"{video_id}_{start_frame}_{end_frame}_{tid}.png")
            final_mask_pil.save(final_mask_filename)
            # print(f"保存最终的mask: {final_mask_filename}")

三、接下来将视频帧和mask输入进入SAM2进行推理就可以了

四、Cutie

在推理的时候主要有三个参数影响推理结果

1、size图像的大小,我的所有图像都是1080*1920大小的,但是我将其设置成800,再往后效果都是一样的,而且推理占的显存还是挺大的

2、max_mem_frames:30

3、min_mem_frames:28

2和3个参数是我在保证size一定的时候最大的值了,我用的显卡是80G的

还有cutie应该是有三个权重,但是mega那个权重是效果最好的

五、总结

总的来说,sam2的large权重是要比cutie效果好很多的。当然以上只是对我的数据集来说,具体效果还要根据实际情况来定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2211168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深圳易图讯科技有限公司承建的厦门应急处突大队三维电子沙盘顺利通过专家验收

近日&#xff0c;深圳易图讯科技有限公司承建的厦门应急处突大队三维电子沙盘系统项目成功通过专家组的严格验收&#xff0c;标志着该系统在应急管理和处置突发事件方面的应用取得了重要突破。 验收过程中&#xff0c;专家组对三维电子沙盘系统的各项功能进行了全面而细致的测试…

第十六周:机器学习笔记

第十六周周报 摘要Abstratc一、机器学习1. Pointer Network&#xff08;指针网络&#xff09;2. 生成式对抗网络&#xff08;Generative Adversarial Networks | GAN&#xff09;——&#xff08;上&#xff09;2.1 Generator&#xff08;生成器&#xff09;2.2 Discriminator&…

Cef加载自定义本地资源

在Cef auto build下载cefCEF Automated Builds 我下载的是104&#xff0c;使用cefsimple工程。 例如&#xff1a;前端资源如下 通过http协议把前端资源加载出来。所有的资源都通过http://local.test.cn/xxx加载。 前端资源包括index.html、test.css、test.js index.html&am…

麒麟系统离线安装英伟达驱动

麒麟系统离线安装英伟达驱动 驱动相关程序下载下载显卡驱动下载CUDA-Toolkit下载cudnn 安装关闭自带图形界面禁用 Nouveau 驱动安装驱动安装CUDA-Toolkit安装cudnn 驱动相关程序下载 下载显卡驱动 进入显卡驱动查询页面&#xff0c;下载对应的显卡驱动&#xff0c;页面如下&a…

第十节:React路由:react-router认识与基本使用

1. React Router的理解 React的路由根据项目的不同使用不同的路由库,web应用主要使用react-router和react-router-dom react-router和react-router-dom的区别 react-rotuer 核心库,提供了一些核心的api,但是没有提供dom操作进行跳转的api react-router-dom扩展了核心库,提供了一…

Edge TTS

edge-tts项目地址&#xff1a;https://github.com/rany2/edge-tts 1.安装部署 在cmd中运行以下命令安装edge-tts pip install edge-tts pip install edge-tts速度非常快&#xff0c;几秒钟就安装完成了。 2.文本转语音 输入以下命令&#xff0c;将一段英文转为音频。 edg…

Linux——传输层协议

目录 一再谈端口号 1端口号范围划分 2两个问题 3理解进程与端口号的关系 二UDP协议 1格式 2特点 3进一步理解 3.1关于UDP报头 3.2关于报文 4基于UDP的应用层协议 三TCP协议 1格式 2TCP基本通信 2.1关于可靠性 2.2TCP通信模式 3超时重传 4连接管理 4.1建立…

RocketMq的学习

1.mq的秒杀场景 2.mq产品的选型

基于SpringBoot的校园兼职管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

solidity中的mapping以及Memory,Storage Calldata

1.Memory&#xff0c;Storage & Calldata 在 Solidity 中&#xff0c;有以下几种数据存储位置&#xff1a; 栈&#xff08;Stack&#xff09;&#xff1a;栈是一种临时存储区域&#xff0c;用于存储局部变量和函数参数。在函数执行期间&#xff0c;栈上的数据会被分配和释…

探索光耦:光耦——电动自行车安全与智能的坚实保障

随着电动自行车市场的蓬勃发展&#xff0c;如何提升其安全性、可靠性和智能化水平已成为行业关注的焦点。在众多关键元件中&#xff0c;光电耦合器&#xff08;简称光耦&#xff09;正以其独特的功能&#xff0c;成为电动自行车设计中的关键角色。下面&#xff0c;让我们一同探…

Ubuntu22.04阿里云服务器 Gitlab搭建CICD

gitlab搭建cicd流水线教程 1、阿里云申请免费云盘 申请免费云盘用于创建gitlab 申请方法百度 2、安装gitlab-ce 更新系统&#xff1a; sudo apt update sudo apt upgrade -y 安装必要的依赖&#xff1a; sudo apt install -y curl openssh-server ca-certificates pos…

【云原生】Helm资源清单管理工具

资源清单管理工具-Helm 文章目录 资源清单管理工具-Helm资源列表基础环境一、Helm的介绍1.1、Helm的价值概述1.2、Helm的关键名词 二、安装部署Helm2.1、解压安装包2.2、添加命令补全设置 三、使用Helm部署服务管理3.1、使用Helm创建chart3.2、响应式创建名称空间3.3、安装char…

基于Arduino的植物状态监测系统

Arduino植物监测/浇水系统 本项目的3D打印及源码开源&#xff0c;可以私信我进行获取 简介 大家好&#xff0c;今天我将向大家介绍一个非常有趣的项目——Arduino植物监测/浇水系统。这个项目利用一些传感器来观察土壤的状况&#xff0c;并根据这些读数来判断植物是否需要浇…

异构环境下统一授权管理系统的兼容性具体如何实现?

在异构环境中&#xff0c;由于不同系统的差异性&#xff0c;实现统一授权管理面临诸多挑战。其中&#xff0c;兼容性问题是关键之一。兼容性的实现不仅关系到不同系统之间的协同工作&#xff0c;还直接影响到整个管理系统的效率和稳定性。 异构系统带来的挑战 异构系统的存在…

手写mybatis之通过注解配置执行SQL语句

前言 可能领导也都觉得可能就是码农不爱说话&#xff0c;其实不爱说话是一方面&#xff0c;但还有另外一方面是有些领导对于码农提出的问题&#xff0c;给出的回复往往是&#xff1a;“你提出这个问题&#xff0c;你就要给出这个问题的解决办法&#xff01;” 所以不同的岗位要…

AD24之铺铜操作

1.选择板框&#xff0c;即机械1层&#xff0c;转换为覆铜 这样顶层就铺好了&#xff0c;还需要铺底层 2.打开底层&#xff0c;选择板框&#xff0c;转换为铺铜&#xff0c;然后给铜皮添加网络和层&#xff0c;最后是铺铜 注意&#xff1a;None铺铜是无效果的&#xff0c;要Ha…

2.使用 Label Studio 标注文本

使用 Label Studio 标注文本 文章目录 使用 Label Studio 标注文本前言Label Studio的简单使用1.创建项目2.添加本地存储3.选择标注模板4.添加数据5.标注6.添加关系 总结 前言 Label Studio是一个开源的功能强大的标注平台&#xff0c;可以标注视频&#xff0c;图片&#xff0…

一个新韭菜的炒股心得

一个新韭菜的炒股心得 前言 股市其实是一场修行。时刻控制人性的弱点。所以量化优势明显&#xff0c;它没有情绪&#xff0c;可以随意止盈止损。我从一个小白一路走过来&#xff0c;发现A股里有学不完的知识,有做不完的功课。我的主要关注点在如何有效实现价值投资(价值投资在…

算法: 位运算题目练习

文章目录 位运算判定字符是否唯一丢失的数字两整数之和只出现一次的数字 II消失的两个数字常见位运算总结 位运算 判定字符是否唯一 有很多解法,比如hash表,或者给字符串排个序,然后遍历… 写这道题时没注意到如果出现奇数个相同字符,此时就应该返回false了. 而不是全部放到位…