非最大值抑制(NMS)函数

news2024/12/29 10:00:46

非最大值抑制(NMS)函数

flyfish

非最大值抑制(Non-Maximum Suppression, NMS)是计算机视觉中常用的一种后处理技术,主要用于目标检测任务。其作用是从一组可能存在大量重叠的候选边界框中,筛选出最具代表性的边界框,即通过置信度分数和重叠区域的过滤,保留最具代表性的边界框。

边界框(Bounding Boxes):一组表示候选目标区域的矩形框,每个框由左上角和右下角的坐标(x1, y1, x2, y2)表示。
置信度分数(Confidence Scores):每个边界框对应的一个置信度分数,表示该框内包含目标的可能性。

执行步骤

初始化:
boxes:输入的边界框列表。
scores:每个边界框对应的置信度得分列表。
confidence_threshold:过滤边界框的最低置信度阈值。
iou_threshold:用于确定边界框是否重叠的 IOU 阈值。

过滤低置信度边界框:
根据 confidence_threshold 过滤掉置信度低于该阈值的边界框。

按置信度排序:
对剩余的边界框按照置信度从高到低排序。

非极大值抑制:
从排序后的列表中选择置信度最高的边界框,并计算其与其他边界框的 Intersection-over-Union (IoU)。
如果 IoU大于 iou_threshold,则移除该边界框(表示重叠太多)。
重复该过程直到处理完所有边界框。

返回结果:
返回保留的边界框的索引。
在这里插入图片描述
可视化 Intersection-over-Union (IoU)

蓝色矩形表示 Box A,红色矩形表示 Box B,绿色矩形表示它们的交集区域,剩余的红色和蓝色是并集区域。
在这里插入图片描述

torchvision.ops.nms 和 cv2.dnn.NMSBoxes 的调用

import numpy as np
import torch
import torchvision.ops as ops
import cv2

# 输入数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])# (x1, y1, x2, y2)格式
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

def convert_to_xywh(boxes): #opencv用 (x, y, w, h)格式
    """
    将边界框从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。
    
    参数:
    - boxes: 形状为 (N, 4) 的数组,其中 N 是边界框的数量
    
    返回:
    - boxes_xywh: 形状为 (N, 4) 的数组,包含转换后的边界框
    """
    boxes_xywh = np.zeros_like(boxes)
    boxes_xywh[:, 0] = boxes[:, 0]  # x
    boxes_xywh[:, 1] = boxes[:, 1]  # y
    boxes_xywh[:, 2] = boxes[:, 2] - boxes[:, 0]  # w
    boxes_xywh[:, 3] = boxes[:, 3] - boxes[:, 1]  # h
    return boxes_xywh

def nms_torchvision(boxes, scores, nms_threshold):
    boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
    scores_tensor = torch.tensor(scores, dtype=torch.float32)
    keep = ops.nms(boxes_tensor, scores_tensor, nms_threshold)
    return keep.numpy()

def nms_opencv(boxes, scores, score_threshold, nms_threshold):
    boxes = convert_to_xywh(boxes)
    indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), score_threshold, nms_threshold)
    return np.array(indices).flatten()

# 调用 NMS
keep_torchvision = nms_torchvision(boxes, scores, nms_threshold)
keep_opencv = nms_opencv(boxes, scores, score_threshold, nms_threshold)

print("使用 torchvision.ops.nms 保留的边界框索引: ", keep_torchvision)
print("使用 cv2.dnn.NMSBoxes 保留的边界框索引: ", keep_opencv)

输出

使用 torchvision.ops.nms 保留的边界框索引:  [0 3 1 7 2 4]
使用 cv2.dnn.NMSBoxes 保留的边界框索引:  [0 3 1 7 2 4]

用纯 NumPy 实现的非最大值抑制(NMS)函数

import numpy as np

def nms(boxes, scores, score_threshold, nms_threshold):
    """单类 NMS 使用 NumPy 实现。"""
    # 过滤掉低置信度的框
    indices = np.where(scores > score_threshold)[0]
    boxes = boxes[indices]
    scores = scores[indices]

    # 如果没有剩余的框,返回空列表
    if len(boxes) == 0:
        return []

    # 提取每个边界框的坐标
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    # 计算每个边界框的面积
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # 根据分数进行排序(从高到低)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(indices[i])
        # 计算当前边界框与其余边界框的交集坐标
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        # 计算交集的宽度和高度
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        # 计算交集面积
        inter = w * h
        # 计算交并比(IOU)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # 只保留 IOU 小于阈值的边界框
        inds = np.where(ovr <= nms_threshold)[0]
        order = order[inds + 1]

    return keep

# 示例数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

# 调用NMS
keep_indices = nms(boxes, scores, score_threshold, nms_threshold)
print("使用 NumPy 实现的 NMS 保留的边界框索引: ", keep_indices)
使用 NumPy 实现的 NMS 保留的边界框索引:  [0, 3, 1, 7, 2, 4]

关于语法的解释

在 NumPy 中,冒号 : 用于数组切片。它们可以用来提取数组的子集、重排数组或选取特定的元素。

示例1

scores.argsort()[::-1]
scores.argsort():返回 scores 中元素的索引数组,这些索引会将 scores 排序。
[::-1]:表示反转数组。
在这个例子中,[::-1] 表示从开始到结束,步长为 -1,因此数组会被反转。这里的两个冒号是为了清楚地表示切片的完整语法 [start:stop:step],其中省略了 start 和 stop,只指定了 step 为 -1。

import numpy as np

scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
sorted_indices = scores.argsort()  # 升序排序的索引
print("sorted_indices:", sorted_indices)

# 反转排序索引(降序排序)
reversed_indices = sorted_indices[::-1]
print("reversed_indices:", reversed_indices)
sorted_indices: [8 5 4 2 7 1 6 3 0]
reversed_indices: [0 3 6 1 7 2 4 5 8]

示例2

boxes[:, 0]
boxes[:, 0]:选取 boxes 数组中第 0 列的所有元素。
: 表示选择所有行,0 表示选择第 0 列。
这段代码的作用是提取 boxes 数组中每个边界框的 x1 坐标(左上角的 x 坐标)。

import numpy as np
boxes = np.array([
    [100, 100, 210, 210],
    [220, 220, 320, 330],
    [300, 300, 400, 400],
    [50, 50, 150, 200]
])

x1 = boxes[:, 0]
print("x1:", x1)
x1: [100 220 300  50]

可视化数据的代码

def plot_boxes(boxes, keep_indices):
    fig, ax = plt.subplots(1, figsize=(12, 12))

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1

        # 所有输入框用蓝色绘制
        edgecolor = 'blue'
        if i in keep_indices:
            # NMS 保留的框用绿色绘制
            edgecolor = 'green'
        else:
            # 被抑制的框用红色绘制
            edgecolor = 'red'
        
        rect = patches.Rectangle((x1, y1), width, height, linewidth=2, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    # 设置坐标范围
    ax.set_xlim(0, np.max(boxes[:, [0, 2]]) + 50)
    ax.set_ylim(0, np.max(boxes[:, [1, 3]]) + 50)
    ax.invert_yaxis()  # 图像坐标系和实际坐标系相反时需要

    plt.show()

# 示例数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

# 调用NMS
keep_indices = nms(boxes, scores, score_threshold, nms_threshold)
print("使用 NumPy 实现的 NMS 保留的边界框索引: ", keep_indices)

# 绘图
plot_boxes(boxes, keep_indices)

可视化 Intersection-over-Union (IoU)的代码

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_iou(boxA, boxB):
    fig, ax = plt.subplots(1, figsize=(8, 8))

    # 绘制 Box A
    x1A, y1A, x2A, y2A = boxA
    widthA = x2A - x1A
    heightA = y2A - y1A
    rectA = patches.Rectangle((x1A, y1A), widthA, heightA, linewidth=2, edgecolor='blue', facecolor='blue', label='Box A')
    ax.add_patch(rectA)

    # 绘制 Box B
    x1B, y1B, x2B, y2B = boxB
    widthB = x2B - x1B
    heightB = y2B - y1B
    rectB = patches.Rectangle((x1B, y1B), widthB, heightB, linewidth=2, edgecolor='red', facecolor='red', label='Box B')
    ax.add_patch(rectB)

    # 计算交集
    xx1 = np.maximum(x1A, x1B)
    yy1 = np.maximum(y1A, y1B)
    xx2 = np.minimum(x2A, x2B)
    yy2 = np.minimum(y2A, y2B)

    w = np.maximum(0, xx2 - xx1)
    h = np.maximum(0, yy2 - yy1)
    intersection_area = w * h

    # 计算并集
    areaA = (x2A - x1A) * (y2A - y1A)
    areaB = (x2B - x1B) * (y2B - y1B)
    union_area = areaA + areaB - intersection_area

    # 计算 IoU
    iou = intersection_area / union_area

    # 绘制交集
    if w > 0 and h > 0:
        rect_intersection = patches.Rectangle((xx1, yy1), w, h, linewidth=2, edgecolor='green', facecolor='green', linestyle='--', label='Intersection')
        ax.add_patch(rect_intersection)

    # 显示图例
    handles, labels = ax.get_legend_handles_labels()

    plt.legend(handles=handles)

    plt.xlim(0, 500)
    plt.ylim(0, 500)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.title(f'IoU = {iou:.2f}')
    plt.show()

# 示例框
boxA = [100, 100, 300, 300]
boxB = [200, 200, 400, 400]

plot_iou(boxA, boxB)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1870246.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件必须要进行跨浏览器测试吗?包括哪些内容和注意事项?

随着互联网的普及和发展&#xff0c;用户对软件的要求越来越高。无论是在台式机、笔记本还是移动设备上&#xff0c;用户都希望能够以最好的体验来使用软件。然而&#xff0c;不同的浏览器在解析网页的方式、支持的技术标准等方面存在差异&#xff0c;这就导致了同一个网页在不…

LeetCode 585, 438, 98

目录 585. 2016年的投资题目链接表要求知识点思路代码 438. 找到字符串中所有字母异位词题目链接标签思路代码 98. 验证二叉搜索树题目链接标签合法区间思路代码 中序遍历思路代码 585. 2016年的投资 题目链接 585. 2016年的投资 表 表Insurance的字段为pid、tiv_2015、tiv…

RabbitMQ WEB管理端介绍

页面功能概览 Overview(概述)Connections(连接)Channels(通道)Exchanges(交换器)Queues(队列)Admin(用户管理)。 1. Overview(概述) 主要分为三部分 1.1 Queued messages&#xff08;所有队列的消息情况&#xff09; Ready&#xff1a;待消费的消息总数Unacked&#xff1a;待应…

抖音集成:通过MessageBox引领数字化营销新潮流

抖音集成&#xff1a;通过MessageBox引领数字化营销新潮流 在数字化营销的大潮中&#xff0c;企业需要不断探索新的方式来优化其营销策略&#xff0c;以抓住更多的市场机会。抖音作为一款全球知名的短视频社交平台&#xff0c;凭借其庞大的用户群体和高度互动的特性&#xff0…

[leetcode]24-game

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:static constexpr int TARGET 24;static constexpr double EPSILON 1e-6;static constexpr int ADD 0, MULTIPLY 1, SUBTRACT 2, DIVIDE 3;bool judgePoint24(vector<int> &nums) {vector&l…

A股跌懵了,股民一片茫然!

今天的A股跌懵了&#xff0c;股民一片茫然&#xff01;让人脸色苍白&#xff0c;盘面上出现了非常奇怪的一幕&#xff0c;不废话&#xff0c;直接说重点&#xff1a; 1、今天两市低开低走&#xff0c;跌懵了&#xff0c;昨晚人民币汇率大幅贬值&#xff0c;创下7.3的记录&#…

转转游戏MQ重构:思考与心得之旅

文章目录 1 背景1.1 起始之由1.2 重构前现状1.3 问题分析 2 重构2.1 目标2.2 制定方案2.2.1 架构设计2.2.2 实施计划2.2.3 测试计划 2.3 部分细节设计 3. 总结 1 背景 游戏业务自 2017 年启航&#xff0c;至今已近乎走过七个春秋&#xff0c;历经漫长岁月的发展&#xff0c;不…

SpringSecutrity原理

一、基于RBAC实现的权限管理通常需要涉及以下几张表&#xff1a; 1. 用户表&#xff08;user&#xff09;&#xff1a;记录系统中的所有用户&#xff0c;包括用户ID、用户名、密码等信息。 2. 角色表&#xff08;role&#xff09;&#xff1a;记录系统中的所有角色&#xff0…

【MySQL】(基础篇十七) —— 存储过程

存储过程 本文将介绍什么是存储过程&#xff0c;为什么要使用存储过程以及如何使用存储过程&#xff0c;并且介绍创建和使用存储过程的基本语法。 MySQL的存储过程是预编译的SQL语句集合&#xff0c;它们作为一个可执行单元存储在数据库中。存储过程能够封装复杂的业务逻辑&a…

分享一款永久免费内网穿透工具——巴比达内网穿透

最近在做web项目&#xff0c;想办法将web项目映射到公网进行访问&#xff0c;由于没有固定IP&#xff0c;只能使用内网穿透的方法&#xff0c;于是在网上搜索了一番&#xff0c;只有神卓互联旗下的这款巴比达内网穿透是真正免费的&#xff0c; 其它的要么用不了、要么限制没有流…

文件进行周期性备份后权限更改的解决方案--使用脚本和定时任务

这里写目录标题 背景现象解决方案原因分析面临的问题解决思路操作步骤每个文件夹权限分配表测试chmod和chown两个命令是否可行写脚本实现定时同步同时修改权限 异地同步改权限在NAS上生成SSH密钥对将NAS的公钥复制到Linux服务器在NAS上编写同步脚本在NAS上执行脚本&#xff0c;…

记录一次OPDS trunc()函数使用错误

说明&#xff1a;本文介绍 场景 在一次SQL查询时&#xff0c;需要对结果值保留两位小数&#xff0c;不四舍五入&#xff0c;直接截取到小数点后两位。如 59.156到59.15&#xff0c;23.2134到23.21&#xff0c;查看官方帮助文档&#xff08;https://help.aliyun.com/zh/maxcom…

【分享】30秒在线自助制作电子证件照

近期由于自己需要制作电子证件照&#xff0c;所以在网上找在线制作电子证件照的网站&#xff0c;找了很多网站都是收费的&#xff0c;也下载了很多app制作&#xff0c;都是要收费的。最后&#xff0c;所以索性自己开发一个网站制作电子证件照。这里分享给需要的朋友。&#xff…

探索Android架构设计

Android 应用架构设计探索&#xff1a;MVC、MVP、MVVM和组件化 MVC、MVP和MVVM是常见的三种架构设计模式&#xff0c;当前MVP和MVVM的使用相对比较广泛&#xff0c;当然MVC也并没有过时之说。而所谓的组件化就是指将应用根据业务需求划分成各个模块来进行开发&#xff0c;每个…

14、电科院FTU检测标准学习笔记-录波功能2

作者简介&#xff1a; 本人从事电力系统多年&#xff0c;岗位包含研发&#xff0c;测试&#xff0c;工程等&#xff0c;具有丰富的经验 在配电自动化验收测试以及电科院测试中&#xff0c;本人全程参与&#xff0c;积累了不少现场的经验 ———————————————————…

力扣:203. 移除链表元素(Java)

目录 题目描述&#xff1a;示例 1&#xff1a;示例 2&#xff1a;代码实现&#xff1a; 题目描述&#xff1a; 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入…

C++学习笔记---POCO库

在Windows系统中安装POCO 1&#xff09;安装OpenSSL POCO编译安装依赖OpenSSL&#xff0c;如果未安装OpenSSL则应该先安装OpenSSL。 假设将OpenSSL安装在C:\OpenSSL-Win64&#xff0c;将C:\OpenSSL-Win64、C:\OpenSSL-Win64\lib添加到PATH环境变量中2&#xff09;安装POCO 将p…

这不是危言耸听!时序Transformer颠覆传统,历史级突破!

【时间序列Transformer】在近年来的深度学习领域中备受关注&#xff0c;它通过将Transformer架构应用于时间序列数据&#xff0c;显著提升了模型在长时间依赖建模和复杂模式识别任务中的表现。时间序列Transformer技术已经在金融预测、气象预报和健康监测等多个领域取得了显著成…

PNAS|这样也可以?拿别人数据发自己Paper?速围观!

还在为数据量小&#xff0c;说服力不足发愁&#xff1f; 想研究脱颖而出、眼前一亮&#xff1f; 想从更高层次的探索微生物的奥秘&#xff0c;发出一篇好文章&#xff1f; 近期&#xff0c;有一篇发表在PNAS(IF11.1)的文章“Deforestation impacts soil biodiversity and ecos…

Swift 周报 第五十六期

文章目录 前言新闻和社区苹果与消费者修改 3500 万美元 iPhone 音响和解协议苹果(AAPL.US)因监管担忧今年不会在欧盟推出 AI 功能苹果暂停高端 Vision 头戴设备研发 计划推出更廉价版 Swift论坛推荐博文话题讨论关于我们 前言 本期是 Swift 编辑组自主整理周报的第五十六期&am…