[从0开始轨迹预测][NMS]:NMS的应用(目标检测、轨迹预测)

news2024/11/14 1:06:43

非极大值抑制(Non-Maximum Suppression,简称NMS)是一种在计算机视觉中广泛应用的算法,主要用于消除冗余和重叠的边界框。在目标检测任务中,尤其是在使用诸如R-CNN系列的算法时,会产生大量的候选区域,而这些区域可能存在大量的重叠。为了解决这个问题,使用NMS算法来保留最有可能的区域,同时抑制其他冗余或重叠的区域。

1. NMS在目标检测领域的应用

非极大值抑制在目标检测领域的基本原理和步骤如下:

  1. 对于每个类别,按照预测框的置信度进行排序,将置信度最高的预测框作为基准。
  2. 从剩余的预测框中选择一个与基准框的重叠面积最大的框,如果其重叠面积大于一定的阈值,则将其删除。
  3. 对于剩余的预测框,重复步骤2,直到所有的重叠面积都小于阈值,或者没有被删除的框剩余为止。

通过这样的方式,NMS可以过滤掉所有与基准框重叠面积大于阈值的冗余框,从而实现检测结果的优化。值得注意的是,NMS的阈值通常需要根据具体的数据集和应用场景进行调整,以兼顾准确性和召回率。

# NMS Python 简单实现
import numpy as np

def nms(dets, thresh):
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep

这段代码首先计算所有候选框的面积和分数,然后按照分数对候选框进行排序。然后,它进入一个循环,每次循环中,它都会选择当前分数最高的框,并将其添加到保留列表中。然后,它会计算这个框与其他所有框的重叠区域,并计算这些重叠区域与各自框的面积之比(即IoU)。如果这个比值大于给定的阈值,那么就会将对应的框从候选列表中删除。这个过程会一直重复,直到所有的框都被处理完毕。

2. NMS在轨迹预测领域的应用

NMS在轨迹预测中的应用,主要是用来处理预测结果中的冗余和重叠的轨迹,对于一些方法,模型预测出大量的候选轨迹,这些轨迹可能存在大量的重叠。为了解决这个问题,可以使用上述NMS算法来保留最有可能的轨迹,同时抑制其他冗余或重叠的轨迹。

假设对某个场景中的某辆车使用模型预测了64条或更多的轨迹,以很好地捕获多模态性,同时每条轨迹对应一个置信度,所有轨迹置信度总和为1。但最终输出时,我们一般仅输出6条轨迹(下游 or 打榜需求),如果直接选择置信度最高的6条轨迹会存在问题,比如说这六条轨迹靠的很近,无法体现多模态性。因此,我们需要使用NMS来选择最终的轨迹:

  1. 将轨迹按照置信度从高到低排序。
  2. 计算每两条轨迹之间最后一个点的距离,会产生一个距离矩阵。
  3. 依次按照置信度高低选取轨迹,比如第一次选择排名第一的轨迹,后面再选择轨迹时需要跟已经选择的所有判断距离是否大于某个阈值,如果小于该阈值,说明存在已选的轨迹与当前要被选择的轨迹很类似,则放弃选择该轨迹。

这样,通过NMS,我们可以从大量的预测轨迹中选择出最具代表性的轨迹,从而提高轨迹预测的效果。

在这里插入图片描述

从图中6条轨迹中选择出3条,如果按照置信度来选,应该选择0.8,0.5,0.4的轨迹,但由于0.5和0.4两条轨迹靠的太近(小于某个阈值)因此最终选择的轨迹为0.8,0.5,0.3三条轨迹。

def batch_nms(pred_trajs, pred_scores, dist_thresh, num_ret_modes=6):
    """

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]  # 对score从大到小排序
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)  最后一个点

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)  # 64*64 的距离矩阵
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1) # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1910946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Spring Boot的旅游信息推荐信息系统设计与实现(源码+lw+部署+讲解)

技术指标 开发语言&#xff1a;Java 框架&#xff1a;Spring BootJSP JDK版本&#xff1a;JDK1.8 数据库&#xff1a;MySQL5.7 数据库工具&#xff1a;Navicat16 开发软件&#xff1a;IDEA Maven包&#xff1a;Maven3.6.3 浏览器&#xff1a;IE浏览器 功能描述 旅游信…

怎么压缩ppt?这几种压缩方法大家都在用!

怎么压缩ppt&#xff1f;当我们沉浸在PPT创作的海洋中&#xff0c;每一个精心的布局、每一个动人的动画&#xff0c;都仿佛是我们心血的结晶&#xff0c;然而&#xff0c;随着我们不断雕琢&#xff0c;PPT文件的大小也在悄然增长&#xff0c;如同一只隐形的巨兽&#xff0c;在不…

文华财经盘立方多空变色波段趋势线指标公式源码

文华财经盘立方多空变色波段趋势线指标公式源码&#xff1a; N1:20; N2:ROUND(N1/2,1); N3:ROUND(SQRT(N1),1); N4:2*EMA2(C,N2)-EMA2(C,N1); 尊重市场:EMA2(N4,N3),COLORRED,LINETHICK2; 尊重市场1:IF(尊重市场<REF(尊重市场,1), 尊重市场,NULL),COLORGREEN,LINETHIC…

Redis数据类型和数据队列

一.Redis数据类型 参考资料&#xff1a;http://www.redis.cn/topics/data-types.html 相关命令参考: http://redisdoc.com/ Redis 是一种基于内存的开源数据结构存储系统&#xff0c;支持多种数据类型&#xff0c;每种数据类型都有自己特定的操作命令。 String&#xff08;字…

渲染农场怎么用更省钱?渲染100邀请码1a12

现在越来越多的设计师开始使用渲染农场&#xff0c;其中收费是个大问题&#xff0c;怎么用渲染农场才能更省钱呢&#xff1f;今天我们就来看下吧。 1、明确渲染方式 要根据不同情况选择合理的渲染方式&#xff0c;比如渲染农场就适合大场景渲染和紧急出图情况&#xff0c;其他…

跟我练习100道FPGA入门题目~(2/100)

难度指数&#xff1a;一颗星 关键词&#xff1a;组合逻辑、入门基础 点击此处直接答题&#xff1a;F学社-全球FPGA技术提升平台 (zzfpga.com) 提交代码就能看到波形图和电路图啦&#xff01; &#xff08;在社区加入群聊&#xff0c;更多学友等着和你探讨~&#xff09;

主流电商平台营销中大数据的应用◆

随着经济的不断发展&#xff0c;网络信息技术不断加强&#xff0c;电子商务和大数据的蓬勃发展极大地方便了人们的生活。本文章主要阐述大数据分析与电商营销的含义、大数据分析在电子商务营销中的应用&#xff0c;以及该应用的作用和存在哪些不足及解决方法。探究大数据分析在…

年化15.73%:创业板指数布林带突破Backtrader策略(代码+数据)

原创文章第582篇&#xff0c;专注“AI量化投资、世界运行的规律、个人成长与财富自由"。 昨天咱们使用backtrader&#xff0c;重写了创业板动量趋势择时&#xff1a;年化19.2%&#xff1a;backtraderquantstats实现创业板动量择时(代码数据) 今天咱们换一个通道指标&#…

Labview_Note_4

1.字符串显示控件设置自动在最下边位置 字符串属性节点中→文本→滚动条位置 滚动框在滚动条中的位置。 如需设置该位置&#xff0c;可连线用于表示滚动行数的数值至该属性。LabVIEW在滚动条的最后一行显示连线至该属性的数值。如需滚动至文本的最后一行&#xff0c;可连线行…

Lab1 论文 MapReduce

目录 &#x1f339;前言 &#x1f985;2 Programming Model &#x1f33c;2.1 Example &#x1f33c;2.2 Types &#x1f33c;2.3 More Examples &#x1f985;3 Implementation(实现) &#x1f33c;3.1 ~ 3.3 &#x1f33c;3.4 ~ 3.6 &#x1f985;4 Refinemen…

OpenBayes 教程上新 | 清华大学强推!YOLOv10 实现更高效的目标检测

过去几年中&#xff0c;由于 YOLO 在计算成本和检测性能之间的有效平衡&#xff0c;它已经成为实时目标检测领域的主要范式。然而&#xff0c;YOLO 依赖于非极大值抑制 (NMS) 进行后处理&#xff0c;这阻碍了 YOLO 的端到端部署&#xff0c;并对推理延迟产生了不利影响。 YOLO…

C++第三弹 -- 类与对象(上)

目录 前言一. 面向过程和面向对象的初步认识二. 类的引入三. 类的定义1.定义2. 命名规则建议 四. 类的访问限定符以及封装1. 访问限定符2.面试题3. 封装 五. 类的作用域六. 类的实例化七. 类的对象大小的计算八. 类成员函数this指针1. this指针的引出2. this指针的特性3. C语言…

虚幻引擎在建筑和房地产中的五大颠覆性应用:探索新时代的优势

最初&#xff0c;虚幻引擎作为一个强大的游戏开发工具出现。它不断推动虚拟环境的可能性边界。正因如此&#xff0c;它的使用自然而然地超越了游戏开发&#xff0c;涵盖了包括建筑工程在内的其他行业。那么&#xff0c;在建筑和房地产领域使用虚幻引擎有哪些好处呢&#xff1f;…

Date TimePicker 时间选择器精确限制到时分秒,此刻按钮点击失效处理

今天在开发的时候遇到一个需求&#xff0c;日期时间选择器组件不能选择已经过去的年月日时分秒。用户只能选择当前时间的时间&#xff0c;如果年月日选择是当天之前的时间&#xff0c;时分秒不做限制&#xff0c;如果年月日选择的是当天时间&#xff0c;就要判断时分秒&#xf…

koa + http-proxy-middleware 搭建一个带转发的静态服务器

背景 由于工作中碰到写普通页面&#xff08;未使用脚手架&#xff09;&#xff0c;需要发起接口请求&#xff0c;但普通页面又无法对接口发起正常请求&#xff0c;故编写一个Koa搭建的带转发功能的静态服务器。 起步 新建一个文件夹&#xff0c;在文件夹下打开 cmd 或者 git …

侯捷C++面向对象高级编程(上)-11-虚函数与多态

1.虚函数 2.virtual 3.继承&#xff0b;复合关系下的构造和析构 4.委托&#xff0b;继承

Java基础语法--基本数据类型

Java基础语法–基本数据类型 Java是一种静态类型语言&#xff0c;这意味着每个变量在使用前都必须声明其数据类型。Java提供了多种基本数据类型&#xff0c;用于存储整数、浮点数、字符和布尔值等。以下是Java中的基本数据类型及其特点&#xff1a; 1. 整型&#xff08;Integ…

深入剖析预处理

目录 1.预定义符号 2.#define 定义常量 3.#define定义宏 4.带有副作用的宏参数 5.宏替换的规则 6.宏函数的对比 7.#和## 7.1 #运算符 7.2 ## 运算符 8.命名约定 9.#undef 10.命令行定义 11.条件编译 12.头文件的包含 12.1 头文件被包含的方式&#xff1a; 12.1…

C++ 语法习题(2)

第三讲 循环语句 1.偶数 编写一个程序&#xff0c;输出 1 到 100之间&#xff08;包括 1 和 100&#xff09;的全部偶数。 输入格式 无输入。 输出格式 输出全部偶数&#xff0c;每个偶数占一行。 输入样例 No input输出样例 2 4 6 ... 100 参考代码: #include <i…

玩具营销是如何拿捏成年人钱包?

好像现在的成年人逐渐热衷于偏向年轻化&#xff0c;问问题会好奇“尊嘟假嘟”&#xff0c;饭量上的“儿童套餐”&#xff0c;娃娃机前排长队......而最突出的莫过于各类各式的玩具不断收割当代年轻人&#xff0c;除去常给大朋友们小朋友们送去玩具福利的“麦、肯”双门&#xf…