交叉熵、Focal Loss以及其Pytorch实现

news2024/9/30 11:23:08

交叉熵、Focal Loss以及其Pytorch实现

本文参考链接:https://towardsdatascience.com/focal-loss-a-better-alternative-for-cross-entropy-1d073d92d075

文章目录

  • 交叉熵、Focal Loss以及其Pytorch实现
    • 一、交叉熵
    • 二、Focal loss
    • 三、Pytorch
      • 1.[交叉熵](https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html?highlight=nn+crossentropyloss#torch.nn.CrossEntropyLoss)
      • 2.[Focal loss](https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py)

一、交叉熵

损失是通过梯度回传用来更新网络参数是之产生的预测结果和真实值之间相似。不同损失函数有着不同的约束作用,不同的数据对损失函数有着不同的影响。

交叉熵是常见的损失函数,常见于语义分割、对比学习等。其函数表达式如下,其中 Y i 和 p i Y_i和p_i Yipi分别表示真实值和预测结果:
C r o s s E n t r o p y = − ∑ i = 1 i = n Y i l o g ( p i ) CrossEntropy=-\sum_{i=1}^{i=n}Y_ilog(p_i) CrossEntropy=i=1i=nYilog(pi)
因为 p i p_i pi值在0~1之间,故交叉熵大于等于0。这个函数什么时候最小呢?数学证明结果表明当 Y i = p i Y_i=p_i Yi=pi时交叉熵最小。下面,我们取二分类情况来进行简单证明:
B C E L o s s = − y l o g x − ( 1 − y ) l o g ( 1 − x ) BCELoss=-ylogx-(1-y)log(1-x) BCELoss=ylogx(1y)log(1x)
对BCELoss求导可得:
− y x + 1 − y 1 − x = y − x x − x 2 -\frac{y}{x}+\frac{1-y}{1-x}=\frac{y-x}{x-x^2} xy+1x1y=xx2yx
所以当 y = x y=x y=x时,二分类交叉熵取最小值。

那么,交叉熵有啥子问题?

  • 从表达式可以看出,交叉熵只针对单个像素进行比较,像素和像素之间并没有联系,这就需要我们在模型中使用空间注意力机制等使得特征在空间上进行交互。针对这个问题,不少论文提出了改进方案,如Context Prior for Scene Segmentation这篇论文就使用了和precision、recall等类似的损失函数(就像Dice loss和F1 score指标一样)。

  • 类别不平衡:这个问题比较常见,语义分割中类别在图片上总像素占比是不平衡。如果类别不平衡比较严重,交叉熵损失就会偏向于占比较高的类别,导致对占比较少的类别预测结果较差。解决这一方法为给交叉熵损失添加权重(平衡交叉熵)等,如下式:
    B a l a n c e d C r o s s E n t r o p y = − ∑ i = 1 i = n α i Y i l o g ( p i ) BalancedCrossEntropy=-\sum_{i=1}^{i=n}\alpha_iY_ilog(p_i) BalancedCrossEntropy=i=1i=nαiYilog(pi)

  • 困难样本:首先,我们要知道困难样本是那些模型反复出现巨大损失的例子,而简单样本是那些容易分类的例子。交叉熵对于所有样本同等对待,导致无法辨别困难样本和简单样本。解决这一问题就是接下来的损失函数Focal loss

二、Focal loss

Focal loss关注的是模型出错的例子,而不是它可以自信地预测的例子,确保对困难的例子的预测随着时间的推移而改善,而不是对容易的例子变得过于自信。

这到底是怎么做到的呢?Focal loss是通过一个叫做Down Weighting的东西来实现的。下调权重是一种技术,它可以减少容易的例子对损失函数的影响,从而使人们更加关注困难的例子。这种技术可以通过在交叉熵损失中加入一个调节因子来实现。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=n(1pi)γlogpi
不同的 γ \gamma γ对损失有什么影响呢?如下图所示

img

不同的 γ \gamma γ ( 1 − p i ) γ (1-p_i)^{\gamma} (1pi)γ有什么影响呢,如下:

img

  • 在误分类样本的情况下, p i pi pi很小,使得调制因子大约或非常接近于1,这使损失函数不受影响。此时,Focal Loss和交叉熵损失相似。
  • 随着模型置信度的提高,即 p i → 1 pi→1 pi1,调制因子将趋于0,从而降低了分类良好的例子的损失值。聚焦参数, γ \gamma γ≥1,将重新调整调制因子,使容易的例子比困难的例子降权更多,减少它们对损失函数的影响。例如,考虑预测概率为0.9和0.6。考虑到 γ \gamma γ=2,对0.9计算出的损失值是4.5e-4,降权系数( 1 / ( 1 − q i ) 2 1/(1-q_i)^2 1/(1qi)2)为100,对0.6则是3.5e-2,降权系数为6.25。从实验来看, γ \gamma γ=2来说效果最好。
  • γ \gamma γ=0时,Focal Loss等同于Cross Entropy。

此外,加入平衡因子 α \alpha α,用来平衡正负样本本身的比例不均:文中 α \alpha α取0.25,即正样本要比负样本占比小,这是因为负例易分。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n α i ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}\alpha_i(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=nαi(1pi)γlogpi
Focal Loss自然地解决了阶级不平衡的问题,(1因为来自多数类别的例子通常容易预测,而来自少数类别的例子由于缺乏数据或来自多数类别的例子在损失和梯度过程中占主导地位而难以预测。由于这种相似性,Focal Loss可能能够解决这两个问题。

三、Pytorch

1.交叉熵

Pytorch可以直接调用交叉熵损失函数nn.CrossEntropyLoss(),其功能还是比较全的。其中weight可以用了进行权重平衡,ignore_index可以用来忽略特定类别。输入的标签不需要进行one hot编码,其内部已经实现。nn.CrossEntropyLoss()=nn.NLLoss() + nn.LogSoftmax。
在这里插入图片描述

2.Focal loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/687373.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅谈商业智能BI的过去、现在和未来

互联网的大发展也在引领各行各业的改变,包括商业智能BI,商业智能BI就是在数字化时代下飞速发展的。商业智能BI与互联网发展的同时,人工智能、大数据、区块链、云计算等新一代信息化、数字化技术也开始进行加速商业智能BI发展及应用&#xff0…

vmareWorkstation-提取vmdk-文件系统

参考博文-CSDN-BugM(作者)-将vmdk作为硬盘挂载到另一个linux系统的虚拟机上 一、以管理员身份运行wmware-workstation 二、将目的vmdk文件映射到一个linux虚拟机上 选择左下方的添加按钮 添加的文件的路径可以查看需要添加的虚拟机的情况,如…

PXE批量网络装机、PXE无人值守安装

目录 一、批量部署的优点 二、基本部署过程 三、部署pxe网络体系要求 四、搭建PXE远程安装服务器的步骤 1.安装并启用TFTP服务 2.安装并启用dhcp服务 3.准备linux内核、初始化镜像文件 4.准备PXE引导程序 5.安装FTP服务,准备CentOS7安装源 6.配置启动菜单文…

ATA-3090功率放大器在新能源汽车上的应用

随着全球对环境保护和节能减排的重视,新能源汽车正逐渐成为汽车市场的主流。而功率放大器作为电子控制系统中的关键部件之一,也扮演着越来越重要的角色。那么,功率放大器在新能源汽车上的应用有哪些呢? 图:新能源汽车 …

数字IC前端学习笔记:仲裁轮询(五)

相关文章 数字IC前端学习笔记:LSFR(线性反馈移位寄存器) 数字IC前端学习笔记:跨时钟域信号同步 数字IC前端学习笔记:信号同步和边沿检测 数字IC前端学习笔记:锁存器Latch的综合 数字IC前端学习笔记&am…

聚焦云原生安全|安全专家揭秘如何强化容器威胁检测能力

4月15日,由极狐主办的“当效率遇上安全 一起开启质效升级之旅”活动顺利开展。 作为国内云原生安全领导厂商,安全狗受邀出席此次活动。 厦门服云信息科技有限公司(品牌名:安全狗)成立于2013年,致力于提供云…

【中危】BootCDN 投毒风险

漏洞描述 BootCDN是免费的前端开源项目 CDN 加速服务,通过同步cdnjs仓库,提供了常用javascript组件的CDN服务。 多个开发者发现在特定请求中(如特定Referer及移动端user-agent)会返回包含指向union.macoms.la地址的恶意js文件&a…

快速提取地形图坐标到Revit建立场地模型?

一、快速提取地形图坐标到Revit建立场地模型? 1、打开免费三维地形软件: 2、定位项目位置: 3、按步骤提取坐标及高程: 4、保存为csv文件格式,打开CSV文件,删除第一行: 5、打开坐标转换软件,按步…

python3套接字编程之socket和socketserver(TCP和UDP通信)

socket和socketserver是python3中socket通信模块,关于其使用做如下总结。 目录 1.socket 1.1模块引入 1.2套接字获取 1.3套接字接口 1.3.1 服务端 1.3.2 客户端套接字函数 1.3.3 公共套接字函数 1.3.4 面向锁的套接字方法 1.3.5 面向文件的套接字的函数 …

Python注释解密、变量大揭秘,数据类型轻松入门!

文章目录 前言注释单行注释多行注释 变量数据类型1.整型(int)2.浮点型(float)3.布尔型(bool)4.字符串(str)5.列表(list)6.元组(tuple)…

RabbitMQ高可用集群部署

文章目录 1.RabbitMQ常见的集群模式2.部署基于镜像队列模式的RabbitMQ高可用集群2.1.镜像队列集群原理2.2.分别在两台机器中部署RabbitMQ2.2.1.基础环境配置2.2.2.安装Erlang环境2.2.3.部署RabbitMQ并开启管理界面2.2.4.配置RabbitMQ各节点变量信息2.2.5.访问RabbitMQ后台管理系…

vue3基础------ 下

目录 二.vue3基础 5.事件处理器 5-1 事件处理器 - 告别原生事件 5-2 事件修饰符 - 事件偷懒符? 6.表单控件绑定 6-1表单输入绑定-一根绳上的蚂蚱 6-2购物车案例 6-3表单修饰符 7.计算属性 7-1计算属性-很聪明,会缓存 7-2 可写计算属性 7-3之前案例的小改…

ModaHub魔搭社区:向量数据库MIlvus服务端配置(三)

目录 gpu 区域 logs 区域 metric_config 区域 gpu 区域 在该区域选择是否在 Milvus 里启用 GPU 用于搜索和索引创建。同时使用 CPU 和 GPU 可以达到资源的最优利用,在特别大的数据集里做搜索时性能更佳。 若要切换到 CPU-only 模式,只要将 enable 设…

【敬伟ps教程】色彩基础

文章目录 在通道内发现色光吸管工具与颜色面板在RGB通道创造色彩色彩三要素选择方式CMYK模式详解 在通道内发现色光 RGB基于色光的混合模式,是最常见的色彩模式 我们新建一个 RGB 画布,前景色改为黑色,AltDelete填充前景色。我们查看图像–…

实验篇(7.2) 18. 星型安全隧道 - 分支互访(IPsec) ❀ 远程访问

【简介】Hub-and-Spoke:各分支机构利用VPN设备与总部VPN设备建立VPN通道后,除了可以和总部进行通讯,还可以利用总部VPN设备互相进行数据交换,而各VPN分支机构不需要进行VPN的隧道连接。 实验要求与环境 OldMei集团深圳总部部署了域…

C# 线程基础 二

目录 八、前台线程和后台线程 九、线程参数的传递 十、线程中的 lock 关键字 十一、Monitor类锁定 结束 八、前台线程和后台线程 默认情况下,显式创建的线程是前台线程,通过手动的设置 Thread 类的属性 IsBackground true 来指示当前线程为一个后…

让GPT-3、ChatGPT、GPT-4一起做脑筋急转弯,GPT-4一骑绝尘!

作者 | python 一个烙饼煎一面一分钟,两个烙饼煎两面几分钟? 让你来回答,是不是一不小心就掉到沟里了?如果让大语言模型来做这种脑筋急转弯会怎样呢?研究发现,模型越大,回答就越可能掉到沟里&a…

VScode连接远程服务器

VScode连接远程服务器 文章目录 VScode连接远程服务器下载扩展通过扩展连接服务器在输入框中输入usernameip进行连接通过已保存的配置信息进行连接 连接成功之后访问服务器文件访问文件 下载扩展 下载以下三个扩展 Remote-SSH Remote - SSH: Editing Configuration Files R…

Docker Network 基础

一、是什么 Docker网络是Docker容器之间和容器与外部网络之间的通信和连接的一种机制。在Docker中,每个容器都可以有自己的网络栈,包括网络接口、IP地址和网络配置。Docker网络提供了一种灵活且可定制的方式,使得容器之间可以相互通信&#x…

【单元测试】Junit 4(二)--eclipse配置Junit+Junit基础注解

目录 1.0 前言 1.1 配置Junit 4 1.1.1 安装包 1.1.2 创建Junit项目 1.2 Junit 4 注解 1.2.1 测试用例相关的注解 1.2.1.1 Before 1.2.1.2 After 1.2.1.3 BeforeClass 1.2.1.4 AfterClass 1.2.1.5 Test 1.2.1.6 Ignore 1.2.1.7 示例 1.2.2 打包测试Suite相关的注解…