YOLOv5涨点必备!改进损失函数EIoU,SIoU,AlphaIoU,FocalEIoU,Wise-IoU

news2024/11/17 13:49:46

目录

一,改进损失函数的作用

二,具体实现


一,改进损失函数的作用

YOLOv5损失函数的作用是衡量预测框与真实框之间的差异,并根据这些差异来更新模型的参数。它帮助模型学习如何准确地检测和定位目标物体,从而提高检测的精度和准确性。

YOLOv5中的损失函数主要包括三个部分:目标分类损失、边界框坐标损失和对象置信度损失。

  1. 目标分类损失:该损失函数用于衡量预测框中的目标类别与真实框中的目标类别之间的差异。它使用交叉熵损失函数来计算分类误差,促使模型学习正确地分类各个目标物体。

  2. 边界框坐标损失:该损失函数用于衡量预测框中的边界框位置与真实框中的边界框位置之间的差异。一般采用平方损失函数或者IOU(交并比)损失函数来衡量边界框的位置偏移,以便模型能够准确地定位目标物体。

  3. 对象置信度损失:该损失函数用于衡量预测框中的对象置信度与真实框中的对象置信度之间的差异。对象置信度表示预测框中是否存在目标物体,它是检测算法中一个关键的指标。通过对对象置信度的损失函数进行优化,模型可以学习如何准确地判断预测框中是否有目标物体。

YOLOv5的损失函数综合考虑了目标分类、边界框位置和对象置信度这三个重要因素,它们共同构成了目标检测的关键要素。通过最小化损失函数,模型可以不断优化参数,提高目标检测的准确性和鲁棒性。

二,具体实现

YOLOv5默认的损失函数为CIoU,另外自带的还有GIoU以及DIoU,

文件路径:utils/metrics.py

函数名为:bbox_iou

原损失函数定义:

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

改为:把上面提及到的这个函数替换成以下

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=False, alpha=1, gamma=0.5, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter/(union + eps), alpha) # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
            if Focal:
                return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter/(union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU

Alpha-IoU的介绍:

论文的名字很好,反映了本文的核心想法。作者将现有的基于IoU Loss推广到一个新的Power IoU系列 Loss,该系列具有一个Power IoU项和一个附加的Power正则项,具有单个Power参数α,称这种新的损失系列为α-IoU Loss。

函数特性:


文中,作者将现有的基于IoU Loss推广到一个新的Power IoU系列 Loss,该系列具有一个Power IoU项和一个附加的Power正则项,具有单个Power参数α。称这种新的损失系列为α-IoU Loss。在多目标检测基准和模型上的实验表明,α-IoU损失:

可以显著地超过现有的基于IoU的损失;

通过调节α,使检测器在实现不同水平的bbox回归精度方面具有更大的灵活性;

对小数据集和噪声的鲁棒性更强。

实验结果表明,α(α>1)增加了high IoU目标的损失和梯度,进而提高了bbox回归精度。

power参数α可作为调节α-IoU损失的超参数以满足不同水平的bbox回归精度,其中α >1通过更多地关注High IoU目标来获得高的回归精度(即High IoU阈值)。

**α对不同的模型或数据集并不过度敏感,在大多数情况下,α=3表现一贯良好。**α-IoU损失家族可以很容易地用于改进检测器的效果,在干净或嘈杂的环境下,不会引入额外的参数,也不增加训练/推理时间。
公式如下:

所以将  alpha  设置为1,   其实还是用的是原本的IOU,并没有加入alpha的属性,一般设置为3,

然后再把  iou改成你需要的,这样就组合而成了  alpha-ciou   ,  alpha-Diou等等

注意:

  1. gamma参数是Focal_EIoU中的gamma参数,一般就是为0.5,有需要可以自行更改。
  2. alpha参数为AlphaIoU中的alpha参数,默认为1,1的意思就是跟正常的IoU一样,如果想采用AlphaIoU的话,论文alpha默认值为3。
  3. 跟Focal_EIoU一样,我认为AlphaIoU的思想同样可以用在其他的IoU变种上,简单来说就是如果你设置了alpha为3,其他IoU设定的参数(GIoU,DIoU,CIoU,EIoU,SIoU)为False的时候,那就是AlphaIoU,如果你设置了alpha为3,CIoU为True的时候,那就是
  4.  
  5. 想用那个IoU变种,就直接设置参数为True即可

除了以上这个函数替换,还需要在utils/loss.py中ComputeLoss Class中的__call__函数中修改一下:



将红框代码替换为:

iou = bbox_iou(pbox, tbox[i], CIoU=True)  # iou(prediction, target)
if type(iou) is tuple:
    lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
    iou = iou[0].squeeze()
else:
    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
    iou = iou.squeeze()

最后修改参数就在调用bbox_iou中进行修改即可,比如上面的代码就是使用了CIoU,如果你想使用Focal_EIoU那么你可以修改为下:

iou = bbox_iou(pbox, tbox[i], EIoU=True, Focal=True) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1118673.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【分类讨论】CF1747D

Problem - D - Codeforces 题意 思路 一看这个做法一定就是分类讨论 先判无解 显然,如果区间异或和不是0一定无解 如果区间内全是0,答案一定是0 之后怎么讨论 注意到需要讨论区间长度 如果长度是奇数,那么直接操作即可,答…

【PyTorch】深度学习实践 1. Overview

目录 人工智能概述 课程前置知识 人工智能 问题分类 推理类 预测类 算法分类 传统算法与智能算法 人工智能领域细分 学习系统的发展 基于规则的系统 经典机器学习算法 表示学习方法 维度诅咒 说明 解决方法 第一代 第二代(深度学习) 传统…

数据结构: 红黑树

目录​​​​​​​ 1.红黑树概念 2.红黑树性质 3.调整 1.如果p和u都是红色,将其都改为黑色即可,然后向上调整 2.如果p红(u黑/u不在),这时候左子树两红,于是给右子树一个红(旋转变色) 2.1…

华为ICT——第四章深度学习和积卷神经

接第三章的末尾: 目录 接第三章的末尾: 1:自适应阈值分割: 2:形态处理: 4:膨胀: 5:腐蚀 6:开运算 7:闭运算 8:特征描述子 9&#xf…

看《软能力》的读后感

最近在图书大夏看了一本书的一部分,书名是《软能力》。本人看到了几个有意思的观点。一是接一些兼职项目。 二是分享自己的技术,让同事能干自己的工作,让自己的工作变得别人也能干,才有机会让自己的职位提升。 三是让手动操作变…

Linux实战——网络连接模式的三种模式

Linux可以分为三种网络模式: 桥接模式 (vmnet0) 仅主机模式 (vmnet1) NAT模式 (vmnet8) 当我们下载了vmware之后,在电脑会出现两个虚拟网卡,VMware Network Adapter VMnet1、VMware Network Adapter VMnet8。 可以通过查找 控…

巡检管理系统哪一款简单实用?如何解决传统巡检难题,实现高效监管?

在电力、燃气、水务等公共服务领域,线路巡检工作是保障公众安全、避免事故发生的重要环节。然而,传统的巡检方式存在一些显著的问题,可能会对公共安全和稳定运行产生不利的影响。为了解决这些问题,需要一种能够实现高效、精准的线…

多线程环境下如何安全的使用线性表, 队列, 哈希表

小王学习录 今日鸡汤安全使用ArrayList安全使用队列安全使用HashMap 今日鸡汤 安全使用ArrayList 使用synchronized锁或者reentrantLock锁使用CopyOnWriteArrayList(COW写时拷贝)类来代替ArrayList类. 多个线程对CopyOnWriteArrayList里面的ArrayList进行读操作, 不会发生线程…

原子核内的相互作用

原子核内的相互作用 氘核基态 和态的混合 核子-核子散射 低能核子-核子散射 n-p散射:只有核力 p-p散射:较复杂 n-n散射:n-n散射没有直接实验 低能 p-p 散射和核力的电荷无关性 高能核子-核子散射 核力的主要性质 核力主要性质 核力是短程力…

Qt第六十五章:自定义菜单栏的隐藏、弹出

目录 一、效果图 二、qtDesigner 三、ui文件如下: 四、代码 一、效果图 二、qtDesigner 原理是利用属性动画来控制QFrame的minimumWidth属性。 ①先拖出相应的控件 ②布局一下 ③填上一些样式 相关QSS background-color: rgb(238, 242, 255); border:2px sol…

【uniapp/uView】解决消息提示框悬浮在下拉框之上

需要实现这样的效果&#xff0c;即 toast 消息提示框在 popup 下拉框之上&#xff1a; 解决方法&#xff0c;把 <u-toast ref"uToast" /> 放在 u-popup 里面即可&#xff0c;这样就可以提升 toast 的优先级&#xff1a; <!-- 弹出下拉框 --><u-popu…

第三章 内存管理 十二、请求分页管理方式

目录 一、页表机制 1、页表结构 二、缺页中断机制 1、有如下例子 2、根据要访问的逻辑地址的页号2&#xff0c;找到该页的状态是没有放入内存&#xff0c;所以会产生缺页中断&#xff0c;将缺页进程堵塞&#xff0c;放入堵塞队列&#xff0c;调页完成后再将其唤醒&#xf…

SAP MM学习笔记37 - 请求书照合中的配送费用

上一次学习了请求书照合中的 追加请求&#xff0c;追加Credit&#xff0c;请求书取消等知识&#xff0c;这一章来学习请求书中的配送费用处理。 SAP MM学习笔记37 - 请求书照合中的 追加请求/追加Credit 等概念/ 请求书的取消-CSDN博客 如下图所示&#xff0c;配送费用分以下两…

ROS功能包编译报错fatal error: xxxxConfig.h: 没有那个文件或目录的解决方法及原理介绍

在ROS中&#xff0c;我们常使用动态调参工具或参数配置文件来进行参数调节&#xff0c;在编译时会生成对应的Config.h文件&#xff0c;如本文例子中的MPCPlannerConfig.h文件 一、报错原因及解决方法 在编译时报以下错误的原因是在编译生成可执行文件mpc_ros的过程中需要使用MP…

【(数据结构)—— 基于单链表实现通讯录】

&#xff08;数据结构&#xff09;—— 基于单链表实现通讯录 一.通讯录的功能介绍1.基于单链表实现通讯录(1). 知识要求(2). 功能要求 二.通讯录的代码实现1.通讯录的底层结构(单链表)(1).思路展示(2).底层代码实现(单链表)1.单链表头文件 —— &#xff08;函数的定义&#x…

国旗升降系统程序及原理图资料

本文主要介绍国旗升降系统设计程序及原理图&#xff08;完整资料见文末链接&#xff09; 系统原理图如下&#xff0c;程序资料见文末 附完整资料链接 百度网盘链接: https://pan.baidu.com/s/1Q5J2J8LgVJ-hoeTSVP95_g?pwd3qkw 提取码: 3qkw

TypeScript 安装

TypeScript 的安装 在电脑上全局安装typescript 在确保电脑上已经安装了node.js的前提下&#xff0c;使用npm工具进行安装。 执行如下命令即可&#xff1a; (执行成功会&#xff0c;会安装当前发布的最新版本的typescript) npm install -g typescript如果是Linux or Mac 系统&…

2023年“绿盟杯”四川省大学生信息安全技术大赛

findme 下载附件打开无法正常显示 使用010editor打开发现CRC报错&#xff0c;很可能是高度被修改了 使用工具爆破图片正确的宽度和高度 这里工具自动修复的依旧不能正常打开显示 我们先对原来图片的高度进行修改 之后再使用工具进行修复&#xff0c;即可正常显示&#xff0c;…

JavaScript基础知识15——专业术语:语句和表达式

哈喽&#xff0c;大家好&#xff0c;我是雷工。 今天看到了JavaScript中的专业术语&#xff1a;语句和表达式&#xff0c;以下为学习笔记。 1、表达式概念&#xff1a; 表达式是可以被求值的代码&#xff0c;JavaScript引擎会将其计算出一个具体的结果。 示例&#xff1a; a…