torchvision.ops.nms实现NMS

news2024/12/23 12:42:44

nms原理:
当目标检测模型对一个目标有多个检测框时,需要滤掉多余的框,留下最接近真实目标的框。

在这里插入图片描述

步骤是这样的:
1.先把目标框初筛一波,比如设阈值为0.25, 把预测概率 < 0.25的目标框滤掉。
2.把 每个类别的 目标框 按预测概率从大到小排序。
3.每个类别的目标框 两两计算 IOU,当IOU > 阈值时,说明这两个目标框高度重合,没必要都留着,留概率较大的那个。注意这里是同类别的框计算IOU,不同类别的即使高度重合也不滤掉。

这样每个类别就滤掉了多余的框。因为已经按照预测概率从大到小排序,所以两两计算IOU时优先计算的是概率较大的框。

下面看下torchvision.ops.nms的用法,
这里假设有N个目标框。
传入参数boxes为Tensor,shape为[N,4], 每个box坐标格式是(x1,y1,x2,y2),即左上角右下角坐标。
如果你的目标框为(x,y,w,h), 需要做格式转换。
score: Tensor, shape为[N], 每个目标框检测的概率,如果是COCO,预测了80个类别的概率,就把最大的概率取出来。
iou阈值:float型,每个类别内两两目标框IOU > 这个阈值时,扔掉概率小的那个。

用它之前先用概率阈值初筛一波box, 把预测概率很低的box去掉。
需要把box的坐标转为左上角右下角坐标格式。
把box按概率从大到小排序。
提取每个box的预测概率。

torchvision.ops.nms的注释:
```python
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4])): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """

下面以yolov8为例,说明如何用torchvision.ops.nms来计算nms, 过滤掉多余的目标框。

这里yolov8的prediction是(1,116,5460), 其中5460是anchor的数量,
116的前4个是目标框坐标(x,y,w,h), 中间80个是COCO数据集中80个类的预测概率,
后面32个是mask coeff, 分割mask用的,nms这里不用。只用到前面84个。

正常情况下每个类别都要算一次nms, 这里用了batched nms, 用了小技巧把所有类别的nms一起计算。

代码为了阅读简洁做了修改。

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
  #prediction:(1,116,5460) 
    
    bs = prediction.shape[0]  # batch size  #这里只有一张图片,所以是1
    nc = 80  # 80个类别
    nm = 32  #最后32是mask coeff,用于目标分割的
    mi = 84  # mask start index
    #4~84是80个类别概率所在的位置,每个anchor取出最大的预测概率,用0.25的阈值过滤一波。
    #最大的都滤掉的话,说明这个anchor处没有预测到目标,得到的是boolean(1,5460), anchor数量
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    #用来保存nms结果
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    
    for xi, x in enumerate(prediction):  # image index, image inference #每个图片计算一次
        #取出刚刚用阈值滤掉一波剩下的anchor,input x是(116,5460),滤掉后剩27个anchor
        #x transpose后变为(5460,116),过滤后为(27,116)
        x = x.transpose(0, -1)[xc[xi]]  # confidence
   
        # 把前面说的box坐标,80个类别的概率和mask分开
        #分别是(27,4), (27,80), (27,32)
        box, cls, mask = x.split((4, nc, nm), 1)
        #torchvision的nms要求是(x1,y1,x2,y2)格式
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        
        #cls:(27,80), 取每行最大值,得到(27,1). conf是最大value, j是最大value对应的index, 也就是class id
        conf, j = cls.max(1, keepdim=True)
        
        #上次已经用conf_thres筛过一次了,这次>conf_thres应该全是true
        #cat之后是(27,38), box:(27,4),conf:(27,1),j:(27,1), mask:(27,32)
        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        #把x的所有行按conf从大到小排序,(27,38)
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        #如果你不想每个类别都做一次nms,而是所有类别一起做nms
        #就需要把不同类别的目标框尽量没有重合,不至于把不同类别的IOU大的目标框滤掉
        #先用每个类别id乘一个很大的数,作为offset,把每个类别的box坐标都加上相应的offset,这是batched nms
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

        i = i[:max_det]  # limit detections
        
        output[xi] = x[i]  #取出NMS过滤剩下的prediction
        
    return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653092.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DEVONthink 3:Mac文档管理工具,知识管理app

DEVONthink Pro是一款功能强大的文档管理软件&#xff0c;它可以帮助用户高效地组织、管理和查找各种类型的文件和信息。 下面是DEVONthink Pro的主要特点介绍&#xff1a; 多功能性&#xff1a;DEVONthink Pro支持多种文件类型和数据源&#xff0c;并提供全面的搜索、分类、过…

在 ZBrush、Substance 3D Painter 和 UE5 中创作警探角色(P1)

小伙伴们大家好&#xff0c;今天瑞云渲染小编给大家分享的是自由CG艺术家Jean Zoudi创建《极乐迪斯科》的警探角色的项目花絮&#xff0c;会解释身体和服装的建模方式&#xff0c;分享角色发型和面部毛发背后的工作流程&#xff0c;也会详细介绍渲染过程。 介绍 大家好&#…

性能测试怎么做?性能测试策略配套适用场景,打通性能测试...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、常见的测试策略…

直流对数放大器

Logarithmic Amplifiers 对数放大器的应用场合 在雷达和一些其他测距的场合&#xff0c;sensor输出的信号的动态范围比较宽&#xff0c;也就是要求sensor输出的弱信号时有比较大的放大倍数&#xff0c;强的信号有较小的放大倍数&#xff0c;以保证sensor输出的信号经过放大器后…

可移动硬盘无媒体是什么意思?移动硬盘显示无媒体数据如何恢复

案例分享&#xff1a;【最近我遇到了一个麻烦&#xff0c;我的移动硬盘突然显示“无媒体”。我不知道发生了什么&#xff0c;我很担心我的硬盘中存储的大量重要数据是否还能恢复。我该怎么解决移动硬盘显示无媒体问题呢&#xff0c;求大神帮帮我吧&#xff01;&#xff01;&…

浏览器是如何实现生成HTTP消息的

我们经常会使用浏览器访问各种网站&#xff0c;获取各种信息&#xff0c;帮助解决工作生活中的问题。那你知道&#xff0c;浏览器是怎么帮助我们实现对web服务器的访问&#xff0c;并返回给我们想要的信息吗&#xff1f; 1. 浏览器生成HTTP消息 我们平时使用的浏览器有很多种&…

【强烈推荐】 十多款2023年必备国内外王炸级AI工具 (免费 精品 好用) 让你秒变神一样的装逼佬感受10倍生产力 (6) AI学习

&#x1f680; 个人主页 极客小俊 ✍&#x1f3fb; 作者简介&#xff1a;web开发者、设计师、技术分享博主 &#x1f40b; 希望大家多多支持一下, 我们一起进步&#xff01;&#x1f604; &#x1f3c5; 如果文章对你有帮助的话&#xff0c;欢迎评论 &#x1f4ac;点赞&#x1…

【LLMs系列】90%chatgpt性能的小羊驼Vicuna模型学习与实战

一、前言 UC伯克利学者联手CMU、斯坦福等&#xff0c;再次推出一个全新模型70亿/130亿参数的Vicuna&#xff0c;俗称「小羊驼」&#xff08;骆马&#xff09;。小羊驼号称能达到GPT-4的90%性能 github 地址: GitHub - lm-sys/FastChat: An open platform for training, servi…

ChatGPT爆火网络背后的故事?

文章目录 前言一、ChatGPT的诞生背景二、ChatGPT的技术原理三、ChatGPT的推广策略四、ChatGPT的未来展望五、橙子送书第2期 前言 ChatGPT是一款基于人工智能技术的聊天机器人&#xff0c;它的出现引起了广泛的关注和热议。在短短的时间内&#xff0c;ChatGPT就成为了全球范围内…

实测|飞凌嵌入式OK3588-C开发板4G模组的使用与测试

本篇试用报告由发烧友 ouxiaolong提供&#xff0c;感谢ouxiaolong的支持。飞凌嵌入式会持续开展开发板有奖试用活动&#xff0c;更有京东E卡等着你&#xff01;欢迎大家的持续关注。 飞凌嵌入式OK3588-C开发板是一款性能强劲的旗舰产品&#xff0c;采用核心板底板的分体式设计…

linuxOPS系统服务_Linux下用户管理

用户概念以及基本作用 **用户&#xff1a;**指的是Linux操作系统中用于管理系统或者服务的人 一问&#xff1a;管理系统到底在管理什么&#xff1f; 答&#xff1a;Linux下一切皆文件&#xff0c;所以用户管理的是相应的文件 二问&#xff1a;如何管理文件呢&#xff1f; …

JDK、JRE、JVM三者的区别

JDK&#xff08;Java Development Kit&#xff09;&#xff1a;Java开发工具包 JRE&#xff08;Java Runtime Environment&#xff09;&#xff1a;Java运行环境 JVM&#xff08;Java Virtual Mechinal&#xff09;&#xff1a;Java虚拟机 &#xff08;1&#xff09;JDK和JRE 是…

Python海龟画图 几种基本图形

注&#xff1a;本文主要根据绘制步骤进行区分&#xff0c;实际使用时应当调节参数以绘制需要的图形。文中的步骤均为循环进行&#xff0c;循环50到100次&#xff0c;具体次数见代码示例。 1.前进小角度旋转 绘制效果如图&#xff0c;如果旋转角度为360的因数则绘制出多边形。 …

OJ Summation of Four Primes

1.题目 题目描述 Euler proved in one of his classic theorems that prime numbers are infinite in number. But can every number be expressed as a summation of four positive primes? I don’t know the answer. May be you can help!!! I want your solution to be v…

弹性盒子(display: flex)布局超全讲解|Flex 布局教程

文章目录 什么是弹性布局&#xff1f;弹性布局的特点&#xff1f;容器的属性justify-contentalign-itemsflex-directionflex-wrapflex-flowalign-contentorder属性flex-grow属性flex-shrink属性flex-basis属性flex属性align-self属性 什么是弹性布局&#xff1f; 弹性布局&…

我被今年就业难度震惊到了

随着毕业季到来&#xff0c;今年高校毕业生就业问题正在被越来越多的人关注。年年都是最难就业季&#xff0c;但今年却格外不同寻常的难。大家都知道 2022 年毕业生人数历史上首次突破千万。而今年毕业生人数&#xff0c;高达 1158 万人&#xff0c;史无前例的多。加上海外留学…

【Unity Shader】从入门到着魔(2)用C#画一个立方体

文章目录 一、构成一个立方需要多少个顶点?二、定义三角面的索引数组:三、定义UV坐标数组:四、最后构建Mesh:五、完整代码:一、构成一个立方需要多少个顶点? 这个问题是面试经常被问到的题。如上图,我们知道在几何中立方体有6个面,8个顶点。但在图形学中,顶点指的是模…

项目管理:制定项目计划,这些作用不可忽视

做任何事&#xff0c;做计划不可缺少&#xff0c;没有计划&#xff0c;就没有控制&#xff0c;编制计划可帮助项目管理团队提前进行思考。 制定计划后&#xff0c;还需要对项目计划进行跟踪&#xff0c;这样才不会让计划白做。 你知道项目计划进行跟踪&#xff0c;有哪些不可…

更新公告:Airtest更新至1.2.10.2版本

1. 前言 本次是Airtest库更新&#xff0c;版本提升至1.2.10.2&#xff0c;内容主要是Android录屏功能的改动。 2. 更改部分 在Airtest1.2.9中提供的cv2模式已经被舍弃&#xff0c;因为容易引发错误&#xff0c;效果也不如ffmpeg&#xff1b; 只有Android需要mode参数&#x…

java8 (jdk 1.8) 新特性 ——初步,发现不一样的新特性

前言 3202 年了&#xff0c;现在市面上的公司几乎都是 jdk1.8, 有也是极少数在用java7 , 即使是一些传统企业&#xff0c;在技术革新方面也很重视&#xff0c;毕竟现在是大数据时代 那么java8 有哪些新特性呢&#xff1f;换句话说为什么在码界 这么受欢迎&#xff01;&#xf…