YOLOv11改进策略【损失函数篇】| Shape-IoU:考虑边界框形状和尺度的更精确度量

news2024/10/1 20:40:07

一、本文介绍

本文记录的是改进YOLOv11的损失函数,将其替换成Shape-IoU。现有边界框回归方法通常考虑真实GT(Ground Truth)框预测框之间的几何关系,通过边界框的相对位置和形状计算损失,但忽略了边界框本身的形状和尺度等固有属性对边界框回归的影响。为了弥补现有研究的不足,Shape-IoU提出了一种关注边界框本身形状和尺度的边界框回归方法。


文章目录

  • 一、本文介绍
  • 二、Shape-IoU设计原理
    • 2.1 原理
    • 2.2 优势
  • 三、Shape-IoU的实现代码
  • 四、添加步骤
    • 4.1 修改ultralytics/utils/metrics.py
    • 4.2 修改ultralytics/utils/loss.py
    • 4.3 修改ultralytics/utils/tal.py


二、Shape-IoU设计原理

Shape-IoU:考虑边界框形状和尺度的更精确度量

以下是关于Shape-IoU的详细介绍:

2.1 原理

  • 分析边界框回归特性:通过对边界框回归样本的分析,得出以下结论:
    • 当回归样本的偏差和形状偏差相同且不全为0时,假设GT框不是正方形且有长短边,边界框形状和尺度的差异会导致其IoU值的差异。
    • 对于相同尺度的边界框回归样本,当回归样本的偏差和形状偏差相同且不全为0时,边界框的形状会对回归样本的IoU值产生影响。沿着边界框短边方向的偏差和形状偏差对应的IoU值变化更为显著。
    • 对于具有相同形状边界框的回归样本,当回归样本偏差和形状偏差相同且不全为0时,与较大尺度的回归样本相比,较小尺度边界框回归样本的IoU值受GT框形状的影响更为显著。
  • Shape - IoU公式
    • I o U = ∣ B ∩ B g t ∣ ∣ B ∪ B g t ∣ IoU = \frac{|B \cap B^{gt}|}{|B \cup B^{gt}|} IoU=BBgtBBgt
    • w w = 2 × ( w g t ) s c a l e ( w g t ) s c a l e + ( h g t ) s c a l e ww = \frac{2 \times (w^{gt})^{scale}}{(w^{gt})^{scale} + (h^{gt})^{scale}} ww=(wgt)scale+(hgt)scale2×(wgt)scale
    • h h = 2 × ( h g t ) s c a l e ( w g t ) s c a l e + ( h g t ) s c a l e hh = \frac{2 \times (h^{gt})^{scale}}{(w^{gt})^{scale} + (h^{gt})^{scale}} hh=(wgt)scale+(hgt)scale2×(hgt)scale
    • d i s t a n c e s h a p e = h h × ( x c − x c g t c ) 2 + w w × ( y c − y c g t c ) 2 distance^{shape} = hh \times (\frac{x_c - x_c^{gt}}{c})^{2} + ww \times (\frac{y_c - y_c^{gt}}{c})^{2} distanceshape=hh×(cxcxcgt)2+ww×(cycycgt)2
    • Ω s h a p e = ∑ t = w , h ( 1 − e − ω t ) θ , θ = 4 \Omega^{shape} = \sum_{t = w, h}(1 - e^{-\omega_t})^{\theta}, \theta = 4 Ωshape=t=w,h(1eωt)θ,θ=4,其中 { ω w = h h × ∣ w − w g t ∣ m a x ( w , w g t ) ω h = w w × ∣ h − h g t ∣ m a x ( h , h g t ) \left\{\begin{array}{l} \omega_{w} = hh \times \frac{|w - w^{gt}|}{max(w, w^{gt})} \\ \omega_{h} = ww \times \frac{|h - h^{gt}|}{max(h, h^{gt})} \end{array}\right. {ωw=hh×max(w,wgt)wwgtωh=ww×max(h,hgt)hhgt
  • 对应的边界框回归损失 L S h a p e − I o U = 1 − I o U + d i s t a n c e s h a p e + 0.5 × Ω s h a p e L_{Shape - IoU} = 1 - IoU + distance^{shape} + 0.5 \times \Omega^{shape} LShapeIoU=1IoU+distanceshape+0.5×Ωshape

在这里插入图片描述

2.2 优势

  • 提高检测性能:论文中通过一系列对比实验,证明了Shape-IoU方法在不同检测任务中能够有效提高检测性能,优于现有方法,在不同检测任务中达到了最先进的性能。
  • 关注边界框自身属性:考虑了边界框本身的形状和尺度对边界框回归的影响,弥补了现有研究忽略这一因素的不足。
  • 在小目标检测任务中的应用:针对小目标检测任务,提出了Shape-Dot DistanceShape-NWD,将Shape-IoU的思想融入其中,提高了在小目标检测方面的性能。

论文:https://arxiv.org/pdf/2312.17663
源码:https://github.com/malagoutou/Shape-IoU


三、Shape-IoU的实现代码

Shape-IoU的实现代码如下:

def shape_iou(box1, box2, xywh=True, scale=0, eps=1e-7):
    (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
    w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
    b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
    b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
 
    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
 
    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps
 
    # IoU
    iou = inter / union
 
    #Shape-Distance    #Shape-Distance    #Shape-Distance    #Shape-Distance    #Shape-Distance    #Shape-Distance    #Shape-Distance  
    ww = 2 * torch.pow(w2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
    hh = 2 * torch.pow(h2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex width
    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
    c2 = cw ** 2 + ch ** 2 + eps                            # convex diagonal squared
    center_distance_x = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2) / 4
    center_distance_y = ((b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
    center_distance = hh * center_distance_x + ww * center_distance_y
    distance = center_distance / c2
  
    omiga_w = hh * torch.abs(w1 - w2) / torch.max(w1, w2)
    omiga_h = ww * torch.abs(h1 - h2) / torch.max(h1, h2)
    shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
    
    iou = iou - distance - 0.5 * ( shape_cost)
    return iou  # IoU

四、添加步骤

4.1 修改ultralytics/utils/metrics.py

此处需要查看的文件是ultralytics/utils/metrics.py

metrics.py中定义了模型的损失函数和计算方法,我们想要加入新的损失函数就只需要将代码放到这个文件内即可

Shape-IoU添加后如下:

在这里插入图片描述

4.2 修改ultralytics/utils/loss.py

utils\loss.py用于计算各种损失。

ultralytics/utils/loss.py在的引用中添加shape_iou,然后在BboxLoss函数内修改如下代码,使模型调用此Shape-IoU损失函数。

在这里插入图片描述


iou = shape_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])

在这里插入图片描述

4.3 修改ultralytics/utils/tal.py

tal.py中是一些损失函数的功能应用。

ultralytics/utils/tal.py在的引用中添加shape_iou,然后在iou_calculation函数内修改如下代码,使模型调用此Shape-IoU损失函数。

在这里插入图片描述

return shape_iou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

在这里插入图片描述

此时再次训练模型便会使用Shape-IoU计算模型的损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183108.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PV大题--专题突破

写在前面: PV大题考查使用伪代码控制进程之间的同步互斥关系,它需要我们一定的代码分析能力,算法设计能力,有时候会给你一段伪代码让你补全使用信号量控制的操作,请一定不要相信某些人告诉你只要背一个什么模板&#…

Java线程入门

目录 一.线程相关概念 1.程序(program) 2.进程 3.线程 4.其他相关概念 二.线程的创建 1.继承Thread 2.Runnable接口 3.多线程机制(重要) 4.start() 三.线程终止--通知 四.线程(Thread)方法 1.常…

fastAPI教程:数据库操作

FastAPI 六、数据库操作 FastAPI支持操作各种数据库,但本身并没有内置关于任何数据库相关的模块。因此我们可以根据需求使用任何数据库,包括关系型(SQL)数据库,例如:PostgreSQL、MySQL、SQLite、Oracle、…

【AGC005D】~K Perm Counting(计数抽象成图)

容斥原理。 求出f(m) ,f(m)指代至少有m个位置不合法的方案数。 怎么求? 注意到位置为id,权值为v ,不合法的情况,当且仅当 v idk或 v id-k 因此,我们把每一个位置和权值抽象成点 ,不合法的情况之间连一…

【JVM】基础篇

1 初识JVM 1.1 什么是JVM JVM 全称是 Java Virtual Machine,中文译名 Java虚拟机。JVM 本质上是一个运行在计算机上的程序,他的职责是运行Java字节码文件。 Java源代码执行流程如下: 分为三个步骤: 1、编写Java源代码文件。 …

自动驾驶系列—深度剖析自动驾驶芯片SoC架构:选型指南与应用实战

🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…

认知杂谈74《远离渣女陷阱,拥抱健康情感》

内容摘要: 渣女在感情中使用甜言蜜语陷阱,利用男性渴望理解和关爱的心理,通过虚假承诺和情感操控来获得利益。 男性易陷入这种陷阱,因为他们可能因压力大、感性而易受感动。为了避免这种情况,男性需要辨别言行一致性&a…

【含文档】基于Springboot+Vue的国风彩妆网站(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

软件设计之SSM(4)

软件设计之SSM(4) 路线图推荐: 【Java学习路线-极速版】【Java架构师技术图谱】 尚硅谷新版SSM框架全套视频教程,Spring6SpringBoot3最新SSM企业级开发 资料可以去尚硅谷官网免费领取 学习内容: 基于配置类方式管理Bean 完全注解开发第三…

共模电感工作原理:【图文讲解】

共模电感,相信做电源较多的朋友用的比较多,而做消费级产品的朋友或许用的不是那么的多。但是还是有必要了解了解。 先上图,看看它长什么样子: (实物图) (结构图) 很显然&#xff0…

【Ubuntu】安装常用软件包-mysql

我的几个服务是部署在docker的同一个网络里,这样相互访问就可以通过docker容器的名字访问,比如容器A访问容器B,就可以http://B:8080/xxx 这样访问,不用关心ip是多少。 所以mysql前面文章给安装到主机里,感觉有点坑自己…

02.usePrevious

在 React 开发中,有时我们需要访问组件的前一个状态或属性。这在进行比较、动画或其他需要历史数据的操作时特别有用。usePrevious 钩子提供了一种简单而有效的方式来存储和访问前一个值。以下是如何实现和使用这个自定义钩子: const usePrevious valu…

【数据类型】C和C++的区别

文章目录 一、字符串二、布尔类型 bool三、数据的输入和输出 C和C在数据类型上打区别不大,下面就二者在这方面的部分区别做比较。 一、字符串 C语言和C在字符串的定义和书写风格上略有差异。 C风格字符串: char str[]"hello";C风格字符串 st…

社交内容电商中的新机遇:2+1链动模式AI智能名片商城小程序

在当今的电商世界里,社交内容电商正蓬勃发展。这种模式基于高质量内容,将有着共同兴趣爱好的用户聚集起来形成社群,随后引导用户进行裂变式的传播与交易。无论是像微信、微博、快手、抖音、今日头条这样的平台形式,还是网红、“大…

算法笔记(四)——模拟

文章目录 替换所有的问号提莫攻击Z字形变换外观数列数青蛙 模拟算法就是根据题目的要求,题目让干神马就做神马,一步一步来 替换所有的问号 题目:替换所有的问号 思路 从左到右遍历整个字符串,找到问号之后,就⽤ a ~ z…

QT系统学习篇(2)- Qt跨平台GUI原理机制

一、Qt工程管理 新建项目: 我们程序员新建项目对话框所有5类项目模板 Application: Qt的应用程序,包含Qt Quick和普通窗口程序。 Library: 它可以创建动态库、静态库、Qt Creator自身插件、Qt Quick扩展插件。 其他项目: 创建单元测试项目、子目录项目…

自动驾驶系列—自动驾驶MCU架构全方位解析:从单核到多核的选型指南与应用实例

🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…

五子棋双人对战项目(3)——匹配模块

一、分析需求 二、约定前后端接口 三、实现游戏大厅页面(前端代码) 四、实现后端代码 五、线程安全问题 六、忙等问题 一、分析需求 需求:多个玩家,在游戏大厅进行匹配,系统会把实力相近的玩家匹配到一起。 要想实…

使用cmake配置pcl环境

项目文件在https://pan.quark.cn/s/d347f72c7432 文件中包含CMakeLists.txt,一个pcd文件,一个cpp源文件。 这里的话,首先你需要下载好cmake软件,并将其添加到环境变量。 CMakeLists.txt文件内容如下 cmake_minimum_required(VER…