yolov5增加iou loss,无痛涨点trick

news2024/12/23 7:51:49

        yolo无痛涨点trick,简单实用

        先贴一张最近一篇论文的结果

后来的几种iou的消融实验结果在一定程度上要优于CIoU,最新的WIoU暂时还没复现。

        本文将在yolov5的基础上增加SIoU,EIoU,Focal-XIoU(X为C,D,G,E,S等)以及AlphaXIoU。

        在yolov5的utils文件夹下新增iou.py文件

import math
import torch


def bbox_iou(box1,
             box2,
             xywh=True,
             GIoU=False,
             DIoU=False,
             CIoU=False,
             SIoU=False,
             EIoU=False,
             Focal=False,
             alpha=1,
             gamma=0.5,
             eps=1e-7):
    """
    计算bboxes iou
    Args:
        box1: predict bboxes
        box2: target bboxes
        xywh: 将bboxes转换为xyxy的形式
        GIoU: 为True时计算GIoU LOSS (yolov5自带)
        DIoU: 为True时计算DIoU LOSS (yolov5自带)
        CIoU: 为True时计算CIoU LOSS (yolov5自带,默认使用)
        SIoU: 为True时计算SIoU LOSS (新增)
        EIoU: 为True时计算EIoU LOSS (新增)
        Focal: 为True时,可结合其他的XIoU生成对应的IoU变体,如CIoU=True,Focal=True时为Focal-CIoU
        alpha: AlphaIoU中的alpha参数,默认为1,为1时则为普通的IoU,如果想采用AlphaIoU,论文alpha默认值为3,此时设置CIoU=True则为AlphaCIoU
        gamma: Focal_XIoU中的gamma参数,默认为0.5
        eps: 防止除0

    Returns:
        iou
    """
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter / (union + eps), alpha)  # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (
                        b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),
                                                                                                 gamma)  # Focal_CIoU
                return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps), gamma)  # Focal_EIou
                return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw, s_ch = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps, (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1, sin_alpha_2 = torch.abs(s_cw) / sigma, torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x, rho_y = (s_cw / cw) ** 2, (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w, omiga_h = torch.abs(w1 - w2) / torch.max(w1, w2), torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(
                        inter / (union + eps), gamma)  # Focal_SIou
                return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIou
            if Focal:
                return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
    return iou  # IoU

在调用bbox_iou函数的地方做如下修改(主要是__call__中):

class ComputeLoss:
    sort_obj_iou = False

    # Compute losses
    def __init__(self, model, autobalance=False):
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        m = de_parallel(model).model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        self.na = m.na  # number of anchors
        self.nc = m.nc  # number of classes
        self.nl = m.nl  # number of layers
        self.anchors = m.anchors
        self.device = device

    def __call__(self, p, targets):  # predictions, targets
        lcls = torch.zeros(1, device=self.device)  # class loss
        lbox = torch.zeros(1, device=self.device)  # box loss
        lobj = torch.zeros(1, device=self.device)  # object loss
        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1)  # faster, requires torch 1.8.0
                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions

                # Regression
                pxy = pxy.sigmoid() * 2 - 0.5
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box
                # iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)
                # lbox += (1.0 - iou).mean()  # iou loss
                # //
                iou = bbox_iou(pbox, tbox[i], EIoU=True)
                if isinstance(iou,tuple):
                    lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
                    iou = iou[0].squeeze()
                else:
                    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
                    iou = iou.squeeze()
                # /

                # Objectness
                iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    j = iou.argsort()
                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
                if self.gr < 1:
                    iou = (1.0 - self.gr) + self.gr * iou
                tobj[b, a, gj, gi] = iou  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(pcls, self.cn, device=self.device)  # targets
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(pcls, t)  # BCE

                # Append targets to text file
                # with open('targets.txt', 'a') as file:
                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor(
            [
                [0, 0],
                [1, 0],
                [0, 1],
                [-1, 0],
                [0, -1],  # j,k,l,m
                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
            ],
            device=self.device).float() * g  # offsets

        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(3,n,7)
            if nt:
                # Matches
                r = t[..., 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
            a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch

        注意需要从对应的py文件中import对应的函数,并需要注释原始函数

# from utils.metrics import bbox_iou
from utils.iou import bbox_iou

         如果需要应用对应的IoU loss的变体,即可将Focal设置为True,并将对应的IoU也设置为True,如CIoU=True,Focal=True时为Focal-CIoU,此时可以调整gamma,默认设置为0.5。

        如果想要使用AlphaXIoU,将alpha设置为3同时将对应的IoU也设置为True即可,alpha默认设置为1。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/196191.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Kindling 观测 Kubernetes 应用网络连接状态

kindling介绍&#xff1a; Kindling 解决的是&#xff0c;在不入侵应用的前提下&#xff0c;如何观测网络的问题&#xff0c;其功能主要是通过暴露内核事件来实现观测。如果主机内核版本高于 4.14&#xff0c;可以使用 eBPF 模块&#xff1b;如果主机内核是低版本&#xff0c;…

多级缓存实现

多级缓存实现1.什么是多级缓存2.JVM进程缓存2.1.导入案例2.2.初识Caffeine2.3.实现JVM进程缓存2.3.1.需求2.3.2.实现3.Lua语法入门3.1.初识Lua3.1.HelloWorld3.2.变量和循环3.2.1.Lua的数据类型3.2.2.声明变量3.2.3.循环3.3.条件控制、函数3.3.1.函数3.3.2.条件控制3.3.3.案例4…

俯卧撑计数 opencv-python + mediapipe

分享一个国外的趣味项目&#xff0c;可以计数&#xff0c;也可以完善进行动作是打分&#xff0c;确定标准程度 原文链接&#xff1a;https://aryanvij02.medium.com/push-ups-with-python-mediapipe-open-a544bd9b4351 程序原理介绍 在新加坡军队中&#xff0c;有一种测试叫做…

程序股票交易接口怎么撤单?

在程序股票交易接口的开发基础上&#xff0c;还能增加一个撤单的委托模块&#xff0c;因为程序股票交易接口的开发不单单是委托下单&#xff0c;那照样也能撤单&#xff0c;这两种的开发原理上&#xff0c;都不冲突&#xff0c;有的股票接口需要计算多种算法&#xff0c;算起来…

CNCAP2021法规adas功能场景

CNCAP2021法规adas功能场景概述功能介绍试验场景概述 C-NCAP是中国汽车技术研究中心于2006年3月2日正式发布的首版中国新车评价规程。中国新车评价规程每三年进行一次规程改版&#xff0c;最新的是2021版本。本文只针对cncap2021主动安全场景进行梳理。 功能介绍 1、AEB(Aut…

vue2低代码平台搭建(三)组件间交互的实现

前言 大家好,我是L丶Y,我们在上一篇文章中主要介绍了低代码平台的页面设计器相关的一些功能原理,打通了页面设计器顶部操作栏、左侧组件列表,中间画布、右侧属性配置四个部分的关系。能够实现组件列表的展示、组件到画布的拖动,属性配置修改对组件渲染效果影响,并说明了…

Picgo配置Bilibili图床

Picgo 配置Bilibili 图床 picgo-plugin-bilibili 为 PicGo 开发的一款插件&#xff0c;新增了B站图床 图床。 使用用户动态的图片上传API。填写SESSDATA即可&#xff0c;获取方式在下面。 文章目录Picgo 配置Bilibili 图床在线安装获取B站SESSDATA图片样式解决B站防盗链&#…

隐函数及参数方程求导——“高等数学”

各位CSDN的uu们你们好呀&#xff0c;今天&#xff0c;小雅兰的内容是隐函数求导和参数方程求导&#xff0c;下面&#xff0c;就让我们进入求导数的世界吧 一、隐函数的导数 二、隐函数求导 三、由参数方程确定的函数的导数 四、相关变化率 一、隐函数的导数 要想知道隐函数…

官宣:计算中间件 Apache Linkis 正式毕业成为 Apache 顶级项目

Apache 软件基金会&#xff08;ASF&#xff09;孵化器于2022年12月03日&#xff0c;通过了 Apache Linkis 计算中间件项目的孵化毕业投票。2023年01月18日&#xff0c;Apache 软件基金会官方宣布 Apache Linkis 顺利毕业&#xff0c;成为 Apache 顶级项目&#xff08;TLP&#…

泊松分布的计算方式

如果都要计算泊松分布了&#xff0c;那么就默认你知道泊松分布的基本知识了&#xff0c;我这里只介绍如何计算&#xff0c;我是用的Excel直接套用公式计算的&#xff0c;如果想在代码里用&#xff0c;我的实现方式是&#xff0c;先用Excel把值全部求出来&#xff0c;然后做成ma…

开组会写论文必备工具清单

来源&#xff1a;投稿 作者&#xff1a;卷舒 编辑&#xff1a;学姐 公式工具 https://www.latexlive.com/ snip这个工具也与之类似&#xff0c;但是需要安装&#xff0c;且有50个的限制。 这是一个LaTeX公式在线编辑器。提供了各类快捷符号及公式模板。在输入区域尝试LaTeX公…

创建SpringBoot工程

目录 一、官网创建 1、进入spring官网&#xff1a;Spring | Home 2、点击Spring Boot&#xff0c;滑倒最下面&#xff0c;点击Spring initializr 3&#xff0c;创建Spring Boot工程&#xff0c;版本选2.多的&#xff0c;高版本配置可能会出现问题&#xff0c;jdk选你电脑上装…

训练营day15

层序遍历 10 226.翻转二叉树 101.对称二叉树 2 102.二叉树的层序遍历 力扣题目链接 给你一个二叉树&#xff0c;请你返回其按 层序遍历 得到的节点值。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 接下来我们再来介绍二叉树的另一种遍历方式&#x…

Java面试题二-并发编程、IO、设计模式、通信/网络、JVM

本文目录如下&#xff1a;Java面试题(二)四、并发编程线程和进程的区别&#xff1f;守护线程是什么&#xff1f;创建线程有哪几种方式&#xff1f;线程有哪些状态&#xff1f;sleep() 和 wait() 有什么区别&#xff1f;线程的sleep()方法和yield()方法有什么区别&#xff1f;线…

KVM虚拟机安装Virtio驱动

KVM虚拟机安装Virtio驱动KVM服务器上配置Virtio驱动ISO文件Windows实例安装Virtio驱动Linux实例安装kvm驱动(一般不需要)KVM服务器上配置Virtio驱动ISO文件 参考&#xff1a;https://www.ibm.com/docs/zh/cloud-orchestrator/2.5.0.1?topicimages-installing-virtio-driver-k…

成功解决:尚硅谷中的谷粒商城前端项目运行依赖问题。【详细图解+问题说明+解决思路】

一个混迹于Github、Stack Overflow、开源中国、CSDN、博客园、稀土掘金、51CTO等 的野生程序员。 目标&#xff1a;分享更多的知识&#xff0c;充实自己&#xff0c;帮助他人 GitHub公共仓库&#xff1a;https://github.com/zhengyuzh 以github为主&#xff1a; 1、分享前端后端…

程序员超级书单:技术人必看

写在前面 电影一部两小时打底&#xff0c;却很少有人可以静下心看30分钟书。今天刷沸点摸鱼, 无意中摸到一条让我有写作冲动的鱼&#xff0c;开工几天了&#xff0c;大家应该都暗暗立下不少flag&#xff0c;比如坚持锻炼&#xff0c;再比如自己今年要看多少本书籍。 行业内卷…

舆情监控工具app推荐,舆情监控工具包括哪些?

舆情监控工具的目的是帮助组织了解公众对各种话题、问题或品牌的意见和情绪。它们可以从多种来源收集数据&#xff0c;包括社交媒体、新闻媒体、博客、论坛和评论区。舆情监控工具app推荐&#xff0c;舆情监控工具包括哪些? 一、舆情监控工具app推荐 TOOM舆情监控工具&…

算法基础集训(第31天)------>BFS之经典【走迷宫】

一&#xff1a;概念定义给定一个 nm 的二维整数数组&#xff0c;用来表示一个迷宫&#xff0c;数组中只包含 0 或 1&#xff0c;其中 0 表示可以走的路&#xff0c;1 表示不可通过的墙壁。最初&#xff0c;有一个人位于左上角 (1,1) 处&#xff0c;已知该人每次可以向上、下、左…

扎克伯格15年邮件曝光:AR/VR平台全盘细节,谈收购Unity的优势

近期一份2015年时扎克伯格的一封邮件曝光&#xff0c;文中解释了Facebook在VR领域的战略策略和VR对未来的影响等&#xff0c;揭露了但是Facebook在VR/AR方向上的全盘计划&#xff0c;同时曝光当时Facebook就有收购Unity的计划&#xff0c;值得一看。以下是摘要&#xff1a;1&am…