深度学习炼丹-不平衡样本的处理

news2025/1/11 14:22:30
  • 前言
  • 一,数据层面处理方法
    • 1.1,数据扩充
    • 1.2,数据(重)采样
      • 数据采样方法总结
    • 1.3,类别平衡采样
  • 二,算法(损失函数)层面处理方法
    • 2.1,Focal Loss
    • 2.2,损失函数加权
  • 参考资料

前言

在机器学习的经典假设中往往假设训练样本各类别数目是均衡的,但在实际场景中,训练样本数据往往都是不均衡(不平衡)的。比如在图像二分类问题中,一个极端的例子是,训练集中有 95 个正样本,但是负样本只有 5 个。这种类别数据不均衡的情况下,如果不做不平衡样本的处理,会导致模型在数目较少的类别上出现“欠学习”现象,即可能在测试集上完全丧失对负样本的预测能力。

除了常见的分类、回归任务,类似图像语义分割、深度估计等像素级别任务中也是存在不平衡样本问题的。

解决不平衡样本问题的处理方法一般有两种:

  1. 从“数据层面”入手:分为数据采样法和类别平衡采样法。
  2. 从“算法层面”入手:代价敏感方法。

注意本文只介绍不平衡样本的处理思想和策略,不涉及具体代码,在实际项目中,需要针对具体人物,结合不平衡样本的处理策略来设计具体的数据集处理或损失函数代码,从而解决对应问题。

一,数据层面处理方法

数据层面的处理方法总的来说分为数据扩充和数据采样法,数据扩充会直接改变数据样本的数量和丰富度,采样法的本质是使得输入到模型的训练集样本趋向于平衡,即各类样本的数目趋向于一致。

数据层面的采样处理方法主要有两种策略:

  1. 数据重采样方法,发生在数据预处理阶段,会改变整体训练集的数目和分布。
  2. 类别平衡采样方法,发生在数据加载阶段(这里的加载是指加载到模型中,不是指从硬盘中读取文件),通过设置采样策略来使得不同类别样本送入模型训练总的次数是近似的。

1.1,数据扩充

所谓数据不平衡,其实就是某些类别的数据量太少,那就直接增加一些呗,简单直接。如果有的选,那肯定是优先选择重新采取数据的办法了,当然大部分时候我们都没得选,这个时候最有效的办法自然是通过数据增强来扩充数据了。

数据增强的手段有多种,常见的如下:

  • 水平 / 竖直翻转
  • 90°,180°,270° 旋转
  • 翻转 + 旋转(旋转和翻转其实是保证了数据特征的旋转不变性能被模型学习到,卷积层面的方法可以参考论文 ACNet)
  • 亮度,饱和度,对比度的随机变化
  • 随机裁剪(Random Crop)
  • 随机缩放(Random Resize)
  • 加模糊(Blurring)
  • 加高斯噪声(Gaussian Noise)

值得注意的是数据增强手段的使用必须结合具体任务而来,除了前三种以外,其他的要慎重考虑。因为不同的任务场景下数据特征依赖不同,比如高斯噪声,在天池铝材缺陷检测竞赛中,如果高斯噪声增加不当,有些图片原本在采集的时候相机就对焦不准,导致工件难以看清,倘若再增加高斯模糊属性,部分图片样本基本就废了。

参考文章 如何针对数据不平衡做处理。

虽然目前深度学习框架中都自带了一些数据增强函数,但更多更强的数据增强手段可以使用一些图像增强库,比如 imgaug 这个 python 库。

模型训练过程中,pytorch 框架如何在数据构建 pipeline 阶段使用 imgaug 库可以参考文章 数据增强-imgaug。

1.2,数据(重)采样

简单的数据重采样方法分为数据上采样over-samplingup-sampling,也叫数据过采样) 或 也叫数据欠采样数据下采样(under-samplingdown-sampling )。

1,对于样本数目较少的类别,可用数据过采样方法over-sampling),即通过复制方法使得该类图像数目增至与样本最多类的样本数一致。

2,而对于样本数较多的类别,可使用数据欠采样Under-sampling,也叫数据欠采样)方法。对于深度学习和计算机视觉领域的任务来说,下采样并不是直接随机丢弃一部分图像,正确的下采样策略是: 在批处理训练时(数据加载阶段 dataloader),对于样本较多的类别,严格控制每批(batch)随机抽取的图像数目,使得每批读取的数据中正负样本是均衡的(类别均衡)。以二分类任务为例,假设原始数据分布情况下每批处理训练正负样本平均数量比例为 9:1,如仅使用下采样策略,则可在每批随机挑选训练样本时每 9 个正样本只取 1 个作为该批训练集的正样本,负样本选择策略不变,这样可使得每批读取的训练数据中正负样本时平衡的。

数据过采样和欠采样示意图如下所示。

数据采样方法总结

数据过采样和欠采样本质的简单理解就是“增加图片”和“删图片”:

  • 过采样:重复正比例数据,实际上没有为模型引入更多形式数据,过分强调正比例数据,会放大正比例噪音对模型的影响。
  • 欠采样:丢弃大类别的部分数据,和过采样一样会存在过拟合的问题。

同时两种数据重采样方法都是会改变数据原始分布的,比如数据过采样增加较小类别的样本数,数据欠采样减少较大类别的样本数,有可能产生模型过拟合等问题

这里的较小类别的意思是样本数目较少的类别,较大类别即样本数目较多的类别。

以上内容都是对解决类别不平衡问题中数据采样方法的策略描述,但想要在实际任务中解决问题,还要求我们加深对任务(task)的分析、对数据的理解分析,以及要求我们有更多的数据处理、数据采样的代码经验,即良好的策略 + 熟练的工具。

需要注意的是,因为仅仅使用数据上采样策略有可能会引起模型过拟合问题,所以在实际任务中,更为保险的数据采样策略哇往往是将上采样和下采样结合起来使用。

1.3,类别平衡采样

前面的数据重采样策略是着重于类别样本数量,而另一类采样策略则是直接着重于类别本身,不改变数据总体样本数,即类别平衡采样方法。其简单策略是把样本按类别分组,每个类别生成一个样本列表,训练过程中随机选择 1 个或几个类别,然后从每个类别所对应的样本列表中随机选择样本,这样可保证每个类别参与训练的机会比较均衡。

上述类别平衡方法过于简单,实际应用中有很多限制,比如在类别数很多的多分类任务中(如 ImageNet 数据集)。由此,在类别平衡采样的基础上,国内海康威视研究院提出了一种“类别重组采样”的平衡方法

类别重组法是在《解析卷积神经网络》这本书中看到的,可惜没在网上找到原论文和代码,但这个方法感觉还是很有用的,且也比较好复现。

如下图所示,类别重组方法步骤如下:

类别重组法步骤示意图

  1. 对原始样本的每个类别的样本分别排序好,计算每个类别的样本数目,并记录样本数最多的那个类别的样本数量 max_num
  2. 基于最大样本数 max_num 产生一个随机数列表,然后用此列表中的随机数对各自类别的样本数求余,得到对应索引值列表 index_listrandom.shuffle(list(range(max_num)))
  3. 根据该索引值列表 index_list,从该类的图像数据中提取图像,生成该类的图像随机列表。
  4. 最后吧所有类别的随机列表连接在一起后一起随机打乱次序,即可得到最终的图像列表,可以发现最终的这个图像随机列表中每个类别的样本数目是一致的(样本数较少的类别,图像会存在多次采样)。然后每轮(epoch)都对此列表进行遍历数据用于模型训练,如此重复。

如何得到一个随机整数列表

类别重组法对有点很明显,在设计好重组代码函数后,只需要原始图像列表即可,所有操作都在内存中在线完成,易于实现且更通用。其实仔细深究可以发现,海康提出的这个类别重组法和前面的数据采样方法是很类似的,其本质都是通过采样(sampler)策略让类别不均衡的各类数据在每轮训练中出现的次数是一致的

二,算法(损失函数)层面处理方法

类别不平衡问题的本质是导致样本数目较少的类别出现“欠学习”这一机器学习现象,直观表现是较小样本的损失函数权重占比也较少。一个很自然的解决办法是增加小样本错分的惩罚代价,并将此代价直接体现在目标函数(损失函数)里,这就是“代价敏感”的方法。“代价敏感”方法的本质可以理解为调整模型在小类别上的注意力。

2.1,Focal Loss

Focal Loss 是在二分类问题的交叉熵(CE)损失函数的基础上引入的,主要是为了解决 one-stage 目标检测中正负样本比例严重失衡的问题,该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘,经实践证明 Focal Lossone-stage 目标检测中还是很有效的,但是在多分类中不一定有效。

Focal Loss 作者通过在交叉熵损失函数上加上一个调整因子(modulating factor ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ,把高置信度 p p p(易分样本)样本的损失降低一些。Focal Loss 定义如下:

F L ( p t ) = − ( 1 − p t ) γ l o g ( p t ) = { − ( 1 − p ) γ l o g ( p ) , i f y = 1 − p γ l o g ( 1 − p ) , i f y = 0 FL(p_t) = -(1-p_t)^\gamma log(p_t) = \left\{\begin{matrix} -(1-p)^\gamma log(p), & if \quad y=1 \\ -p^\gamma log(1-p), & if\quad y=0 \end{matrix}\right. FL(pt)=(1pt)γlog(pt)={(1p)γlog(p),pγlog(1p),ify=1ify=0

Focal Loss 有两个性质:

  • 当样本被错误分类且 p t p_t pt 值较小时,调制因子接近于 1loss 几乎不受影响;当 p t p_t pt 接近于 1,调质因子(factor)也接近于 0容易分类样本的损失被减少了权重,整体而言,相当于增加了分类不准确样本在损失函数中的权重。
  • γ \gamma γ 参数平滑地调整容易样本的权重下降率,当 γ = 0 \gamma = 0 γ=0 时,Focal Loss 等同于 CE Loss γ \gamma γ 在增加,调制因子的作用也就增加,实验证明 γ = 2 \gamma = 2 γ=2 时,模型效果最好。

直观地说,调制因子减少了简单样本的损失贡献,并扩大了样本获得低损失的范围。例如,当 γ = 2 \gamma = 2 γ=2 时,与 C E CE CE 相比,分类为 p t = 0.9 p_t = 0.9 pt=0.9 的样本的损耗将降低 100 倍,而当 p t = 0.968 p_t = 0.968 pt=0.968 时,其损耗将降低 1000 倍。这反过来又增加了错误分类样本的重要性(对于 p t ≤ 0.5 pt≤0.5 pt0.5 γ = 2 \gamma = 2 γ=2,其损失最多减少 4 倍)。在训练过程关注对象的排序为正难 > 负难 > 正易 > 负易。

1. 正难3. 正易, γ \gamma γ 衰减
2. 负难, α \alpha α 衰减4. 负易, α 、 γ \alpha、\gamma αγ衰减

在实践中,我们通常采用带 α \alpha αFocal Loss

F L ( p t ) = − α ( 1 − p t ) γ l o g ( p t ) FL(p_t) = -\alpha (1-p_t)^\gamma log(p_t) FL(pt)=α(1pt)γlog(pt)

作者在实验中采用这种形式,发现它比非 α \alpha α 平衡形式(non- α \alpha α-balanced)的精确度稍有提高。实验表明 γ \gamma γ 取 2, α \alpha α 取 0.25 的时候效果最佳。

更多理解参考 focal loss 论文。

2.2,损失函数加权

除了 Focal Loss 这种高明的损失函数策略外,针对图像分类问题,还有一种简单直接的损失函数加权方法,即在计算损失函数过程中,对每个类别的损失做加权处理,具体的 PyTorch 实现方式如下:

weights = torch.FloatTensor([1, 1, 8, 8, 4]) # 类别权重分别是 1:1:8:8:4
# pos_weight_weight(tensor): 1-D tensor,n 个元素,分别代表 n 类的权重,
# 为每个批次元素的损失指定的手动重新缩放权重,
# 如果你的训练样本很不均衡的话,是非常有用的。默认值为 None。
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()

参考资料

  • 《解析卷积神经网络》
  • 如何针对数据不平衡做处理
  • 10 Techniques to deal with Imbalanced Classes in Machine Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/71020.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Akka 学习(三)Actor的基本使用

目录一 基本案例1.1 Java 版1.2 Scala版二 Actor的创建2.1 ActorRef2.2 Props2.3 ActorSelection三 Promise、Future和事件驱动的编程模型3.1 阻塞与事件驱动3.2 Future进行Actor响应3.2.1 Java版3.2.2 Scala 版3.2.3 总结3.3 成功处理3.4 失败处理3.5 恢复3.6 链式调用3.7 结果…

小程序开发工具怎么使用?

小程序开发工具怎么用? 小程序开发工具分两种: 一种是微信官方提供的微信开发者工具 这个需要从事代码行业,职业是程序员又或者对代码知识有一定程度的人,才能上手使用。 另一种是第三方小程序开发平台,提供的小程序开发工具 …

python环境、基础语法、几种常见的数据类型

文章目录前言一、基本知识介绍二、举例实操以及重要知识再现(列表、元组、集合、字典)前言 一、基本知识介绍 python基础 标准库与扩展库中的对象的导入与使用: import 模块名(as别名) import numpy as np from 模块名 import 对象名&#x…

程序人生:快来一起学习软件测试,一起月薪过万(测试理论基础学习)

测试基础 为什么要有测试呢?现在软件已经和人的生活息息相关了,所以保证软件的稳定很重要。但是所有开发出来的软件都是有缺陷的。包括代码错误,逻辑错误,设计不合理等。 测试的目的 测试的目的主要有四个点 1找到软件缺陷 2…

Flink SQL增量查询Hudi表

前言 前面总结了Spark SQL增量查询Hudi表和Hive增量查询Hudi表。最近项目上也有Flink SQL增量查询Hudi表的需求,正好学习总结一下。 官网文档 地址:https://hudi.apache.org/cn/docs/querying_data#incremental-query 参数 read.start-commit 增量查…

WWW2022 | 基于领域增强的图对比协同过滤方法+代码实践

嘿,记得给“机器学习与推荐算法”添加星标今天跟大家分享一篇将对比学习应用于图协同过滤方法的文章,该论文发表于WWW2022会议上。其主要思想是在图神经网络协同过滤方法上应用了两种领域类型的对比学习方法,分别是显式的结构领域和隐式的语义…

TGK-Planner-前后端路径规划(基于梯度的后端无约束优化)

高速移动无人机的在线路径规划一直是学界当前研究的难点,引起了大量机器人行业的研究人员与工程师的关注。然而无人机的计算资源有限,要在短时间内规划出一条安全可执行的路径,这就要求无人机的运动规划算法必须轻型而有效。本文将介绍一种无…

electron-vue中报错 Cannot use import statement outside a module解决方案(亲测有效!!!)

错误: Cannot use import statement outside a module(不能在模块之外使用导入语句)。 原因: 安装的某个依赖包里使用了import语法,因为我们打包输出的是commonjs规范,所以不识别import语法而导致报错。 可以从 .electron-vue/w…

PrimoBurnerSDK蓝光刻录工具开发工具包

PrimoBurnerSDK蓝光刻录工具开发工具包 PrimoBurnerSDK是一个CD、DVD和蓝光刻录工具开发工具包。它还提供了一个全面灵活的API,用于快速轻松地实现各种燃烧/翻录替代方案。 PrimoBurner SDK for.NET的强大功能: 自2003年以来一直在发展的广泛使用的老式发…

比机器人还智能的数字孪生地下停车场监管系统!

现在的停车场管理大多采用人工或智能收费系统,两种方式都有一个弊端就是无法直接知晓停车场内部信息。 车驶入停车场只能自行寻找停车位,工作人员也只有走进停车场才能知晓停车场内部情况,无可避免造成很多麻烦。 停车场智慧监管系统结合数…

期货开户交易操作技巧

期货交易的时候需要有一些操作技巧,以及要注意一些操作上常见的错误。 个人建议刚刚开始交易的投资者期货交易的投资者,一定要多看慢做,首先要摒弃做这个会一夜暴富的想法。抱着个想法来的往往都会折戟沉沙,一去不复返了。所以我…

基于springboot+mybatis+mysql+vue中学生成绩管理系统

基于springbootmybatismysqlvue中学生成绩管理系统一、系统介绍二、功能展示1.登陆2.用户管理(管理员)3.班主任信息管理(管理员)4.教师信息管理(管理员、班主任)5.学生信息管理(管理员)6.成绩信息管理(管理员、班主任、…

一个人,仅30天!开发一款3D竞技足球游戏!他究竟经历了些什么?

今天,晓衡向大家推荐一款Coco Store 优质 3D足球竞技游戏 资源《足球快斗》玩法介绍:游戏为 7V7 足球竞技类玩法。玩家控制本队的一个球员(脚下高亮圆圈显示的是玩家),其他球员和守门员为电脑AI控制,期间可…

Jvm上如何运行其他语言?JSR223规范最详细讲解

一 在Java的平台里,其实是可以执行其他的语言的。包括且不仅限于jvm发展出来的语言。 有的同学可能会说,在java项目里执行其他语言,这不吃饱了撑着么,java体系那么庞大,各种工具一应俱全,放着好好的java不…

责任链模式在复杂数据处理场景中的实战

相信大家在日常的开发中都遇到过复杂数据处理和复杂数据校验的场景,本文从一线开发者的角度,分享了责任链模式在这种复杂数据处理场景下的实战案例,此外,作者在普通责任链模式的基础上进行了升级改造,可以适配更加复杂…

34_DAC原理及数模转换实验

目录 数模转换原理 DAC模块框图 事件选择控制数字模拟转换 DAC转换 DAC数据格式 选择DAC触发 DAC输出电压计算 硬件连接 DAC配置步骤 实验源码 数模转换原理 STM32的DAC模块(数字/模拟转换模块)是12位数字输入,电压输出型的DAC。DAC可以配置为8位或12位模式,也可以与…

linux安装nginx

1.nginx官网 http://nginx.org/en/download.html 下载安装包,如图所示下载nginx-1.23.2,并上传到指定目录:/usr/local/src/nginx 2.解压 tar -zxvf nginx-1.23.2.tar.gz3.安装nginx, cd /usr/local/src/nginx/nginx-1.23.2 该目录…

Titanic 泰坦尼克数据集 特诊工程 机器学习建模

以下内容为讲课时使用到的泰坦尼克数据集分析、建模过程,整体比较完整,分享出来,希望能帮助大家。部分内容由于版本问题,可能无法顺利运行。 Table of Contents 1 经典又有趣的Titanic问题1.1 目标1.2 解决方法1.3 项目目的2…

Vector-常用CAN工具 - CANoe入门到精通_03

NetWork Node 前面已经介绍了CANoe的基本情况、硬件环境搭建、CANoe软件环境配置,今天我们就来聊一下NetWork Node,在我们的测试工作中,大部分情况我们默认CANoe作为一个Client端,但是有些情况,我们需要实时监测被测件…

Akka 学习(四)Remote Actor

目录一 介绍1.1 Remote Actor1.2 适用场景1.3 踩坑点二 实战2.1 需求2.2 Java 版本2.2.1 效果图2.2.2 实体类2.2.3 服务端Actor 处理2.2.4 服务端配置文件2.2.5 客服端Actor处理2.2.6 客服端配置文件2.2.7 测试2.3 Scala 版本2.3.1 效果2.2.3 服务端Actor处理2.3.4 客户端Actor…