pytorch——损失函数之nn.BCELoss二进制交叉熵和 nn.BCEWithLogitsLoss

news2025/1/14 4:22:23

文章目录

  • 1、pytorch损失函数之nn.BCELoss()(二进制交叉熵)
    • 1.1 是什么?
    • 1.2 怎么代码实现和代码使用?
    • 1.3 推导过程
      • 分析交叉熵作为损失函数的梯度情况:
      • 举一个sigmoid导致的梯度消失的MSE损失的例子
    • 1.3 应用场景
      • 1.3.1 二分类
      • 1.3.2 多分类
      • 1.3.3 位置的回归
      • 1.3.4 用途的一个示例
  • 2、BCEWithLogitsLoss
  • 参考

1、pytorch损失函数之nn.BCELoss()(二进制交叉熵)

基础的损失函数 BCE (Binary cross entropy)

1.1 是什么?

这种BCE损失是交叉熵损失的一种特殊情况,因为当你只有两个类时,它可以被简化为一个更简单的函数。这用于测量例如自动编码器中重建的误差。这个公式假设x和y是概率,所以它们严格地在0和1之间
在这里插入图片描述

1.2 怎么代码实现和代码使用?

pytorch中,表示求一个二分类的交叉熵:

class torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction=‘elementwise_mean’)

它的loss如下:
l ( x , y ) = L = { l 1 , l 2 , . . . , l n } , 其中 l n = − w n [ y n l o g y n ^ + ( 1 − y n ) l o g ( 1 − y n ^ ) ] l(x,y)=L=\{l_1,l_2,...,l_n\},其中l_n=-w_n[y_nlog\hat{y_n}+(1-y_n)log(1-\hat{y_n})] l(x,y)=L={l1,l2,...,ln},其中ln=wn[ynlogyn^+(1yn)log(1yn^)]

这里n表示批量大小。 w n w_n wn​表示权重。

当参数reduce设置为 True,且参数size_average设置为True时,表示对交叉熵求均值,当size_average设置为Flase时,表示对交叉熵求和。参数weight设置的是 w n w_n wn​,其是一个tensor, 且size与批量数一样(不设置时可能都为1)。目标值 y的范围是0-1之间。输入输出的维度都是 ( ( N , ∗ ) (N,*) N,N是批量数,*表示目标值维度。

1.3 推导过程

我们定义:一个二项分布,随机变量只有两种可能值,所以是一个二分类。定义二分类的交叉熵形式:

− y l o g y ^ − ( 1 − y ) l o g ( 1 − y ^ ) . . . . . . . . . . . . . . ( 1 ) -ylog\hat{y}-(1-y)log(1-\hat{y})..............(1) ylogy^(1y)log(1y^)..............(1)
其中 y ^ \hat{y} y^​是输出值在0-1之间.

就是将最后分类层的每个输出节点使用sigmoid激活函数激活,然后对每个输出节点和对应的标签计算交叉熵损失函数,具体图示如下所示:

图片来源:https://www.zhihu.com/question/358811772/answer/920451413

在这里插入图片描述

左上角就是对应的输出矩阵(batch_size x num_classes), 然后经过sigmoid激活后再与绿色标签计算交叉熵损失,计算过程如右方所示。

import torch
import numpy as np

pred = np.array([[-0.4089, -1.2471, 0.5907],
                [-0.4897, -0.8267, -0.7349],
                [0.5241, -0.1246, -0.4751]])
label = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [1, 0, 1]])

pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()

crition1 = torch.nn

在这里插入图片描述

输出结果一致,因此训练时使用BCEWithLogitsLoss()和MultiLabelSoftMarginLoss()都可。

分析交叉熵作为损失函数的梯度情况:

我们假设,对于批量样本 ( x 1 , y 1 ) , ( x 2 , y 2 ) . . . {(x_1,y_1),(x_2,y_2)...} (x1,y1),(x2,y2)...则可以对交叉熵求和或者求均值:

∑ i − y i l o g y i ^ − ( 1 − y i ) l o g ( 1 − y i ^ ) . . . . . . . . . . . ( 2 ) \sum_{i}-y_ilog\hat{y_i}-(1-y_i)log(1-\hat{y_i})...........(2) iyilogyi^(1yi)log(1yi^)...........(2)
(这里我们将标签值y视作先验分布, y ^ \hat{y} y^​为模型分布)

若激活函数使用的是sigmoid函数,则 y ^ = σ ( z ) \hat{y}=\sigma(z) y^=σ(z),其中 z = w x + b z=wx+b z=wx+b。采用链式法则求导,则有:

1 n ∑ i − y i l o g y i ^ − ( 1 − y i ) l o g ( 1 − y i ^ ) . . . . . . . . . . ( 2 ) \frac{1}{n}\sum_{i}-y_ilog\hat{y_i}-(1-y_i)log(1-\hat{y_i})..........(2) n1iyilogyi^(1yi)log(1yi^)..........(2)

求导,可得:
∂ L ∂ w = − 1 n ∑ i ( y σ ( z ) − 1 − y 1 − σ ( z ) ) ∂ σ ∂ w = − 1 n ∑ i ( y σ ( z ) − 1 − y 1 − σ ( z ) ) σ ′ x \frac{\partial L}{\partial w}=-\frac{1}{n}\sum_i(\frac{y}{\sigma(z)}-\frac{1-y}{1-\sigma(z)})\frac{\partial \sigma}{\partial w}=-\frac{1}{n}\sum_i(\frac{y}{\sigma(z)}-\frac{1-y}{1-\sigma(z)}) {\sigma}'x wL=n1i(σ(z)y1σ(z)1y)wσ=n1i(σ(z)y1σ(z)1y)σx

由于 σ ( z ) = 1 / ( 1 + e − z ) \sigma(z)=1/(1+e^{-z}) σ(z)=1/(1+ez)

所以最终得到: ∂ L ∂ w = 1 n ∑ i x ( σ ( z ) − y ) \frac{\partial L}{\partial w}=\frac{1}{n}\sum_i x(\sigma(z)-y) wL=n1ix(σ(z)y)

而对偏置的导数也等于 ∂ L ∂ b = 1 n ∑ i ( σ ( z ) − y ) \frac{\partial L}{\partial b}=\frac{1}{n}\sum_i (\sigma(z)-y) bL=n1i(σ(z)y)可以看见使用交叉熵作为损失函数后,反向传播的梯度不在于sigmoid函数的导数有关了。这就从一定程度上避免了梯度消失。

举一个sigmoid导致的梯度消失的MSE损失的例子

二次函数为损失函数的梯度情况,梯度消失问题

二次函数 L = ( y − y ^ ) 2 2 L=\frac{(y-\hat{y})^2}{2} L=2(yy^)2

采用链式法则求导,则有:

∂ L ∂ w = ( y ^ − y ) σ ( z ) ′ x \frac{\partial L}{\partial w}=(\hat{y}-y){\sigma(z)}'x wL=(y^y)σ(z)x
∂ L ∂ b = ( y ^ − y ) σ ( z ) ′ \frac{\partial L}{\partial b}=(\hat{y}-y){\sigma(z)}' bL=(y^y)σ(z)
可以看出梯度都与sigmoid函数的梯度有关,如下图所示,sigmoid函数在两端的梯度均接近0,这导致反向传播的梯度也很小,这就这就不利于网络训练,这就是 梯度消失问题 。

在这里插入图片描述

1.3 应用场景

在机器学习或者深度学习中,分类问题是一个最常见的任务,分类问题一般又分为:二分类任务、多分类任务和多标签分类任务

  • 二分类任务:输出只有0和1两个类别;
  • 多分类任务:一般指的是输出只有一个标签,类别之间是互斥的关系;
  • 多标签分类任务:输出的结果是多标签,类别之间可能互斥也可能有依赖、包含等关系。

在面对不同的分类问题的时候,选择的loss function也不一样,二分类和多标签分类通常使用sigmoid函数而多分类则一般使用softmax函数(互斥性质)。

1.3.1 二分类

BCE可以处理二分类问题,而且通常是sigmoid+BCELoss。

This loss is a special case of cross entropy for when you have only two classes so it can be reduced to a simpler function. This is used for measuring the error of a reconstruction in, for example, an auto-encoder. This formula assume xx and yy are probabilities, so they are strictly between 0 and 1.

1.3.2 多分类

若是遇到多分类问题使用二进制交叉熵。
目标:多分类问题 => 多个二分类问题

比如我们有3个类别,那么我们通过softmax得到 y ^ = [ 0.2 , 0.5 , 0.3 ] \hat{y}=[0.2,0.5,0.3] y^=[0.2,0.5,0.3]的到的一个一个样本的分类结果,这个结果的通俗解释就是:为第一类的概率为0.2,为第二类的概率为0.5,为第三类的结果过0.3。
假设这个样本真实类别为第二类,那么我们希望模型输出的结果过应该是 y = [ 0 , 1 , 0 ] y=[0,1,0] y=[0,1,0],这个就是标签值。那么损失函数可以使用交叉熵:

L = − ∑ k 3 y k l o g ( y ^ ) L=-\sum_k^3y_klog(\hat{y}) L=k3yklog(y^)

可以看见实际上这个求和只有一项。也就是 L = − l o g ( 0.5 ) L=-log(0.5) L=log(0.5)

pytorch中提供了多分类使用的损失函数nn.CrossEntropyLoss()使用的原理,与这里类似。

作者:杨夕
链接:https://www.zhihu.com/question/358811772/answer/2677137156
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

class BCELosswithLogits(nn.Module):
      def __init__(self, pos_weight=1, reduction='mean'):
          super(BCELosswithLogits, self).__init__()
          self.pos_weight = pos_weight
          self.reduction = reduction

      def forward(self, logits, target):
          # logits: [N, *], target: [N, *]
          logits = F.sigmoid(logits)
          loss = - self.pos_weight * target * torch.log(logits) - \
                (1 - target) * torch.log(1 - logits)
          if self.reduction == 'mean':
              loss = loss.mean()
          elif self.reduction == 'sum':
              loss = loss.sum()
          return loss

存在问题:由于 head classes的主导以及negative instances的影响,导致 BCE Loss 函数 容易受到 类别不均衡问题 影响;

优化方向:绝大部分balancing方法都是reweight BCE从而使得稀有的instance-label对能够得到得到合理的“关注”

1.3.3 位置的回归

使用中心位置使用BCE是有理论依据的,可以认为,效果等价于square L2 norm(这个结论的出处还没找到,等找到了补充,20230506)

1.3.4 用途的一个示例

在这里插入图片描述

2、BCEWithLogitsLoss

nn.BCEWithLogitsLoss() 函数等效于 sigmoid + nn.BCELoss。

在这里插入图片描述

BCEWithLogitsLoss损失函数把 Sigmoid 层集成到了 BCELoss 类中. 该版比用一个简单的 Sigmoid 层和 BCELoss 在数值上更稳定, 因为把这两个操作合并为一个层之后, 可以利用 log-sum-exp 的 技巧来实现数值稳定.

torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)

参数:

    weight (Tensor, optional) – 自定义的每个 batch 元素的 loss 的权重. 必须是一个长度 为 “nbatch” 的 Tensor

参考

https://atcold.github.io/pytorch-Deep-Learning/en/week11/11-1/
https://mp.weixin.qq.com/s/AwgQcafQ2pAuU7_0gEFnmg
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#normalization-layers-source
https://samuel92.blog.csdn.net/article/details/105900876
https://blog.csdn.net/geter_CS/article/details/84747670

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/494747.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java版工程项目管理系统平台,助力工程企业实现数字化管理系统源代码

Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下: 首页 工作台:待办工作、消息通知、预警信息,点击可进入相应的列表 项目进度图表:选择(总体或单个)项目显示1…

一文讲透TCP/IP协议 | 图解+秒懂+史上最全

目录 🙋‍♂️ TCP/IP协议详解 🙋‍♂️ TCP/IP协议的分层模型 OSI模型的七层框架 TCP/IP协议与七层ISO模型的对应关系 (一)TCP/IP协议的应用层 (二)TCP/IP协议的传输层 (三)…

Vuex从了解到实际运用(二)——获取vuex中的全局状态(state getters)

vuex从了解到实际运用——获取vuex中的全局状态state getters 知识回调(不懂就看这儿!)场景复现项目实战vuex定义一个store实例在store中定义数据在组件中获取值vuex的计算属性通过getters获取全局状态state和getters获取全局状态的区别 知识…

Windows安装Docker 容器教程

Windows安装Docker 容器教程 什么是docker I. 简介 什么是 Docker 容器 Docker 容器是一种轻量级、可移植、自包含的软件打包和部署技术。它可以将应用程序和依赖项打包在一个可移植的容器中,并提供一个一致的运行环境,无论在哪个计算机上运行都能够…

Copyleaks:AI抄袭和内容检测工具

【产品介绍】 Copyleaks是一个基于AI人工智能的抄袭和内容检测工具,可以帮助用户在互联网上发现和防止内容被盗用。支持检测各种类型的文本,包括学术论文、网站内容、商业文件、法律合同、创意作品等,并提供详细的相似度报告和原始来源链接。…

基于R语言APSIM模型应用

随着数字农业和智慧农业的发展,基于过程的农业生产系统模型在模拟作物对气候变化的响应与适应、农田管理优化、作物品种和株型筛选、农田固碳和温室气体排放等领域扮演着越来越重要的作用。APSIM (Agricultural Production Systems sIMulator)模型是世界知名的作物生…

【Hello Network】TCP协议

作者:小萌新 专栏:网络 作者简介:大二学生 希望能和大家一起进步 本篇博客简介:较为详细的介绍TCP协议 TCP协议 TCP协议可靠性TCP的协议格式序号与确认序号窗口大小六个标志位 确认应答机制 (ACK)超时重传机…

Spring整合Swagger自动生成API文档

认识Swagger Swagger 是一个规范和完整的框架,用于生成、描述、调用和可视化 RESTful 风格的 Web 服务。总体目标是使客户端和文件系统作为服务器以同样的速度来更新。参数和模型紧密集成到服务器端的代码,允许API来始终保持同步。 作用: …

【LeetCode】数据结构题解(6)[回文链表]

回文链表 1.题目来源2.题目描述3.解题思路4.代码展示 所属专栏:玩转数据结构题型 博主首页:初阳785 代码托管:chuyang785 感谢大家的支持,您的点赞和关注是对我最大的支持!!! 博主也会更加的努力…

C++入门2(缺省参数 inline函数 函数重载 函数模板)

C入门2 缺省参数结合优先级 inline函数vs中的测试实例inline函数要点内联函数与宏定义区别: 函数重载定义名字粉碎技术C编译时函数名修饰约定规则 函数模板 缺省参数 函数定义时,缺省值赋值是从右向左依次赋值 调用函数时,从左向右依次给实参值&#xf…

【HTTP/1.1、HTTP/2、HTTP/3】

文章目录 HTTP/1.1 如何优化?避免发送HTTP请求减少HTTP次数减少 HTTP 响应的数据大小 HTTP/2HTTP/1.1性能问题HTTP/2的性能优化头部压缩二进制帧(重点)并发传输服务器主动推送资源 HTTP/2问题总结 HTTP/3HTTP/2的性能问题队头阻塞TCP 与 TLS …

跟着我学 AI丨打败李世石和柯洁的 AlphaGo

强化学习是一种人工智能的方法,它模仿了人类学习的方式。通过试错来学习,实现从经验中提取知识的目的。强化学习的核心思想是基于奖励的学习,它的目标是通过在环境中采取行动,并根据行动结果获得奖励,从而学会最优的行…

CNNs: AlexNet补充

CNNs: AlexNet的补充 导言对AlexNet模型进行调整模型不同层的表征其他探索总结 导言 上上篇和上一篇我们详细地讲述了AlexNet的网络结构和不同超参数对同一数据集的不同实验现象。 本节,我们就AlexNet的一些其他相关问题进行解剖,如修改AlexNet参数量调…

JVM内存模型基础

大家好,我是易安! 我们知道运行一个Java应用程序,我们必须要先安装JDK或者JRE包。这是因为Java应用在编译后会变成字节码,然后通过字节码运行在JVM中,而JVM是JRE的核心组成部分。 JVM不仅承担了Java字节码的分析&#…

JavaWeb ( 五 ) Servlet

2.3.Servlet Servlet(Server Applet)是Java Servlet的简称。 是在服务器端执行的 , 用于响应客户端请求的Java类。HttpServlet 是使用java语言对http通信的实现。 2.3.1.Servlet声明 在 web.xml 中声明Servlet的请求url及对应的类路径 , 3.0版本后可以…

APSIM模型

随着数字农业和智慧农业的发展,基于过程的农业生产系统模型在模拟作物对气候变化的响应与适应、农田管理优化、作物品种和株型筛选、农田固碳和温室气体排放等领域扮演着越来越重要的作用。APSIM (Agricultural Production Systems sIMulator)模型是世界知名的作物生…

趣谈西工大电子实习物联网智慧交通系统

学习简介: 物联网智慧交通系统是电子实习中相当有趣的一个环节,可以在一定程度上弥补没有被分配到智能小车的遗憾。在这个模块当中,你将在老师的带领下以完成两个小任务为驱动,让自身能力在八个学时当中充分锻炼。 下面这两张图…

微信小程序商城搭建--后端+前端+小程序端

介绍: 前端技术:React、AntdesignPro、umi、JavaScript、ES6、TypeScript、 小程序 后端技术:Springboot、Mybatis、Spring、Mysql 软件架构: 后端采用Springboot搭配前端React进行开发,完成用户管理、轮播图管理、…

[MySQL / Mariadb] 数据库学习-Linux中安装MySQL,YUM方式

[Mariadb] 数据库学习笔记 在Linux中安装MySQL,YUM方式mariadb 介绍安装启服务初始配置修改密码 密码策略,默认策略是1show variables; 查所有变量show variables like "%变量%"; 查特定的变量参数临时:永久: MySQL基本操作连接SQL…

使用@PropertySource加载配置文件

1.PropertySource和PropertySources注解 1.1.PropertySource注解概述 PropertySource注解是Spring 3.1开始引入的配置类注解。通过**PropertySource注解可以将properties配置文件中的key/value存储到Spring的Environment中,Environment接口提供了方法去读取配置文…