YOLOv9改进策略【损失函数篇】| 利用MPDIoU,加强边界框回归的准确性

news2024/9/20 0:50:46

一、背景

  1. 目标检测和实例分割中的关键问题
    • 现有的大多数边界框回归损失函数在不同的预测结果下可能具有相同的值,这降低了边界框回归的收敛速度和准确性。
  2. 现有损失函数的不足
    • 现有的基于 ℓ n \ell_n n范数的损失函数简单但对各种尺度敏感。
    • 当预测框与真实框具有相同的宽高比但不同的宽度和高度值时,现有损失函数可能会存在问题,限制了收敛速度和准确性。

文章目录

  • 一、背景
  • 二、原理
    • 2.1 IoU计算原理
    • 2.2. 基于最小点距离的IoU度量
    • 2.3. 作为边界框回归损失函数
    • 2.4 MPDIoU的计算公式
  • 三、添加步骤
    • 3.1 utils\metrics.py
    • 3.2 修改utils\loss_tal_dual.py
  • 四、总结


MPDIoU(Intersection over Union with Minimum Points Distance)是一种用于高效且准确的边界框回归的损失函数。

二、原理

2.1 IoU计算原理

IoU(Intersection over Union)即交并比,用于衡量预测边界框和真实边界框的重合程度。

1. 交集计算:
- 首先确定预测边界框和真实边界框的交集区域。
- 对于两个以左上角和右下角坐标表示的矩形框,分别找出它们在横坐标和纵坐标方向上的重叠区间。
- 如果两个矩形框在横坐标和纵坐标方向上都有重叠部分,那么这个重叠区域就是一个矩形,其面积就是交集的大小。
2. 并集计算:
- 计算预测边界框和真实边界框的并集区域。
- 并集的大小等于两个矩形框各自的面积之和减去它们的交集面积。
3. 比值计算:
- 最后,IoU的值就是交集面积与并集面积的比值。

2.2. 基于最小点距离的IoU度量

  • 原论文中受水平矩形的几何特性启发,设计了一种基于最小点距离的新型IoU度量MPDIoU,直接最小化预测边界框和真实边界框的左上角和右下角点之间的距离。
  • MPDIoU的计算通过两个任意凸形状 A A A B B B,用其左上角和右下角点的坐标来表示,通过计算两个框的交集与并集之比,再减去左上角和右下角两点距离的归一化值来得到MPDIoU

2.3. 作为边界框回归损失函数

  • 在训练阶段,通过最小化基于MPDIoU的损失函数 L M P D I o U = 1 − M P D I o U L_{MPDIoU}=1-MPDIoU LMPDIoU=1MPDIoU,使模型预测的每个边界框 B p r d B_{prd} Bprd接近其真实框 B g t B_{gt} Bgt
  • 现有损失函数中的所有因素(如非重叠区域、中心点距离、宽高偏差等)都可以通过左上角和右下角两点的坐标确定,这意味着提出的 L M P D I o U L_{MPDIoU} LMPDIoU不仅考虑了这些因素,还简化了计算过程。

在这里插入图片描述

2.4 MPDIoU的计算公式

  1. MPDIoU的计算公式:

    • M P D I o U = A ∩ B A ∪ B − d 1 2 w 2 + h 2 − d 2 2 w 2 + h 2 MPDIoU=\frac{A\cap B}{A\cup B}-\frac{d_{1}^{2}}{w^{2}+h^{2}}-\frac{d_{2}^{2}}{w^{2}+h^{2}} MPDIoU=ABABw2+h2d12w2+h2d22
    • 其中 A A A B B B是两个任意凸形状, ( x A 1 , y A 1 ) (x_{A1}, y_{A1}) (xA1,yA1) ( x A 2 , y A 2 ) (x_{A2}, y_{A2}) (xA2,yA2)表示(A)的左上角和右下角点坐标, ( x B 1 , y B 1 ) (x_{B1}, y_{B1}) (xB1,yB1) ( x B 2 , y B 2 ) (x_{B2}, y_{B2}) (xB2,yB2)表示 B B B的左上角和右下角点坐标。
    • d 1 2 = ( x 1 B − x 1 A ) 2 + ( y 1 B − y 1 A ) 2 d_{1}^{2}=(x_{1}^{B}-x_{1}^{A})^{2}+(y_{1}^{B}-y_{1}^{A})^{2} d12=(x1Bx1A)2+(y1By1A)2 d 2 2 = ( x 2 B − x 2 A ) 2 + ( y 2 B − y 2 A ) 2 d_{2}^{2}=(x_{2}^{B}-x_{2}^{A})^{2}+(y_{2}^{B}-y_{2}^{A})^{2} d22=(x2Bx2A)2+(y2By2A)2
  2. 基于MPDIoU的损失函数计算公式:

    • L M P D I o U = 1 − M P D I o U L_{MPDIoU}=1-MPDIoU LMPDIoU=1MPDIoU

三、添加步骤

3.1 utils\metrics.py

此处需要查看的文件是utils\metrics.py

metrics.py中定义了模型的损失函数和计算方法,我们想要加入新的损失函数就只需要将代码放到这个文件内即可。YOLOv9原模型中使用的是CIoU,并且在原YOLOv9的代码中已经实现了MPDIoU的代码,

MPDIoU的代码在utils\metrics.py的第254行,如下:

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, MDPIoU=False, feat_h=640, feat_w=640, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    elif MDPIoU:
        d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
        d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
        mpdiou_hw_pow = feat_h ** 2 + feat_w ** 2
        return iou - d1 / mpdiou_hw_pow - d2 / mpdiou_hw_pow  # MPDIoU
    return iou  # IoU

在这里插入图片描述

3.2 修改utils\loss_tal_dual.py

utils\loss_tal_dual.py是损失函数的辅助分支+主分支损失计算文件。

utils\loss_tal_dual.py的75行处修改成如下代码,使模型调用此MPDIoU损失函数。


iou = bbox_iou(pred_bboxes_pos, target_bboxes_pos, xywh=False, MPDIoU=True)

在这里插入图片描述

四、总结

当发现预测边界框和真实边界框具有相同的宽高比但不同的宽度和高度值时,MPDIoU损失函数比现有损失函数更有效,此时可以尝试将损失函数修改成MPDIoU查看效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2075755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Word文件密码忘记,该如何才能编辑Word文件呢?

Word文件打开之后,发现编辑功能都是灰色的,无法使用,无法编辑,遇到这种情况,是因为Word文件设置了限制编辑导致的。一般情况下,我们只需要输入Word密码,将限制编辑取消就可以正常编辑文件了&…

LLM大模型:生成式人工智能完全指南,240页pdf

《你最后一本需要的AI书籍。我们保证!》 AI技术发展如此迅速,这本书可能已经过时了!但别担心——《生成性AI完全过时指南》依然是任何想将生成性AI从玩具变成工具的人必读的书籍。无论未来如何变化,它都能教你如何充分利用AI。你…

FL Studio24.1.1.4239中文高级版破解补丁+永久免费激活码许可证

FL Studio 24.1.1.4239中文版,音乐制作人的“瑞士军刀” 在音乐制作的世界中,有一款软件被誉为“瑞士军刀”,那就是FL Studio 24.1.1.4239中文版。它不仅功能强大,而且界面友好,让音乐制作变得简单又有趣。今天&#…

大模型如何赚钱,杀手级应用是什么、创业机会在哪里?

除了通义大模型外,MiniMax、月之暗面、智谱AI、猎户星空、零一万物、百川智能六家大模型厂商已经与钉钉达成合作。目前,钉钉生态伙伴总数超过5600家,其中AI 生态伙伴已经超过100家;钉钉AI每天调用量超1000万次。 在下午的圆桌对话…

八种dll文件丢失怎么恢复的步骤分享,超全面介绍dll文件几解决方法

在使用Windows操作系统的过程中,我们时常会遇到程序运行错误提示,其中“DLL文件丢失”是一类非常典型的问题。这类错误不仅令人困扰,还可能阻碍软件或系统功能的正常使用。动态链接库(DLL)文件是Windows系统中的一个关…

LeetCode 精选 75 回顾

目录 一、数组 / 字符串 1.交替合并字符串 (简单) 2.字符串的最大公因子 (简单) 3.拥有最多糖果的孩子(简单) 4.种花问题(简单) 5.反转字符串中的元音字母(简单&a…

基于大语言模型的医疗问答系统的设计与研究

目录 研究背景及意义 国内外研究现状 研究内容 研究方案与技术路线 大语言模型的基本原理 大语言模型的部署 大语言模型微调 大语言模型提示工程(Prompt) 大语言模型RAG技术 LangChain 多模态大语言模型 研究背景及意义 大语言模型&#xff0…

网络安全售前入门03——审计类产品了解

目录 1.前言 2.堡垒机介绍 2.1产品架构功能 2.2应用场景 2.3部署形式 2.4产品价值 2.5选型依据 3.日志审计 3.1产品架构功能 3.2应用场景 3.3部署形式 3.4产品价值 3.5选型依据 后续 1.前言 为方便初接触网络安全售前工作的小伙伴了解网安行业情况,我制作一系统…

CSS文本样式(一)

一、font-family 1、font-family属性 font-family​ :属性指定元素的​字体​,语法格式如下: ​font-family​: 字体1,字体2,...; 有两种字体系列名称: ​字体系列​:特定的字体系列(如Times New Rom…

Mac上免费使用Typora保姆级教程 简单 2024可用

一、官网安装正版软件 Typora官网--点击进入Typora官网下载正版软件 二、找到软件文件 进入访达,commandshiftG打开路径搜索,输入 /Applications/Typora.app/Contents/Resources/TypeMark 进入Typora文件夹 打开这个文件 三、修改字段 然后搜索字段…

Ubuntu2004编译VLC-QT(记录)(根据官方步骤来)

来到VLC-QT的github官方地址--VLC-QT(点击前面的) 下载官方源码,也可以git clone拉取 2:解压源码之后,进入文件夹 创建文件夹“build”用于存放待会编译产生的相关文件,执行 mkdir buildcd build 回到VLC…

【Redis】Redis 持久化 -- RDB AOF

文章目录 1 持久化介绍2 RDB2.1 RDB 介绍2.2 触发方式2.3 流程介绍2.4 RDB 文件2.5 RDB 优缺点 3 AOF3.1 AOF 介绍3.2 缓冲区刷新策略3.3 AOF 重写机制3.3.1 重写机制介绍3.3.2 混合持久化3.3.3 重写触发方式3.3.4 AOF 重写流程 3.4 AOF 优缺点 4 启动时数据恢复 1 持久化介绍 …

OceanBase V4 技术解读:从Alter Table 看DDL的支持

背景 数据库类型可以划分为两大类:关系型数据库和非关系型数据库。而关系型数据库以表格形式进行数据组织,同时遵循表关系的约束,例如创建一张表,表里面包含多个列,不同的列可以有不同的类型。当需要改表结构&#xf…

什么是数据库 DevOps?

在深入研究数据库 DevOps 之前,先回顾一下什么是 DevOps。它没有统一的定义,但我们知道它起源于软件开发方法与部署和运维的结合。 大约 2007 年和 2008 年,软件开发和 IT 界人士提出了这样的担忧:两个行业的分离,即编…

Datawhale X 李宏毅苹果书 AI夏令营(深度学习入门)task3

实践方法论 在应用机器学习算法时,实践方法论能够帮助我们更好地训练模型。如果在 Kaggle 上的结果不太好,虽然 Kaggle 上呈现的是测试数据的结果,但要先检查训练数据的损失。看看模型在训练数据上面,有没有学起来,再…

解锁 TypeScript Record 的奇妙用法:轻松搞定键值对!

在没有非常了解 Record 之前,定义对象的类型,一般使用 interface。它是 TS 中定义数据结构的一种方式,用来描述对象的形状、函数类型、类的结构等。 // 基本用法 interface User {name: string;age: number;isAdmin: boolean; }const user: …

抖音ip地址与实际地址不符是怎么回事

在数字化时代,社交媒体已成为人们日常生活不可或缺的一部分,而抖音作为其中的佼佼者,更是吸引了数以亿计的用户。然而,在使用抖音的过程中,不少用户发现了一个有趣而又令人困惑的现象:抖音显示的IP地址与实…

趣味算法------煤球数目

目录 前言: 题目描述: 解题思路: 具体代码: 前言: 数列在数学中是一个非常基础且重要的概念,它指的是按照一定顺序排列的一系列数。数列中的每一个数被称为该数列的项。 数列可以分为有限数列和无限数列…

7 nestjs 环境变量

安装 pnpm i --save nestjs/confignestjs/config 内部使用 dotenv 实现。 配置 一般会在根模块AppModal中导入,并使用.forRoot()静态方法导入它的配置 import { Module } from nestjs/common; import { ConfigModule } from nestjs/config; ​ Module({imports: …

降低游戏直播软件开发风险:自建团队、外包公司与现成源码

随着游戏直播行业的快速发展,越来越多的企业和个人开始涉足这一领域。然而,在游戏直播软件的开发过程中,选择合适的开发模式对于降低供应链风险至关重要。本文将探讨三种主要的游戏直播软件开发模式,并分析它们各自的风险管理策略…