bytetrack 解决跟踪后框晃动的问题

news2025/1/7 10:44:50

使用距离最近的匹配的检测框 替代 bytetrack返回的跟踪框 作为最终的返回结果
在这里插入图片描述

完整byte_tracker.py代码为:

import numpy as np
from collections import deque
import os
import os.path as osp
import copy
import torch
import torch.nn.functional as F

from yolov5.utils.general import xywh2xyxy, xyxy2xywh


from trackers.bytetrack.kalman_filter import KalmanFilter
from trackers.bytetrack import matching
from trackers.bytetrack.basetrack import BaseTrack, TrackState

class STrack(BaseTrack):
    shared_kalman = KalmanFilter()
    def __init__(self, tlwh, score, cls):

        # wait activate
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        if frame_id == 1:
            self.is_activated = True
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score
        self.cls = new_track.cls

    def update(self, new_track, frame_id):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1
        # self.cls = cls

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score

    @property
    # @jit(nopython=True)
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
                width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    # @jit(nopython=True)
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    # @jit(nopython=True)
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


class BYTETracker(object):
    def __init__(self, track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.track_buffer=track_buffer
        
        self.track_thresh = track_thresh
        self.match_thresh = match_thresh
        self.det_thresh = track_thresh + 0.1
        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()

    def update(self, dets, _):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        xyxys = dets[:, 0:4]
        xywh = xyxy2xywh(xyxys)
        confs = dets[:, 4]
        clss = dets[:, 5]
        
        classes = clss.numpy()
        xyxys = xyxys.numpy()
        confs = confs.numpy()

        remain_inds = confs > self.track_thresh
        inds_low = confs > 0.1
        inds_high = confs < self.track_thresh

        inds_second = np.logical_and(inds_low, inds_high)
        
        dets_second = xywh[inds_second]
        dets = xywh[remain_inds]
        
        scores_keep = confs[remain_inds]
        scores_second = confs[inds_second]
        
        clss_keep = classes[remain_inds]
        clss_second = classes[inds_second]
        

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores_keep, clss_keep)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with high score detection boxes'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.iou_distance(strack_pool, detections)
        #if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with low score detection boxes'''
        # association the untrack to the low score detections
        if len(dets_second) > 0:
            '''Detections'''
            detections_second = [STrack(xywh, s, c) for (xywh, s, c) in zip(dets_second, scores_second, clss_second)]
        else:
            detections_second = []
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        #if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Ramained match {} s'.format(t4-t3))

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        outputs = []

        for t in output_stracks:
            output= []
            tlwh = t.tlwh
            tid = t.track_id
            tlwh = np.expand_dims(tlwh, axis=0)
            xyxy = xywh2xyxy(tlwh)
            xyxy = np.squeeze(xyxy, axis=0)

            #########################################################
            x_track = (xyxy[0] + xyxy[2]) / 2
            y_track = (xyxy[1] + xyxy[3]) / 2
            x1 = -1
            y1 = -1
            x2 = -1
            y2 = -1
            max_dis = 10000000
            for xyxy_det in xyxys:
                x_det = (xyxy_det[0] + xyxy_det[2]) / 2
                y_det = (xyxy_det[1] + xyxy_det[3]) / 2
                dis = (x_det - x_track) * (x_det - x_track) + (y_det - y_track) * (y_det - y_track)
                if dis < max_dis:
                    max_dis = dis
                    x1 = xyxy_det[0]
                    y1 = xyxy_det[1]
                    x2 = xyxy_det[2]
                    y2 = xyxy_det[3]
            xyxy =  np.squeeze([x1,y1,x2,y2])
            #########################################################

            output.extend(xyxy)
            output.append(tid)
            output.append(t.cls)
            output.append(t.score)
            outputs.append(output)

        return outputs
#track_id, class_id, conf

def joint_stracks(tlista, tlistb):
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if not i in dupa]
    resb = [t for i, t in enumerate(stracksb) if not i in dupb]
    return resa, resb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271564.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何使用OBS Studio录制屏幕?

可以进入官网或github进行下载&#xff1a; https://obsproject.com/download 安装包解压后进入bin 进入64-bit 选择obs 64 进入OBS Studio后在来源内右键&#xff0c;选择添加 选择添加显示器采集即可录取整个屏幕&#xff0c;窗口采集可选择窗口进行录制 选择对应显示器即配置…

ArcGIS Server 10.2授权文件过期处理

新的一年&#xff0c;arcgis server授权过期了&#xff0c;服务发不不了。查看ecp授权文件&#xff0c;原来的授权日期就到2024.12.31日。好吧&#xff0c;这里直接给出处理方法。 ArcGIS 10.2安装时&#xff0c;有的破解文件中会有含一个这样的注册程序&#xff0c;没有的话&…

循环冗余校验CRC的介绍

一、简介 循环冗余校验CRC&#xff08;Cyclic Redundancy Check&#xff09;是数据通信领域中最常用的一种差错校验码。该校验方法中&#xff0c;使用多项式出发&#xff08;模2除法&#xff09;运算后的余数为校验字段。CRC只能实现检错&#xff0c;不能实现纠错&#xff0c;使…

消息中间件类型都有哪些

在消息中间件的专业术语中&#xff0c;我们可以根据其特性和使用场景将其划分为几种主要的类型。这些类型不仅反映了它们各自的技术特点&#xff0c;还决定了它们在不同应用场景下的适用性。 1. 点对点&#xff08;Point-to-Point&#xff09;消息中间件&#xff1a; • 这类中…

微信小程序中 “页面” 和 “非页面” 的区别

微信小程序中 “页面” 和 “非页面” 的区别&#xff0c;并用表格进行对比。 核心概念&#xff1a; 页面 (Page)&#xff1a; 页面是微信小程序中用户可以直接交互的视图层&#xff0c;也是小程序的基本组成部分。每个页面都有自己的 WXML 结构、WXSS 样式和 JavaScript 逻辑…

卸载wps后word图标没有变成白纸恢复

这几天下载了个wps教育版&#xff0c;后头用完了删了 用习惯的2019图标 给兄弟我干没了&#xff1f;&#xff1f;&#xff1f; 其他老哥说什么卸载关联重新下 &#xff0c;而且还要什么撤销保存原来的备份什么&#xff0c;兄弟也是不得不怂了 后头就发现了这个半宝藏博主&…

SQL Server导出和导入可选的数据库表和数据,以sql脚本形式

一、导出 1. 打开SQL Server Management Studio&#xff0c;在需要导出表的数据库上单击右键 → 任务 → 生成脚本 2. 在生成脚本的窗口中单击进入下一步 3. 如果只需要导出部分表&#xff0c;则选择第二项**“选择具体的数据库对象(Select specific database objects)”**&am…

基于SpringBoot在线竞拍平台系统功能实现十五

一、前言介绍&#xff1a; 1.1 项目摘要 随着网络技术的飞速发展和电子商务的普及&#xff0c;竞拍系统作为一种新型的在线交易方式&#xff0c;已经逐渐深入到人们的日常生活中。传统的拍卖活动需要耗费大量的人力、物力和时间&#xff0c;从组织拍卖、宣传、报名、竞拍到成…

Android Camera压力测试工具

背景描述&#xff1a; 随着系统的复杂化和业务的积累&#xff0c;日常的功能性测试已不足以满足我们对Android Camera相机系统的测试需求。为了确保Android Camera系统在高负载和多任务情况下的稳定性和性能优化&#xff0c;需要对Android Camera应用进行全面的压测。 对于压…

配置嵌入式服务器

一、如何定制和修改Servlet容器的相关配置 修改和server有关的配置&#xff08;ServerProperties&#xff09; server.port8081 server.context‐path/tx server.tomcat.uri-encodingUTF-8二、注册servlet三个组件【Servlet、Filter、Listener】 由于SpringBoot默认是以jar包…

GPIO、RCC库函数

void GPIO_DeInit(GPIO_TypeDef* GPIOx); void GPIO_AFIODeInit(void); void GPIO_Init(GPIO_TypeDef* GPIOx, GPIO_InitTypeDef* GPIO_InitStruct); void GPIO_StructInit(GPIO_InitTypeDef* GPIO_InitStruct); //输出 读 uint8_t GPIO_ReadInputDataBit(GPIO_TypeDef* GPIOx,…

3D高斯点云CUDA版本数据制作与demo运行

0. 简介 关于UCloud(优刻得)旗下的compshare算力共享平台 UCloud(优刻得)是中国知名的中立云计算服务商&#xff0c;科创板上市&#xff0c;中国云计算第一股。 Compshare GPU算力平台隶属于UCloud&#xff0c;专注于提供高性价4090算力资源&#xff0c;配备独立IP&#xff0c;…

框架模块说明 #09 日志模块_01

背景 日志模块是系统的重要组成部分&#xff0c;主要负责记录系统运行状态和定位错误问题的功能。通常&#xff0c;日志分为系统日志、操作日志和安全日志三类。虽然分布式数据平台是当前微服务架构中的重要部分&#xff0c;但本文的重点并不在此&#xff0c;而是聚焦于自定义…

conda指定路径安装虚拟python环境

DataBall 助力快速掌握数据集的信息和使用方式&#xff0c;会员享有 百种数据集&#xff0c;持续增加中。 需要更多数据资源和技术解决方案&#xff0c;知识星球&#xff1a; “DataBall - X 数据球(free)” -------------------------------------------------------------…

aws(学习笔记第二十二课) 复杂的lambda应用程序(python zip打包)

aws(学习笔记第二十二课) 开发复杂的lambda应用程序(python的zip包) 学习内容&#xff1a; 练习使用CloudShell开发复杂lambda应用程序(python) 1. 练习使用CloudShell CloudShell使用背景 复杂的python的lambda程序会有许多依赖的包&#xff0c;如果不提前准备好这些python的…

driftingblues6靶场攻略

首先 打开kali&#xff0c;扫描主机 地址是192.168.111.143 访问网站 主页源码看一看&#xff0c;没啥用 老套路&#xff0c; 用nmap扫描一下开放端口 用dirsearch扫描一下目录 如果说扫描不到&#xff0c;那就可能是字典不行&#xff0c;换工具就完了 nmap -sV 192.168.…

【顶刊TPAMI 2025】多头编码(MHE)之Part 6:极限分类无需预处理

目录 1 标签分解方法的消融研究2 标签分解对泛化的影响3 讨论4 结论 论文&#xff1a;Multi-Head Encoding for Extreme Label Classification 作者&#xff1a;Daojun Liang, Haixia Zhang, Dongfeng Yuan and Minggao Zhang 单位&#xff1a;山东大学 代码&#xff1a;https:…

vue视频录制 限制大小,限制时长

<template><div style"height: 100vh;background: #000;"><span style"color: #fff;font-size: 18px;">切换数量&#xff1a;{{ devices.length }}</span><video ref"video" autoplay muted playsinline></vid…

毕业项目推荐:基于yolov8/yolov5的行人摔倒检测识别系统(python+卷积神经网络)

文章目录 概要一、整体资源介绍技术要点功能展示&#xff1a;功能1 支持单张图片识别功能2 支持遍历文件夹识别功能3 支持识别视频文件功能4 支持摄像头识别功能5 支持结果文件导出&#xff08;xls格式&#xff09;功能6 支持切换检测到的目标查看 二、数据集三、算法介绍1. YO…

高等数学学习笔记 ☞ 无穷小比较与等价无穷小替换

1. 无穷小比较 1. 本质&#xff1a;就是函数的极限趋于0时的速度&#xff0c;谁快谁慢的问题。 2. 定义&#xff1a;若是在同一自变量的变化过程中的无穷小&#xff0c;且&#xff0c;则&#xff1a; ①&#xff1a;若&#xff0c;则称是比的高阶无穷小&#xff0c;记作&…