​Segment-and-Track Anything——通用智能视频分割、跟踪、编辑算法解读与源码部署

news2024/11/24 6:52:23

一、 万物分割

随着Meta发布的Segment Anything Model (万物分割)的论文并开源了相关的算法,我们可以从中看到,SAM与GPT-4类似,这篇论文的目标是(零样本)分割一切,将自然语言处理(NLP)的提示范式引入了计算机视觉(CV)领域,为CV基础模型提供了更广泛的支持和深度研究的机会。
Segment Anything与传统的图像分割有两个很大的区别:

1、数据收集和主动学习的方式。

对于一个庞大的数据集,例如包含十亿组数据的情况,标注全部数据几乎是不可行的。因此,一个解决方案是采用主动学习的方法。这种方法可以分为以下步骤:
初步标注: 首先,对数据集的一部分进行手动标注。这可以是一个小样本,但应涵盖多种情况和类别,以确保模型获得足够的多样性。
半监督学习: 接下来,使用已标注的数据来训练一个初始模型。这个模型可以用来预测未标注数据的标签。
人工校验与修正: 模型生成的预测标签需要经过人工校验和修正,以确保其准确性。这可以通过专业人员或者众包的方式来完成。
迭代循环: 重复上述步骤,逐渐扩展已标注数据的数量。每次迭代都会提高模型的性能,因为它可以在更多数据上进行训练。
通过这种方式,可以逐步提高数据集的标注质量,而不需要手动标注所有数据。当数据集足够大并且模型被训练到一定程度时,其性能将会显著提升。

2. prompt

Segment Anything 引入了prompt的概念。Prompt是一种用户输入的提示,用于引导模型生成特定类型的回复。这在像GPT-3和SAM这样的模型中非常有用。用户可以提供一个问题或者描述,以帮助模型理解其意图并生成相关的回答或操作。
例如,在SAM中,你可以输入一个提示词,如“Cat”或“Dog”,以告诉模型你希望它分割出照片中的猫或狗。模型将自动检测并绘制框,以实现分割。这个提示词可以用来限定模型的任务,使其更专注于特定的信息提取或操作。
在这里插入图片描述
这两个概念都是在处理大规模数据和提高模型性能方面非常重要的工程性工作。通过合理的数据收集和主动学习策略,以及通过引导模型的prompt,可以更好地满足用户需求,提高模型的效果,并逐步改进模型的性能。

二、​Segment-and-Track Anything

1、算法简介

SAM的出现统一了分割这个任务很多应用,也表明了在CV领域可能存在大规模模型的潜力。这一突破肯定会对CV领域的研究带来重大变革,许多任务将得到统一处理。这一新的数据集和范式结合了超强的零样本泛化能力,将对CV领域产生深远影响。但缺乏对视频数据的支持。随后,浙江大学ReLER实验室的科研人员在最新开源的SAM-Track项目其中,解锁了SAM的视频分割能力,即:分割并跟踪一切(Segment-and-track anything,SAM-track)。SAM-Track在单卡上即可支持各种时空场景中的目标分割和跟踪,包括街景、AR、细胞、动画、航拍等,可同时追踪超过200个物体,为用户提供了强大的视频编辑能力。“Segment and Track Anything” 利用自动和交互式方法。主要使用的算法包括 SAM(Segment Anything Models)用于自动/交互式关键帧分割,以及 DeAOT(Decoupling features in Associating Objects with Transformers)(NeurIPS2022)用于高效的多目标跟踪和传播。SAM-Track 管道实现了 SAM 的动态自动检测和分割新物体,而 DeAOT 负责跟踪所有识别到的物体。

2、项目特点

自动/交互式分割:项目中的 SAM(Segment Anything Models)算法提供了自动和交互式关键帧分割的功能。通过 SAM,用户可以选择使用自动分割算法或与算法进行交互,以实现对视频中任意对象的精确分割。这种灵活性使得该项目适用于不同需求的应用场景。

高效多目标跟踪:Segment-and-Track-Anything 还引入了 DeAOT 算法,用于实现高效的多目标跟踪和传播。DeAOT 利用先进的跟踪技术,能够准确地跟踪视频中的多个对象,并支持对象之间的传播和关联。这使得项目在处理复杂场景和多目标跟踪任务时表现出色。

独立和开放性:该项目是一个独立的开源项目,可以直接访问和使用。它提供了丰富的文档和示例代码,帮助用户快速上手并进行定制开发。同时,项目欢迎社区的贡献和扩展,这使得用户能够与其他研究者和开发者共享经验和成果。

应用广泛性:Segment-and-Track-Anything 的分割和跟踪功能可以应用于各种视频分析任务,包括视频监控、智能交通、行为分析等。它为研究者和开发者提供了一个强大的工具,用于处理和分析具有复杂动态场景的视频数据。
在这里插入图片描述

三、项目部署

项目地址:https://github.com/z-x-yang/Segment-and-Track-Anything

1.部署环境

我这里测试部署的系统win 10, cuda 11.8,cudnn 8.5,GPU是RTX 3060, 8G显存,使用conda创建虚拟环境。
创建并激活一个虚拟环境:

conda create -n sta python==3.10
activate sta

下载项目:

git clone https://github.com/z-x-yang/Segment-and-Track-Anything.git
cd Segment-and-Track-Anything
pip install gradio
pip install scikit-image

因为要使用GPU,这里单独安装pytorch:

conda install pytorch2.0.0 torchvision0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia

因为项目的依赖要使用sh脚本进行安装,win下不支持bash,所以要单独安装m2-base:

conda install m2-base

安装项目依赖:

bash script/install.sh

当出现下面提示代表安装成功。
在这里插入图片描述
GroundingDINO可能会安装不成功,可以直接从源码安装:

git clone https://github.com/IDEA-Research/GroundingDINO.git
cd GroundingDINO/
pip install -e .
cd …

下载所需模型:

bash script/download_ckpt.sh

如果模型下载不成功,也可以手动复制这个地址把模型下载了放到指定目录目录.

2.运行项目

python app.py

然后打开http://127.0.0.1:7860/
在这里插入图片描述
导入一个视频,然后只追踪其中一个人,效果如下:
在这里插入图片描述

视频目标追踪:

目标分割与目标追踪

3.分割与追踪处理代码

import sys
sys.path.append("..")
sys.path.append("./sam")
from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from aot_tracker import get_aot
import numpy as np
from tool.segmentor import Segmentor
from tool.detector import Detector
from tool.transfer_tools import draw_outline, draw_points
import cv2
from seg_track_anything import draw_mask


class SegTracker():
    def __init__(self,segtracker_args, sam_args, aot_args) -> None:
        """
         Initialize SAM and AOT.
        """
        self.sam = Segmentor(sam_args)
        self.tracker = get_aot(aot_args)
        self.detector = Detector(self.sam.device)
        self.sam_gap = segtracker_args['sam_gap']
        self.min_area = segtracker_args['min_area']
        self.max_obj_num = segtracker_args['max_obj_num']
        self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
        self.reference_objs_list = []
        self.object_idx = 1
        self.curr_idx = 1
        self.origin_merged_mask = None  # init by segment-everything or update
        self.first_frame_mask = None

        # debug
        self.everything_points = []
        self.everything_labels = []
        print("SegTracker has been initialized")

    def seg(self,frame):
        '''
        Arguments:
            frame: numpy array (h,w,3)
        Return:
            origin_merged_mask: numpy array (h,w)
        '''
        frame = frame[:, :, ::-1]
        anns = self.sam.everything_generator.generate(frame)

        # anns is a list recording all predictions in an image
        if len(anns) == 0:
            return
        # merge all predictions into one mask (h,w)
        # note that the merged mask may lost some objects due to the overlapping
        self.origin_merged_mask = np.zeros(anns[0]['segmentation'].shape,dtype=np.uint8)
        idx = 1
        for ann in anns:
            if ann['area'] > self.min_area:
                m = ann['segmentation']
                self.origin_merged_mask[m==1] = idx
                idx += 1
                self.everything_points.append(ann["point_coords"][0])
                self.everything_labels.append(1)

        obj_ids = np.unique(self.origin_merged_mask)
        obj_ids = obj_ids[obj_ids!=0]

        self.object_idx = 1
        for id in obj_ids:
            if np.sum(self.origin_merged_mask==id) < self.min_area or self.object_idx > self.max_obj_num:
                self.origin_merged_mask[self.origin_merged_mask==id] = 0
            else:
                self.origin_merged_mask[self.origin_merged_mask==id] = self.object_idx
                self.object_idx += 1

        self.first_frame_mask = self.origin_merged_mask
        return self.origin_merged_mask

    def update_origin_merged_mask(self, updated_merged_mask):
        self.origin_merged_mask = updated_merged_mask
        # obj_ids = np.unique(updated_merged_mask)
        # obj_ids = obj_ids[obj_ids!=0]
        # self.object_idx = int(max(obj_ids)) + 1

    def reset_origin_merged_mask(self, mask, id):
        self.origin_merged_mask = mask
        self.curr_idx = id

    def add_reference(self,frame,mask,frame_step=0):
        '''
        Add objects in a mask for tracking.
        Arguments:
            frame: numpy array (h,w,3)
            mask: numpy array (h,w)
        '''
        self.reference_objs_list.append(np.unique(mask))
        self.curr_idx = self.get_obj_num()
        self.tracker.add_reference_frame(frame,mask, self.curr_idx, frame_step)

    def track(self,frame,update_memory=False):
        '''
        Track all known objects.
        Arguments:
            frame: numpy array (h,w,3)
        Return:
            origin_merged_mask: numpy array (h,w)
        '''
        pred_mask = self.tracker.track(frame)
        if update_memory:
            self.tracker.update_memory(pred_mask)
        return pred_mask.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.uint8)
    
    def get_tracking_objs(self):
        objs = set()
        for ref in self.reference_objs_list:
            objs.update(set(ref))
        objs = list(sorted(list(objs)))
        objs = [i for i in objs if i!=0]
        return objs
    
    def get_obj_num(self):
        objs = self.get_tracking_objs()
        if len(objs) == 0: return 0
        return int(max(objs))

    def find_new_objs(self, track_mask, seg_mask):
        '''
        Compare tracked results from AOT with segmented results from SAM. Select objects from background if they are not tracked.
        Arguments:
            track_mask: numpy array (h,w)
            seg_mask: numpy array (h,w)
        Return:
            new_obj_mask: numpy array (h,w)
        '''
        new_obj_mask = (track_mask==0) * seg_mask
        new_obj_ids = np.unique(new_obj_mask)
        new_obj_ids = new_obj_ids[new_obj_ids!=0]
        # obj_num = self.get_obj_num() + 1
        obj_num = self.curr_idx
        for idx in new_obj_ids:
            new_obj_area = np.sum(new_obj_mask==idx)
            obj_area = np.sum(seg_mask==idx)
            if new_obj_area/obj_area < self.min_new_obj_iou or new_obj_area < self.min_area\
                or obj_num > self.max_obj_num:
                new_obj_mask[new_obj_mask==idx] = 0
            else:
                new_obj_mask[new_obj_mask==idx] = obj_num
                obj_num += 1
        return new_obj_mask
        
    def restart_tracker(self):
        self.tracker.restart()

    def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):
        ''''
        Use bbox-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            bbox: [[x0, y0], [x1, y1]]
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        '''
        # get interactive_mask
        interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

        # draw bbox
        masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))

        return refined_merged_mask, masked_frame

    def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
        '''
        Use point-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            coords: nd.array [[x, y]]
            modes: nd.array [[1]]
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        '''
        # get interactive_mask
        interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)

        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

        # draw points
        # self.everything_labels = np.array(self.everything_labels).astype(np.int64)
        # self.everything_points = np.array(self.everything_points).astype(np.int64)

        masked_frame = draw_points(coords, modes, masked_frame)

        # draw outline
        masked_frame = draw_outline(interactive_mask, masked_frame)

        return refined_merged_mask, masked_frame

    def add_mask(self, interactive_mask: np.ndarray):
        '''
        Merge interactive mask with self.origin_merged_mask
        Parameters:
            interactive_mask: numpy array (h, w)
        Return:
            refined_merged_mask: numpy array (h, w)
        '''
        if self.origin_merged_mask is None:
            self.origin_merged_mask = np.zeros(interactive_mask.shape,dtype=np.uint8)

        refined_merged_mask = self.origin_merged_mask.copy()
        refined_merged_mask[interactive_mask > 0] = self.curr_idx

        return refined_merged_mask
    
    def detect_and_seg(self, origin_frame: np.ndarray, grounding_caption, box_threshold, text_threshold, box_size_threshold=1, reset_image=False):
        '''
        Using Grounding-DINO to detect object acc Text-prompts
        Retrun:
            refined_merged_mask: numpy array (h, w)
            annotated_frame: numpy array (h, w, 3)
        '''
        # backup id and origin-merged-mask
        bc_id = self.curr_idx
        bc_mask = self.origin_merged_mask

        # get annotated_frame and boxes
        annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
        for i in range(len(boxes)):
            bbox = boxes[i]
            if (bbox[1][0] - bbox[0][0]) * (bbox[1][1] - bbox[0][1]) > annotated_frame.shape[0] * annotated_frame.shape[1] * box_size_threshold:
                continue
            interactive_mask = self.sam.segment_with_box(origin_frame, bbox, reset_image)[0]
            refined_merged_mask = self.add_mask(interactive_mask)
            self.update_origin_merged_mask(refined_merged_mask)
            self.curr_idx += 1

        # reset origin_mask
        self.reset_origin_merged_mask(bc_mask, bc_id)

        return refined_merged_mask, annotated_frame

if __name__ == '__main__':
    from model_args import segtracker_args,sam_args,aot_args

    Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
    
    # ------------------ detect test ----------------------
    
    origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
    origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
    grounding_caption = "swan.water"
    box_threshold = 0.25
    text_threshold = 0.25

    predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
    masked_frame = draw_mask(annotated_frame, predicted_mask)
    origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_RGB2BGR)

    cv2.imwrite('./debug/masked_frame.png', masked_frame)
    cv2.imwrite('./debug/x.png', annotated_frame)

四、 报错

1.下载模型问题

requests.exceptions.SSLError: (MaxRetryError(“HTTPSConnectionPool(host=‘huggingface.co’, port=443): Max retries exceeded with url: /bert-base-uncased/resolve/main/tokenizer_config.json (Caused by SSLError(SSLEOFError(8, ‘EOF occurred in violation of protocol (_ssl.c:997)’)))”), ‘(Request ID: d4f21f96-45fd-47a1-9afb-b7e4260a6f3b)’)

https://huggingface.co/bert-base-uncased/tree/main

在这里插入图片描述
可以手动从这里下载模型,然后放到指定的目录:
在这里插入图片描述

2. imageio版本问题

TypeError: The keyword fps is no longer supported. Use duration(in ms) instead, e.g. fps=50 == duration=20 (1000 * 1/50).

pip uninstall imageio
pip install imageio==2.23.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1037021.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】二叉排序树;平衡二叉树的知识点学习总结

目录 1、二叉排序树 1.1 定义 1.2 查找操作 1.3 插入操作 1.4 删除操作 1.5 C语言实现二叉排序树的基本操作 2、平衡二叉树的知识点总结 2.1 定义 2.2 插入操作 2.3 调整“不平衡” 2.4 删除操作 1、二叉排序树 1.1 定义 二叉排序树&#xff08;Binary Search …

云计算与大数据——部署Hadoop集群并运行MapReduce集群(超级详细!)

云计算与大数据——部署Hadoop集群并运行MapReduce集群(超级详细&#xff01;) Linux搭建Hadoop集群(CentOS7hadoop3.2.0JDK1.8Mapreduce完全分布式集群) 本文章所用到的版本号&#xff1a; CentOS7 Hadoop3.2.0 JDK1.8 基本概念及重要性 很多小伙伴部署集群用hadoop用mapr…

C++设计模式_06_Decorator 装饰模式

本篇将会介绍Decorator 装饰模式&#xff0c;它是属于一个新的类别&#xff0c;按照C设计模式_03_模板方法Template Method中介绍的划分为“单一职责”模式。 “单一职责”模式讲的是在软件组件的设计中&#xff0c;如果责任划分的不清晰&#xff0c;使用继承得到的结果往往是随…

HT for Web (Hightopo) 使用心得(2)- 2D 图纸、节点、连线 与基本动画

概括来说&#xff0c;用 HT for Web 做可视化主要分为两部分&#xff0c;也就是 2D 和 3D。这两部分需要单独创建。在它们被创建完成后&#xff0c;我们再把它们集成到一起。 HT for Web 的 2D 部分主要是指 ht.graph.GraphView (简称 GraphView&#xff0c;也就是 2D 图纸)。…

Java项目:SSM的食堂点餐系统

作者主页&#xff1a;Java毕设网 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 一、相关文档 系统中的核心用户是系统管理员&#xff0c;管理员登录后&#xff0c;通过管理员菜单来管理后台系统。主要功能有&#xff1a;个人中心、用户管理…

自动发现、zabbix_proxy代理

自动发现&#xff1a;自己去发现被监控的主机 它能够根据用户事先定义的规则自动添加监控的主机或服务等。 优点 加快Zabbix部署&#xff08;agent&#xff09; 简化管理 在快速变化的环境中使用Zabbix&#xff0c;而不需要过度管理 部署自动发现(新机子) rpm -Uvh https://re…

OSI 七层网络协议最全的图

OSI 七层网络协议最全的图 文章出处&#xff1a;https://www.shuzhiduo.com/A/RnJWawowdq/

DINO(ICLR 2023)

DINO&#xff08;ICLR 2023&#xff09; DETR with Improved deNoising anchOr box DINO发展&#xff1a; Conditional DETR->DAB-DETR&#xff08;4D,WH修正&#xff09; DN-DETR&#xff08;去噪训练&#xff0c;deNoising 稳定匹配过程&#xff09; Deformable DETR&…

后端大厂面试-16道面试题

1 java集合类有哪些&#xff1f; List是有序的Collection&#xff0c;使用此接口能够精确的控制每个元素的插入位置&#xff0c;用户能根据索引访问List中元素。常用的实现List的类有LinkedList&#xff0c;ArrayList&#xff0c;Vector&#xff0c;Stack。 ArrayList是容量…

基于同名面片的TLS测站点云配准

1、原理介绍 2、代码介绍 基于C++编写的程序代码如下,其依赖eigen矩阵运算库,在创建工程时包含库目录中使用了相对路径,因此其下载下来直接可以运行,不用单独在设置环境,非常方便。

Java项目:SpringBoot高校宿舍管理系统

作者主页&#xff1a;Java毕设网 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 一、相关文档 宿舍是大学生学习与生活的主要场所之一&#xff0c;宿舍管理是高校学工管理事务中尤为重要的一项。随着我国高校招生规模的进一步扩大&#xff0…

异步回调

Future 设计的初衷&#xff1a;对将来的某个事件的结果进行建模 package com.kuang.future;import com.kuang.pc.C;import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.uti…

ubuntu | 安装NVIDIA套件:驱动、CUDA、cuDNN

CUDA 查看支持最高的cuda版本 nvidia-smiCUDA Version:12.2 区官网下在12.2.x最新的版本即可CUDA Toolkit Archive | NVIDIA Developer 下载安装 wget https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run sudo…

《开发实战》16 | 缓存设计:缓存可以锦上添花也可以落井下石

不要把 Redis 当作数据库 Redis 的确具有数据持久化功能&#xff0c;可以实现服务重启后数据不丢失。这一点&#xff0c;很容易让我们误认为 Redis 可以作为高性能的 KV 数据库。Redis 的特点是&#xff0c;处理请求很快&#xff0c;但无法保存超过内存大小的数据。第一&#…

leetcode 22. 括号生成

2023.9.24 看到组合两个字&#xff0c;想到了回溯。 大致思路是将所有可能的组合列出来&#xff0c;通过中止条件筛选掉无效的括号。 第一个中止条件&#xff1a;如果右括号数量大于左括号&#xff0c;那括号肯定无效。 第二个中止条件&#xff1a;当左右括号数量相等&#x…

swiper使用

介绍 Swiper&#xff08;swiper master&#xff09;是一个第三方的库&#xff0c;可以用来实现移动端、pc端的滑动操作。&#xff0c;swiper应用广泛&#xff0c;使用频率仅次于jquery, 轮播图类排名第一&#xff0c;是网页设计师必备技能&#xff0c;众多耳熟能详的品牌在使用…

Keil 无法烧写程序

问题描述&#xff1a; Keil MDK V5.38 按 F8 键无法烧录程序&#xff0c;提示: Error: Flash Download failed - "Cortex-M7", No Algorithm found for: 08000000H - 080013D3H 解决办法&#xff1a; Debug 工具改为&#xff1a;ST-Link Debugger Debug 的 Conne…

mac怎么把两张图片拼在一起

mac怎么把两张图片拼在一起&#xff1f;在如今的生活中&#xff0c;喜欢摄影的朋友们越来越多。拍照已经成为我们的一种习惯&#xff0c;因为当我们遇到美景或迷人的人物时&#xff0c;总是忍不住按下快门&#xff0c;将它们定格。随着时间的推移&#xff0c;我们渐渐发现自己的…

[Java | Web] JavaWeb——JSON与AJAX简介

目录 一、JSON 简介 1、什么是 JSON 2、JSON 的定义和访问 3、JSON 在 JS 中两种常用的转换方法 4、JSON 在 Java 中的使用 5、匿名内部类 二、AJAX 简介 1、什么是 AJAX 2、原生 JS 的 AJAX 请求示例 3、JQuery 中的 AJAX 请求 一、JSON 简介 1、什么是 JSON JSON…

Elasticsearch:什么是向量和向量存储数据库,我们为什么关心?

Elasticsearch 从 7.3 版本开始支持向量搜索。从 8.0 开始支持带有 HNSW 的 ANN 向量搜索。目前 Elasticsearch 已经是全球下载量最多的向量数据库。它允许使用密集向量和向量比较来搜索文档。 矢量搜索在人工智能和机器学习领域有许多重要的应用。 有效存储和检索向量的数据库…