目录
- 1. 文章主要内容
- 2. VariFocal Loss损失函数(原理:简单介绍,可自行详细研究)
- 2.1 VariFocal Loss损失函数
- 2.2 博主数据集实验效果
- 3. 代码详细改进流程(重要)
- 3.1 新建varifocalLoss.py文件
- 3.2 修改hyp.scratch-low.yaml文件
- 3.3 修改loss.py文件
- 4. 本篇小结
1. 文章主要内容
本篇博客主要涉及两个主体内容。第一个:简单介绍VariFocal Loss的原理。第二个:基于YOLOv5 6.0版本,将损失函数替换为VariFocal Loss的详细调试步骤(通读本篇博客大概需要5分钟左右的时间)。
2. VariFocal Loss损失函数(原理:简单介绍,可自行详细研究)
2.1 VariFocal Loss损失函数
VariFocal Loss是从Focal Loss而来,所以我们要首先了解Focal Loss。Focal Loss提出来要解决的问题是训练数据中,正负样本不均衡的问题。 何为正负样本不均衡?比如说,我们训练的图片样本,尤其是包含很多小目标的图片样本,其实要检测的目标(也就是我们说的正样本)只占图片区域的少部分(综合来看),大部分的区域则为背景区域(也就是我们说的负样本);这就会导致训练数据负样本占多,而正样本相对来说占少数,模型的训练效果会变差,Focal Loss给难分、易分样本加上权重因子,提升难分样本的权重,降低易分样本的权重,从而控制正负样本均衡的问题,其中背景类一般为易分样本,而目标类为难分样本。 同时,Focal Loss适合检测密集型目标的图片样本,这个对小尺寸、拥挤、遮挡等特点的数据集会有不错的效果。
VariFocal Loss是在Focal Loss的基础上提出的,因为Focal Loss对正负样本的处理是均衡的,而varifocal loss仅减少了负样本的损失贡献,而不以同样的方式降低正样本的权重。具体的公式、原理还请查看原论文或者网上的文章解析。
原论文地址:Focal Loss 论文、VariFocal Loss 论文
2.2 博主数据集实验效果
博主所训练的数据集特点:小尺寸目标居多,密集且目标尺寸不一,实验数据如下所示:
原YOLOv5s框架实验数据:P(查准率):0.935、R(召回率):0.927、mAP@0.5(平均检测精度):0.942
YOLOv5s+VFLoss:P(查准率):0.974、R(召回率):0.95、mAP@0.5(平均检测精度):0.962
由实验数据对比,YOLOv5s+VariFocal Loss能够极大的提升R,同时P、mAP也有所提升,提升R指标,就表明能够检测更多的目标,可以减少模型漏检的问题,且FLOPs和原YOLOv5模型一样。
3. 代码详细改进流程(重要)
3.1 新建varifocalLoss.py文件
( 注意:博主使用的是Pycharm集成开发工具)首先在data->tricks目录下新建一个叫varifocal loss的py文件( 注意:tricks文件夹是自定义创建的,没有自己创建一个即可),将如下代码复制到varifocal loss的py文件中:
import torch
from torch import nn
class VFLoss(nn.Module):
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super(VFLoss, self).__init__()
# 传递 nn.BCEWithLogitsLoss() 损失函数 must be nn.BCEWithLogitsLoss()
self.loss_fcn = loss_fcn #
self.gamma = gamma
self.alpha = alpha
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply VFL to each element
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
pred_prob = torch.sigmoid(pred) # prob from logits
focal_weight = true * (true > 0.0).float() + self.alpha * (pred_prob - true).abs().pow(self.gamma) * (
true <= 0.0).float()
loss *= focal_weight
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
另外项目的data-tricks目录结构如下所示:( 之所以要新建文件,是为了方便,清晰的分辨哪些创新点,而不是一股脑都放在一个文件中)
3.2 修改hyp.scratch-low.yaml文件
在data/hyps文件夹下面,找到hyp.scratch-low.yaml(注意:hyps下面有多个yaml文件,博主这里修改hyp.scratch-low.yaml,是因为在train.py文件中调用了此文件,如果你调用了是另外的yaml文件,则需要在你调用的那个yaml文件中做修改),修改fl_gamma的这一行值,原本是0.0,这里修改为1.5即可,如下图所示:
3.3 修改loss.py文件
loss.py文件在utils文件夹下面,打开并定位到如下代码部分(大概是111行),修改成如下的代码所示:
BCEcls, BCEobj = VFLoss(BCEcls, g), VFLoss(BCEobj, g)
这里用VFLoss替换了YOLOv5的分类、置信度损失,回归框损失没有替换。同时,我们注意到,g = h[‘fl_gamma’] 这行代码就是hyp.scratch-low.yaml的fl_gamma值,设置为1.5(1.5为经验值)就可以进入到if g > 0 的条件当中,另外需要导入VFLoss的引用,不然会报错,只需要在文件首部添加from data.tricks.varifocalLoss import VFLoss
即可。
4. 本篇小结
本篇博客主要介绍了VF+YOLOv5的修改详细流程,助力模型对小目标、多尺寸、密集目标的检测涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!