深入理解二分类和多分类CrossEntropy Loss和Focal Loss

news2024/11/17 21:30:36

深入理解二分类和多分类CrossEntropy Loss和Focal Loss

二分类交叉熵

在二分的情况下,模型最后需要预测的结果只有两种情况,对于每个类别我们的预测得到的概率为 p p p 1 − p 1-p 1p,此时表达式为( 的 log ⁡ \log log底数是 e e e):
L = 1 N ∑ i L i = 1 N ∑ i − [ y i ⋅ log ⁡ ( p i ) + ( 1 − y i ) ⋅ log ⁡ ( 1 − p i ) ] L=\frac{1}{N} \sum_{i} L_i =\frac{1}{N} \sum_{i} -[y_i \cdot \log (p_i) +(1-y_i) \cdot \log (1-p_i)] L=N1iLi=N1i[yilog(pi)+(1yi)log(1pi)]
其中:

  • y i y_i yi —— 表示样本 i i i的label,正类为1 ,负类为0
  • p i p_i pi—— 表示样本 i i i预测为正类的概率

由于二分类交叉熵很容易理解,在此就不做举例了。

多分类交叉熵

多分类交叉熵就是对二分类交叉熵的扩展,在计算公式中和二分类稍微有些许区别,但是还是比较容易理解,具体公式如下所示:
L = 1 N ∑ i L i = − 1 N ∑ i ∑ c = 1 M y i c log ⁡ ( p i c ) L=\frac{1}{N} \sum_{i} L_i=-\frac{1}{N} \sum_{i} \sum_{c=1}^M y_{ic} \log(p_{ic}) L=N1iLi=N1ic=1Myiclog(pic)
其中:

  • M M M——类别的数量
  • y i c y_{ic} yic——符号函数(0或1 ),如果样本 i i i的真实类别等于 c c c取 1,否则取 0
  • p i c p_{ic} pic——观测样本 i i i属于类别 c c c的预测概率

举例说明

预测(已经经过softmax归一化)真实
0.1 0.2 0.70 0 1
0.3 0.4 0.30 1 0
0.1 0.2 0.71 0 0

现在我们利用这个表达式计算上面例子中的损失函数值:
sample 1 loss = − ( 0 × log ⁡ 0.1 + 0 × log ⁡ 0.2 + 1 × log ⁡ 0.7 ) = 0.35 , sample 2 loss = − ( 0 × log ⁡ 0.1 + 1 × log ⁡ 0.7 + 0 × log ⁡ 0.2 ) = 0.35 , sample 3 loss = − ( 1 × log ⁡ 0.3 + 0 × log ⁡ 0.4 + 0 × log ⁡ 0.4 ) = 1.20 , L = 0.35 + 0.35 + 1.2 3 = 0.63 \text{sample 1 loss}=-(0 \times \log 0.1+0 \times \log 0.2 + 1 \times \log 0.7)=0.35 ,\\ \text{sample 2 loss}=-(0 \times \log 0.1+1 \times \log 0.7 + 0 \times \log 0.2)=0.35 ,\\ \text{sample 3 loss}=-(1 \times \log 0.3+0 \times \log 0.4 + 0 \times \log 0.4)=1.20,\\ L=\frac{0.35+0.35+1.2}{3}=0.63 sample 1 loss=(0×log0.1+0×log0.2+1×log0.7)=0.35,sample 2 loss=(0×log0.1+1×log0.7+0×log0.2)=0.35,sample 3 loss=(1×log0.3+0×log0.4+0×log0.4)=1.20,L=30.35+0.35+1.2=0.63
其实可以看到,多分类交叉熵只计算正确标签对应概率的损失值,相对错误标签其 y i c = 0 y_{ic}=0 yic=0,所以导致错误标签对应的损失值为0。

Pytorch的CrossEntropyLoss分析

参数设定

CrossEntropyLoss在Pytorch官网中,我们可以看到整个文档已经对该函数CrossEntropyLoss进行了较充分的解释。所以我们简要介绍其参数和传入的值的格式,特别是针对多分类的情况。

常见的传入参数如下所示:

  • weight:传入的是一个list或者tensor,其检索对应位置的值为该类的权重。注意,如果是GPU的环境下,则传入的值必须是tensor,并且其应该在GPU中。

  • reduction:传入的是一个字符串,有三种形式可以选择,分别是mean/sum/none,默认是meanmeansum如字面意思所示,代表损失值取平均,损失值求和的形式。none是计算每个位置对应的损失值,返回和label对应的形状。

更多参数解释如下图所示:

使用方法

CrossEntropyLoss传入的值为两个,分别是inputtarget。输出只有一个Output

  • input的形状为 ( N , C ) / ( N , C , d 1 , d 1 , … ) (N,C)/(N,C,d_1,d_1,\ldots) (N,C)/(N,C,d1,d1,),前者对应二维情况,后者对应高维情况,值得注意的是 C C C是在dim=1的位置上,可能在高维的情况下很多人都以为默认应该是最后一个维度dim=-1

  • target的形状为 ( N ) / ( N , d 1 , d 1 , … ) (N)/(N,d_1,d_1,\ldots) (N)/(N,d1,d1,),前者对应二维情况,后者对应高维情况。注意的是target的值对应的是类别对应的索引,不是one-hot的形式

  • Output的形状和target的形状一致。

更多参数解释如下图所示:

二维情况下对应的5分类交叉熵损失计算(官网示例):

>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

高维情况下对应的交叉熵计算:

input = torch.randn(2,3,5,5,4)#最后一个维度对应的是类别
target = torch.empty(2,3,5,5, dtype=torch.long).random_(4) #四分类
loss_fn=CrossEntropyLoss(reduction='sum')
_input=torch.permute(input,dims=(0,-1,1,2,3))
loss=loss_fn(_input,target)#输入的类别一定是在dim=1的位置上
print(loss)
# 当然也可以将输入先转为2维的形式在计算,结果是一样的
_input=input.view(-1,4)
_target=target.view(-1)
loss=loss_fn(_input,_target)
print(loss)

内在原理

Pytorch中的CrossEntropyLoss()是将logSoftmax()NLLLoss()函数进行合并的,也就是说其内在实现就是基于logSoftmax()NLLLoss()这两个函数。

input=torch.rand(3,5)
target=torch.empty(3,dtype=torch.long).random_(5)
loss_fn=CrossEntropyLoss(reduction='sum')
loss=loss_fn(input,target)
print(loss)
_input=torch.nn.LogSoftmax(dim=1)(input)
loss=torch.nn.NLLLoss(reduction='sum')(_input,target)
print(loss)

其实也就是和官网上所说的一样,CrossEntropyLoss()是对输出计算softmax(),在对结果取log()对数,最后使用NLLLoss()得到对应位置的索引值。

Focal Loss原理和实现

Focal Loss来自于论文Focal Loss for Dense Object Detection,用于解决类别样本不平衡以及困难样本挖掘的问题,其公式非常简洁:
F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) FL(p_t)=- \alpha_t (1-p_t) ^{\gamma} \log (p_t) FL(pt)=αt(1pt)γlog(pt)
p t p_t pt是模型预测的结果的类别概率值。 − log ⁡ ( p t ) - \log (p_t) log(pt)和交叉熵损失函数一致,因此当前样本类别对应的那个 p t p_t pt如果越小,说明预测越不准确, 那么 ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ 这一项就会增大,这一项也作为困难样本的系数,预测越不准,Focal Loss越倾向于把这个样本当作困难样本,这个系数也就越大,目的是让困难样本对损失和梯度的贡献更大。

前面的 α t \alpha_t αt是类别权重系数。如果你有一个类别不平衡的数据集,那么你肯定想对数量少的那一类在loss贡献上赋予一个高权重,这个 α t \alpha_t αt就起到这样的作用。因此, α t \alpha_t αt应该是一个向量,向量的长度等于类别的个数,用于存放各个类别的权重。一般来说 α t \alpha_t αt中的值为每一个类别样本数量的倒数,相当于平衡样本的数量差距。

这里提供一个二维/高维的Focal Loss的实现:

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5,1])):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        logpt = nn.functional.log_softmax(input, dim=1) #计算softmax后在计算log
        pt = torch.exp(logpt) #对log_softmax去exp,把log取消就是概率
        alpha=self.alpha[target].unsqueeze(dim=1) # 去取真实索引类别对应的alpha
        logpt = alpha*(1 - pt) ** self.gamma * logpt #focal loss计算公式
        loss = nn.functional.nll_loss(logpt, target,reduction='sum') # 最后选择对应位置的元素
        return loss

参考资料

CrossEntropy官网详细说明。

Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss

详解PyTorch实现多分类Focal Loss——带有alpha简洁实现

最近工作

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/503233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何用ChatGP协助你,从品牌角度对产品提出升级建议?

该场景对应的关键词库(19个): 品牌洋葱图思维模型、产品信息、人群、品类、属性、体验、差异化特征、功效、品牌价值主张、目标用户、需求、痛点、爽点、消费者、外观、功能、结构、产品优化建议、产品开发可行性。 提问模板(3个&#xff09…

《Vue.js 设计与实现》—— 01 权衡的艺术

书籍链接:https://weread.qq.com/web/bookDetail/c5c32170813ab7177g0181ae 框架设计里到处都体现了权衡的艺术。 当我们设计一个框架时,框架本身的各个模块之间并不是相互独立的,而是相互关联、相互制约的。 作为框架设计者,一…

Windows10安装免安装版redis

下载 官方下载地址:github.com/MicrosoftAr…选择版本 解压安装 配置环境变量&注册成服务 配置环境变量 以管理员启动命令行,在redis安装根目录,把redis注册服务 redis-server --service-install redis.windows-service.conf --lo…

Communications chemisty|德睿智药工作-用于分子性质预测的药物约束异构图Transformer模型

德睿智药的分子性质预测任务 题目: Pharmacophoric-constrained heterogeneous graph transformer model for molecular property prediction 文献来源:COMMUNICATIONS CHEMISTRY | (2023) 6:60 | 代码:https://github.com/stardj/PharmHG…

springboot+dubbo+zookeeper 项目实战

现在有一段代码再前台,后台系统中都存在,都需要这段代码,存在这种情况,我们可以选择将这段代码提取出来作为一个服务,让前台和后台系统作为消费者远程调用这段代码,提高了代码的复用性。 springboot集成dub…

Unity Audio -- (2)创建动态音效

评估场景需求 本节的目标是添加脚步声到角色身上,当角色走路时,触发动画事件并播放声音。 脚步声是我们在真实世界中常常被我们所忽视的声音,但脚步声能够传达出许多环境信息。你现在可以花一小段时间绕着你周围的环境走一走并仔细听听脚步声…

CLIP : Learning Transferable Visual Models From Natural Language Supervision

CLIP : Learning Transferable Visual Models From Natural Language Supervision IntroductionApproach Introduction 在raw的数据上自监督的训练模型,已经在NLP领域取得了革命性进展,这种模型需要收到硬件、数据的限制,但是能得到很好的迁…

算法 DAY55 动态规划11 392.判断子序列 115.不同的子序列

392.判断子序列 本题可以直接用双指针解法。但是本题是编辑距离的入门题目,故采用动态规划解法为后序“编辑距离”类题目打基础。 本题与最大子序列非常相似,但不同的是s必须连续,t可以不连续。 五部曲 1、dp[i][j] 表示以下标i-1为结尾的字…

Seata介绍

介绍: Seata的设计目标是对这个业务无侵入,因此从业务无侵入的2PC方案开始的,在传统的2PC的基础上演进的。它把一个分布式事务拆分理解成一个包含了若干分支事务的全局事务。全局事务的职责是协调其下管辖的分支事务达成一致性,要…

25.自定义注解

自定义注解 一、什么是注解 Annontation是Java1.5开始引入的新特征,中文名称叫注解。 它提供了一种安全的类似注释的机制,用来将信息或元数据(metadata)与程序元素(类、方法、成员变量等)进行关联。为程序…

大数据技术之SparkSQL——数据的读取和保存

一、通用的加载和保存方式 SparkSQL提供了通用的保存数据和数据加载的方式。根据不同的参数读取,并保存不同格式的数据。SparkSQL默认读取和保存的文件格式为Parquet。 1.1 加载数据 spark.read.load 是加载数据的通用方式。 如果读取不同格式的数据,可…

如何编译DPDK静态库

阅读前面文章https://blog.csdn.net/qq_36314864/article/details/130243348,知道了哪些dpdk文件可以在windows下生成。 打开vs,新建一个生成静态库工程,在生成的lib文件中找到D:\dpdk-21.07\build\lib D:\dpdk-21.07\build\drivers找到对应的文件,并按照路径,新建筛选项…

【Vue学习笔记7】Vue3中如何开发组件

重点学习:vue3.0之组件通信机制defineProps(组件接收外部传来的参数)、defineEmits(向组件外部传递参数)。 1. 评级组件第一版 简单的评级需求,只需要一行代码就可以实现: "★★★★★☆…

SLAM面试笔记(5) — ROS面试

目录 1 ROS概述 2 ROS通信机制 问题:服务通信概念 问题:服务通信理论模型 3 常见面试题 问题:roslaunch和rosrun区别? 问题:什么是ROS? 问题:ROS中的节点是什么? 问题&…

挠性航天器姿态机动动力学模型及PD鲁棒控制

挠性航天器姿态机动动力学模型及PD鲁棒控制 1挠性航天器姿态机动动力学模型2挠性航天器姿态机动PD鲁棒控制2.1 动力学模型及PD控制律2.2仿真模型2.3 控制程序2.4 被控对象程序2.5 绘图程序2.6 结果 1挠性航天器姿态机动动力学模型 2挠性航天器姿态机动PD鲁棒控制 2.1 动力学模…

【NLP开发】Python实现聊天机器人(ChatterBot,集成web服务)

🍺NLP开发系列相关文章编写如下🍺: 🎈【NLP开发】Python实现词云图🎈🎈【NLP开发】Python实现图片文字识别🎈🎈【NLP开发】Python实现中文、英文分词🎈🎈【N…

澳大利亚兔灾和——栈?

一.背景 1859年,当一位叫托马斯奥斯汀的农民收到英国老家送来的24只野兔并将它们放归农场的时候,他绝对意想不到,这些看似人畜无害的小兔子,竟为古老的澳洲大陆带来一场巨大的生态破坏。到20世纪初,澳大利亚的兔子数量…

操作系统内存管理(上)——内存管理基础

一、内存的基本知识 1.什么是内存?有什么作用? 内存可存放数据。程序执行前先放到内存才能被CPU处理——缓和CPU和硬盘之间的速度矛盾。 给内存的存储单元编址。如果计算机按字节编址,则每个存储单元大小为1字节。即1B8b(8个二进…

智能医院导航导诊系统,门诊地图导航怎么做?

现在很多医院都是综合化大型医院,有很多的科室,院区面积也逐渐扩大,一方面给病患提供了更为全面的医疗资源,另一方面,医院复杂的环境也给病患寻医问诊带来了一定的困扰。电子地图作为大家最喜闻乐见的高效应用形式&…

Python的socket模块及示例

13.2 socket模块 socket由一些对象组成,这些对象提供网络应用程序的跨平台标准。 13.2.1 认识socket模块 socket又称“套接字”,应用程序通常通过“套接字”向网络发出请求或应答网络请求,使主机间或一台计算机上的进程间可以通信。sock…