Table of contents
- 1. Paper link
- 2. Understanding the formula
- Code
1. Paper link
https://arxiv.org/pdf/2008.13367.pdf
2. Understanding the formula
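A quick note: IACS here stands for IoU-aware classification score. The target score in the VFL paper, denoted q, is a soft label tied to the IoU. For the selected positive samples (that is, samples that overlap a GT box), the classification target is no longer one of the two extremes 0 and 1 but an IoU-dependent label q. The benefit is that a sample with a large IoU gets a label such as 0.9 and one with a small IoU gets a label such as 0.1, which is also more intuitive.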
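For reference, the loss built on this target can be written as below. This is transcribed from the definition in the linked paper, where p is the predicted IACS, q is the target score (the IoU between the predicted box and its ground-truth box for a positive sample, and 0 for a negative sample), and α and γ are the hyperparameters of the negative term.

```latex
\mathrm{VFL}(p, q) =
\begin{cases}
-q \,\bigl( q \log p + (1 - q) \log (1 - p) \bigr), & q > 0 \\[4pt]
-\alpha \, p^{\gamma} \log (1 - p), & q = 0
\end{cases}
```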
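When q > 0, VFL applies no hyperparameter to the positive sample, so its loss is not decayed at all. When q = 0, VFL does apply hyperparameters to the negative sample: gamma reduces the contribution of negative samples and alpha keeps them from being suppressed too much, but overall the contribution of negatives is still reduced. This is the difference from focal loss, where both terms carry hyperparameters, so the contribution of positive samples is down-weighted as well.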
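The asymmetry described here can be checked numerically with a small scalar sketch. It is not taken from the paper or from PaddleDetection: the function names are hypothetical, the inputs are probabilities rather than logits, and the focal loss defaults (alpha=0.25, gamma=2.0) are assumed to be the commonly used ones.

```python
import math


def bce(p, q):
    """Binary cross-entropy between a probability p and a (soft) target q."""
    return -(q * math.log(p) + (1.0 - q) * math.log(1.0 - p))


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss for a hard label y in {0, 1} (assumed defaults)."""
    if y == 1:
        # The positive term is scaled by alpha * (1 - p)^gamma.
        return -alpha * (1.0 - p) ** gamma * math.log(p)
    # The negative term is scaled by (1 - alpha) * p^gamma.
    return -(1.0 - alpha) * p ** gamma * math.log(1.0 - p)


def varifocal_loss_scalar(p, q, alpha=0.75, gamma=2.0):
    """Scalar Varifocal Loss for a probability p and an IoU-aware target q."""
    if q > 0:
        # Positive sample: plain BCE weighted by q, no gamma-based decay.
        return q * bce(p, q)
    # Negative sample: BCE scaled down by alpha * p^gamma.
    return -alpha * p ** gamma * math.log(1.0 - p)


# A confident positive prediction: focal loss shrinks it far more than VFL does.
print("positive, focal loss:", focal_loss(0.9, 1))
print("positive, VFL       :", varifocal_loss_scalar(0.9, 0.9))

# An easy negative prediction: VFL scales it well below plain BCE.
print("negative, plain BCE :", bce(0.1, 0.0))
print("negative, VFL       :", varifocal_loss_scalar(0.1, 0.0))
```

In this sketch only the negative branch of VFL contains alpha and gamma, while both branches of focal loss do.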
Code
The code below is copied from the PaddleDetection toolkit:
# Import needed by the function below.
import paddle.nn.functional as F


def varifocal_loss(pred,
                   target,
                   alpha=0.75,
                   gamma=2.0,
                   iou_weighted=True,
                   use_sigmoid=True):
    """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

    Args:
        pred (Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of classes.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal Loss.
            Defaults to 0.75.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive example with the iou target. Defaults to True.
        use_sigmoid (bool, optional): Whether pred holds logits that still
            need a sigmoid. Defaults to True.

    Returns:
        Tensor: The loss summed over classes, with shape (N,).
    """
    # pred and target should be of the same size
    assert pred.shape == target.shape
    # pred_new is always a probability; pred stays as passed in
    if use_sigmoid:
        pred_new = F.sigmoid(pred)
    else:
        pred_new = pred
    target = target.cast(pred.dtype)
    # Positive entries (target > 0) are weighted by target (or by 1);
    # the remaining entries are weighted by alpha * |pred_new - target|^gamma
    if iou_weighted:
        focal_weight = target * (target > 0.0).cast('float32') + \
            alpha * (pred_new - target).abs().pow(gamma) * \
            (target <= 0.0).cast('float32')
    else:
        focal_weight = (target > 0.0).cast('float32') + \
            alpha * (pred_new - target).abs().pow(gamma) * \
            (target <= 0.0).cast('float32')
    # Element-wise BCE against the soft target, scaled by the focal weight
    if use_sigmoid:
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
    else:
        loss = F.binary_cross_entropy(
            pred, target, reduction='none') * focal_weight
    # Sum over the class axis to get one loss value per sample
    loss = loss.sum(axis=1)
    return loss
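One small difference from the formula: for q = 0 the code uses the weight alpha * (pred_new - target).abs().pow(gamma) rather than alpha * p^gamma. My guess is that target is not necessarily exactly 0 in practice; the expression also covers the target = 0 case, where it reduces to alpha * pred_new.pow(gamma). Note that this term is masked by (target <= 0.0), so it only applies where target is not positive.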
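To check this remark, here is a minimal scalar sketch in plain Python. It re-expresses the focal_weight expression from the listing above for a single (p, q) pair with iou_weighted=True, and compares the result with the piecewise definition from the paper. All function names are hypothetical, and p is assumed to already be a probability in (0, 1).

```python
import math


def bce(p, q):
    """Binary cross-entropy between a probability p and a (soft) target q."""
    return -(q * math.log(p) + (1.0 - q) * math.log(1.0 - p))


def focal_weight(p, q, alpha=0.75, gamma=2.0):
    """Scalar form of the focal_weight expression (iou_weighted=True)."""
    pos_mask = 1.0 if q > 0.0 else 0.0
    neg_mask = 1.0 if q <= 0.0 else 0.0
    return q * pos_mask + alpha * abs(p - q) ** gamma * neg_mask


def vfl_from_code(p, q, alpha=0.75, gamma=2.0):
    """Loss as the listing computes it: BCE times the focal weight."""
    return bce(p, q) * focal_weight(p, q, alpha, gamma)


def vfl_from_paper(p, q, alpha=0.75, gamma=2.0):
    """Loss as the paper writes it, case by case."""
    if q > 0:
        return -q * (q * math.log(p) + (1.0 - q) * math.log(1.0 - p))
    return -alpha * p ** gamma * math.log(1.0 - p)


# Compare both forms on a few positive (q > 0) and negative (q = 0) pairs.
for p, q in [(0.3, 0.0), (0.8, 0.0), (0.6, 0.7), (0.9, 0.2)]:
    same = math.isclose(vfl_from_code(p, q), vfl_from_paper(p, q))
    print(f"p={p}, q={q}, forms agree: {same}")
```

With q = 0 the factor abs(p - q) ** gamma is simply p ** gamma, so under these assumptions the two forms coincide for non-negative targets.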