yolov7改进之使用QFocalLoss

news2025/4/17 3:54:08

深度学习三大件:数据、模型、Loss。一个好的Loss有利于让模型更容易学到需要的特征,不过深度学习已经白热化了,Loss这块对一个成熟任务的提升是越来越小了。虽然如此,也不妨碍我们在难以从数据和模型层面入手时,从这个方面尝试了。

BCEBlurWithLogitsLoss

yolov7中loss由三部分构成,cls loss, obj loss, box loss。分别是类别损失、框的置信度损失、框的位置损失。其中cls loss和obj loss都是使用BCEBlurWithLogitsLoss,这个loss的源代码如下:

class BCEBlurWithLogitsLoss(nn.Module):
    # BCEwithLogitLoss() with reduced missing label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()

与普通的交叉熵损失相比,这个损失可以降低missing label(也就是本身是个正样本,但是没有标注)以及false label(错误标注,代码中注释部分)。做法是降低这些样本的loss权重,具体而言就是通过pred和label的差异来衡量,差异太大就降低权重。

FocalLoss

yolov7中也提供了另一种Loss, 在hyp参数中设置大于0的gamma就可以开启,FocalLoss的思想是加大困难样本的权重,同时为正负样本分配不同的权重,公式如下:
F L ( p t ) = − a t ( 1 − p t ) γ l o g ( p t ) FL(p_t)=-a_t(1-p_t)^\gamma log(p_t) FL(pt)=at(1pt)γlog(pt)
讨论对于一个二分类的问题,也就是两个类别讨论。当一个样本被分错时,也就是当标签类y = 1时,p = 0.3,根据上式可以看到,y=1 , p= 0.3 , 则 p t = 0.3 p_t = 0.3 pt=0.3那么 ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ就很大(通常 γ \gamma γ取2)。这也就说明,分错的这个类表示难分的类。

QFocalLoss

FocalLoss存在一些问题:
classification score 和 IoU/centerness score 训练测试不一致。

这个不一致主要体现在两个方面:

1) 用法不一致。训练的时候,分类和质量估计各自训记几个儿的,但测试的时候却又是乘在一起作为NMS score排序的依据,这个操作显然没有end-to-end,必然存在一定的gap。

2) 对象不一致。借助Focal Loss的力量,分类分支能够使得少量的正样本和大量的负样本一起成功训练,但是质量估计通常就只针对正样本训练。那么,对于one-stage的检测器而言,在做NMS score排序的时候,所有的样本都会将分类score和质量预测score相乘用于排序,那么必然会存在一部分分数较低的“负样本”的质量预测是没有在训练过程中有监督信号的,有就是说对于大量可能的负样本,他们的质量预测是一个未定义行为。这就很有可能引发这么一个情况:一个分类score相对低的真正的负样本,由于预测了一个不可信的极高的质量score,而导致它可能排到一个真正的正样本(分类score不够高且质量score相对低)的前面。具体可以参考QFocalLoss作者的博客, 知乎《大白话 Generalized Focal Loss》。
对于第一个问题,为了保证training和test一致,同时还能够兼顾分类score和质量预测score都能够训练到所有的正负样本,那么一个方案呼之欲出:就是将两者的表示进行联合。这个合并也非常有意思,从物理上来讲,我们依然还是保留分类的向量,但是对应类别位置的置信度的物理含义不再是分类的score,而是改为质量预测的score。这样就做到了两者的联合表示,同时,暂时不考虑优化的问题,我们就有可能完美地解决掉第一个问题

简单的说,做过检测的应该知道,一个检测结果的置信度=类别的置信度* 框的置信度。但是训练的时候,这两个置信度是分开的,这个不一致可能出现类别置信度低,但框置信度高,。QFocalLoss的出发点就是在训练时,分类损失中的label修改成label* 框置信度。这样修改会出现0~1之间的label,原版的FocalLoss是不支持离散label的,所以作者进行了魔改:
image.png
yolov7中有人提交了一个QFocalLoss代码,不过并没有启用,并且从代码来看,并没有正确实现作者的主要意图:将分类score和框置信度score联合训练。所以我没有使用yolov7中的实现,而是对其进行了修改:

class QFocalLoss(nn.Module):
    # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, beta = 2.0):
        super(QFocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.beta = beta
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, target):
        assert len(target) ==2, "target must be a tuple of (class, score)"
        label, score = target
        iou_target = label*score.view(-1,1)
        pred_sigmoid = torch.sigmoid(pred)  # prob from logits
        scale_factor = torch.abs(pred_sigmoid - iou_target)
        beta = self.beta
        loss = self.loss_fcn(pred, iou_target) * scale_factor.pow(beta)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

主要区别在于此处target由类别标签和预测的框置信度共同构成。
为了使用QFocalLoss, 需要在调用时,将预测框的执行度也传入,具体而言,其实就是预测框和GT的iou, 在代码里本身就已经有了:

 # Objectness
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

                # Classification
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), selected_tcls] = self.cp
                    
                    if isinstance(self.BCEcls, QFocalLoss):
                        lcls += self.BCEcls(ps[:, 5::], (t, iou.detach()))
                    else:
                        lcls += self.BCEcls(ps[:, 5:], t)  # BCE

再把代码中原本用来触发FocalLoss的地方,修改为触发QFocalLoss即可。

使用体验

从个人在私有任务使用看,使用QFocalLoss对于map基本没有提升,甚至误检增多了(recall升高)。分析可能是和FocalLoss一样的问题,由于提高了困难样本的权重,会造成过于关注一些模棱两可的样本。对于小样本来说,估计会有提升,对于一般任务,不如使用BCEBlurWithLogitsLoss。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1160261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Object转List<>,转List<Map<>>

这样就不会局限在转换到List<Map<String,Object>>这一种类型上了.可以转换成List<Map<String,V>>上等,进行泛型转换虽然多了一个参数,但是可以重载啊注: 感觉field.get(key) 这里处理的不是很好,如果有更好的办法可以留言 public static <K, V> …

大数据Doris(十四):Doris表中的数据基本概念

文章目录 Doris表中的数据基本概念 一、​​​​​​​Row & Column

前端项目 index.html 中发请求 fetch

想要在前端项目 index.html文件中向后端发起请求&#xff0c;但是引入axios报错&#xff08;我这边会报错&#xff09;&#xff0c;可以使用fetch。 //window.location.origin----获取域名&#xff0c;包括协议、主机号、端口号fetch(window.location.origin "/api/pla…

el-tabs 默认选中第一个

1. 实际开发中el-tabs 都会设置第一个为默认值 ,这样会好看一点, 而渲染的数据经常是通过后端返回的数据 , v-model 无法写死默认值 解决办法 , 通过计算机属性 ,在data 定义一个 selectedTab watch: {defaultTab(newVal) {this.selectedTab newVal; // 设置第一个标签页…

腾讯云双11云服务器活动:3年366元,超多超值云服务器!

腾讯云在双11活动中推出了一款3年366元的云服务器&#xff0c;配置为2核2G 40GB SSD盘&#xff0c;300GB月流量&#xff0c;4M带宽。这一配置相较于其他厂商同等规格的云服务器&#xff0c;具有较高的性价比。在市场上很少有厂商提供3年期的优惠服务器&#xff0c;因此此次双11…

函数式接口详解(Java)

函数式接口详解&#xff08;Java&#xff09;_函数式接口作为参数_凯凯凯凯.的博客-CSDN博客 函数式接口&#xff1a;有且仅有一个抽象方法的接口 Java中函数式编程体现就是Lambda表达式&#xff0c;所以函数式接口就是可以适用于Lambda使用的接口 只有确保接口中仅有一个抽…

第二证券:怎么判断股票浮筹多少?

股票的浮筹是指公司的股份中&#xff0c;揭露生意在市场上的股份&#xff0c;一般是指除了大股东和筹码安稳的组织等&#xff0c;其他组织和个人能够自在生意的股份。在出资股票时&#xff0c;了解公司的浮筹是非常重要的&#xff0c;由于它直接联络到股票的供需联络和股价动摇…

初识JavaScript(一)

文章目录 一、JavaScript介绍二、JavaScript简介1.ECMAScript和JavaScript的关系2.ECMAScript的历史3.什么是Javascript&#xff1f;4.JavaScript的作用?5.JavaScript的特点 三、JavaScript基础1.注释语法2.JavaScript的使用 四、JavaScript变量与常量变量关键字var和let的区别…

苹果AirTag固件更新

苹果公司针对其热销的物品追踪器 AirTag 于今天发布了新的固件更新&#xff0c;最新版本号为 2A61&#xff0c;但是这次更新苹果并未提供发布说明&#xff0c;所以目前还不知道这次更新有什么新内容。 关于这次更新&#xff0c;用户无法自己手动更新 AirTag 固件&#xff0c;因…

5.1 运输层协议概述

思维导图&#xff1a; 前言&#xff1a; 第5章 运输层笔记 1. 概览 主要内容&#xff1a;介绍运输层协议的特点、进程间通信、端口、UDP和TCP协议、可靠传输、TCP报文段的首部格式、TCP的关键概念&#xff08;如滑动窗口、流量控制、拥塞控制和连接管理&#xff09;。重要性…

自定义类型结构体(上)

目录 结构体类型的声明结构体的概念结构体的声明特殊的声明结构的自引用 结构体变量的创建和初始化结构成员访问操作符 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 &#x1f412;&#x1f412;&#x1f412; 个人主页 &#x1f978;&#x1f978;&#x1…

linux下df -h 命令一直卡住的解决方法

在Linux中&#xff0c;偶尔遇到用 df -h 查看磁盘情况时&#xff0c;一直卡住无法显示结果。 解决方法&#xff1a; 1、首先使用strace追踪到底执行到哪里卡住 $ strace df -h 2、如果没有strace命令则进行安装 $ yum install strace -y 3、显示出卡住的地方&#xff0c;如…

[0xGame 2023 公开赛道] week4 crypto/pwn/rev

最后一周结束了&#xff0c;难度也很大&#xff0c;已经超出我这认为的新生程度了。 crypto Orac1e 先看题&#xff0c;题目先是给了加密过的flag然后提供不限次数的解密&#xff0c;不过仅提供解密后unpad的结果。 from Crypto.Util.number import * from Crypto.Cipher i…

WinCC7.5 将归档数据打印到MSHGrid(不是MSFlexGrid控件)

参考网址&#xff1a;https://www.cnblogs.com/fishingsriver/p/14397431.html 新建变量 MSHGrid控件 查询按钮 Sub OnClick(ByVal Item) Dim myCatalog,myDS,PCN…

【python】路径管理+路径拼接问题

路径管理 问题相对路径问题绝对路径问题 解决os库pathlib库最终解决 问题 环境&#xff1a;python3.7.16 win10 相对路径问题 因为python的执行特殊性&#xff0c;使用相对路径时&#xff0c;在不同路径下用python指令会有不同的索引效果&#xff08;python的项目根目录根据执…

【嵌入式】HC32F07X CAN通讯配置和使用配置不同缓冲器以连续发送

一 背景说明 使用小华&#xff08;华大&#xff09;的MCU HC32F07X实现 CAN 通讯配置和使用 二 原理分析 【1】CAN原理说明&#xff08;参考文章《CAN通信详解》&#xff09;&#xff1a; CAN是控制器局域网络(Controller Area Network, CAN)的简称&#xff0c;是一种能够实现…

公司上网行为监控能监控到什么

公司上网行为监控是一个备受关注的话题&#xff0c;它可以监控员工的网络行为&#xff0c;保护企业的机密和数据安全。但是&#xff0c;这种监控行为也会涉及到员工的隐私权和数据安全问题。 公司上网行为监控能监控到什么&#xff1a; 1、访问网站精确到员工信息、计算机名称…

uni-starter 使用常见问题

1. Invalid uni-id config file 没有找到uni-id文件导致 需要在uniCloud-aliyun/cloudfunctions/common/uni-config-center/uni-id/下新建 config.json 如果没有uni-id 就新建一个。 注意&#xff1a;config.json是一个标准json文件&#xff0c;不支持注释 uni-starter 按照…

git简单介绍,回车换行问题,倒计时+进度条小程序的实现+代码

目录 git--版本控制工具 介绍 使用 小程序 引入 回车换行问题 缓冲区问题 倒计时 分析 代码 进度条 分析 代码 git--版本控制工具 首先,我们需要下载git : yum install git 介绍 Git是一种分布式版本控制系统&#xff0c;用于跟踪文件和目录的变化并协调多个人之…

ACID模型

ACID 是数据库管理系统&#xff08;DBMS&#xff09;中用来确保事务处理正确性和可靠性的四个特性的首字母缩写。ACID 是指原子性&#xff08;Atomicity&#xff09;、一致性&#xff08;Consistency&#xff09;、隔离性&#xff08;Isolation&#xff09;和持久性&#xff08…