样本不平衡的解决办法

news2024/11/17 7:51:07

背景

Focal loss是最初由何恺明提出的,最初用于图像领域解决数据不平衡造成的模型性能问题。本文试图从交叉熵损失函数出发,分析数据不平衡问题,focal loss与交叉熵损失函数的对比,给出focal loss有效性的解释。

交叉熵损失函数

L o s s = L ( y , p ^ ) = − y l o g ( p ^ ) − ( 1 − y ) l o g ( 1 − p ^ ) Loss = L(y, \hat{p})=-ylog(\hat{p}) - (1 - y)log(1-\hat{p}) Loss=L(y,p^)=ylog(p^)(1y)log(1p^)
其中, p ^ \hat{p} p^为预测概率,y为真实label,二分类对应0、1。
对二分类中换种写法如下
L c e ( y , p ^ ) = { − l o g ( p ^ ) ,         i f    y = 1 − l o g ( 1 − p ^ ) , i f    y = 0 ( 1 ) L_{ce}(y, \hat{p}) = \left\{\begin{matrix} -log(\hat{p}), \ \ \ \ \ \ \ if \ \ y = 1&\\ -log(1-\hat{p}), if \ \ y = 0 & \end{matrix}\right. (1) Lce(y,p^)={log(p^),       if  y=1log(1p^),if  y=0(1)

样本不均衡问题

对于所有样本,损失函数为
L = 1 N ∑ i = 1 N l ( y i , p ^ i ) L=\frac{1}{N}\sum_{i=1}^{N}l(y_i, \hat{p}_i) L=N1i=1Nl(yi,p^i)
对于二分类问题,损失函数具体化为
L = 1 N ( ∑ y i = 1 m − l o g ( p ^ ) + ∑ y i = 0 n − l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-log(\hat{p}) + \sum_{y_i=0}^{n}-log(1-\hat{p})) L=N1(yi=1mlog(p^)+yi=0nlog(1p^))
其中m为正样本个数,n为负样本个数,N为样本总数,m+n=N。

当样本分布失衡时,在损失函数L的分布也会发生倾斜,如m<<n时,负样本就会在损失函数占据主导地位。由于损失函数的倾斜,模型训练过程中会倾向于样本多的类别,造成模型对少样本类别的性能较差。

平衡交叉熵函数(balanced cross entropy)

基于样本非平衡造成的损失函数倾斜,一个直观的做法就是在损失函数中添加权重因子,提高少数类别在损失函数中的权重,平衡损失函数的分布。如在上述二分类问题中,添加权重参数 α ∈ [ 0 , 1 ] \alpha \in [0, 1] α[0,1] 1 − α 1-\alpha 1α
L = 1 N ( ∑ y i = 1 m − α l o g ( p ^ ) + ∑ y i = 0 n − ( 1 − α ) l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i=1}^{m}-\alpha log(\hat{p}) + \sum_{y_i=0}^{n}-(1- \alpha)log(1-\hat{p})) L=N1(yi=1mαlog(p^)+yi=0n(1α)log(1p^))
其中 α 1 − α = n m \frac {\alpha}{1- \alpha} = \frac {n}{m} 1αα=mn,即正负样本损失权重的根据正负样本的比例分布进行设置,具体里说,成反比关系。

focal loss

focal loss也是针对样本不均衡问题,从loss角度提供的另外一种解决方法。
focal loss的具体形式为:
L f l = { − ( 1 − p ^ ) γ l o g ( p ^ )    i f   y = 1 − p ^ l o g ( 1 − p ^ )         i f   y = 0   ( 2 ) L_{fl} = \left\{\begin{matrix} -(1-\hat{p})^{\gamma}log(\hat{p}) \ \ if \ y=1 &\\ -\hat{p}log(1-\hat{p}) \ \ \ \ \ \ \ if \ y=0 & \end{matrix}\right. \ (2) Lfl={(1p^)γlog(p^)  if y=1p^log(1p^)       if y=0 (2)
p t = { p ^      i f   y = 1 1 − p ^   o t h e r w i s e p_t = \left\{\begin{matrix} \hat{p} \ \ \ \ if \ y = 1&\\ 1-\hat{p} \ otherwise & \end{matrix}\right. pt={p^    if y=11p^ otherwise
将focal loss表达式(2)统一为一个表达式:
L f l = − ( 1 − p t ) γ l o g ( p t )      ( 3 ) L_fl=-(1-p_t)^{\gamma}log(p_t) \ \ \ \ (3) Lfl=(1pt)γlog(pt)    (3)
同理可将交叉熵表达式(1)统一为一个表达式:
L c e = − l o g ( p t ) L_{ce} = -log(p_t) Lce=log(pt)
p t p_t pt 反映了与ground truth即类别 y y y 的接近程度, p t p_t pt越大说明越接近类别 y y y,即分类越准确。
γ \gamma γ 为可调节因子。

对比表达式(3)和(4), focal loss相比交叉熵多了一个modulating factor即 ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ。对于分类准确的样本 p t → 1 p_t \rightarrow 1 pt1,modulating factor趋近于0。对于分类不准确的样本 1 − p t → 1 1-p_t \rightarrow 1 1pt1,modulating factor趋近于1。即相比交叉熵损失,focal loss对于分类不准确的样本,损失没有改变,对于分类准确的样本,损失会变小。 整体而言,相当于增加了分类不准确样本在损失函数中的权重。

p t p_t pt也反应了分类的难易程度, p t p_t pt 越大,说明分类的置信度越高,代表样本越易分; p t p_t pt 越小,分类的置信度越低,代表样本越难分。因此focal loss相当于增加了难分样本在损失函数的权重,使得损失函数倾向于难分的样本,有助于提高难分样本的准确度。focal loss与交叉熵的对比,可见下图:
在这里插入图片描述

focal loss vs balanced cross entropy

focal loss相比balanced cross entropy而言,二者都是试图解决样本不平衡带来的模型训练问题,后者从样本分布角度对损失函数添加权重因子,前者从样本分类难易程度出发,使loss聚焦于难分样本。

focal loss为什么有效

focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。

相信很多人会在这里有一个疑问,样本难易分类角度怎么能够解决样本非平衡的问题,直觉上来讲样本非平衡造成的问题就是样本数少的类别分类难度较高。因此从样本难易分类角度出发,使得loss聚焦于难分样本,解决了样本少的类别分类准确率不高的问题,当然难分样本不限于样本少的类别,也就是focal loss不仅仅解决了样本非平衡的问题,同样有助于模型的整体性能提高。

要想使模型训练过程中聚焦难分类样本,仅仅使得Loss倾向于难分类样本还不够,因为训练过程中模型参数更新取决于Loss的梯度。
w = w − α ∂ L ∂ w    ( 5 ) w = w - \alpha \frac{ \partial {L}}{ \partial {w}} \ \ (5) w=wαwL  (5)
如果Loss中难分类样本权重较高,但是难分类样本的Loss的梯度为0,难分类样本不会影响模型学习过程。

对于梯度问题,在focal loss的文章中,也给出了答案。如下图所示为focal loss的梯度示意图。其中 x t = y x x_t = yx xt=yx ,其中 y ∈ { 0 , 1 } y \in {\{0, 1}\} y{0,1} 为类别, p t = σ ( x t ) p_t=\sigma(x_t) pt=σ(xt) ,对于易分样本, x t > 0 x_t > 0 xt>0,即 p t > 0.5 p_t>0.5 pt>0.5 。由图中可以看出,对于focal loss而言,在 x t > 0 x_t>0 xt>0 时,导数很小,趋近于0。因此对于focal loss,导数中易是难分类样本占主导,因此学习过程更加聚焦正在难分类样本。
在这里插入图片描述

一点小思考

难分类样本与易分类样本其实是一个动态概念,也就是说 p t p_t pt 会随着训练过程而变化。原先易分类样本即 p t p_t pt 大的样本,可能随着训练过程变化为难训练样本即 p t p_t pt 小的样本。

上面讲到,由于Loss梯度中,难训练样本起主导作用,即参数的变化主要是朝着优化难训练样本的方向改变。当参数变化后,可能会使原先易训练的样本 p t p_t pt 发生变化,即可能变为难训练样本。当这种情况发生时,可能会造成模型收敛速度慢,正如苏剑林在他的文章中提到的那样。

为了防止难易样本的频繁变化,应当选取小的学习率。防止学习率过大,造成 w w w 变化较大从而引起 p t p_t pt 的巨大变化,造成难易样本的改变。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592855.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

危机先知:TOOM舆情监控助力风险预警

随着社交媒体和互联网的普及&#xff0c;公众的声音在网络上如洪水般涌现。这些声音传递着情绪、态度和观点&#xff0c;对个人、组织甚至整个社会产生着巨大影响。因此&#xff0c;舆情监控成为了一个不可或缺的工具&#xff0c;帮助企业和组织及时了解公众对其品牌、产品或服…

决策树基本理论知识

目录 1、决策树是一种树模型 2、决策树的训练与测试 3、信息增益&#xff08;ID3&#xff09; 3.1、衡量标准-熵 3.2、决策树构造实例 4、决策树算法 ​5、连续值离散化 6、预剪枝 1、决策树是一种树模型&#xff1a; &#xff08;1&#xff09;、从根结点开始一步步走…

【C++】哈希表封装unordered系列

文章目录 前言一、哈希表的封装总结 前言 在看本篇文章前大家尽量拿出上一篇文章的代码跟着一步步实现&#xff0c;否则很容易引出大量模板错误而无法解决。 一、哈希表的封装 首先我们要解决映射的问题&#xff0c;我们目前的代码只能映射整形&#xff0c;那么如何支撑浮点数…

Java使用zxing.jar生成二维码

由于时代科学的进步&#xff0c;二维码已经和我们的生活密不可分&#xff0c;在开发过程中往往会涉及到和二维码相关的开发&#xff0c;今天这篇文章就教会大家如何使用zxing.jar包生成二维码 下面这个就是百度上面自带的一个生成二维码的功能&#xff0c;那他是怎么实现这个功…

计算机组成原理与体系结构概述

目录 一、计算机的发展 二、计算机的硬件系统 三、硬件的工作原理 四、计算机系统的层次结构 五、计算机的性能指标 一、计算机的发展 第一代计算机&#xff1a;电子管计算机 第一台电子计算机&#xff1a;ENIAC&#xff08;1946&#xff09; 设计目的&#xff1a;计算导弹…

平板触控笔哪种好?主动式电容笔推荐

现在市面上的电容笔分为主动式和被动式电容笔&#xff0c;很多小伙伴都分不清主动式和被动式电容笔的区别。今天给大家介绍一下这两款电容笔的区别。给大家分享几款好用的平替电容笔。 一、主动式电容笔和被动式电容笔的区别&#xff1a; 1.主动式电容笔&#xff1a; 主动式电…

数据结构与算法(九)

红黑树复习 图 图&#xff0c;是一种数据结构 集合只有同属于一个集合&#xff1b;线性结构存在一对一的关系&#xff0c;树形结构一对多的关系&#xff0c;图形结构&#xff0c;多对多的关系。 微信中&#xff1a;许多的用户组成了一个多对多的朋友关系网&#xff0c;这个关…

【C语言】变量

&#x1f6a9; WRITE IN FRONT &#x1f6a9; &#x1f50e; 介绍&#xff1a;"謓泽"正在路上朝着"攻城狮"方向"前进四" &#x1f50e;&#x1f3c5; 荣誉&#xff1a;2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2022博客之星T…

【机器学习】分类问题和逻辑(Logistic)回归算法详解

在阅读本文前&#xff0c;请确保你已经掌握代价函数、假设函数等常用机器学习术语&#xff0c;最好已经学习线性回归算法&#xff0c;前情提要可参考https://blog.csdn.net/weixin_45434953/article/details/130593910 分类问题是十分广泛的一个问题&#xff0c;其代表问题是&…

Android studio 环境安装

1. Java JDK安装 https://download.oracle.com/java/17/latest/jdk-17_windows-x64_bin.exe 下载jdk-17 并安装 安装完成后设置环境变量 #新增环境变量JAVA_HOME C:\Program Files\Java\jdk-17#Path 环境变量添加 %JAVA_HOME%\bin %JAVA_HOME%\jdk\bin#新增环境变量CLASSPAT…

HEVC量化编码介绍

介绍 ● 视频编码中&#xff0c;残差信号经过DCT&#xff0c;变换系数具有较大动态范围&#xff0c;因此对变换系数量化可以有效减小信号取值空间&#xff0c;获得更好的压缩效果&#xff1b; ● 多对一映射机制&#xff0c;所以不可避免的引入失真&#xff0c;这是视频编码中…

Spring(三)对bean的详解

一、引入外部属性文件 首先我们将依赖进行导入&#xff1a; <!--MySQL驱动--><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.22</version></dependency><!--数据…

idea连接Linux服务器

一、 介绍 配置idea的ssh会话和sftp可以实现对linux远程服务器的访问和文件上传下载&#xff0c;是替代Xshell的理想方式。这样我们就能在idea里面编写文件并轻松的将文件上传到linux服务器中。而且还能远程编辑linux服务器上的文件。掌握并熟练使用&#xff0c;能够大大提高我…

烂怂if-else代码优化方案 | 京东云技术团队

0.问题概述 代码可读性是衡量代码质量的重要标准&#xff0c;可读性也是可维护性、可扩展性的保证&#xff0c;因为代码是连接程序员和机器的中间桥梁&#xff0c;要对双边友好。Quora 上有一个帖子&#xff1a; “What are some of the most basic things every programmer s…

日本医疗保健和健康管理公司【Zerospo】申请纳斯达克IPO上市

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉&#xff0c;来自日本的医疗保健和健康管理公司【Zerospo】&#xff0c;近期已向美国证券交易委员会&#xff08;SEC&#xff09;提交招股书&#xff0c;申请在纳斯达克IPO上市&#xff0c;股票代码为&#xff…

感谢海洋一所陈老师用Pospac MMS解算pospac数据及GNSS验潮

非常感谢海洋一所陈老师 帮忙用Pospac MMS解算博主的pospa从数据。解算的结果txt文件大小有2个G&#xff0c;令人非常吃惊&#xff0c;因为原始数据的时长不到1天&#xff0c;打开文件才知道每行位置数据的间隔时间是5ms&#xff0c;5ms正是惯导数据的采样频率。 用抽稀软件按…

短视频矩阵系统源码-开源开发php语言搭建

短视频矩阵系统源码---------- php源码是什么&#xff1f; PHP源码指的就是PHP源代码&#xff0c;源代码是用特定编程语言编写的人类可读文本&#xff0c;源代码的目标是为可以转换为机器语言的计算机设置准确的规则和规范。因此&#xff0c;源代码是程序和网站的基础。 PHP…

【数据结构】插入排序详细图解(一看就懂)

&#x1f4af; 博客内容&#xff1a;【数据结构】插入排序详细图解&#xff08;一看就懂&#xff09; &#x1f600; 作  者&#xff1a;陈大大陈 &#x1f989;所属专栏&#xff1a;数据结构笔记 &#x1f680; 个人简介&#xff1a;一个正在努力学技术的准前端&#xff0c;…

ARM-FS6818-点亮LED灯

点亮LED灯 1.开发板介绍 2.cpu控制硬件原理 六大指令里边&#xff0c;只有内存访问指令能访问cpu之外的内容。那cpu如何控制硬件&#xff1f; *load/store指令-->操作4G内存 任何一个芯片都有一个地址映射表。告诉地址空间是如何映射的&#xff0c;便于我们找到对应的硬件地…

ChatGPT3.5-4资源汇总,直连无梯子

什么是ChatGPT? ChatGPT&#xff0c;全称&#xff1a;聊天生成预训练转换器&#xff08;英语&#xff1a;Chat Generative Pre-trained Transformer&#xff09;&#xff0c;是OpenAI开发的人工智能聊天机器人程序&#xff0c;于2022年11月推出。该程序使用基于GPT-3.5、GPT-4…