Focal Loss-解决样本标签分布不平衡问题

news2024/11/26 18:54:47

文章目录

    • 背景
      • 交叉熵损失函数
      • 平衡交叉熵函数
    • Focal Loss损失函数
      • Focal Loss vs Balanced Cross Entropy
      • Why does Focal Loss work?
    • 针对VidHOI数据集
    • Reference

背景

Focal Loss由何凯明提出,最初用于图像领域解决数据不平衡造成的模型性能问题。

交叉熵损失函数

L o s s = L ( y , p ^ ) = − y l o g ( p ^ ) − ( 1 − y ) l o g ( 1 − p ^ ) Loss=L(y,\hat{p})=-ylog(\hat{p})-(1-y)log(1-\hat{p}) Loss=L(y,p^)=ylog(p^)(1y)log(1p^)

其中, p ^ \hat{p} p^为预测概率大小。y为label,二分类中对应0和1。
L c e ( y , p ^ ) = { − l o g ( p ^ ) , if  y = 1 − l o g ( 1 − p ^ ) , if  y = 0 L_{ce}(y,\hat{p})= \left\{ \begin{array}{ll} -log(\hat{p}), & \text{if } y = 1 \\ -log(1-\hat{p}), & \text{if }y=0 \end{array} \right. Lce(y,p^)={log(p^),log(1p^),if y=1if y=0
对于所有样本,需要求平均作为最终的结果:
L = 1 N ∑ i = 1 N l ( y i , p ^ i ) L=\frac{1}{N}\sum_{i=1}^{N}l(y_i,\hat{p}_i) L=N1i=1Nl(yi,p^i)
对于二分类问题,可以改写成:
L = 1 N ( ∑ y i = 1 m − l o g ( p ^ ) + ∑ y i = 0 n − l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-log(\hat{p})+\sum_{y_i=0}^{n}-log(1-\hat{p})) L=N1(yi=1mlog(p^)+yi=0nlog(1p^))
其中,N为样本总数,m和n为正、负样本数, m + n = N m+n=N m+n=N

当样本分布不平衡时,损失函数L的分布也会发生倾斜,若m>>n时,正样本就会在损失函数中占据主导地位,由于损失函数的倾斜,训练的模型会倾向于样本较多的类别,导致对较少样本类别的性能较差。

平衡交叉熵函数

对于样本不平衡造成的损失函数倾斜,最直接的方法就是添加权重因子,提高少数类别在损失函数中的权重,从而平衡损失函数的分布。还是以之前的二分类问题为例,我们添加权重参数 α ∈ [ 0 , 1 ] \alpha∈[0,1] α[0,1]
L = 1 N ( ∑ y i = 1 m − α l o g ( p ^ ) + ∑ y i = 0 n − ( 1 − α ) l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-\alpha log(\hat{p})+\sum_{y_i=0}^{n}-(1-\alpha)log(1-\hat{p})) L=N1(yi=1mαlog(p^)+yi=0n(1α)log(1p^))
其中, α 1 − α = n m \frac{\alpha}{1-\alpha}=\frac{n}{m} 1αα=mn,权重大小由正负样本数量比来设置。

Focal Loss损失函数

Focal Loss从loss角度提供了一种样本不均衡的解决方案:
L f o c a l ( y , p ^ ) = { − ( 1 − p ^ ) γ l o g ( p ^ ) , if  y = 1 − p ^ γ l o g ( 1 − p ^ ) , if  y = 0 L_{focal}(y,\hat{p})= \left\{ \begin{array}{ll} -(1-\hat{p})^\gamma log(\hat{p}), & \text{if } y = 1 \\ -\hat{p}^\gamma log(1-\hat{p}), & \text{if }y=0 \end{array} \right. Lfocal(y,p^)={(1p^)γlog(p^),p^γlog(1p^),if y=1if y=0
p t = { p ^ , if  y = 1 1 − p ^ , otherwise.  p_t= \left\{ \begin{array}{ll} \hat{p}, & \text{if } y = 1 \\ 1-\hat{p}, & \text{otherwise. } \end{array} \right. pt={p^,1p^,if y=1otherwise. 

则表达式统一为:
L f o c a l = − ( 1 − p t ) γ l o g ( p t ) L_{focal}=-(1-p_t)^\gamma log(p_t) Lfocal=(1pt)γlog(pt)
与交叉熵表达式对照: L c e = − l o g ( p t ) L_{ce}=-log(p_t) Lce=log(pt),仅仅多了一个可变系数 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ.

其中, p t p_t pt反应了与ground truth的接近程度,越大表示分类越准。 γ > 0 \gamma>0 γ>0为调节因子。

对于分类不准确的样本, p t → 0 p_t→0 pt0 ( 1 − p t ) γ → 1 (1-p_t)^\gamma→1 (1pt)γ1 L f o c a l → L c e L_{focal}→L_{ce} LfocalLce;对于分类准确的样本, p t → 1 p_t→1 pt1 ( 1 − p t ) γ → 0 (1-p_t)^\gamma→0 (1pt)γ0 L f o c a l → 0 L_{focal}→0 Lfocal0;因此,Focal Loss对于分类不准确的样本,损失没有改变;对于分类准确的样本,损失会变小。整体来看,Focal Loss增加了分类不准确样本在损失函数中的权重。

如下是不同调节因子 γ \gamma γ对应的Loss-proba分布图,可以看出Cross Entropy(CE)和Focal Loss(FL)之间的区别,Focal Loss使损失函数更倾向于难分的样本。

在这里插入图片描述

Focal Loss vs Balanced Cross Entropy

  • Focal Loss是从样本分类难易程度出发,让Loss聚焦于难分类的样本;
  • Balanced Cross Entropy是从样本分布角度对Loss添加权重因子。
    • 缺点:仅仅考虑样本分布,有些难以区分的类别的样本数可能也比较多,此时被BCE赋予了较低的权重,会导致模型很难识别该类别!

Why does Focal Loss work?

Focal Loss从样本难易分类的角度出发,解决了样本不平衡导致模型性能较低的问题。

WHY?

样本不平衡造成的问题就是,样本数少的类别分类难度大,因此Focal Loss聚焦于难分样本,解决了样本少的类别分类精度不高的问题,对于难分样本中样本多的类别,也会被Focal Loss聚焦。因此,它不仅解决了样本不平衡问题,还提升了模型整体性能。

但是,要使模型训练过程中聚焦于难分类样本,仅仅将Loss倾向于难分类样本是不够的,因为模型参数更新取决于Loss的梯度:
w = w − α ∂ L ∂ w w=w-\alpha\frac{\partial L}{\partial w} w=wαwL
若Loss中难分类样本的权重较高,但是难分类样本的Loss梯度为0,难分类样本就不会影响到模型的参数更新。对于梯度问题,Focal Loss中的梯度与 x t x_t xt的关系如下所示,其中 x t = y x x_t=yx xt=yx y ∈ { − 1 , 1 } y∈\{-1,1\} y{1,1}为类别, p t = σ ( x t ) p_t=\sigma(x_t) pt=σ(xt),对于易分样本, x t > 0 x_t>0 xt>0,即 p t > 0.5 p_t>0.5 pt>0.5,由下图可知,此时的导数趋于0。对于难分样本,导数数值较大,因此,学习过程中更聚焦于难分样本。

在这里插入图片描述

难易分类样本是动态的, p t p_t pt在训练的过程中,可能会在难易之间相互转换。

在Loss梯度中,难训练样本起主导作用,参数朝着优化难训练样本的方向改变,变化之后可能会导致原本易训练的样本 p t p_t pt变化,即变成难训练样本。若发生了这种情况会导致模型收敛速度较慢。

为了防止这种难易样本的频繁变化,应该选择较小的学习率。

针对VidHOI数据集

因为VidHOI数据集中的一个人-物对会被多个交互标签同时标注,如< human,next to & watch & hold, cup >,所以会面临multi-class multi-label的分类问题。以往常常使用Binary cross-entropy,能够计算每个交互类别独立于其他类别的损失。但是,VidHOI数据集分布不均且具有长尾分布,为了解决这个不均衡问题同时避免过分强调最频繁类别的重要性,我们采用class-balanced Focal loss:
C B f o c a l ( p i , y i ) = − 1 − β 1 − β n i ( 1 − p y i ) γ l o g ( p y i ) w i t h   p y i = { p i , if  y i = 1 1 − p i , otherwise. CB_{focal}(p_i,y_i)=-\frac{1-\beta}{1-\beta^{n_i}}(1-p_{y_i})^{\gamma}log(p_{y_i}) \\ with \ p_{y_i} = \left\{ \begin{array}{ll} p_i, & \text{if } y_i = 1 \\ 1-p_i, & \text{otherwise.} \end{array} \right. CBfocal(pi,yi)=1βni1β(1pyi)γlog(pyi)with pyi={pi,1pi,if yi=1otherwise.

其中的 − ( 1 − p y i ) γ l o g ( p y i ) -(1-p_{y_i})^{\gamma}log(p_{y_i}) (1pyi)γlog(pyi)是Lin提出的Focal loss, p i p_i pi表示预估为第i个类别的可能性, y i ∈ { 0 , 1 } y_i∈\{0,1\} yi{0,1}表示Ground Truth的label。变量 n i n_i ni表示第i个类别在Ground Truth下的样本量, β ∈ [ 0 , 1 ) \beta∈[0,1) β[0,1)是可调节参数。所有类别的平均损失作为一个预测的损失。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class FocalBCEWithLogitLoss(nn.modules.loss._Loss):
    """Focal Loss with binary cross-entropy
    Implement the focal loss with class-balanced loss, using binary cross-entropy as criterion
    Following paper "Class-Balanced Loss Based on Effective Number of Samples" (CVPR2019)

    Args:
        gamma (int, optional): modulation factor gamma in focal loss. Defaults to 2.
        alpha (int, optional): modulation factor alpha in focal loss. If a integer, apply to all;
            if a list or array or tensor, regard as alpha for each class; if none, no alpha. Defaults to None.
        weight (Optional[torch.Tensor], optional): weight to each class, !not the same as alpha. Defaults to None.
        size_average (_type_, optional): _description_. Defaults to None.
        reduce (_type_, optional): _description_. Defaults to None.
        reduction (str, optional): _description_. Defaults to "mean".
    """

    def __init__(
        self,
        gamma=2,
        alpha=None,
        weight: Optional[torch.Tensor] = None,
        size_average=None,
        reduce=None,
        reduction: str = "mean",
        pos_weight: Optional[torch.Tensor] = None,
    ):
        super(FocalBCEWithLogitLoss, self).__init__(size_average, reduce, reduction)
        self.gamma = gamma
        # a number for all, or a Tensor with the same num_classes as input
        if isinstance(alpha, (list, np.ndarray)):
            self.alpha = torch.Tensor(alpha)
        else:
            self.alpha = alpha
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)
        self.weight: Optional[torch.Tensor]
        self.pos_weight: Optional[torch.Tensor]

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                alpha_t = self.alpha.repeat(input.shape[0], 1)
            else:
                alpha_t = torch.ones_like(input) * self.alpha
        else:
            alpha_t = None
		# 二元交叉熵
        ce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        # pt = torch.exp(-ce)
        # modulator = ((1 - pt) ** self.gamma)
        # following author's repo https://github.com/richardaecn/class-balanced-loss/blob/master/src/cifar_main.py#L226-L266
        # explaination https://github.com/richardaecn/class-balanced-loss/issues/1
        # A numerically stable implementation of modulator.
        if self.gamma == 0.0:
            modulator = 1.0
        else:
            # e^(-gamma*target*input - gamma*log(1+e^(-input)))
            modulator = torch.exp(-self.gamma * target * input - self.gamma * torch.log1p(torch.exp(-input)))
        # focal loss
        fl_loss = modulator * ce
        # alpha
        if alpha_t is not None:
            alpha_t = alpha_t * target + (1 - alpha_t) * (1 - target)
            fl_loss = alpha_t * fl_loss
        # pos weight
        if self.pos_weight is not None:
            fl_loss = self.pos_weight * fl_loss
        # reduction
        if self.reduction == "mean":
            return fl_loss.mean()
        elif self.reduction == "sum":
            return fl_loss.sum()
        else:
            return fl_loss

C B f o c a l ( p i , y i ) = − 1 − β 1 − β n i ( 1 − p y i ) γ l o g ( p y i ) w i t h   p y i = { p i , if  y i = 1 1 − p i , otherwise. CB_{focal}(p_i,y_i)=-\frac{1-\beta}{1-\beta^{n_i}}(1-p_{y_i})^{\gamma}log(p_{y_i}) \\ with \ p_{y_i} = \left\{ \begin{array}{ll} p_i, & \text{if } y_i = 1 \\ 1-p_i, & \text{otherwise.} \end{array} \right. CBfocal(pi,yi)=1βni1β(1pyi)γlog(pyi)with pyi={pi,1pi,if yi=1otherwise.

原始版本的代码:

def focal_loss(labels, logits, alpha, gamma):
  """Compute the focal loss between `logits` and the ground truth `labels`.

  Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
  where pt is the probability of being classified to the true class.
  pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

  Args:
    labels: A float32 tensor of size [batch, num_classes].
    logits: A float32 tensor of size [batch, num_classes].
    alpha: A float32 tensor of size [batch_size]
      specifying per-example weight for balanced cross entropy.
    gamma: A float32 scalar modulating loss from hard and easy examples.
  Returns:
    focal_loss: A float32 scalar representing normalized total loss.
  """
  with tf.name_scope('focal_loss'):
    logits = tf.cast(logits, dtype=tf.float32)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)

    # positive_label_mask = tf.equal(labels, 1.0)
    # probs = tf.sigmoid(logits)
    # probs_gt = tf.where(positive_label_mask, probs, 1.0 - probs)
    # # With gamma < 1, the implementation could produce NaN during back prop.
    # modulator = tf.pow(1.0 - probs_gt, gamma)

    # A numerically stable implementation of modulator.
    if gamma == 0.0:
      modulator = 1.0
    else:
      modulator = tf.exp(-gamma * labels * logits - gamma * tf.log1p(
          tf.exp(-1.0 * logits)))

    loss = modulator * cross_entropy

    weighted_loss = alpha * loss
    focal_loss = tf.reduce_sum(weighted_loss)
    # Normalize by the total number of positive samples.
    focal_loss /= tf.reduce_sum(labels)
  return focal_loss

Reference

  1. https://zhuanlan.zhihu.com/p/266023273
  2. https://github.com/nizhf/hoi-prediction-gaze-transformer
  3. https://github.com/richardaecn/class-balanced-loss/blob/master/src/cifar_main.py#L226-L266

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/940853.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式底层驱动需要知道的基本知识

先说结论&#xff0c;能&#xff0c;肯定能&#xff0c;必须能&#xff01; 但是&#xff0c;问题重点在于坚持&#xff0c;程序员这一行 &#xff0c;下班回家一般都要10点了&#xff0c;再刷两个小时枯燥的学习视频&#xff0c;我想大多数人是坚持不下来的。 但是&#xff…

ABB D674A906U01流量计变送器模块

流量测量&#xff1a; 该模块用于准确测量液体或气体的流量&#xff0c;通常以标准单位&#xff08;如立方米每小时或加仑每分钟&#xff09;表示。 传感器技术&#xff1a; 它通常使用各种传感器技术&#xff08;例如涡轮、电磁、超声波等&#xff09;来检测流体的流动并进行…

冠达管理:股票停牌后会大涨吗?

股票停牌是指证券买卖所为了保护市场秩序、保护出资者利益等原因暂时中止某些股票的买卖。但是&#xff0c;股票停牌前的股价与停牌后的股价会有什么不同呢&#xff1f;股票停牌后是否会大涨呢&#xff1f;在本文中&#xff0c;咱们将从多个视点进行剖析&#xff0c;以帮助人们…

合宙Air724UG LuatOS-Air LVGL API控件--按钮 (Button)

按钮 (Button) 按钮控件&#xff0c;这个就不用多说了&#xff0c;界面的基础控件之一。 示例代码 – 按键回调函数 event_handler function(obj, event) if event lvgl.EVENT_CLICKED then print(“Clicked\n”) elseif event lvgl.EVENT_VALUE_CHANGED then print(“To…

java.lang.IllegalStateException: Unable to find

java.lang.IllegalStateException: Unable to find a SpringBootConfiguration, you need to use ContextConfiguration or SpringBootTest(classes…) with your test 错误场景&#xff1a;在使用mybatisplus做测试时&#xff0c;出现此错误 解决方案&#xff1a;SpringBoot…

【MCU】SD NAND芯片之国产新选择

文章目录 前言传统SD卡和可贴片SD卡传统SD卡可贴片SD卡 实际使用总结 前言 随着目前时代的快速发展&#xff0c;即使是使用MCU的项目上也经常有大数据存储的需求。可以看到经常有小伙伴这样提问&#xff1a; 大家好&#xff0c;请问有没有SD卡芯片&#xff0c;可以直接焊接到P…

python可视化matplotlib——绘制正弦和余弦

这是一个使用matplotlib库绘制正弦和余弦函数曲线的代码示例。代码中导入了需要的库&#xff0c;并设置了x轴和y轴的标签字体为华文楷体。然后&#xff0c;使用numpy生成一组x轴上的值t&#xff0c;并使用正弦函数生成对应的y轴值s&#xff0c;再使用余弦函数生成对应的y轴值z。…

使用Tampermonkey(篡改猴)向页面注入js脚本

一、Tampermonkey 简单介绍 Tampermonkey是一款浏览器插件&#xff0c;适用于Chrome、Microsoft Edge、Safari、Opera Next 和 Firefox。他允许我们自定义javascript给指定网页添加功能&#xff0c;或修改现有功能。也可以用来辅助调试&#xff0c;或去除网页广告等。 官网地…

Vulkan LoaderLayer

目录 一、Loader The Loader 二、Layer 调度链Dispatch Chains JSON 一、Loader Vulkan是一个层架构&#xff0c;由Vulkan ApplicationLoaderLayerICDs(Installable Client Drivers)组成。 Vulkan 是一个显式 API&#xff0c;可以直接控制 GPU 的实际工作方式。因此&#x…

get请求报错400 非法参数

get请求报错400 非法参数 背景&#xff1a;get请求数据&#xff0c;SpringBoot提供接口&#xff0c;返回400&#xff0c;报错非法参数此种情况排除接口本身错误之外&#xff0c;检查参数中有没有特殊字符 " < > [ \ ] ^ { | } 我这边就是因为其中一个参数中有中括…

数字货运保持深层角力,满帮业绩与投资价值双丰收

上半年&#xff0c;经济持续活跃&#xff0c;货运物流行业承担着帮助经济要素流通的职责&#xff0c;也成为最直接的受益者。 数字货运平台满帮8月23日发布的财报显示&#xff0c;2023年第二季度&#xff0c;其平台单量、用户量均取得显著增长&#xff0c;并带动平台业绩创下新…

【网络设备】交换机的概念、工作原理、功能以及以太网帧格式

个人主页&#xff1a;insist--个人主页​​​​​​ 本文专栏&#xff1a;网络基础——带你走进网络世界 本专栏会持续更新网络基础知识&#xff0c;希望大家多多支持&#xff0c;让我们一起探索这个神奇而广阔的网络世界。 目录 一、认识交换机 二、交换机的主要功能 1、数…

Android——基本控件(下)(十九)

1. 菜单&#xff1a;Menu 1.1 知识点 &#xff08;1&#xff09;掌握Android中菜单的使用&#xff1b; &#xff08;2&#xff09;掌握选项菜单&#xff08;OptionsMenu&#xff09;的使用&#xff1b; &#xff08;3&#xff09;掌握上下文菜单&#xff08;ContextMenu&am…

STM32F4_SD卡

目录 前言 1. SDIO协议简介 2. SDIO命令及响应 3. SD卡的操作模式及切换 4. STM32的SDIO接口 5. SDIO结构体 6. SDIO相关寄存器 7. 实验程序 7.1 main.c 7.2 SDIO_Card.c 7.3 SDIO_Card.h 前言 在之前的单片机学习过程中&#xff0c;我们已经了解到了单片机系统都需…

SQL server开启变更数据捕获(CDC)

一、CDC简介 变更数据捕获&#xff08;Change Data Capture &#xff0c;简称 CDC&#xff09;&#xff1a;记录 SQL Server 表的插入、更新和删除操作。开启cdc的源表在插入、更新和删除操作时会插入数据到日志表中。cdc通过捕获进程将变更数据捕获到变更表中&#xff0c;通过…

java子类继承父类方法、或者接口中方法的javadoc注释

说明 详情可以阅读&#xff1a; https://docs.oracle.com/en/java/javase/19/docs/specs/javadoc/doc-comment-spec.html#method-comment-inheritance 子类继承父类、或者子类实现接口&#xff0c;在子类中为了避免重复写注释&#xff0c;可以在子类方法注释的主要描述部分、或…

基于GitHooks实现项目自动实时部署

目录 基于GitHooks实现项目自动部署 基于SVNJenkins发布项目 基于GitHooks实现项目自动部署 以上创建的所有任务&#xff0c;构建工作是基于在开发人员提交完代码到远程仓库完成&#xff0c;通知运维后&#xff0c;需要手动执行构建任务&#xff0c;这样就有些不太方便。我们…

智能优化算法一元函数优化

目录 一、问题描述 二、解决方法 1.模拟退火 1.1 算法思路 1.2 求解代码 1.3 计算结果 2.粒子群算法 2.1 算法思路 2.2 求解代码 2.3 计算结果 3.遗传算法 3.1 算法思路 3.2 求解代码 3.3 计算结果 一、问题描述 本篇文章所做的是分别用模拟退火、粒子群算法…

MySQL 8.1安装

1. 下载地址 https://dev.mysql.com/downloads/mysql/8.0.html 我这里没有采用installer安装&#xff0c;因为installer安装依赖visual studio&#xff0c;所以&#xff0c;我下载的是zip文件。 最终下载的版本如下&#xff1a; 2. 添加环境变量 解压&#xff0c;添加环境…

图的存储:邻接表法

1.邻接表的定义 不同于邻接矩阵&#xff08;二维数组存储&#xff09;&#xff0c;邻接表采用的顺序链式存储实现的。 1.存储方式 顶点&#xff1a;使用结构体存储顶点&#xff0c;一个顶点包括顶点信息和指向第一条边或者弧的指针。用邻接表存储图&#xff1a;使用结构体存…