Focal Loss论文解读和调参教程

news2024/12/24 8:24:28

论文:Focal Loss for Dense Object Detection

论文papar地址:ICCV 2017 Open Access Repository

在各个主流深度学习框架里基本都有实现,本文会以mmcv里的focal loss实现为例(基于pytorch)

简介:

本文是何恺明团队ICCV 2017的一篇文章,主要针对检测场景类别不均衡导致一阶段算法没有二阶段算法精度高,在CE loss的基础上进行改进,提出了Focal Loss,并且本文改动了faster rcnn,魔改成了一个一阶段的算法RetinaNet,也是后续很多工作拿来当baseline的anchor-based一阶段算法。

动机是作者认为,一阶段和二阶段算法的精度差距,主要原因是一阶段基本都是dense detect(指采样的区域很密集,简而言之就是anchor box/proposal很多),而二阶段的算法是精选出高质量的样本(比如RPN、selective search),在二阶段产生相对较少的ROI进行回归和分类预测。一阶段产生那么多anchor ,但是其中只有一小部分变成最后预测的bbox result,因此会有很多易分类负样本在loss function里占很大的比重,就会不利于训练。也就是说Focal Loss的贡献就是缓解了类别不平衡问题(注意:这里的类别不平衡不单单是指正负样本数量的不平衡,还有难易样本数量的不平衡)。

Focal Loss具体原理

修改是基于CE loss的(因此focal loss是分类的loss,当然也用于检测框的分类,只是跟回归无关),首先为正样本加入权重因子α,这样的操作一般叫Balanced Cross Entropy,为了解决正负样本不平衡对损失函数造成的影响。

最原本的CE loss(cross entroy loss交叉熵损失函数)形式如下:

为了解决正负样本不平衡问题(负样本太多,正样本太少),一个nature的思路就是给正负样本添加权重alpha,用来减小负样本的占比影响,

 显然alpha越大,正样本的loss占比越大!即α设置的越大,负样本对loss的影响越小。这样就解决了正负样本数量不平衡对最后整个loss函数造成的影响。

下面解决难易样本数量不平衡:在训练时,易分样本数量远大于难分样本数量,易分样本指的是:target为正样本,且pred得分(检测框的score)高,即易分正样本;target为负样本,且pred得分低,即易分负样本

为此我们再引入一个权重gamma,用来减小易分样本的占比影响

 至此,只需要组合上面的α和γ,就得到了Focal Loss的最终形式:

 这种分类loss既能够缓解正负样本数量不均衡的问题,也能缓解难易样本数量不均衡问题,只引入了两个超参数。

值得一提的是,作者在原文中通过实验证明,在COCO数据集上,α取0.25,γ取2的组合精度最高。

RetinaNet

因为这篇文章里提出了一个比较著名的网络RetinaNet,因此顺便也介绍下。

RetinaNet是一个一阶段的网络,由一个主干网络和两个特定于任务(目标检测)的两个子网络(其实就是一个分类头+一个回归头)。

作者用这个很简单的retinanet当做一个一阶段算法的baseline,通过在上面用focal loss超越了二阶段的faster rcnn精度,同时又保留了一阶段的高效率。以此来证明一阶段和二阶段的算法精度差距确实就在于作者提出的类别不平衡猜想。

mmcv中focal loss实现源码和调参

这里首先提示一句,一般看到的二阶段算法的cls_loss都是最基础的CE loss,因为二阶段已经有成熟的RPN,因此生成的anchor或者说proposal的类别不均衡问题不严重,因此没必要用focal loss。

这里就以mmdet里的focal loss实现为例,源码位置在mmdet\models\losses\focal_loss.py

class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether to the prediction is
                used for sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)

        else:
            raise NotImplementedError
        return loss_cls

可以看到只需要在init这个loss的时候赋予gamma和alpha就可以,比如我改变我的htc算法config里的

loss_cls=dict(
    type='CrossEntropyLoss',
    use_sigmoid=False,
    loss_weight=1.0),

改成

loss_cls=dict(
    type='FocalLoss),

即可,用的alpha和gamma都是论文里默认的“最优决策”:α=0.25,γ=2.0

当然这两个超参数要根据你实际的数据集和任务场景调整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/418154.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1.mybatis-plus入门及使用

1.什么是MybatisPlus MyBatis-Plus 官网 为什么要学MybatisPlus&#xff1f; MybatisPlus可以节省大量时间&#xff0c;所有的CRUD代码都可以自动化完成MyBatis-Plus是一个MyBatis的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、提高效…

Java——数组中出现次数超过一半的数字

题目链接 牛客在线oj题——数组中出现次数超过一半的数字 题目描述 给一个长度为 n 的数组&#xff0c;数组中有一个数字出现的次数超过数组长度的一半&#xff0c;请找出这个数字。 例如输入一个长度为9的数组[1,2,3,2,2,2,5,4,2]。由于数字2在数组中出现了5次&#xff0c;…

FastDFS与Nginx结合搭建文件服务器,并内网穿透实现公网访问

文章目录前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.1 …

低代码开发重要工具:jvs-flow (流程引擎)2.1.7版本更新内容

流程引擎主要包含了流程定义和编辑、任务分配和处理、流程监控和跟踪、数据模型和存储、条件和规则设置、安全性和权限管理、性能优化以及持续集成和部署等功能&#xff0c;以满足不同业务场景下的需求。 JVS流程引擎从V2版本开始&#xff0c;由flowable切换为 jvs-flow&#…

2023 年 五 大数据恢复软件帮助您找回数据

您是否刚刚丢失了一份需要数天工作才能更换的重要文件&#xff1f;不要恐慌&#xff01;此列表中排名前 10 位的最佳数据恢复软件应用程序可以帮助您找回数据&#xff0c;您甚至可能不必在它们上花任何钱。 五大最佳数据恢复软件工具 以下是我们最喜欢的 10 大数据恢复软件应用…

记录-vue项目中使用PWA

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 前言&#xff1a; 梳理了一下项目中的PWA的相关用法&#xff0c;下面我会正对vue2和vue3的用法进行一些教程示例&#xff0c;引入离线缓存机制&#xff0c;即使你断网&#xff0c;也能访问页面。一旦用…

动力节点王鹤SpringBoot3笔记——第八章 文章管理模块

目录 第八章 文章管理模块 8.1 配置文件 8.2 视图文件 8.3 Java代码 第八章 文章管理模块 创建新的Spring Boot项目&#xff0c;综合运用视频中的知识点&#xff0c;做一个文章管理的后台应用。 新的Spring Boot项目Lession20-BlogAdmin。Maven构建工具&#xff0c;包…

VxLAN数据中心L2互连(hand-off方式)

用Arista的veos做了个DCI&#xff08;hand-off&#xff09;实验。模拟了VxLAN数据中心hand-off方式做L2互通。 在此分享。 实现思路 分别在DC1、DC2内配置BGP EVPN协议创建VXLAN隧道&#xff0c;实现各数据中心内部VM之间的通信&#xff0c;DC1-BL和DC2-BL通过二层接口方式接…

spring事务(注解 @Transactional )失效场景

目录标题1. 代理不生效1.1 将注解标注在接口方法上1.2 被final、static关键字修饰的类或方法1.3 类方法内部调用示例解决方案&#xff1a;新加一个Service方法1.4 (类本身) 未被spring管理2. 框架或底层不支持的功能2.1 非public修饰的方法2.2 多线程调用举例1举例22.3 数据库本…

C. Uncle Bogdan and Country Happiness(dfs + 回溯)

Problem - C - Codeforces 波格丹叔叔在弗林特船长的团队里待了很长一段时间&#xff0c;有时会怀念他的家乡。今天他告诉你他的国家是如何引入幸福指数的。有n个城市和n -1条连接城市的无方向道路。任何城市的公民都可以通过这些道路到达任何其他城市。城市编号从1到n&#xf…

【软考:软件设计师】 4 计算机组成与体系结构(三)计算机安全 | 加密技术

欢迎来到爱书不爱输的程序猿的博客, 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 本文收录于软考中级&#xff1a;软件设计师系列专栏,本专栏服务于软考中级的软件设计师考试,包括不限于知识点讲解与真题讲解两大部分,并且提供电子教材与电子版真题,关注私聊即可 …

服务(第二篇)LAMP

一、编译安装apache ①关闭防火墙&#xff0c;将安装Apache所需软件包传到/opt目录下 systemctl stop firewalld.service setenforce 0 [rootxxx opt]# ls apr-1.6.2.tar.gz apr-util-1.6.0.tar.gz httpd-2.4.29.tar.bz2 ②安装环境依赖包 yum -y install gcc gcc-c mak…

专业排名全美top6|建筑学硕士学历CSC获批顺利赴美

E老师人文社科背景&#xff0c;二本院校任教&#xff0c;硕士毕业&#xff0c;没有英文文章&#xff0c;且申请周期只有一个月。据此我们提出&#xff0c;以赶上CSC申报为前提&#xff0c;尽量申请美国综合或者专业排名靠前的学校。最终我们助E老师获得美国专业排名TOP6的弗吉尼…

六个阶段形成CRM销售漏斗,优点有哪些

CRM销售漏斗是反映机会状态以及销售效率的重要的销售管理模型。对企业来说&#xff0c;CRM销售漏斗是一个必不可少的工具。通过销售漏斗&#xff0c;企业可以跟踪和分析客户旅程的每个阶段&#xff0c;并制定相应的销售战略。下面来说说&#xff0c;什么是CRM销售漏斗&#xff…

高频PCB电路设计常见的66个问题

随着电子技术快速发展&#xff0c;以及无线通信技术在各领域的广泛应用&#xff0c;高频、高速、高密度已逐步成为现代电子产品的显著发展趋势之一。信号传输高频化和高速数字化&#xff0c;迫使PCB走向微小孔与埋/盲孔化、导线精细化、介质层均匀薄型化&#xff0c;高频高速高…

Redis消息队列实现异步秒杀

Redis秒杀优化 改进秒杀业务&#xff0c;提高并发性能 需求&#xff1a; 1.新增秒杀优惠券的同时&#xff0c;将优惠券的信息保存到redis中 2.基于Lua脚本&#xff0c;判断秒杀库存&#xff0c;一人一单&#xff0c;决定用户是否抢购成功 3.如果抢购成功&#xff0c;将优惠…

Android系统启动流程--init进程的启动流程

这可能是个系列文章&#xff0c;用来总结和梳理Android系统的启动过程&#xff0c;以加深对Android系统相对全面的感知和理解&#xff08;基于Android11&#xff09;。 1.启动电源&#xff0c;设备上电 引导芯片代码从预定义的地方&#xff08;固化在ROM&#xff0c;全称Read …

hive 入门 一般用于正式环境 修改元数据(二)

安装配置可参考 https://blog.csdn.net/weixin_43205308/article/details/130020674 1、如果启动过derby&#xff0c;最小初始化过 在安装路径下删除 derby.log metastore_db rm -rf derby.log metastore_db此处省略安装mysql数据库 2、配置MySQL 登录mysql mysql -uroot …

EightCap易汇:外汇投资入门需要了解哪些必要知识?

外汇市场是国际投资市场&#xff0c;日内交易量巨大&#xff0c;盈利机会极多。外汇是一种含有杠杆的投资产品&#xff0c;杠杆带来了高收益&#xff0c;也会带来高风险&#xff0c;对于外汇新手来说存在一定难度。新手投资者要如何交易&#xff0c;才能抓住外汇市场的盈利机会…

C++标准库 -- 关联容器 (Primer C++ 第五版 · 阅读笔记)

C标准库 -- 关联容器(Primer C 第五版 阅读笔记&#xff09;第11章 关联容器------(持续更新)11.1、使用关联容器11.2、关联容器概述11.3、关联容器操作11.4、无序容器第11章 关联容器------(持续更新) 关联容器和顺序容器有着根本的不同:关联容器中的元素是按关键字来保存和…