目录
前言
交叉熵损失函数
平衡交叉熵
Focal Loss
代码实现
前言
Focal loss是一个常用的解决类别不平衡问题的损失函数,由何恺明提出的(论文名称:Focal Loss for Dense Object Detection),用于图像领域解决one-stage目标检测中正负样本极不平衡和难分类样本学习问题。本文从交叉熵损失函数出发,分析样本不平衡问题,将focal loss与交叉熵损失函数对比,给出focal loss有效性的解释。
交叉熵损失函数
分类通常用到交叉熵,而且Focal Loss 也是基于交叉熵进行改进的,先介绍一下交叉熵的原理,会更易于理解Focal Loss。
二分类交叉熵损失函数,公式定义如下:
现定义如下的:
得到变形后的损失函数如下:
平衡交叉熵
由于存在正负样本极不平衡的问题,直接使用交叉熵损失函数,得到的效果不好。于是,首先平衡交叉熵。
一般为了解决类别不平衡的问题,会在损失函数中每个类别前增加一个权重因子 ∈ [0, 1]来协调类别不平衡。使用类似的方式定义,得到二分类平衡交叉熵损失函数:
平衡交叉熵采用平衡正负样本的重要性,但是没有区分难易样本。然后,类间不均衡较大会导致,交叉熵损失在训练的时候受到影响。易分类的样本的分类错误的损失占了整体损失的绝大部分,并主导梯度,会压垮交叉熵损失函数。
Focal Loss
Focal Loss在平衡交叉熵损失函数的基础上,增加一个调节因子降低易分类样本权重,聚焦于困难样本的训练,其定义如下:
权重帮助处理了类别的不均衡。
其中,是调节因子,≥ 0是可调节的聚焦参数,下图展示了 ∈ [0, 5]不同值时focal loss曲线。
γ
控制曲线的形状. γ
的值越大, 好分类样本的loss就越小, 我们就可以把模型的注意力投向那些难分类的样本. 一个大的 γ
让获得小loss的样本范围扩大了。同时,当γ=0
时,这个表达式就退化成了Cross Entropy Loss (交叉熵损失函数)。
在上图中,“蓝”线代表交叉熵损失。X轴即“预测为真实标签的概率”(为简单起见,将其称为pt)。Y轴是给定pt后Focal loss和CE的loss的值。
从图像中可以看出,当模型预测为真实标签的概率为0.6左右时,交叉熵损失仍在0.5左右。因此,为了在训练过程中减少损失,我们的模型将必须以更高的概率来预测到真实标签。换句话说,交叉熵损失要求模型对自己的预测非常有信心。但这也同样会给模型表现带来负面影响。
深度学习模型会变得过度自信, 因此模型的泛化能力会下降。
当使用γ> 1的Focal Loss可以减少“分类得好的样本”或者说“模型预测正确概率大”的样本的训练损失,而对于“难以分类的示例”,比如预测概率小于0.5的,则不会减小太多损失。
Focal Loss特点:
- 当很小时(样本难分,不管分的是否正确),调节因子趋近1,损失函数中样本的权重不受影响;当很大时(样本易分,不管分的是否正确),调节因子趋近0,损失函数中样本的权重下降很多
- 聚焦参数可以调节易分类样本权重的降低程度,越大权重降低程度越大
通过分析Focal Loss函数的特点可知,该损失函数降低了易分类样本的权重,聚焦在难分类样本上。
代码实现
class WeightedFocalLoss(nn.Module):
"Non weighted version of Focal Loss"
def __init__(self, alpha=.25, gamma=2):
super(WeightedFocalLoss, self).__init__()
self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
self.gamma = gamma
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
targets = targets.type(torch.long)
at = self.alpha.gather(0, targets.data.view(-1))
pt = torch.exp(-BCE_loss)
F_loss = at*(1-pt)**self.gamma * BCE_loss
return F_loss.mean()
主要参考了这篇文章:Focal Loss损失函数详解
这篇:focal loss 通俗讲解
提出它的论文解读:Focal loss论文详解