非极大值抑制算法(Non-Maximum Suppression,NMS)

news2025/2/23 3:26:42

https://tcnull.github.io/nms/
https://blog.csdn.net/weicao1990/article/details/103857298

目标检测中检测出了许多的候选框,候选框之间是有重叠的,NMS作用重叠的候选框只保留一个

算法:

  1. 将所有候选框放入到集和B
  2. 从B中选出分数S最大的bm
  3. 将bm 从集和B中移除到集和D
  4. 计算bm与B中剩余的候选框之间的IOU。
  5. 如果iou>Nt则将其从B中删除。(去除重叠比较多的候选框)
  6. 循环直至B为空。
  7. D会越来越多。
def box_iou_union_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:
    """
    Return intersection-over-union (Jaccard index) and  of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1: set of boxes (x1, y1, x2, y2)[N, 4]
        boxes2: set of boxes (x1, y1, x2, y2)[M, 4]
        eps: optional small constant for numerical stability

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
        union (Tensor[N, M]): the nxM matrix containing the pairwise union
            values
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]
    y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]
    x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]
    y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]

    inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps  # [N, M]
    union = (area1[:, None] + area2 - inter)
    return inter / union, union
def nms_cpu(boxes, scores, thresh):
    """
    Performs non-maximum suppression for 3d boxes on cpu
   
    Args:
        boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): score for each box [N]
        iou_threshold (float): threshould when boxes are discarded
   
    Returns:
        keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
            sorted in decreasing order of scores
    """
    ious = box_iou(boxes, boxes)
    _, _idx = torch.sort(scores, descending=True)
   
    keep = []
    while _idx.nelement() > 0:
        keep.append(_idx[0])
        # get all elements that were not matched and discard all others.
        non_matches = torch.where((ious[_idx[0]][_idx] <= thresh))[0]
        _idx = _idx[non_matches]
    return torch.tensor(keep).to(boxes).long()

template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {
  // a, b hold box coords as (y1, x1, y2, x2) with y1 < y2 etc.
  T bottom = max(a[0], b[0]), top = min(a[2], b[2]);
  T left = max(a[1], b[1]), right = min(a[3], b[3]);
  T width = max(right - left, (T)0), height = max(top - bottom, (T)0);
  T interS = width * height;

  T Sa = (a[2] - a[0]) * (a[3] - a[1]);
  T Sb = (b[2] - b[0]) * (b[3] - b[1]);

  return interS / (Sa + Sb - interS);
}
at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) {
  /* dets expected as (n_dets, dim) where dim=4 in 2D, dim=6 in 3D */
  AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
  AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
  at::cuda::CUDAGuard device_guard(dets.device());//管理CUDA设备上下文,并指制定使用的CUDA设备

  bool is_3d = dets.size(1) == 6;
  //按照第一维(索引为0)对scores降序排序,并返回一个包含排序后索引的 Tensor,std::get<1>(...) 提取排序后索引的 Tensor。
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  //使用排序后的索引 order_t 从 dets 中选择对应的元素,返回一个新的 Tensor; dets_sorted,其中的元素按照排序后的顺序排列。
  auto dets_sorted = dets.index_select(0, order_t);
  //bbox个数
  int dets_num = dets.size(0);
  //该函数用于计算在CUDA中进行并行化计算时所需的最大列块数。其中,dets_num代表待处理的数据数量,threadsPerBlock表示每个CUDA块中的线程数量。
  //函数通过将dets_num除以threadsPerBlock并向上取整得到最大列块数col_blocks。这个函数常用于确定CUDA并行计算中需要启动多少个CUDA块来处理所有数据。
  const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);//获取block个数

  //该函数创建了一个名为mask的空Tensor,其大小为dets_num * col_blocks,数据类型为长整型(at::kLong)。该Tensor的属性与dets相同。
  at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));//创建一个空容器  D

  dim3 blocks(col_blocks, col_blocks);//定义了一个二维网格,其维度为col_blocks行col_blocks列
  dim3 threads(threadsPerBlock);//定义了一个线程块,其维度为threadsPerBlock;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();//获取当前的CUDA流,用于异步计算。


  if (is_3d) {
  //std::cout << "performing NMS on 3D boxes in CUDA" << std::endl;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      dets_sorted.type(), "nms_kernel_cuda", [&] {
        nms_kernel_3d<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            (unsigned long long*)mask.data_ptr<int64_t>());
      });
   }
   else {
    //该函数是PyTorch CUDA扩展中的一个宏定义,用于在CUDA代码中处理浮点类型数据(包括float、double、half)的泛型编程。
    //它会根据传入的输入数据类型,自动选择对应的CUDA内核函数进行计算。
    //这样可以避免为每种数据类型编写重复的代码,提高代码的可维护性和可扩展性。
    //相当与模板类
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      dets_sorted.type(), "nms_kernel_cuda", [&] {
        nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            (unsigned long long*)mask.data_ptr<int64_t>());
      });

   }

  //将mask_cpu的数据拷贝到主机内存中
  at::Tensor mask_cpu = mask.to(at::kCPU);
  unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();

  //创建一个空的remv Tensor,其大小为col_blocks,数据类型为unsigned long long。用于记录每个block中哪些线程被标记为无效。
  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  //创建一个空的keep Tensor,其大小为dets_num,数据类型为long。用于存放整个筛选后的索引。
  at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < dets_num; i++) {
    int nblock = i / threadsPerBlock;//第几个block
    int inblock = i % threadsPerBlock;//block中第几个threads

    //判断当前线程是否被标记为无效
    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      unsigned long long* p = mask_host + i * col_blocks;//实际上将所有的形式设置成one hot形式。
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];//按位或运算,并将结果保留在remv数组中
      }
    }
  }

  AT_CUDA_CHECK(cudaGetLastError());//检测CUDA 是否发生错误
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)//从张量keep中选择制定维度、起始位置和长度的子张量
           .to(order_t.device(), keep.scalar_type())});
}
template <typename T>
__global__ void nms_kernel(const int n_boxes, const float iou_threshold, const T* dev_boxes, unsigned long long* dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;
  const int row_size = min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);//当前block行可以放多少个box
  const int col_size = min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);//获取block列可以放多少个box

  __shared__ T block_boxes[threadsPerBlock * 4];
  if (threadIdx.x < col_size) {//将当前block列的box拷贝到shared memory
    block_boxes[threadIdx.x * 4 + 0] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
    block_boxes[threadIdx.x * 4 + 1] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
    block_boxes[threadIdx.x * 4 + 2] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
    block_boxes[threadIdx.x * 4 + 3] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
  }
  __syncthreads();//同步

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;//当前block行box的索引
    const T* cur_box = dev_boxes + cur_box_idx * 4;//当前block行box的指针
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;//自己的IOU就不算了
    }
    for (i = start; i < col_size; i++) {
      if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {//计算各自的IOU
        t |= 1ULL << i;//以二进制的形式表示重叠关系成立
      }
    }
    const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);//列block数
    dev_mask[cur_box_idx * col_blocks + col_start] = t;//将重叠关系写入shared memory
  }
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1867278.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hadoop3:Yarn框架的三种调度算法

一、概述 目前&#xff0c;Hadoop作业调度器主要有三种&#xff1a;FIFO、容量&#xff08;Capacity Scheduler&#xff09;和公平&#xff08;Fair Scheduler&#xff09;。Apache Hadoop3.1.3默认的资源调度器是Capacity Scheduler。 CDH框架默认调度器是Fair Scheduler。 …

Http客户端-Feign 学习笔记

作者介绍&#xff1a;计算机专业研究生&#xff0c;现企业打工人&#xff0c;从事Java全栈开发 主要内容&#xff1a;技术学习笔记、Java实战项目、项目问题解决记录、AI、简历模板、简历指导、技术交流、论文交流&#xff08;SCI论文两篇&#xff09; 上点关注下点赞 生活越过…

EVPN-VXLAN:如何在数据中心使用

移动设备的迅速普及、社交媒体和协作工具的使用不断增加&#xff0c;使得网络中的端点数量日益增多。这种端点的快速增长促使对更有效的分段策略的需求&#xff0c;以区分不同用户、设备和流量类型。EVPN-VXLAN通过在物理第3层底层网络建立第2层覆盖虚拟网络的隧道&#xff0c;…

VMware虚拟机安装CentOS7.9 Oracle 11.2.0.4 RAC+单节点RAC ADG

目录 一、参考资料 二、RAC环境配置清单 1.主机环境 2.共享存储 3.IP地址 4.虚拟机 三、系统参数配置 1. 配置网卡 1.1 配置NAT网卡 1.2 配置HostOnly网卡 2. 修改主机名 3. 配置/etc/hosts 4. 关闭防火墙 5. 关闭Selinux 6. 配置内核参数 7. 配置grid、oracle…

SSI 注入漏洞

0x00漏洞描述 SSI 英文是 Server Side Includes 的缩写&#xff0c;翻译成中文就是服务器端包含的意思。从技术角度上说&#xff0c;SSI 就是在 HTML 文件中&#xff0c;可以通过注入注释调用的命令或指针。SSI 具有强大的功能&#xff0c;只要使用一条简单的 SSI 命令就可以实…

Ubuntu挂载window的网络共享文件夹爱

1.进入win10创建一个用户smb密码也是smb 2.右键进入文件夹共享 3.进入Ubuntu安装支持cifs-utils sudo apt update sudo apt install cifs-utils 4.sudo mkdir /mnt/shared 5.挂载&#xff1a; sudo mount -t cifs -o usernamesm bpasswordsmb //172.16.11.37(windowsIP)/s…

结构体(二)

今天来继续介绍我们有关结构体的相关知识 结构体的自引用 结构体的自引用&#xff0c;顾名思义嘛&#xff0c;就是在我们的结构体中再次引用该结构体&#xff0c;这一点跟我们的函数递归有异曲同工之妙&#xff0c;不了解函数递归的小伙伴可以移步到我之前做过的一期&#xf…

ElasticSearch索引架构与存储

关于ES官网的介绍: Elasticsearch provides near real-time search and analytics for all types of data. Whether you have structured or unstructured text, numerical data, or geospatial data, Elasticsearch can efficiently store and index it in a way that support…

django学习入门系列之第三点《案例 商品推荐部分》

文章目录 划分区域搭建骨架完整代码小结往期回顾 划分区域 搭建骨架 /*商品图片&#xff0c;父级设置*/ .slider .sd-img{display: block;width: 1226px;height: 460px; }<!-- 商品推荐部分 --> <!--搭建出一个骨架--> <div class"slider"><di…

openwrt igmp 适配

每弄完一次&#xff0c;过不多久就忘了&#xff0c;这次决心记下来。 openwrt 的igmpproxy 包是干嘛的&#xff1f;原来&#xff0c;组播包并不能穿透路由&#xff0c;也就是我们在wan端播放的组播视频流&#xff0c;lan端是没法收到的&#xff0c;igmpproxy就是用来打通wan端…

【Linux】进程信号_3

文章目录 八、进程信号2. 信号的保存3. 信号的处理 未完待续 八、进程信号 2. 信号的保存 实际执行信号的处理动作称为信号递达(Delivery) 信号从产生到递达之间的状态,称为信号未决(Pending)。 进程可以选择阻塞 (Block )某个信号。 被阻塞的信号产生时将保持在未决状态,直到…

网络问题排障专题-AF网络问题排障

目录 一、数据交换基本原理 1、ARP协议工作原理 数据包如图&#xff1a; 2、二层交换工作原理 简述核心概念&#xff1a; 二层交换原理-VLAN标签 3、三层交换工作原理 二、AF各种部署模式数据转发流程 1、路由模式数据转发流程 三、分层/分组逐一案例讲解 1、问题现…

自然语言处理——英文文本预处理

高质量数据的重要性 数据的质量直接影响模型的性能和准确性。高质量的数据可以显著提升模型的学习效果&#xff0c;帮助模型更准确地识别模式、进行预测和决策。具体原因包括以下几点&#xff1a; 噪音减少&#xff1a;高质量的数据经过清理&#xff0c;减少了无关或错误信息…

redis哨兵模式(Redis Sentinel)

哨兵模式的背景 当主服务器宕机后&#xff0c;需要手动把一台从服务器切换为主服务器&#xff0c;这就需要人工干预&#xff0c;费事费力&#xff0c;还会造成一段时间内服务不可用。这不是一种推荐的方式。 为了解决单点故障和提高系统的可用性&#xff0c;需要一种自动化的监…

【D3.js in Action 3 精译】1.2.2 可缩放矢量图形(一)

译注 由于 1.2.2 小节介绍 SVG 的篇幅过多&#xff0c;为了方便查阅&#xff0c;后续将分多个小节依次进行翻译。为了确保整个 1.2.2 小节的完整性&#xff0c;特意将上一篇包含的 SVG 小节的内容整理出来重新编排。敬请留意。 1.2.2 SVG - 可缩放矢量图形 可伸缩矢量图形&…

Django-开发一个列表页面

需求 基于ListView,创建一个列表视图,用于展示"BookInfo"表的信息要求提供分页提供对书名,作者,描述的查询功能 示例展示: 1. 数据模型 models.py class BookInfo(models.Model):titlemodels.CharField(verbose_name"书名",max_length100)authormode…

【STM32-存储器映射】

STM32-存储器映射 ■ STM32F1-4G地址空间分成8个块■ STM32F1-Block0■ STM32F1-Block1■ STM32F1-Block2■ STM32F1- ■ STM32F1-4G地址空间分成8个块 ■ STM32F1-Block0 有出厂 BootLoader 就可以使用串口下载程序。如Keil5图中IROM地址是0x8000000 开始 就是flash地址 ■ S…

安装vue开发者工具

浏览器控制台提示&#xff1a; 打开网址 GitHub - vuejs/devtools: ⚙️ Browser devtools extension for debugging Vue.js applications. 点击添加 上图地址&#xff1a;Installation | Vue Devtools 安装好了

线上民族传统服饰商城

摘 要 随着互联网的不断发展和普及&#xff0c;电子商务成为了人们生活中不可或缺的一部分。传统的线下购物方式逐渐被线上购物所取代&#xff0c;人们越来越习惯在互联网上购物。而民族传统服饰作为我国丰富多样的民族文化的重要组成部分&#xff0c;具有独特的艺术价值和商业…

不是KVM不支持精简置备的磁盘,而是VMM

正文共&#xff1a;999 字 11 图&#xff0c;预估阅读时间&#xff1a;1 分钟 书接上文&#xff08;不会吧&#xff01;KVM竟然不支持磁盘的精简置备&#xff01;&#xff1f;&#xff09;&#xff0c;我们已经掌握了通过“虚拟系统管理器VMM”创建虚拟机的基本方法&#xff0c…