NMS非极大值抑制

news2025/1/12 0:45:06

文章目录

  • 一、NMS详解
  • 二、NMS具体步骤与实现
    • 1.步骤
    • 2、代码(pytorch版本)



一、NMS详解

NMS即非极大值抑制,常被用于目标检测等,即只保留检测同一物体置信度最大的框。
具体作用可以看图:
在这里插入图片描述

可以看出,未经过nms的图片,有很多指向同一物体的框。

二、NMS具体步骤与实现

1.步骤

这里是转发bubbliiiing博主的!
本博文实现的是多分类的非极大抑制:
输入shape为[ batch_size, all_anchors, 5+num_classes ]

第一个维度是图片的数量。
第二个维度是所有的预测框。
第三个维度是所有的预测框的预测结果。
这里的预测结果是(x,y,w,h,包含种类的概率,所有种类的概率),在这里我使用的种类为2分类。
非极大抑制的执行过程如下所示:
1、对所有图片进行循环。(循环1)
2、找出该图片中得分大于门限函数的框。在进行重合框筛选前就进行得分的筛选可以大幅度减少框的数量。
3、判断第2步中获得的框的种类与得分。即找出该图片中不同框所对应的最大种类的概率以及种类。取出预测结果中框的位置与之进行堆叠。此时最后一维度里面的内容由5+num_classes变成了4+1+2,四个参数代表框的位置,一个参数代表预测框是否包含物体,两个参数分别代表种类的置信度与种类。
4、对种类进行循环,(循环2)非极大抑制的作用是筛选出一定区域内属于同一种类得分最大的框,对种类进行循环可以帮助我们对每一个类分别进行非极大抑制。
5、根据得分对该种类进行从大到小排序。
6、每次取出得分最大的框(循环3),计算其与其它所有预测框的重合程度,重合程度过大的则剔除。

2、代码(pytorch版本)

def bbox_iou(self, box1, box2, x1y1x2y2=True):
    """
        计算IOU
    """
    if not x1y1x2y2:
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, min=0) * \
                torch.clamp(inter_rect_y2 - inter_rect_y1, min=0)
                
    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    
    iou = inter_area / torch.clamp(b1_area + b2_area - inter_area, min = 1e-6)

    return iou

def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    #----------------------------------------------------------#
    #   将预测结果的格式转换成左上角右下角的格式。
    #   prediction  [batch_size, num_anchors, 85]
    #----------------------------------------------------------#
    box_corner          = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        #----------------------------------------------------------#
        #   对种类预测部分取max。
        #   class_conf  [num_anchors, 1]    种类置信度
        #   class_pred  [num_anchors, 1]    种类
        #----------------------------------------------------------#
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        #----------------------------------------------------------#
        #   利用置信度进行第一轮筛选
        #----------------------------------------------------------#
        conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

        #----------------------------------------------------------#
        #   根据置信度进行预测结果的筛选
        #----------------------------------------------------------#
        image_pred = image_pred[conf_mask]
        class_conf = class_conf[conf_mask]
        class_pred = class_pred[conf_mask]
        if not image_pred.size(0):
            continue
        #-------------------------------------------------------------------------#
        #   detections  [num_anchors, 7]
        #   7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
        #-------------------------------------------------------------------------#
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

        #------------------------------------------#
        #   获得预测结果中包含的所有种类
        #------------------------------------------#
        unique_labels = detections[:, -1].cpu().unique()

        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
            detections = detections.cuda()

        for c in unique_labels:
            #------------------------------------------#
            #   获得某一类得分筛选后全部的预测结果
            #------------------------------------------#
            detections_class = detections[detections[:, -1] == c]

            # #------------------------------------------#
            # #   使用官方自带的非极大抑制会速度更快一些!
            # #------------------------------------------#
            # keep = nms(
            #     detections_class[:, :4],
            #     detections_class[:, 4] * detections_class[:, 5],
            #     nms_thres
            # )
            # max_detections = detections_class[keep]
            
            # 按照存在物体的置信度排序
            _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
            detections_class = detections_class[conf_sort_index]
            # 进行非极大抑制
            max_detections = []
            while detections_class.size(0):
                # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
                max_detections.append(detections_class[0].unsqueeze(0))
                if len(detections_class) == 1:
                    break
                ious = self.bbox_iou(max_detections[-1], detections_class[1:])
                detections_class = detections_class[1:][ious < nms_thres]
            # 堆叠
            max_detections = torch.cat(max_detections).data
            
            # Add max detections to outputs
            output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
        
        if output[i] is not None:
            output[i]           = output[i].cpu().numpy()
            box_xy, box_wh      = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
            output[i][:, :4]    = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
    return output


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/590380.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Freertos的ESP-IDF开发——8.使用wifi访问HTTP服务器

目录 0. 前言其他ESP-IDF文章 1. 前期准备1.1头文件准备1.2 http 服务器搭建 2. 连接 wifi3.http访问任务4. 完整代码 0. 前言 使用ESP32使用 wifi 访问 http 服务器 开发环境&#xff1a;ESP-IDF 4.2 操作系统&#xff1a;Ubuntu22.04 开发板&#xff1a;自制的ESP32-WROOM-…

流行框架(二)网络请求库 OKhttp

文章目录 概述HttpURLConnectionGET和POST获取文本数据GETPOST OKHttp基本使用依赖与权限发起一个get请求重要概念OkHttpClientRequestCallRealCallAsyncCall 请求调度器Dispatcher同步请求execute的执行异步请求enqueue的执行两种请求方式的总结 OkHttp拦截器链拦截器种类addI…

字节狂问1小时,小伙offer到手,太狠了!(字节面试真题)

前言&#xff1a; 在尼恩的&#xff08;50&#xff09;读者社群中&#xff0c;经常有小伙伴&#xff0c;需要面试 头条、美团、阿里、京东等大厂。 下面是一个小伙伴成功拿到字节飞书offer&#xff0c;通过一小时拷问的面试经历&#xff0c;就两个字&#xff1a; 深&#xf…

基于STM32的SYN6288语音播报模块驱动实验(代码开源)

前言&#xff1a;本文为手把手教学 SYN6288 语音播报模块的驱动实验&#xff0c;本教程的 MCU 采用STM32F103ZET6。通过 CubeMX 软件配置 UART 串口协议驱 SYN6288 模块进行规定的语音播报。考虑到 SYN6288 模块的集成化与智能化很高&#xff0c;所以该模块的使用是极其便利的。…

【HarmonyOS】初识低代码平台开发元服务

【关键字】 HarmonyOS、低代码平台、元服务开发、拖拽式开发 【写在前面】 今天要分享的是HarmonyOS中的低代码开发相关的内容&#xff0c;低代码开发是DevEco Studio提供的一种UI界面可视化的构建方式&#xff0c;通过图形化的自由拖拽数据的参数化配置&#xff0c;可以快速…

【Java项目】基于SpringBoot+Vue的校园二手商品交易平台

文章目录 功能简述功能展示用户模块购物车模块管理员模块物物对价功能实现 代码 视频演示 代码下载 项目内含有 功能简述 系统登录界面的实现 系统首页界面的实现 用户信息管理界面的实现 商品购物功能的实现 购物车管理功能及支付功能的实现 物物对价功能的实现 用户安全设置…

【面试需了解】jvm垃圾回收机制-GC基础知识、jvm基本组成、查看、排查

前言 jvm垃圾回收机制-GC基础知识、jvm基本组成、查看、排查 文章目录 前言GC基础知识概述 JVM基本组成1. 虚拟机的组成2. jvm的内存区域 查看jvm排查jvm问题1. 正常运行的系统2. 对于已经发生了OOM的系统 GC基础知识 概述 什么是垃圾 一个对象没有被引用&#xff0c;没有任何…

Spring MVC详解(学习总结)

一、Sprig MVC简介1.1介绍1.2MVC是什么 二、Spring MVC实现原理2.1核心组件2.2工作流程 三、第一个Spring MVC四、常用注解五、参数绑定5.1URL风格参数绑定5.2RESTful风格的URL参数获取5.3映射Cookie5.4使用POJO绑定参数5.5JSP页面的转发和重定向 六、Spring MVC数据绑定6.1基本…

vulnstack(红日)内网渗透靶场二: 免杀360拿下域控

前言 在我之前的文章vulnstack(一)打靶&#xff0c;我主要依赖Cobalt Strike进行后期渗透测试&#xff0c;这次我计划使用Metasploit框架(MSF)来进行这个阶段的工作。这个靶场与之前的不同之处在于它的WEB服务器安装了360安全卫士。虽然这增加了挑战的难度&#xff0c;但只要我…

Shell脚本攻略:循环语句while、until

目录 一、理论 1.while 2.until 3.break 4.continue 二、实验 1.实验一 2.实验二 3.实验三 4.实验四 5.实验五 一、理论 1.while (1)while用法 while循环满足条件执行&#xff0c;不满足不执行。 用于不知道循环次数&#xff0c;需要主动结束循环或达到条件结束…

二开项目权限应用全流程-按钮级控制

二开项目权限应用全流程-按钮级控制 员工A和员工B都可以访问同一个页面&#xff08;以员工管理为例&#xff09;&#xff0c;但是员工A可以导出excel&#xff0c;员工B就不可以导出excel(看不到按钮) 思路 用户登陆成功后&#xff0c;用户可以访问的按钮级别权限保存在point…

阿里巴巴淘天集团后端暑期实习面经

目录 1.面向对象三大特性2.重写和重载3.protected 关键字和 default 关键字的作用范围4.栈帧中有哪些东西&#xff1f;5.堆中有哪些区域&#xff1f;6.new 一个对象存放在哪里&#xff1f;7.CMS 收集器回收阶段8.CMS 收集器回收过程哪些需要暂停线程&#xff1f;9.HashMap JDK …

手机行业再多一条“鲶鱼”,小度青禾要打一场漂亮突围战?

文 | 智能相对论 作者 | 佘凯文 智能手机到底还是不是一门好生意&#xff1f; 在换机周期被无限拉长、市场竞争越发激烈、高端市场迟迟无法突破等共同背景下&#xff0c;智能手机到底还是不是一门好生意&#xff0c;成为行业内这两年被热议的话题之一。 由TechInsights发布…

腾讯云轻量应用服务器CPU主频多少?型号?

腾讯云轻量应用服务器CPU型号是什么&#xff1f;轻量服务器处理器主频&#xff1f;腾讯云服务器网账号下的CPU处理器型号为2.5GHz主频的Intel(R) Xeon(R) Gold 6133 CPU和2.4GHz主频Intel(R) Xeon(R) CPU E5-26xx v4&#xff0c;腾讯云轻量应用服务器不支持指定底层物理服务器的…

NodeJs内存快照分析

&#xff08;头等人&#xff0c;有本事&#xff0c;没脾气&#xff1b;二等人&#xff0c;有本事&#xff0c;有脾气&#xff1b;末等人&#xff0c;没本事&#xff0c;大脾气。——南怀瑾&#xff09; NodeJs内存分析的必要性 回顾过去&#xff0c;我们排查web应用问题的途径…

36岁大龄程序员全职接单三个月的感触

36岁大龄程序员&#xff0c;原以为逃过35岁危机&#xff0c;没想到在年前被优化&#xff0c;拿了N2&#xff0c;12w薪资后&#xff0c;我开始了全职接单的道路。现在每个月平均收入有个20K&#xff0c;一路走来挺有感触的&#xff0c;把自己的经验分享给大家。 赚钱&#xff0…

【Jmeter】生成html格式接口自动化测试报告

jmeter自带执行结果查看的插件&#xff0c;但是需要在jmeter工具中才能查看&#xff0c;如果要向领导提交测试结果&#xff0c;不够方便直观。 笔者刚做了这方面的尝试&#xff0c;总结出来分享给大家。 这里需要用到ant来执行测试用例并生成HTML格式测试报告。 一、ant下载安…

Android13蓝牙 停用绝对音量功能

Android13蓝牙 停用绝对音量功能 文章目录 Android13蓝牙 停用绝对音量功能一、前言二、代码实现分析过程1、查看SettingsLib源码资源2、查看原生Setitntgs 相关字符&#xff08;1&#xff09;xml 布局文件中的显示&#xff08;2&#xff09; java 代码文件中的控制串口上控制&…

Vue注册界面精美模板分享

文章目录 &#x1f412;个人主页&#x1f3c5;Vue项目常用组件模板仓库&#x1f4d6;前言&#xff1a;&#x1f380;源码如下&#xff1a; &#x1f412;个人主页 &#x1f3c5;Vue项目常用组件模板仓库 &#x1f4d6;前言&#xff1a; 本篇博客主要提供vue组件之注册组件源码…

硬件软件【部署】

开发板和主机 1.功能不同&#xff1a;帮助开发者进行嵌入式系统的开发和调试&#xff0c;具有较强的硬件拓展能力&#xff0c;可以连接各种传感器/执行器等外设。主机为满足一般的计算需求而设计&#xff0c;具备更强的计算和图形处理能力。 2.架构不同&#xff1a;开发板通常…