iou的cpu和gpu源码实现

news2024/12/27 12:19:23

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

简介

IoU(Intersection over Union)是一种测量在特定数据集中检测相应物体准确度的一个标准,通常用于目标检测中预测框(bounding box)之间准确度的一个度量(预测框和实际目标框)。
在这里插入图片描述

IoU计算的是“预测的边框”和“真实的边框”的交叠率,即它们的交集和并集的比值。最理想情况是完全重叠,即比值为1。

IoU的计算方法如下:

计算两个框的交集面积,即两个框的左、上、右、下四个点的交集。
计算两个框的并集面积,即两个框的左、上、右、下四个点的并集。
计算交集面积和并集面积的比值,即为 IoU 值。
IoU的优点是可以反映预测检测框与真实检测框的检测效果,并且具有尺度不变性,即对尺度不敏感。但是,IoU也存在一些缺点,例如无法反映两个框之间的距离大小(重合度),如果两个框没有相交,则 IoU 值为 0,无法进行学习训练。

源码实现:

cpu版源码实现:

def iou_core(box1: Tensor, box2: Tensor, area_sum: Tensor):
    overlap_w = torch.min(box1[2],box2[2]) - torch.max(box1[0],box2[0])
    overlap_h = torch.min(box1[3],box2[3]) - torch.max(box1[1],box2[1])
    if overlap_w <= 0 or overlap_h <= 0:
        return 0
    overlap_area = overlap_h * overlap_w
    return overlap_area / (area_sum - overlap_area)

def iou_cpu(box1: Tensor, box2: Tensor):
    box1_num = box1.size(0)
    box2_num = box2.size(0)
    box1_dim = box1.size(1)
    box2_dim = box2.size(1)
    if box1_dim != 4 or box2_dim != 4:
        return -1

    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    result = torch.zeros(size=(box1_num, box2_num))
    for i in range(box1_num):
        for j in range(box2_num):
            if box1_area[i] >= 0 and box2_area[j] >= 0:
                result[i, j] = iou_core(box1[i], box2[j], box1_area[i] + box2_area[j])
            else:
                result[i, j] = 9999
    return result

gpu版源码实现:

__device__ float iou_core(const float* box1 ,const float* box2){
    float box1_x0 = *(box1 + 0);
    float box1_y0 = *(box1 + 1);
    float box1_x1 = *(box1 + 2);
    float box1_y1 = *(box1 + 3);
    float box2_x0 = *(box2 + 0);
    float box2_y0 = *(box2 + 1);
    float box2_x1 = *(box2 + 2);
    float box2_y1 = *(box2 + 3);
    if(!(box1_x0 < box1_x1 && box1_y0 < box1_y1 && box2_x0 < box2_x1 && box2_y0 < box2_y1)){
        return 9999;
    }

    float inter_x0 = std::max(box1_x0, box2_x0);
    float inter_x1 = std::min(box1_x1, box2_x1);
    float inter_y0 = std::max(box1_y0, box2_y0);
    float inter_y1 = std::min(box1_y1, box2_y1);
    float inter_area = (inter_x1 - inter_x0)*(inter_y1-inter_y0);
    inter_area = std::max(inter_area, 0.0f);

    float box1_area = (box1_x1 - box1_x0)*(box1_y1-box1_y0);
    float box2_area = (box2_x1 - box2_x0)*(box2_y1-box2_y0);
    float iou = inter_area / (box1_area + box2_area - inter_area);
    printf("iou =%f\n",iou);
    return iou;
}

__global__ void iou_gpu_kernel(const int box1_num,
const float* box1_ptr,
const int box2_num,
const float* box2_ptr,
float* result_ptr){
    const int box1_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
    const int box2_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
    printf("gpu: box1_idx = %d, box2_idy= %d\n",box1_idx,box2_idx);
    if(box1_idx>=box1_num || box2_idx>=box2_num){
        return;
    }
    printf("gpu: box1_idx = %d, box2_idy= %d, result_id= %d\n",box1_idx,box2_idx,box1_idx * box2_num + box2_idx);
    const float* box1 = box1_ptr + box1_idx * 4;
    const float* box2 = box2_ptr + box2_idx * 4;
    float iou = iou_core(box1, box2);
    *(result_ptr + box1_idx * box2_num + box2_idx) = iou;
}

void iou_gpu_launch(const int box1_num,
const float* box1_ptr,
const int box2_num,
const float* box2_ptr,
float* result_ptr){
    dim3 blocks(DIVUP(box1_num, THREADS_PER_BLOCK),DIVUP(box2_num, THREADS_PER_BLOCK));//每个grid的blocks
    dim3 threads(THREADS_PER_BLOCK,THREADS_PER_BLOCK);//每个block里面的thread
    printf("blocks=(%d %d), threads=(%d %d)\n",
        DIVUP(box1_num, THREADS_PER_BLOCK),DIVUP(box2_num, THREADS_PER_BLOCK),
        THREADS_PER_BLOCK,THREADS_PER_BLOCK);
    iou_gpu_kernel<<<blocks,threads>>>(box1_num,box1_ptr,box2_num,box2_ptr,result_ptr);
    cudaDeviceSynchronize();// waiting for gpu work
    printf("gpu done\n");
}

耗时测试:

import torch
from iou import iou_gpu, iou_cpu
from utils import TicToc

device = torch.device('cuda:0')
input1 = torch.Tensor([[0, 0, 1, 1],
                       [0, 2, 1, 3],
                       [0.2, 0, 1, 1],
                       [0.1, 2, 1, 3],
                       [0.11, 0, 1, 1],
                       [0, 2.4, 1, 3],
                       [0.2, 0.1, 1, 1],
                       [0.7, 2.5, 1, 3],
                       [0, 0, 6, 1],
                       [1.5, 2, 1, 3]]).to(device)
input2 = torch.Tensor([[0.5, 0, 1.5, 1],
                       [0, 0.5, 1, 1.5],
                       [0.5, 0.5, 1.5, 1.5],
                       [0, 0.5, 1, 2.5]]).to(device)

tictic = TicToc('iou fun')
for i in range(1000):
    result = iou_gpu(input1, input2)
tictic.toc()
tictic.tic()
for i in range(1000):
    result2 = iou_cpu(input1.to('cpu'), input2.to('cpu'))
tictic.toc()
pass

具体流程说明:

IoU的计算方法如下:
计算两个框的交集面积,即两个框的左、上、右、下四个点的交集。
计算两个框的并集面积,即两个框的左、上、右、下四个点的并集。
计算交集面积和并集面积的比值,即为 IoU 值。
在实际应用中,通常设定 IoU 的阈值,例如 0.5 或 0.7 等,当 IoU 值大于阈值时,认为预测成功。通过调整阈值,可以得到不同的模型,再通过不同的评价指标(如 ROC 曲线、F1 值等)来确定最优模型。

如需获取全套代码请参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1407388.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Arduino U8g2库:图形界面库的强大利器,

Arduino U8g2库&#xff1a;图形界面库的强大利器 介绍 在Arduino世界中&#xff0c;图形界面的显示通常是一项关键的任务。为了简化这个过程&#xff0c;提高开发效率&#xff0c;许多库被开发出来&#xff0c;其中U8g2库就是其中之一。U8g2库是一个功能强大的图形库&#x…

uniapp复选框 实现排他选项

选择了排他选项之后 复选框其他选项不可以选择 <view class"reportData" v-for"(val, index) in obj" :key"index"> <view v-if"val.type 3" ><u-checkbox-group v-model"optionValue" placement"colu…

web系统服务器监控检查

一、检查操作系统是否存在增减文件&#xff0c;是否有shell被上传 要检查操作系统是否存在增减文件或是否有shell被上传&#xff0c;您可以按照以下步骤进行操作&#xff1a; 文件完整性检查&#xff1a; 使用文件系统的完整性检查工具&#xff0c;例如fsck&#xff08;对于ext…

项目一:踏上Java开发之旅

文章目录 一、实战概述二、实战步骤任务1&#xff1a;安装配置JDK并开发第一个Java程序步骤一&#xff1a;安装JDK步骤二&#xff1a;配置JDK环境变量步骤三&#xff1a;开发第一个Java程序 课堂练习任务1、打印个人信息任务2、打印直角三角形任务3、打印一颗爱心任务4、打印史…

git:使用git rebase合并多次commit为一个

git log&#xff1a;找到需要合并的最早 commit 的父级 git rebase -i 73a5cd8597除第一个 pick 外&#xff0c;将其它改成 s&#xff0c;改完后保存退出 保存完后弹出 commit message 合并提示&#xff0c;根据这次合并的目的&#xff0c;重写commit message&#xff0c;改完后…

软考复习之软件工程篇

软件生命周期 问题定义&#xff1a;要示系统分析员与用户进行交流&#xff0c;弄清”用户需要计算机解决什么问题”然后提出关于“系统目标与范围的说明”&#xff0c;提交用户审查和确认 可行性研究&#xff1a;一方面在于把待开发的系统的目标以明确的语言描述出来&#xf…

httpClient忽略https的证书认证

忽略https证书认证代码: /*** 创建模拟客户端&#xff08;针对 https 客户端禁用 SSL 验证&#xff09;* return* throws Exception*/public static CloseableHttpClient createHttpClientWithNoSsl() throws Exception {// Create a trust manager that does not validate cer…

【C++】初步认识基于C的优化

C祖师爷在使用C语言时感觉到了不方便的一些点&#xff0c;于是一步一步改进优化&#xff0c;最后形成了C 本文将盘点一下基于C的语法优化 目录 命名空间&#xff1a;命名空间定义&#xff1a;命名空间使用&#xff1a; C输入&输出&#xff1a;cout&#xff1a;endl&#…

司铭宇老师:门店服装销售技巧培训:卖衣服销售方法和技巧

门店服装销售技巧培训&#xff1a;卖衣服销售方法和技巧 在服装零售行业&#xff0c;销售方法和技巧对于提升销售业绩和增强顾客满意度至关重要。一个成功的销售人员需要掌握如何吸引顾客、如何展示商品、如何促成交易等多方面的技能。以下是关于卖衣服的销售方法和技巧的详细…

ai智能写作软件有分享吗?分享4款解放双手的软件!

随着人工智能技术的不断发展&#xff0c;AI智能写作软件逐渐成为内容创作者们的新宠。这些软件不仅能够帮助我们快速生成高质量的文本内容&#xff0c;还能在优化搜索引擎排名等方面发挥重要作用。本文将为大家介绍几款常用的AI智能写作软件&#xff0c;让您轻松提升内容创作效…

如何在飞书创建企业ChatGPT智能问答助手应用并实现公网远程访问(1)

文章目录 前言环境列表1.飞书设置2.克隆feishu-chatgpt项目3.配置config.yaml文件4.运行feishu-chatgpt项目5.安装cpolar内网穿透6.固定公网地址7.机器人权限配置8.创建版本9.创建测试企业10. 机器人测试 前言 在飞书中创建chatGPT机器人并且对话&#xff0c;在下面操作步骤中…

Unity | 渡鸦避难所-8 | URP 中利用 Shader 实现角色受击闪白动画

1. 效果预览 当角色受到攻击时&#xff0c;为了增加游戏的视觉效果和反馈&#xff0c;可以添加粒子等动画&#xff0c;也可以使用 Shader 实现受击闪白动画&#xff1a;受到攻击时变为白色&#xff0c;逐渐恢复为正常颜色 本游戏中设定英雄受击时播放粒子效果&#xff0c;怪物…

pytorch实战-6手写数字加法机-迁移学习

1 概述 迁移学习概念&#xff1a;将已经训练好的识别某些信息的网络拿去经过训练识别另外不同类别的信息 优越性&#xff1a;提高了训练模型利用率&#xff0c;解决了数据缺失的问题&#xff08;对于新的预测场景&#xff0c;不需要大量的数据&#xff0c;只需要少量数据即可…

IP代理可以保护信息安全吗?

“随着互联网的普及和发展&#xff0c;网络安全问题已经成为众多企业和个人所面临的严峻挑战。保护信息安全已成为企业的核心竞争力之一&#xff0c;而IP代理正成为实现这一目标的有效手段。” 一、IP代理真的可以保护用户信息安全吗&#xff1f; IP代理作为一种网络工具&…

CSS基本知识总结

目录 一、CSS语法 二、CSS选择器 三、CSS样式表 1.外部样式表 2.内部样式表 3.内联样式 四、CSS背景 1.背景颜色&#xff1a;background-color 2.背景图片&#xff1a;background-image 3.背景大小&#xff1a;background-size 4.背景图片是否重复&#xff1a;backg…

鸿蒙应用开发学习:获取手机位置信息

一、前言 移动应用中经常需要获取设备的位置信息&#xff0c;因此在鸿蒙应用开发学习中&#xff0c;如何获取手机的位置信息是必修课。之前我想偷懒从别人那里复制黏贴代码&#xff0c;于是在百度上搜了一下&#xff0c;可能是我输入的关键字不对&#xff0c;结果没有找到想要…

离线编译 onnxruntime-with-tensortRT

记录为centos7的4090开发机离线编译onnxruntime的过程&#xff0c;因为在离线的环境&#xff0c;所以踩了很多坑。 https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html 这里根据官网的推荐安装1.15 版本的onnx 因为离线环境&#xff0c;所以很…

10个常考的前端手写题,你全都会吗?(下)

前言 &#x1f4eb; 大家好&#xff0c;我是南木元元&#xff0c;热爱技术和分享&#xff0c;欢迎大家交流&#xff0c;一起学习进步&#xff01; &#x1f345; 个人主页&#xff1a;南木元元 今天接着上篇再来分享一下10个常见的JavaScript手写功能。 目录 1.实现继承 ES5继…

【制作100个unity游戏之23】实现类似七日杀、森林一样的生存游戏2(附项目源码)

本节最终效果演示 文章目录 本节最终效果演示系列目录前言添加小动物模型动画动物AI脚本效果 添加石头石头模型拾取物品效果 源码完结 系列目录 【制作100个unity游戏之23】实现类似七日杀、森林一样的生存游戏1&#xff08;附项目源码&#xff09; 【制作100个unity游戏之23】…

卓振江:我的大数据能力提升之路 | 提升之路系列(二)

导读 为了发挥清华大学多学科优势&#xff0c;搭建跨学科交叉融合平台&#xff0c;创新跨学科交叉培养模式&#xff0c;培养具有大数据思维和应用创新的“π”型人才&#xff0c;由清华大学研究生院、清华大学大数据研究中心及相关院系共同设计组织的“清华大学大数据能力提升项…