1. 动机
目标跟踪是计算机视觉领域一种常用的算法,用于将前后帧中的同一个目标关联起来,从而可以针对某一个特定目标进行分析,如对状态进行投票平滑获取更为稳健的结果。
然而,目前流行的跟踪算法大多基于检测框(bbox)之间的IoU进行匹配。对于小目标或点目标的检测,IoU通常不是一个好的选择:由于目标太小,同一目标在相邻两帧之间的IoU很容易为0,从而导致匹配失败。
为了解决这个问题,本文提出了一种基于点之间距离的跟踪方法:将目标建模为一个点,通过计算前后帧点之间的距离,利用匈牙利匹配来进行跟踪。
2. 方法
直接上代码,里面给出了跟踪方法的定义以及一个使用示例:
"""
Test for multi-target tracker
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
# 定义目标类
class Target:
    """A tracked point together with its bookkeeping state.

    Attributes:
        x, y: Current position of the point, stored exactly as passed in.
        id: Track identifier; ``None`` until something assigns one.
        miss_num: Counter of missed frames, starting at 0.
        lost: Flag marking the target as lost, initially ``False``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.id = None
        self.miss_num, self.lost = 0, False
# 定义跟踪器类
class Tracker:
    """Multi-target point tracker based on Euclidean distance.

    Each frame's detections are assigned to the existing tracks with the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``) using the
    point-to-point Euclidean distance as the cost.

    Args:
        max_age: A track is marked lost (and removed) once its ``miss_num``
            reaches this value.
        thres_dist: An assignment is accepted only when the distance between
            the detection and the track is strictly below this value.

    Attributes:
        tracked_targets: List of currently alive track objects.
        last_id: Identifier that will be given to the next created track.
    """

    def __init__(self, max_age=1, thres_dist=100):
        self.max_age = max_age
        self.thres_dist = thres_dist
        self.tracked_targets = []
        self.last_id = 0

    def update(self, targets):
        """Associate the detections of a new frame with the existing tracks.

        Args:
            targets: Sequence of detections; each item is indexable with
                ``[0]`` (x) and ``[1]`` (y). May be empty.

        Side effects:
            * Matched tracks take the position of their detection and have
              ``miss_num`` reset to 0.
            * Every non-lost track without an accepted match has ``miss_num``
              incremented and is flagged ``lost`` once it reaches ``max_age``.
            * Every detection without an accepted match starts a new track.
            * Tracks flagged ``lost`` are removed from ``tracked_targets``.
        """
        num_detections = len(targets)
        # Tracks already flagged as lost never take part in the matching;
        # excluding them avoids putting infinite costs into the matrix.
        active = [trk for trk in self.tracked_targets if not trk.lost]

        matched_detections = set()
        matched_tracks = set()

        # The assignment solver needs a proper 2-D cost matrix, so it is only
        # invoked when both sides are non-empty.
        if num_detections and active:
            detections = np.asarray(
                [[det[0], det[1]] for det in targets], dtype=float
            )
            positions = np.asarray(
                [[trk.x, trk.y] for trk in active], dtype=float
            )
            # distances[i, j]: Euclidean distance from detection i to track j.
            distances = np.linalg.norm(
                detections[:, None, :] - positions[None, :, :], axis=2
            )
            row_ind, col_ind = linear_sum_assignment(distances)

            for i, j in zip(row_ind, col_ind):
                # Pairs that are too far apart are rejected: the track is
                # aged below and the detection becomes a new track.
                if distances[i, j] < self.thres_dist:
                    track = active[j]
                    track.x = targets[i][0]
                    track.y = targets[i][1]
                    track.miss_num = 0
                    matched_detections.add(int(i))
                    matched_tracks.add(int(j))

        # Age every active track that did not receive an accepted match,
        # including tracks left out of the assignment entirely.
        for j, track in enumerate(active):
            if j in matched_tracks:
                continue
            track.miss_num += 1
            if track.miss_num >= self.max_age:
                track.lost = True

        # Unmatched detections (indexed by detection, not by track) start
        # new tracks.
        for i in range(num_detections):
            if i not in matched_detections:
                self.create_target(targets[i][0], targets[i][1])

        # Drop lost tracks in place so external references to the list stay
        # valid.
        self.tracked_targets[:] = [
            trk for trk in self.tracked_targets if not trk.lost
        ]

    def create_target(self, x, y):
        """Create a new track at ``(x, y)`` and give it the next free id."""
        tar = Target(x, y)
        tar.id = self.last_id
        self.tracked_targets.append(tar)
        self.last_id += 1
def get_color_by_id(id):
    """Return the colour triple assigned to a track id.

    The id is reduced modulo the size of a fixed 13-entry palette, so ids
    beyond the palette length wrap around and reuse colours.
    """
    palette = (
        (255, 0, 0),
        (255, 255, 0),
        (255, 0, 255),
        (30, 140, 100),
        (0, 255, 0),
        (0, 50, 200),
        (100, 0, 30),
        (100, 100, 0),
        (20, 10, 200),
        (20, 250, 100),
        (145, 0, 90),
        (15, 10, 190),
        (15, 100, 100),
    )
    return palette[id % len(palette)]
if __name__ == '__main__':
    # Demo: four points start in the corners of a 500x500 canvas and drift
    # towards the opposite corner; the tracker labels them frame by frame and
    # the result is shown live and saved as an animated GIF.
    import copy

    import cv2
    from PIL import Image

    img = np.ones([500, 500, 3]).astype(np.uint8)
    pil_imgs = [Image.fromarray(img)]

    # Initial detections, one per corner.
    preds = [[1, 1], [1, 499], [499, 1], [499, 499]]

    tracker = Tracker()

    # Simulate a sequence of frames.
    for i in range(100):
        # Random per-axis step in [0, delta); frames 10-13 use a larger step.
        delta = 10
        if i in [10, 11, 12, 13]:
            delta = 30

        preds[0][0] += np.random.randint(0, delta)
        preds[0][1] += np.random.randint(0, delta)
        preds[1][0] += np.random.randint(0, delta)
        preds[1][1] -= np.random.randint(0, delta)
        preds[2][0] -= np.random.randint(0, delta)
        preds[2][1] += np.random.randint(0, delta)
        preds[3][0] -= np.random.randint(0, delta)
        preds[3][1] -= np.random.randint(0, delta)

        # Pass a copy so the tracker never aliases the simulated positions.
        tracker.update(copy.deepcopy(preds))

        # The canvas is never cleared, so the id labels accumulate into a
        # visible trail for every track.
        for target in tracker.tracked_targets:
            print("Target {}: ({}, {}); lost: {}".format(target.id, target.x, target.y, target.lost))
            # OpenCV expects the text origin as a pair of integers; a tuple
            # is accepted by old and new Python bindings alike.
            origin = (int(target.x), int(target.y))
            cv2.putText(img, str(target.id), origin, cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                        get_color_by_id(target.id), thickness=1, lineType=cv2.LINE_AA)

        pil_imgs.append(Image.fromarray(img))
        cv2.imshow('img', img)
        if cv2.waitKey(30) == ord('q'):
            break

    # Close the preview window once the simulation ends or is interrupted.
    cv2.destroyAllWindows()

    # Assemble the collected frames into an animated GIF.
    image_0 = pil_imgs[0]
    image_0.save('track_result.gif', save_all=True, append_images=pil_imgs[1:], duration=30, loop=0)
运行上述代码后,会保存一个GIF文件,展示了多个目标的跟踪结果,如下图: