YOLOv8损失函数改进-增加MPDIoU提升边界框回归精度【附代码】

news2024/9/21 20:50:15

文章目录

  • 前言
  • 文章概述
  • 必要环境
  • 一、修改方法
    • 1.修改配置文件
    • 2. 增加 MPDIoU
    • 3. 修改 BboxLoss类
    • 4. 修改 v8DetectionLoss 类的 init 方法
  • 二、训练代码
  • 三、训练过程
  • 总结


前言

本篇博客我们将详细介绍如何在 YOLOv8项目中增加 MPDIoULoss,包括如何修改配置文件、增加新的损失函数、调整现有的损失计算模块,以及增加训练代码来使用新的损失函数。相信通过这篇博文会使大家更佳熟悉YOLOv8项目的整体结构
在这里插入图片描述


文章概述

1. default.yaml中新增参数mpdiou,用于控制是否使用 MPDIoU损失
2. 在metrics.py中添加MPDIoU函数
3. 修改 BboxLoss 类的 init 和 forward 函数,加入了MPDIoU损失的计算
4. 修改v8DetectionLoss 类的 init 函数,新增mpdiou参数
5. 编写了训练和验证的主函数,支持命令行参数设置,支持开启或关闭MPDIoU损失


必要环境

  1. 配置yolov8/v10环境 可参考往期博客
    地址:搭建YOLOv10环境 训练+推理+模型评估
  2. 论文地址
    地址:MPDIoU: A Loss for Efficient and Accurate Bounding Box
    Regression

一、修改方法

1.修改配置文件

我们需要在配置文件 ultralytics\cfg\default.yaml 中增加新的参数 mpdiou ,该参数负责控制是否使用 MPDIoULoss

mpdiou: False

参数详解:
mpdiou: 用于指定是否启用 MPDIoULoss,默认值为 False,表示不使用

2. 增加 MPDIoU

在 ultralytics\utils\metrics.py文件中的bbox_iou函数中增加增加MPDIoU

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, MPDIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).

    Args:
        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
                               (x1, y1, x2, y2) format. Defaults to True.
        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
    """

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
            b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU or MPDIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or MPDIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
            rho2 = (
                           (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
                   ) / 4  # center dist**2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU

            elif MPDIoU:
                sq_sum = (cw ** 2) + (ch ** 2)
                d12 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
                d22 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
                return iou - ((d12 / sq_sum) - (d22 / sq_sum))

            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

关键代码

elif MPDIoU:
    sq_sum = (cw ** 2) + (ch ** 2)
    d12 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
    d22 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
    return iou - ((d12 / sq_sum) - (d22 / sq_sum))

对应公式
在这里插入图片描述

3. 修改 BboxLoss类

我们需要在 ultralytics\utils\loss.py 的BboxLoss类中集成 MPDIoULoss,需要修改 init 和 forward 方法,将这两个函数替换为如下代码

class BboxLoss(nn.Module):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max=16,mpdiou=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
        self.mpdiou = mpdiou

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        if self.mpdiou:
            iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, MPDIoU=True)
        else:
            iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

参数详解:
mpdiou: 指定是否使用 MPDIoULoss

4. 修改 v8DetectionLoss 类的 init 方法

我们还需在 ultralytics\utils\loss.py的v8DetectionLoss类中集成 MPDIoULoss 的相关参数,需要修改 init 方法,将该函数代码替换为如下代码

class v8DetectionLoss:
    """Criterion class for computing training losses."""

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.mpdiou = self.hyp.mpdiou

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max,mpdiou=self.mpdiou).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

参数详解:
self.mpdiou: 从default.yaml中读取,指定是否使用MPDIoULoss

二、训练代码

完整训练代码如下 其中mpdiou参数控制是否使用MPDIoULoss

# -*- coding:utf-8 -*-

from ultralytics import YOLO
import os
import argparse

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def parse_args():
    parser = argparse.ArgumentParser(description="YOLO Training and Evaluation Script")
    parser.add_argument('--mpdiou', action='store_true', default=True, help="Use MPDIoU")
    parser.add_argument('--weights', type=str, default='yolov8n.pt', help="Path to the model")
    parser.add_argument('--mode', type=str, choices=['train', 'val'], default='train', help="Mode: train or val")
    parser.add_argument('--data', type=str, default='data.yaml', help="Data configuration file")
    parser.add_argument('--epoch', type=int, default=100, help="Number of epochs for training")
    parser.add_argument('--batch', type=int, default=16, help="Batch size")
    parser.add_argument('--workers', type=int, default=8, help="Number of data loading workers")
    parser.add_argument('--device', type=str, default='0', help="Device to run on, e.g., '0' for GPU 0")
    return parser.parse_args()


def main():
    args = parse_args()

    if args.mode == 'train':
        model = YOLO(args.weights)
        model.train(data=args.data, epochs=args.epoch, batch=args.batch, workers=args.workers, device=args.device,
                    mpdiou=args.mpdiou)  # 训练模型
    else:
        batch = args.batch * 2
        model = YOLO(args.weights)
        print(model.model)
        model.val(data=args.data, batch=batch, workers=args.workers, device=args.device)


if __name__ == '__main__':
    main()

三、训练过程

随便找了几张图测试是否能跑通
在这里插入图片描述
在这里插入图片描述


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!
最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看
b站主页:https://b23.tv/1upjbcG
学习交流群:995760755

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1921846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

U盘打不开难题解析:原因、恢复与预防全攻略

在日常生活中,U盘作为一种便捷的数据存储设备,深受大家的喜爱。然而,有时我们可能会遇到U盘打不开的情况,这不仅令人困扰,还可能导致重要数据的丢失。那么,当U盘打不开时,我们该如何应对呢&…

[PM]原型与交互设计

原型分类 1.草图原型 手绘图稿, 规划的早期,整理思路会使用 2.低保真原型 简单交互, 无需配色, 黑白灰为主, 产品规划和评审阶段使用 标准化的低保真原型是高保真原型的基础 3.高保真原型 复杂交互, 一般用于公开演示, 产品先产出低保真原型, 设计师根据原型产出设计稿 低保…

2024-07-12 Unity AI状态机1 —— 框架介绍

文章目录 1 有限状态机2 状态机实现框架2.1 StateMachine2.2 BaseState2.3 ...State2.4 IAIObject 3 框架类图 本文章参考 B 站唐老狮 2023 年直播内容。点击前往唐老狮 B 站主页。 1 有限状态机 ​ 有限状态机(Finite - State Machine,FSM&#xff09…

如何让招投标数据成为企业决策的金钥匙?

在当今数据驱动的时代,招投标信息作为行业竞争情报的重要组成部分,正日益成为企业制定战略决策的关键依据。本文将深入探讨招投标数据采集的重要性,以及它如何为企业决策提供强有力的数据支持,同时揭秘如何高效、精准地获取这些数…

花几千上万学习Java,真没必要!(二)

1、注释: java代码注释分3种: 单行注释://注释信息 多行注释: /*注释信息*/ 文本注释:/**注释信息*/ public class TestComments {// 这是单行注释,用于注释单行代码或解释代码功能/* 这是多行注释,用于注释多行代码…

eMMC规范 - 寻址/信息寄存器/总线协议/时序图/速度模式

存储器寻址 e•MMC 规范的早期实现(至 v4.1 的版本)是采用 32-bit 域实现字节寻址的。这种寻址 机制允许最大 2 GB 的 e•MMC 容量。 为了支持更大的容量,寻址机制升级到支持扇区寻址( 512B 扇区)。对所有容量大于 2 …

在PyQt中为自己开发的软件实现远程文件“一机一码”授权管理实例

在使用PyQt搞软件开发时,开发者往往想要给自己的软件添加一个授权机制,只有当客户提供了授权码并且开发者将授权码放在授权管理系统的时候,客户端才能正常启动。这几天小陶就在捣鼓这个事,发现确实是可行的。 如果没有进行授权&a…

mybatis动态传入参数 pgsql 日期 Interval ,day,minute

mybatis动态传入参数 pgsql 日期 Interval 在navicat中,标准写法 SELECT * FROM test WHERE time > (NOW() - INTERVAL 5 day)在mybatis中,错误写法 SELECT * FROM test WHERE time > (NOW() - INTERVAL#{numbers,jdbcTypeINTEGER} day)报错内…

html5——CSS高级选择器

目录 属性选择器 E[att^"value"] E[att$"http"] E[att*"http"] 关系选择器 子代: 相邻兄弟: 普通兄弟: 结构伪类选择器 链接伪类选择器 伪元素选择器 CSS的继承与层叠 CSS的继承性 CSS的层叠性 …

redis介绍与布署

redis remote dictionary server(远程字典服务器) 是一个开源的,使用c语言编写的非关系型数据库,支持内存运行并持久化,采用key-value的存储形式。 单进程模型意味着可以在一台服务器上启动多个redis进程,…

基于语义的法律问答系统

第一步,准备数据集 第二步,构建索引数据集,问答对数据集,训练数据集,召回评估数据集 第三步,构建dataloader,选择优化器训练模型,之后召回评估 第四步,模型动转静,之后…

改摄像头IQ(目前我知道的功能是看色彩)

1、SrcCode\Dx\580_CARDV_ETHCAM_RX_EVB(每个项目不同找到对应的)\isp.dtsi 将下面路径改成对应镜头的 2、将新的IQ复制到文件夹下code\hdal\vendor\isp\configs\dtsi

Vue3 引入腾讯地图 包含标注简易操作

1. 引入腾讯地图API JavaScript API | 腾讯位置服务 (qq.com) 首先在官网注册账号 并正确获取并配置key后 找到合适的引入方式 本文不涉及版本操作和附加库 据体引入参数参考如下图 具体以链接中官方参数为准标题 在项目根目录 index.html 中 写入如下代码 <!-- 引入腾…

【SQL】如何用SQL写透视表

【背景】 报表中有一大需求是透视表,目前有很多分析类应用也搭载了此类功能,那么我们能不能直接用SQL做透视表呢? 【分析】 BI类软件将透视表功能做在了前端,但是数据本身还是存储在数据库中,所以必然有方法可以用SQL直接实现透视表。 【心法】 透视表是任意选取一个…

【C语言】经典C语言笔试面试题目

01. 请填写bool , float, 指针变量 与“零值”比较的if语句。 提示&#xff1a;这里“零值”可以是0, 0.0 , FALSE 或者“空指针”。 例如 int n 与“零值”比较的 if 语句为&#xff1a; if ( n 0 ) if ( n ! 0 )以此类推。 请写出 bool flag 与“零值”比较的 if 语句&a…

IT运维也有自己的节日 724向日葵IT运维节,三大版本如何选?

“724运维节”&#xff0c;是2016年由开放运维联盟发起倡议&#xff0c;广大运维人员共同投票产生的属于运维人自己的节日。 对于运维人最大的印象&#xff0c;那就是工作都需要7x24小时待命&#xff0c;是名副其实的“日不落骑士”&#xff0c;这也是大家选择724这一天作为运…

2024最新6月泛二级域名秒收泛目录(二级域名泛站群)

5月免费版本无后台 无更新功能不自动引蜘蛛 2024年5月最新泛程序&#xff0c;秒收秒排&#xff01;&#xff08;泛型程序&#xff09; - 虚良SEO博客 新曾功能&#xff1a; 后台管理 蜘蛛统计 域名添加 一键强引蜘蛛 蜘蛛统计 识别真假蜘蛛 全自动引蜘蛛 域名要求 …

viteExternalsPlugin 插件管理外部依赖

viteExternalsPlugin 是一个 Vite 插件&#xff0c;用于将指定的模块或库配置为外部依赖 安装&#xff1a; npm i vite-plugin-externals 1.实战用途 比如从项目 index.html 中引入一些SDK文件&#xff0c;我这个是引入的CHATUI vite.config.js 配置&#xff1a; import {…

OSS存储桶密钥泄露【案例】

OSS存储桶密钥泄露 同样的&#xff0c;在前几天的攻防演练中的经历&#xff0c;本文我们将为OSS存储桶单独做文章 公开配置文件泄露 录屏、截图缺失了。发现这个存储桶密钥是因为我在鹰图对一个能够控制生成类似容器的站点&#xff0c;抓包发现api是另一个子域的站点&#x…

C#变量、常量与运算符

文章目录 变量变量定义命名规则作用域和生命周期 常量特殊字符常量 运算符算术运算符关系运算符逻辑运算符位运算符赋值运算符其他运算符 变量 变量就是一个存储空间的名字&#xff0c;变量是什么类型&#xff0c;这个空间里面存储的就是什么类型的数据。 变量定义 <data_t…