【完整】分类模型中类别不均衡问题解决

news2024/12/24 2:30:26

目录

1. 数据类别不均衡问题

2. 解决办法

过采样:

欠采样:

ensemble 方法:

修改损失函数:

梯度调和机制:

Dice Loss:

标签平滑:

3. 类别不均衡问题loss设计

4. 梯度调和机制GHM Gradient Harmonizing Mechanism

(1)简介

(2) 理论

(3) 计算优化

5. Dice Loss

(1)Dice Coeffificient

(2)自调节

 6. 标签平滑Label Smoothing

(1)cross-entropy

(2)存在问题

 (3)正则化

7. 总结

8. 代码


1. 数据类别不均衡问题

常见的分类任务大部分数据的标签都是某几类,而很少的类别的数据有时也很重要,因而需要模型去预测。

2. 解决办法

过采样:

对于某些类别数据比较少,对它们进行重复采样,以达到相对平衡,重复采样的时候,有时也会对数据加上一点噪声

缺点:过采样可能导致这些类别产生过拟合的现象

欠采样:

对于某些类别数据特别多,只使用部分数据,抛弃一些数据;

缺点:欠采样则容易导致模型的泛化性变差

ensemble 方法:

结合 ensemble 方法,则将数据切分为 N 部分,每部分都包含数据少的类别的所有样本和数据多的类别的部分样本,训练 N 个模型,最后进行集成。

缺点:使用 ensemble 则会提高部署成本和带来性能问题

修改损失函数:

详情见后文第3部分

梯度调和机制:

详情见后文第4部分

Dice Loss:

详情见后文第5部分

标签平滑:

详情见后文第6部分

3. 类别不均衡问题loss设计

Focal Loss 是一种专门为类别不均衡设计的loss,在《Focal Loss for Dense Object Detection》这篇论文中被提出来,应用到了目标检测任务的训练中,但其实这是一种无关特定领域的思想,可以应用到任何类别不均衡的数据中。

首先,还是从二分类的cross entropy(CE)loss入手:

为了符号的方便,

 

 

 p\in [0,1]为模型对于 label=1(ground-truth)的类别的预测概率。 

 下图的蓝色曲线为原生的 CE loss,容易看出来,那些容易分类(p_{t}\gg 0.5)的样本也会产生不小的 loss,但这些大量的容易样本的 loss 加起来,会压过那些稀少类别的样本的 loss。

改进方法为:给 CE loss 增加一个权重因子\alpha ,正样本权重因子为 \alpha \in [0,1],负样本为 1-\alpha

实际使用中,一般设置为类别的逆反频率,即频率低的类别权重应该更大,比如稀少的正样本的 \alpha为负样本的频率。或者当作一个超参数。

但是,这种做法只是平衡了正负样本的重要性,无法区分容易(easy)样本和困难(hard)样本,这也是类别不均衡的数据集很容易出现的问题:容易分类的样本贡献了大部分的 loss,并且主导了梯度。

因此,Focal Loss 的主要思想就是让 loss 关注那些困难样本,而降低容易样本的重要性。

 

 如上式,在 CE 的基础增加一个调节因子 \left ( 1-p_{t} \right )^{\gamma }。上图 1 可以看出,\gamma 越大,容易样本的 loss 贡献越小。

Focal Loss 具有以下两个属性:

1. 当一个样本被错误分类时,且p_{t}  很小时(即为困难样本),那么调节因子是接近 1 的,loss 则基本不受影响。而相反的,当 p_{t}->1,分类很好的样本(容易样本),调节因子则会偏于 0,loss 贡献变得很小;

2. 不同的 \gamma 参数可以平滑地调整容易样本的重要性降低的比率。当\gamma =0  时,则等同于普通的 CE。而当 \gamma  变大时,那么调节因子的影响也会同样变大,即容易样本的重要性会降低。

论文在实验中,Focal Loss 还保留上述的权重因子\alpha _{t} : 

 通常来说,当增加\gamma时,\alpha 应该稍微降低。

在作者的实验中, \gamma =2,\alpha =0.25取得了最佳性能。

缺点:

1. Focal Loss 存在两个超参数,并且是互相影响,构成许多参数组合,会导致调参需要很多尝试成本

2. Focal Loss 是一种静态的 loss,那么同一种超参数无法适用于不同的数据分布;

4. 梯度调和机制GHM Gradient Harmonizing Mechanism

出自《Gradient Harmonized Single-stage Detector》论文

(1)简介

GHM 有两大观点:

观点 A:与focal loss类似

模型从容易分类的样本的到收益很少,模型应该关注那些困难分类的样本,不管它属于哪一种类别,但大量的容易样本加起来的贡献会盖过困难样本,使得训练效益很低;

观点 B:

1. Focal Loss 存在两个超参数,并且是互相影响,构成许多参数组合,会导致调参需要很多尝试成本;并且,Focal Loss 是一种静态的 loss,那么同一种超参数无法适用于不同的数据分布;

下图展示了上述观点 ,梯度范数 gradient norm 的大小则代表样本的分类难易程度,收益实际即对应为梯度;

 2. 有一些特别困难分类的样本,它们很可能是离群点,加入这些样本的训练,会影响模型的稳定性;

3. 提出了 gradient density(梯度密度)的梯度调和机制,来缓解这种类别不均衡的问题。

通过 GHM 的梯度调和之后,容易样本的 gradient norm 会被削弱许多,并且特别困难的样本也会被轻微削弱,分别对应观点 A 和观点 B-2 的解决方案。

(2) 理论

其主要思想是:首先仍然是降权大量容易样本贡献的梯度总和,其次是对于那些特别困难样本即离群点,也应当相对地降权

对于二分类问题,同样的交叉熵 loss 如下:

 其中, p\in [0,1]为模型的预测概率, p*\in [0,1]为真实的标签 ground-truth label;

x 为模型 unnormalized 的直接输出p=sigmoid(x),.

这里的 g 为 gradient norm,可以用来表示一个样本的分类难易程度以及对在全局梯度中的影响程度,g 越大则分类难度越高。

下图 2 展示了在目标检测模型中 gradient norm 的分布情况,表明了容易样本在梯度中会占主导地位,以及模型无法处理一些特别困难的样本,这些样本的数量甚至超过了中等困难的样本,但模型不应过于关注这些样本,因为它们可以认为是离群点。(对应上述观点 A 和观点 B-2)

为了解决这种 gradient norm 的分布问题,论文提出了一种调和手段:Gradient Density

其中,g_{k} 为第 k 个样本的 gradient norm。

g 的梯度密度即 GD(g) 表示落于以 g 为中心,长度为\epsilon  的中心区域的样本数量,然后除以有效长度进行标准化。

那么,梯度密度调和参数为:

 N 为样本数量。

GD(g_{i})/N 可以看作是梯度上在第 i 个样本周边样本频率的一种正则化:        

1. 如果所有样本的梯度是均匀分布的,那么对于每个样本 i:GD(g_{i})=N\rightarrow \beta _{i}=1,意味着每个样本都没起到任何改变

2. 容易样本的频率很高,那么\beta  就会变得很小,起到降低这些样本的权重的效果;并且特别困难样本即离群点的频率会比中等困难的样本频率多,意味着这些离群点的\beta 会相对较小,那么也会相对地轻微降低这些样本的权重; 

3. 从第 2 点可以看出,GHM 其实只适用于那些容易样本和特别困难样本的数量比中等困难样本多的场景。

 因此,经过 GHM 调和之后的 loss 为:

(3) 计算优化

GHM 的计算复杂度是 O(N^{2}),论文通过 Unit Region 的方法来逼近原生的梯度密度,大大降低了计算复杂度 .

首先,将 gradient norm 的值域空间 [0,1] 划分为 M 个长度为\epsilon  的 Unit Region。对于第 j 个 Unit Region:r_{j}=[(j-1)\epsilon ,j\epsilon ] 。

接着,让R_{j}等于落在r_{j}的样本数量;并且定义 ind(g)=t,s.t.(t-1)\epsilon <g<t\epsilon ,即计算 g 所在的 Unit Region 的索引的函数

则,梯度密度的近似函数如下,得到计算复杂度优化的 GHM Loss:

怎么理解这种近似思想呢:

1. 先回忆原生 GHM 的梯度密度计算:g 的梯度密度即 GD(g) 表示落于以 g 为中心,长度为\epsilon  的中心区域的样本数量,然后除以有效长度进行标准化;

2. 将 gradient norm 划分了 M 个 Unit Region 之后,假如第 i 个样本的 g_{i} 落入第 j 个 Unit Region,那么同样落入该 Unit Region 的样本可以认为是落于以 g_{i}为中心的中心区域,并且有效长度为\epsilon ,即得到上述的近似梯度密度函数

 最后,在使用 Unit Region 优化之后,还结合 Exponential moving average(EMA)的思想,让梯度密度更加平滑,减少对极端数据的敏感度

 R_{i}^{(t)} 为在 t 次遍历中,落入第 j 个 Unit Region 的数量; \alpha即为 EMA 中的 momentum 参数。

5. Dice Loss

Dice Loss 来自《Dice Loss for Data-imbalanced NLP Tasks》这篇论文

该论文观点有2:

1)负样本数量远超过正样本,导致容易的负样本会主导了模型的训练;

2)交叉熵其实是准确率(accuracy)导向的导致了训练和测试的不一致。在训练过程中,每一个样本对目标函数的贡献是相同,但是在测试的时候,像分类任务很重要的指标 F1 score,由于正样本数量很少,每一个正样本就对于 F1 score 的贡献则更多了。

(1)Dice Coeffificient

dice coeffificient 是一种 F1 导向的统计,用于计算两个集合的相似度:

 对应到二分类场景中,A 是模型预测为正样本的样本集合,B 是真实的正样本集合。此时,dice coefficient 其实等同于 F1:

对于每一个样本x_{i} ,它对应 dice coefficient 的为:

但是,显而易见,这样会导致负样本(y_{i1}=0)对目标的贡献为 0。因此,为了避免负样本的作用为 0,让训练更加平滑,在分子和分母中同时加入一个因子\gamma :

 

为了更快地收敛,分母可以为平方的形式,那么 Dice Loss(DL)则变为:

(修改为 1-DSC,目的是让 DSC 最大化变成目标函数最小化,这是 loss 函数常用的转换套路了,并且让 loss 为正数)

另外,以计算 set-level 的 dice coefficient,而不是独立样本的 dice coefficient 加起来,这样可以让模型更加容易学习:

 

(2)自调节

上述未经过平滑的 DSC 公式其实是 F1 的 soft 版本,因为对于 F1 score,只存在正判或误判。模型预估通常以 0.5 为边界来判断是否为正样本:

 DSC 使用连续的概率 p,而不是使用二分p_{i1}>0.5 ,这种 gap 对于均衡的数据集不是什么大问题。

但是对于大部分为容易的负样本的数据集来说,是存在极端的害处:

  • 容易分类的负样本很容易主导整个训练过程,因为它们的预测概率相对来说更容易接近 0;

  • 同时,模型会变得难以区分困难分类的负样本和正样本,这对于 F1 score 的表现有着很大的负向影响。

为了解决这种问题,DSC 在原来的基础上,给 soft 概率 p 乘上一个衰减因子(1-p) :

 (1-p_{i1})^{\alpha }是一个与每一个样本关联的权重,并且在训练过程会动态改变,根据样本的分类难易程度,实现对样本权重的自调节:

 (1-p_{i1})^{\alpha }p_{i1}对于预测概率接近 0 和 1 的容易样本,该值明显小很多,可以减少模型对这些样本的关注。

 6. 标签平滑Label Smoothing

标签平滑不是针对不均衡类别设计的 loss 优化,但不失为一种提升分类模型泛化能力的有效措施。出自这篇论文《Rethinking the Inception Architecture for Computer Vision》

(1)cross-entropy

在 K 分类模型中,第 k 个 label 的预估概率为:

k\in 1,2,...,K, z_{i}为 logits

ground-truth 真实 label 为:q(k|x),\sum_{k}q(k|x)=1 。

那么,对应的交叉熵 loss 则为:

(2)存在问题

对 z_{k} 求导得到梯度为:\frac{\partial l}{\partial z_{_{k}}}=p(k)-q(k),并且范围在 -1 到 1

对于我们的交叉熵 loss,最小化则等同于真实标签的最大似然,而仅当q(k)=\delta _{k,y}  时才能达到最大似然, \delta _{k,y}当 k=y 时为 1,其他则为 0。

而对于有限值的z_{k}  是无法达到这种最大似然的情况,但可以接近这种情况,当所有的z_{y}\gg z_{k} for k\neq y,即当对应 ground-truth 的 logits 远远大于其他的 logits,直观上来看,这是由于模型对自己的预测结果太过于自信了,这会产生以下两个问题:

  1. 它可能会造成过拟合。如果模型学习到了为每个样本的 ground-truth label 赋予完全的概率,那这无法保证泛化性

  2. 它鼓励最大的 logit 和其他的 logits 差别尽可能大,再加上梯度的边界仅在 -1 到 1,这会降低模型的适应(adapt)能力。

 (3)正则化

基于上述分析,作者提出了一种优化的交叉熵,增加了正则化:label-smoothing regularization

其中, \epsilon为 [0,1] 的超参数,K 为标签类别数量。

  • 这种方法避免了最大的 logit 比其他 logits 太过于大,给模型增加了正则化,提升了模型的泛化能力;

  • 即使发生这种情况,交叉熵 loss 会变得更大,因为不同于q(k)=\delta _{k,y} ,每个{q}'_{k}  都会贡献 loss。

7. 总结

(1)FocalLoss和DiceLoss思想比较接近,都是为了减少模型对容易样本的关注而进行的loss优化,而GHMLoss除了对容易样本降权,还实现了对特别困难样本的轻微降权,因为特别困难的样本可以认为是离群点。

(2)GHM Loss 仅适用于二分类,而 Focal Loss 和 Dice Loss 很容易扩展到多分类,但实际使用中 Focal Loss对于多分类调参比较困难(每种类别对应的 -balanced,加上 ,参数组合过于多)。

(3) Label Smoothing 虽然不是针对类别不均衡的问题,但在分类模型中,其效果往往比原生的交叉熵有些小提升。

8. 代码

https://github.com/QunBB/DeepLearning/tree/main/Trick/unbalance


参考:

分类模型:类别不均衡问题之loss设计 (qq.com)  

论文:

[1]Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision (pp. 2980-2988).

[2] Li, B., Liu, Y., & Wang, X. (2019, July). Gradient harmonized single-stage detector. In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 8577-8584).

[3]Li, X., Sun, X., Meng, Y., Liang, J., Wu, F., & Li, J. (2019). Dice loss for data-imbalanced NLP tasks. arXiv preprint arXiv:1911.02855.

[4]Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., & Wojna, Z. (2016). Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 2818-2826).

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/114066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Graphviz安装向导

目录 1、首先在官网下载graphviz 2、安装。 3、测试 1、首先在官网下载graphviz 下载网址&#xff1a;Download | Graphviz 根据自身电脑位数选择合适的下载地址 2、安装。 打开第一步已经下载好的软件。点击下一步&#xff0c;在安装路径选择时可将安装路径修改为 E:\G…

JavaScript:栈的封装及十进制转二进制栈方法实现案例

栈的定义&#xff1a;是只允许在一端进行插入或删除的线性表。首先栈是一种线性表&#xff0c;但限定这种线性表只能在某一端进行插入和删除操作。 JavaScript中对栈的封装 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8&qu…

微信HOOK 协议接口 实战开发篇 2.好友列表与二叉树

前言&#xff1a;由于篇幅所限&#xff0c;文章无法详细到每个步骤&#xff0c;仅能写出关键的HOOK思路 好友列表 好友和群列表在汇编代码中有固定的常量保存 如图示&#xff0c;找到常量&#xff0c;回车进入 入口地址结构为 其指针内部便是我们需要的数据 群列表 搜索Ch…

Linux中的进程状态

目录 一、冯诺伊曼体系结构​编辑 关于冯诺依曼&#xff0c;必须强调几点&#xff1a; 二、操作系统 1、概念 2、操作系统的作用 3、本质 4、总结 5、系统调用和库函数概念 三、进程 1、基本概念 2、描述进程 3、task_struct 4、查看进程 5、通过系统调用获取进程…

第19章 随机变量

第19章 随机变量 19.1随机变量示例 定义19.1.1&#xff1a;概率空间上的随机变量R是域等于样本空间的全函数。 R的陪域可以是任何东西&#xff0c;但通常是实数的一个子集。 例&#xff1a; 例如&#xff0c;假设我们抛三个独立的、公平的硬币。令C表示正面朝上的次数。如…

js中的JSON的简单用法

目录 1.JSON说明 2.JSON.stringify 3.JSON.parse 4.示例 1.JSON说明 当数据在浏览器与服务器之间进行交换时&#xff0c;这些数据只能是文本&#xff0c;JSON 属于文本并且我们能够把任何 JavaScript 对象转换为 JSON&#xff0c;然后将 JSON 发送到服务器。我们也能把从服…

最强docker部署模板

00.背景 最近学校让一个小组做一个web项目最后部署到linux服务器上&#xff0c;项目本身并不难就是简单的增删改查&#xff0c;但是我想借着这个机会写一个docker部署的模板&#xff0c;方便自己以后用&#xff0c;也希望可以帮助到大家。 01.docker简介 docker可以快捷 轻量…

Redis原理篇—网络模型

Redis原理篇—网络模型 笔记整理自 b站_黑马程序员Redis入门到实战教程 用户空间和内核态空间 服务器大多都采用 Linux 系统&#xff0c;这里我们以 Linux 为例来讲解: ubuntu 和 Centos 都是 Linux 的发行版&#xff0c;发行版可以看成对 Linux 包了一层壳&#xff0c;任何 …

第八章:数据库编程

一、嵌入式、过程化SQL、存储过程和函数 1、【单选题】 下表为oracle数据库表cj.temp_20221106的数据。建立存储过程: CREATE OR REPLACE PROCEDURE proc_temp_20221106(i INT) IS CURSOR c_temp IS SELECT * FROM cj.temp_20221106; ROW_NR c_temp%ROWTYPE; i_count …

【Linux】基础IO——系统文件IOfd重定向理解

文章目录一、回顾C文件接口1.打开和关闭2.读写文件3.细节二、系统文件I/O 1.open和closeumask小细节2.read和write1.write2.read3.小总结三、理解文件四、文件描述符fd1.引入2.理解3.分配规则4.close(1)问题五、重定向1.重定向2.接口3.追加重定向4.输入重定向六、Linux一切皆文…

信息技术 定义内涵

工作流运行 定义内涵 工作流运行是工作流模板的依次执行&#xff0c;在工作流运行时&#xff0c;用户可以随时取消或查看正在 运行的任务。由于工作流运行的模板的不同&#xff0c;运行过程中可能会产生不同的新资源&#xff0c;如数据 处理类型的工作流会产生新的数据集&…

Java框架精品项目【用于个人学习】

难度系数说明&#xff1a; 难度系数用来说明项目本身进行分析设计的难度 难度系数大于1的项目可用作参赛作品、大作业、计算机毕业设计等需求 前言 大家好&#xff0c;我是二哈喇子&#xff0c;此博文整理了各种项目需求&#xff0c;用于博主自己学习&#xff0c;当做个人笔记…

黑烟车识别抓拍系统 python

黑烟车识别抓拍系统利用Python基于YOLOv5通过道路已有卡口相机对现场画面中包括黑烟车、车牌信息、车头车尾照片、林格曼黑度等级数据回传给后台。Python是一种由Guido van Rossum开发的通用编程语言&#xff0c;它很快就变得非常流行&#xff0c;主要是因为它的简单性和代码可…

IMX6ULL学习笔记(14)——GPIO接口使用【C语言方式】

一、GPIO简介 i.MX6ULL 芯片的 GPIO 被分成 5 组,并且每组 GPIO 的数量不尽相同&#xff0c;例如 GPIO1 拥有 32 个引脚&#xff0c; GPIO2 拥有 22 个引脚&#xff0c; 其他 GPIO 分组的数量以及每个 GPIO 的功能请参考 《i.MX 6UltraLite Applications Processor Reference M…

【魔法圣诞树】代码实现详解 --多种实战编程技巧倾情打造

一、前言 本文会基于C# GDI技术 从零到一 实现一颗 魔法圣诞树&#xff01;源码和素材在文末全部都有&#xff01; 二、魔法圣诞树 对于用代码画圣诞树&#xff0c;网上各种编程语言像python、css、java、c/c我们都有见到过了&#xff0c;那么在绘图方面&#xff0c;还有一位…

从刘润的商业简史,预测互联网与能源的未来,辉煌的人生需要顺势而为

所有理所当然的现在&#xff0c;都是曾经看起来不可能的未来。 所有现在看起来不可想象的未来&#xff0c;可能都是明天理所当然的现在。 未来已来&#xff0c;只是尚未流行。 “一切历史都是当代史。”学习历史的目的&#xff0c;正是为了从中总结规律&#xff0c;然后用这些…

第四章:数据库安全性

一、数据库安全概述和控制 1、【单选题】TCSEC/TDI安全级别划分中&#xff0c;C1级需要实现的安全策略为&#xff1a; 我的答案&#xff1a;A 2、【单选题】能够对系统的数据加以标记&#xff0c;对标记的主体和客体实施强制存取控制&#xff08;MAC&#xff09;、审计等安全机…

绿盟SecXOps安全智能分析技术白皮书 工作流运行

工作流运行 定义内涵 工作流运行是工作流模板的依次执行&#xff0c;在工作流运行时&#xff0c;用户可以随时取消或查看正在 运行的任务。由于工作流运行的模板的不同&#xff0c;运行过程中可能会产生不同的新资源&#xff0c;如数据 处理类型的工作流会产生新的数据集&…

【关于时间序列的ML】项目 10 :用机器学习预测降雨

&#x1f50e;大家好&#xff0c;我是Sonhhxg_柒&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流&#x1f50e; &#x1f4dd;个人主页&#xff0d;Sonhhxg_柒的博客_CSDN博客 &#x1f4c3; &#x1f381;欢迎各位→点赞…

浅谈会话技术:Cookie,Session、Token

◼️ 什么是会话 会话&#xff1a; 数据交互的过程&#xff0c;在web中指 浏览器从发出一个请求到浏览器关闭&#xff0c;这个过程就是一个会话。在这个过程中&#xff0c;需要有很多的状态和数据需要我们关注&#xff0c;记录&#xff0c;这个就是我们要研究的会话 ◼️ 什么…