常用分类损失CE Loss、Focal Loss及GHMC Loss理解与总结

news2024/9/28 1:07:33

一、CE Loss

定义

交叉熵损失(Cross-Entropy Loss,CE Loss)能够衡量同一个随机变量中的两个不同概率分布的差异程度,当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

公式

二分类

二分类的CE Loss公式如下,

其中,M:正样本数量,N:负样本数量,y_{i}:真实值, p_{i}:预测值

多分类

在计算多分类的CE Loss时,首先需要对模型输出结果进行softmax处理。公式如下,

其中, output:模型输出,p:对模型输出进行softmax处理后的值, ​​​​​:真实值的one hot编码​(假设模型在做5分类,如果y_{i}=2,则=[0,0,1,0,0])

代码实现

二分类

import torch
import torch.nn as nn
import math

criterion = nn.BCELoss()
output = torch.rand(1, requires_grad=True)
label = torch.randint(0, 1, (1,)).float()
loss = criterion(output, label)

print("预测值:", output)
print("真实值:", label)
print("nn.BCELoss:", loss)

for i in range(label.shape[0]):
    if label[i] == 0:
        res = -math.log(1-output[i])
    elif label[i] == 1:
        res = -math.log(output[i])
print("自己的计算结果", res)


"""
预测值: tensor([0.7359], requires_grad=True)
真实值: tensor([0.])
nn.BCELoss: tensor(1.3315, grad_fn=<BinaryCrossEntropyBackward0>)
自己的计算结果 1.331509556677378
"""

多分类

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("预测值:", output)
print("真实值:", label)
print("nn.CrossEntropyLoss:", loss)

output = torch.softmax(output, dim=1)
print("softmax后的预测值:", output)

one_hot = torch.zeros_like(output).scatter_(1, label.view(-1, 1), 1)
print("真实值对应的one_hot编码", one_hot)

res = (-torch.log(output) * one_hot).sum()
print("自己的计算结果", res)


"""
预测值: tensor([[-0.7459, -0.3963, -1.8046,  0.6815,  0.2965]], requires_grad=True)
真实值: tensor([1])
nn.CrossEntropyLoss: tensor(1.9296, grad_fn=<NllLossBackward0>)
softmax后的预测值: tensor([[0.1024, 0.1452, 0.0355, 0.4266, 0.2903]], grad_fn=<SoftmaxBackward0>)
真实值对应的one_hot编码 tensor([[0., 1., 0., 0., 0.]])
自己的计算结果 tensor(1.9296, grad_fn=<SumBackward0>)
"""

二、Focal Loss

定义

虽然CE Loss能够衡量同一个随机变量中的两个不同概率分布的差异程度,但无法解决以下两个问题:1、正负样本数量不平衡的问题(如centernet的分类分支,它只将目标的中心点作为正样本,而把特征图上的其它像素点作为负样本,可想而知正负样本的数量差距之大);2、无法区分难易样本的问题(易分类的样本的分类错误的损失占了整体损失的绝大部分,并主导梯度)

为了解决以上问题,Focal Loss在CE Loss的基础上改进,引入了:1、正负样本数量调节因子以解决正负样本数量不平衡的问题;2、难易样本分类调节因子以聚焦难分类的样本

公式

二分类

公式如下,

 

​​​​​​​

其中,\alpha:正负样本数量调节因子,\gamma:难易样本分类调节因子

多分类

其中,\alpha _{y_{i}}y_{i}类别的权重

代码实现

二分类

def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

步骤1、首先对输入进行sigmoid处理,

p = torch.sigmoid(inputs)

步骤2、随后求出CE Loss,

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

步骤3、定义p_{t}^{i},公式为:

p_t = p * targets + (1 - p) * (1 - targets)

步骤4、为CE Loss添加难易样本分类调节因子,

loss = ce_loss * ((1 - p_t) ** gamma)

步骤5、定义\alpha _{t}^{i},公式为:

alpha_t = alpha * targets + (1 - alpha) * (1 - targets)

步骤6、为步骤4的损失添加正负样本数量调节因子,

loss = alpha_t * loss

多分类

def multi_cls_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: torch.Tensor,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    
    inputs = inputs.float()
    targets = targets.float()
    ce_loss = nn.CrossEntropyLoss()(inputs, targets, reduction="none")
    one_hot = torch.zeros_like(inputs).scatter_(1, targets.view(-1, 1), 1)
    p_t = inputs * one_hot
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * one_hot
        loss = alpha_t * loss

    return loss

三、GHMC Loss

定义

Focal Loss在CE Loss的基础上改进后,解决了正负样本不平衡以及无法区分难易样本的问题,但也会过分关注难分类的样本(离群点),导致模型学歪。为了解决这个问题,GHMC(Gradient Harmonizing Mechanism-C)定义了梯度模长,该梯度模长正比于分类的难易程度,目的是让模型不要关注那些容易学的样本,也不要关注那些特别难分的样本

公式

1、定义梯度模长

二分类的CE Loss公式如下,

假设x是模型的输出,假设p=sigmoid(x),求损失对x的偏导,

因此,定义梯度模长如下,

其中, p:预测值,p^{\ast }:真实值

梯度模长与样本数量的关系如下,

2、定义梯度密度(单位梯度模长g上的样本数量

  

其中,g_{k}:第k个样本的梯度模长,\delta _{\varepsilon }(g_{k},g)g_{k}(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})范围内的样本数量,l_{\varepsilon }(g):区间(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})的长度

3、定义梯度密度协调参数(gradient density harmonizing parameter)

其中,N:样本总数

 4、定义GHMC Loss

 

代码实现

def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


class GHMC(nn.Module):
    def __init__(
            self,
            bins=10,
            momentum=0,
            use_sigmoid=True,
            loss_weight=1.0):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.edges = [float(x) / bins for x in range(bins+1)]
        self.edges[-1] += 1e-6
        if momentum > 0:
            self.acc_sum = [0.0 for _ in range(bins)]
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """ Args:
        pred [batch_num, class_num]:
            The direct prediction of classification fc layer.
        target [batch_num, class_num]:
            Binary class target for each sample.
        label_weight [batch_num, class_num]:
            the value is 1 if the sample is valid and 0 if ignored.
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # 计算梯度模长
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        
        # 设置有效区间个数
        n = 0
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight

步骤一、将梯度模长划分为bins(默认为10)个区域,

self.edges = [float(x) / bins for x in range(bins+1)]
"""
[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000, 1.0000]
"""

步骤二、计算梯度模长

g = torch.abs(pred.sigmoid().detach() - target)

步骤三、计算落入不同bin区间的梯度模长数量

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0
for i in range(self.bins):
    inds = (g >= edges[i]) & (g < edges[i+1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        if mmt > 0:
            self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
            weights[inds] = tot / self.acc_sum[i]
        else:
            weights[inds] = tot / num_in_bin
        n += 1
if n > 0:
    weights = weights / n

步骤四、计算GHMC Loss

loss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / tot * self.loss_weight

【参考文章】

Focal Loss的理解以及在多分类任务上的使用(Pytorch)_focal loss 多分类_GHZhao_GIS_RS的博客-CSDN博客

focal loss 通俗讲解 - 知乎

Focal Loss损失函数(超级详细的解读)_BigHao688的博客-CSDN博客

5分钟理解Focal Loss与GHM——解决样本不平衡利器 - 知乎 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/737832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【QT】QT搭建OpenCV环境

QT/OpenCV 01、开始之前02、QT03、CMake04、OpenCV05、配置06、测试 01、开始之前 本文版本&#xff1a; 1、QT&#xff1a;Based on Qt 5.12.2 (MSVC 2017, 32 bit)&#xff0c;编译方式是MinGW 2、CMake&#xff1a;cmake-3.27.0-rc4-windows-x86_64.msi 3、OpenCV&#xff1…

2023年值得入手的开放式耳机推荐,蓝牙耳机的选购指南分享推荐

身为一个音乐爱好者&#xff0c;出于对音质和佩戴舒适的追求&#xff0c;也有入手了很多品类的耳机&#xff0c;其中不乏有有线耳机、无线蓝牙耳机&#xff0c;两种不同的音频传输方式大类&#xff0c;其各自所拥有的特性也是不同的。而居于后者的无线蓝牙耳机&#xff0c;在现…

【Java基础教程】(八)面向对象篇 · 第二讲:Java 数组全面解析——动态与静态初始化、二维数组、方法参数传递、排序与转置、对象数组、操作API~

Java基础教程之面向对象 第二讲 本节学习目标1️⃣ 概念1.1 动态初始化1.2 静态初始化 2️⃣ 二维数组3️⃣ 数组与方法参数的传递4️⃣ 数组排序5️⃣ 数组转置6️⃣ 对象数组7️⃣ 数组操作API7.1 数组复制7.2 数组排序 &#x1f33e; 总结 本节学习目标 掌握数组的动态及静…

水库监测中仪器安装及监测结果的要求有哪些

水库监测点位布设需要根据水库运行情况和安全监测的需求来进行&#xff0c;一般分为基础监测点位和重要部位监测点位&#xff0c;基础监测点位主要包括上游水位、上游库水位变幅、库岸稳定以及上下游坝坡稳定等。重要部位监测点位主要包括坝轴线、溢洪道进口和泄水洞出口等部位…

前端报错:“Uncaught SyntaxError: missing ) after argument list“只是参数列表后面缺少 “)”?

报错"Uncaught SyntaxError: missing ) after argument list"&#xff0c;字面翻译过来的意思&#xff1a;语法错误: 参数列表后面缺少 )。 一直以为是少了 一个小括号找了好久 发现并不是 据提示是参数列表的问题&#xff0c;找到文件中存在参数列表的地方。如下图…

如何利用MyBatis完成web项目的环境搭建(导入核心依赖包、日志、编译环境,配置文件以及Druid连接池)

目录 项目环境搭建 servlet实例 核心依赖 导入日志 编译环境 mapper注册 resouces中 dao中 MyBatis配置文件 实例效果 导入配置文件 Druid连接池 Druid连接池是什么&#xff1f; 如何配置Druid连接池&#xff1f; 实体类 实例效果 项目环境搭建 1.在pom.xml中…

STM32 Proteus UCOSII系统锅炉报警系统设计压力温度水位-0059

STM32 Proteus UCOSII系统锅炉报警系统设计压力温度水位-0059 Proteus仿真小实验&#xff1a; STM32 Proteus UCOSII系统锅炉报警系统设计压力温度水位-0059 功能&#xff1a; 硬件组成&#xff1a;51单片机 8位数码管MAX7219数码管驱动模块多个按键LED灯蜂鸣器 1.准确测量…

IronOCR for .NET 2023.7.0 Crack

IronOCR for .NET 关于 读取 .NET 应用程序中图像和 Pdf 文本的高级 OCR &#xff08;光学字符识别&#xff09; 库。 IronOCR for .NET enables software engineers to read text content from images & PDFs in .NET applications and Web sites. Read text and barcod…

HarmonyOS/OpenHarmony应用开发-程序包安装、卸载、更新流程

一、应用程序包安装和卸载流程 1.开发者 开发者可以通过调试命令进行应用的安装和卸载&#xff0c;可参考多HAP的调试流程。 图1 应用程序包安装和卸载流程&#xff08;开发者&#xff09; 2.终端设备用户 开发者将应用上架应用市场后&#xff0c;终端设备用户可以在终端设…

python_day4_dict

字典dict:键值对(无重复,无下标索引&#xff09; my_dict {python: 99, java: 88, c: 77, c: 66} my_dict2 {} # 空字典 my_dict3 dict() print(f"my_dict:{my_dict},类型为&#xff1a;{type(my_dict)}") print(f"my_dict2:{my_dict2},类型为&#xff1a;…

AI应用系列--- TalkingPhoto 会说话的照片

利用HeyGen的服务可以生成有趣的Talkingphoto&#xff0c;方法有二&#xff1a; 1、访问HeyGen - AI Video Generator 网站&#xff0c;登录后即可根据提示或者案例生成talkingphoto 2、是使用HeyGen的 Discord​​​​​​机器人&#xff1a;https://discord.com/channels/1…

MySQL数据库期末项目 图书馆管理系统

1 项目需求分析 1.1 项目名称 图书馆管理系统 1.2 项目功能 在以前大多部分图书馆都是由人工直接管理&#xff0c;其中每天的业务和操作流程非常繁琐复杂&#xff0c;纸质版的登记信息耗费了大量的人力物力。因此图书馆管理系统应运而生&#xff0c;该系统采用智能化设计&#…

我来为你揭秘如何将音频转文字才简单

曾经有一位聋哑人士&#xff0c;他很想写一本回忆录&#xff0c;但是因为无法听取自己的回忆录音&#xff0c;他不得不寻找其他方法。于是&#xff0c;他试着用一些软件将他的录音转成文字&#xff0c;但是结果却非常糟糕&#xff0c;充斥着大量错误和不连贯的词语。于是&#…

【大虾送书第一期】《高并发架构实战:从需求分析到系统设计》

目录 ✨写在前面 ✨足够真实的高并发系统设计场景 ✨贴合工作场景的设计文档形式 ✨求同存异的典型系统架构案例 &#x1f990;博客主页&#xff1a;大虾好吃吗的博客 &#x1f990;专栏地址&#xff1a;免费送书活动专栏地址 写在前面 很多软件工程师的职业规划是成为架构师&a…

手机副业哪些靠谱,推荐几个兼职思路

科思创业汇 大家好&#xff0c;这里是科思创业汇&#xff0c;一个轻资产创业孵化平台。赚钱的方式有很多种&#xff0c;我希望在科思创业汇能够给你带来最快乐的那一种&#xff01; 下面给大家介绍几个靠谱的兼职项目 1.问答答主 知乎、百度、悟空等渠道做问答&#xff0c;…

【手把手】一篇讲清楚FastDFS的安装及使用

分布式存储发展历程 前段时间618活动火热进行&#xff0c;正是购物的好时机。当我们访问这些电商网站的时候&#xff0c;每一个商品都会有各式各样的图片展示介绍&#xff0c;这些图片一张两张可以随便丢在服务器的某个文件夹中&#xff0c;可是电商网站如此大体量的图片&…

XSS漏洞学习笔记

浏览器安全 同源策略 影响源的因素&#xff1a;host,子域名,端口,协议 a.com通过以下代码: <script scrhttp://b.com/b.js> 加载了b.com上的b.js&#xff0c;但是b.js是运行在a.com页面中的&#xff0c;因此相对于当前打开的页面(a.com)来说&#xff0c;b.js的源就应该…

Nodejs 学习笔记

Author&#xff1a;德玛玩前端 Date: 2023-07-06 Nodejs 一、Nodejs概述 1.1、什么是JavaScript 1995年由Netscape公司退出&#xff0c;后经ECMA统一标准的脚本语言。通常狭义上理解的JS是指在浏览器内置的JS解释器中运行的&#xff0c;主要用途是操作网页内容&#xff0c;实…

跨境电商亚马逊卖家为何要使用云服务器?

云服务器&#xff0c;作为跨境电商亚马逊开店必备的工具之一&#xff0c;深受各方卖家的青睐&#xff0c;由于跨境电商亚马逊平台有着一人一店的规定&#xff0c;但很多卖家朋友&#xff0c;为了获得更多的流量&#xff0c;便去开设多个店铺进行引流&#xff0c;这样操作极易诱…

【数据结构】搜索二叉树/map/set

二叉搜索树&#xff08;搜索二叉树&#xff09; 1.1.二叉搜索树概念 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树: 若它的左子树不为空&#xff0c;则左子树上所有节点的值都小于根节点的值 若它的右子树不为空&#xff0c;则…