均方误差损失函数(MSE)和交叉熵损失函数详解

news2024/12/23 1:04:59

为什么需要损失函数

前面的文章我们已经从模型角度介绍了损失函数,对于神经网络的训练,首先根据特征输入和初始的参数,前向传播计算出预测结果,然后与真实结果进行比较,得到它们之间的差值。

损失函数又可称为代价函数或目标函数,是用来衡量算法模型预测结果和真实标签之间吻合程度(误差)的函数。通常会选择非负数作为预测值和真实值之间的误差,误差越小,则模型越好。

     有了这个损失函数,我们便可以采用优化算法更新网络参数,使得训练样本的平均损失最小。

而损失函数根据任务的不同,也可以分为不同的类型,下面进行介绍。

 

均方误差损失函数(MSE)

其中f(xi)是第i个样本的模型预测值,Yi是第i个样本的真实标签值,二者差值求平方,一共有n个样本,平方和求平均。

在回归问题中,均方误差损失函数用于度量样本点到回归曲线的距离,通过最小化平方损失使样本点可以更好地拟合回归曲线。由于无参数、计算成本低和具有明确物理意义等优点,MSE已成为一种优秀的距离度量方法。尽管MSE在图像和语音处理方面表现较弱,但它仍是评价信号质量的标准。

代码实现:

import numpy as np

# 自定义实现

def MSELoss(x:list,y:list):

    """    x:list,代表模型预测的一组数据    y:list,代表真实样本对应的一组数据    """

    assert len(x)==len(y)

    x=np.array(x)

    y=np.array(y)

    loss=np.sum(np.square(x - y)) / len(x)

    return loss

#计算过程举例x=[1,2]y=[0,1]loss=((1-0)**2 + (2-1)**2)÷2=(1+1)÷2=1

# pytorch版本

loss = nn.MSELoss()

predict = torch.randn(3, 5, requires_grad=True)

target = torch.randn(3, 5)

output = loss(predict, target)

从代码中可以看到,MSELoss需要的两个参数分别是真实标签值和模型预测值,两者可以是任意形状的张量,但二者形状和维度需要一致。就是说每个样本的预测值和标签值可以是任意维度的张量,这点要注意,在实际应用中时要认真考虑标签的形状。

 

交叉熵损失

pytorch中的CrossEntropyLoss()函数实际就是先把输出结果进行sigmoid,随后再放到传统的交叉熵函数中,就会得到结果。

交叉熵是信息论中的一个概念,最初用于估算平均编码长度,引入机器学习后,用于评估当前训练得到的概率分布与真实分布的差异情况。为了使神经网络的每一层输出从线性组合转为非线性逼近,以提高模型的预测精度,在以交叉熵为损失函数的神经网络模型中一般选用tanh、sigmoid、softmax或ReLU作为激活函数。

交叉熵损失函数刻画了实际输出概率与期望输出概率之间的相似度,也就是交叉熵的值越小,两个概率分布就越接近,特别是在正负样本不均衡的分类问题中,常用交叉熵作为损失函数。目前,交叉熵损失函数是卷积神经网络中最常使用的分类损失函数,它可以有效避免梯度消散。在二分类情况下也叫做对数损失函数。

一般的交叉熵用数学公式表示是:

-Q(x) log P(x)

其中Q(x)是真实值,P(x)是预测值。

当p(x)和Q(x)是矩阵的时候,就分别对其计算,然后求和即可

在pytorch中的交叉熵损失CrossEntropyLoss 包含了两部分,softmax和交叉熵计算,下面分别介绍这两部分

假设有 N 个样本,每个样本属于 C 个类别之一。对于第 i 个样本,它的真实类别标签为 y_i,模型的预测输出 logits 为xi​=(xi1​,xi2​,…,xiC​),其中xic表示第i个样本在第c 类别上的原始输出分数(logits)(注意这里是预测分数值,不是概率值)。

交叉熵损失的计算步骤如下:

(1)预测概率分布
对 logits 进行 softmax 操作,将预测输出其转换为概率分布:

其中 pic表示第i个样本属于第c类别的预测概率。

   此时预测输出的概率分布是f(xi)=(pi1,pi2,…,piC)

  1. 真实概率分布:

对于样本i,其真实分布会根据归属的类别自动创建一个one-hot概率分布,即所属类别的位置为1,其它均为0,则会输出一个one-hot概率分布Q(xi)=(qi1,qi2,…,qiC)。比如5个类别,第i个样本的真实类别为3,则Q(xi)[0,0,1,0,0]。

实际计算的时候不难发现target中为0经过乘法都是0了,因此最后只剩下正确类型的这个损失差距 最后公式可以演变成 - log Q(x)

(3)负对数似然(Negative Log-Likelihood)
对于单个样本,计算负对数似然:

其中是第i 个样本的交叉熵损失,但事实上,只在真实类别位置处概率为1,其余位置均为0,因此,可以进一步简化为

 其中,yi代表第i个样本在真实类别j=yi处的预测概率。其本质是利用真实概率分布筛选了预测概率分布在真实类别的概率值,并求负对数似然。

对于N个样本,则对这N个样本的交叉熵损失函数求和再求平均即可。

  1. 代码解析

cross_loss = torch.nn.CrossEntropyLoss(reduction='none')

#注意这里的预测输入是N*C,其中N是样本数,C是类别数,此时还不是概率,所以使用交叉熵损失函数的网络最后不需要softmax,损失函数自带。

input = torch.tensor([[4, 14, 19, 15],

                       [18, 6, 14, 7],

                       [18, 5, 3, 16]], dtype=torch.float)

#真实标签是每个样本的类别(1*N),api会自动生成one-hot概率分布(N*C)

    target = torch.tensor([0, 3, 2])

  #然后计算损失函数值

  loss = cross_loss(input, target)

    torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

参数

    参数说明:

1. weight:

  • CE 和 BCE 系列都有此参数,用于为每个类别的 loss 设置权重,常用于类别不均衡问题;
  • weight 必须是 float 类型的 1D tensor,长度和类别长度一致:weight = torch.from_numpy(np.array([0.6, 0.2, 0.2])).float().to(device)
  • 注意:weight 加起来未必一定要等于 1,类 c 对应的 weight 为 W_c = (N-N_c) / N,数目越多的类,weight 越小,weight 越大,此类得到的 loss 被放大;

2. ignore_index:

  • 其中 BCE 系列没有此参数,此参数用于指定忽略某些类别的 loss;

3. size_average:

  • 该参数指定 loss 是否在一个 batch 内平均,即是否除以 N,目前此参数已经被弃用

4. reduce:

  • 目前此参数已经被弃用

5. reduction:

  • 此参数在新版本中是为了取代 ”size_average“ 和 "reduce" 参数的;
  • mean (default):返回 N 个 loss 的平均值;
  • sum:返回 N 个 loss 的 sum;
  • None:直接返回一个 batch 中的 N 个 loss;

6. pos_weight:

  • 只有 BCEWithLogits 系列有次参数;
  • 与 weight 参数的区别是:WIP;

(5)nn.CrossEntropyLoss=nn.LogSoftmax(dim=1)+nn.NLLLoss()

(5)多维交叉熵

文本类数据通常是三维数据,预测通常是(batch_size,seq_length,num_vocab_size),而target是(batch_size,seq_length),此时需要预测的形状,通常使用permute操作成 (batch_size,num_vocab_size,seq_length)

参考资料

https://zhuanlan.zhihu.com/p/261059231

交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数 - 不愿透漏姓名的王建森 - 博客园

MSELoss — PyTorch 2.5 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2263971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

抓包 127.0.0.1 (loopback) 使用 tcpdump+wireshark

直接使用 wireshark无法抓取 127.0.0.1环回的数据包,一种解决方法是先传到路由器再返回,但这样可能造成拥塞。 Linux 先使用tcpdump抓包并输出为二进制文件,然后wireshark打开。 比如 sudo tcpdump -i lo src host localhost and dst host…

免费GIS工具箱:轻松将glb文件转换成3DTiles文件

在GIS地理信息系统领域,GLB文件作为GLTF文件的二进制版本,主要用于3D模型数据的存储和展示。然而,GLB文件的使用频率相对较低,这是因为GIS系统主要处理的是地理空间数据,如地图、地形、地貌、植被、水系等,…

安防监控Liveweb视频汇聚融合平台助力执法记录仪高效使用

Liveweb平台可接入的设备除了常见的智能分析网关与摄像头以外 ,还可通过GB28181协议接入执法记录仪,实现对执法过程的全程监控与录像,并对执法轨迹与路径进行调阅回看。那么,如何做到执法记录仪高效使用呢? 由于执法记…

【Unity3D】实现可视化链式结构数据(节点数据)

关键词:UnityEditor、可视化节点编辑、Unity编辑器自定义窗口工具 使用Newtonsoft.Json、UnityEditor相关接口实现 主要代码: Handles.DrawBezier(起点,终点,起点切线向量,终点切线向量,颜色,n…

网络安全核心目标CIA

网络安全的核心目标是为关键资产提供机密性(Confidentiality)、可用性(Availablity)、完整性(Integrity)。作为安全基础架构中的主要的安全目标和宗旨,机密性、可用性、完整性频频出现,被简称为CIA,也被成为你AIC,只是顺序不同而已…

[项目代码] YOLOv8 遥感航拍飞机和船舶识别 [目标检测]

项目代码下载链接 <项目代码>YOLO 遥感航拍飞机和船舶识别<目标检测>https://download.csdn.net/download/qq_53332949/90163939YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为…

去雾Cycle-GAN损失函数

文章目录 GAN-LossIdentity-LossDP-lossCycle-Loss G和F都是生成器 G是hazy → \to → gt F是gt → \to → hazy D y D_y Dy​判别无雾图是真实还是生成的? D x D_x Dx​判别有雾图是真实还是生成的? GAN-Loss 在 DAM-CCGAN 中存在两个判别器 D x D_x D…

2024年企业中生成式 AI 的现状报告

从试点到生产,企业 AI 格局正在被实时改写。我们对 600 名美国企业 IT 决策者进行了调查,以揭示新兴的赢家和输家。 从试点到生产 2024 年标志着生成性人工智能成为企业关键任务的一年。这些数字讲述了一个戏剧性的故事:今年人工智能支出飙升…

组件十大传值

一、defineProps 和 defineEmits defineProps 用于定义子组件接收的 props,即父组件传递给子组件的数据。 接收父组件传递的数据:定义子组件可以接受的属性及其类型。类型检查:确保传递的数据符合预期的类型。 defineEmits 用于定义子组件…

WPF 依赖属性和附加属性

除了普通的 CLR 属性, WPF 还有一套自己的属性系统。这个系统中的属性称为依赖属性。 1. 依赖属性 为啥叫依赖属性?不叫阿猫阿狗属性? 通常我们定义一个普通 CLR 属性,其实就是获取和设置一个私有字段的值。假设声明了 100 个 …

递归实现指数型枚举(递归)

92. 递归实现指数型枚举 - AcWing题库 每个数有选和不选两种情况 我们把每个数看成每层,可以画出一个递归搜索树 叶子节点就是我们的答案 很容易写出每dfs函数 dfs传入一个u表示层数 当层数大于我们n时,去判断每个数字的选择情况,输出被选…

mac 安装graalvm

Download GraalVM 上面链接选择jdk的版本 以及系统的环境下载graalvm的tar包 解压tar包 tar -xzf graalvm-jdk-<version>_macos-<architecture>.tar.gz 移入java的文件夹目录 sudo mv graalvm-jdk-<version> /Library/Java/JavaVirtualMachines 设置环境变…

【WPS安装】WPS编译错误总结:WPS编译失败+仅编译成功ungrib等

WPS编译错误总结&#xff1a;WPS编译失败仅编译成功ungrib等 WPS编译过程问题1&#xff1a;WPS编译失败错误1&#xff1a;gfortran: error: unrecognized command-line option ‘-convert’; did you mean ‘-fconvert’?解决方案 问题2&#xff1a;WPS编译三个exe文件只出现u…

Redis 集群实操:强大的数据“分身术”

目录 Redis Cluster集群模式 1、介绍 2、架构设计 3、集群模式实操 4、故障转移 5、常用命令 Redis Cluster集群模式 1、介绍 redis3.0版本推出的Redis Cluster 集群模式&#xff0c;每个节点都可以保存数据和整个集群状态&#xff0c;每个节点都和其他所有节点连接。Cl…

数据结构day5:单向循环链表 代码作业

一、loopLink.h #ifndef __LOOPLINK_H__ #define __LOOPLINK_H__#include <stdio.h> #include <stdlib.h>typedef int DataType;typedef struct node {union{int len;DataType data;};struct node* next; }loopLink, *loopLinkPtr;//创建 loopLinkPtr create();//…

植物大战僵尸杂交版v3.0.2最新版本(附下载链接)

B站游戏作者潜艇伟伟迷于12月21日更新了植物大战僵尸杂交版3.0.2版本&#xff01;&#xff01;&#xff01;&#xff0c;有b站账户的记得要给作者三连关注一下呀&#xff01; 不多废话下载链接放上&#xff1a; 夸克网盘链接&#xff1a;&#xff1a;https://pan.quark.cn/s/5c…

Unity 圆形循环复用滚动列表

一.在上一篇垂直循环复用滚动列表的基础上&#xff0c;扩展延申了圆形循环复用滚动列表。实现此效果需要导入垂直循环复用滚动列表里面的类。 1.基础类 using System.Collections.Generic; using UnityEngine; using UnityEngine.UI; using UnityEngine.EventSystems; using …

京东大数据治理探索与实践 | 京东零售技术实践

01背景和方案 在当今的数据驱动时代&#xff0c;数据作为关键生产要素之一&#xff0c;其在商业活动中的战略价值愈加凸显&#xff0c;京东也不例外。 作为国内领先的电商平台&#xff0c;京东在数据基础设施上的投入极为巨大&#xff0c;涵盖数万台服务器、数 EB 级存储、数百…

android:sharedUserId 应用进程声明介绍

背景 adb install 安装系统软件报错,原因是签名不一致,进程改变。 代码分析 AndroidManifest.xml 定义的 android:sharedUserId 应用归属进程不同,从phone切换到system。 初始配置 <manifest xmlns:android="http://schemas.android.com/apk/res/android"c…

Liveweb视频融合共享平台在果园农场等项目中的视频监控系统搭建方案

一、背景介绍 在我国的大江南北遍布着各种各样的果园&#xff0c;针对这些地处偏僻的果园及农场等环境&#xff0c;较为传统的安全防范方式是建立围墙&#xff0c;但是仅靠围墙仍然无法阻挡不法分子的有意入侵和破坏&#xff0c;因此为了及时发现和处理一些难以察觉的问题&…