深度学习中常用损失函数介绍

news2024/11/20 11:46:42

选择正确的损失函数对于训练机器学习模型非常重要。不同的损失函数适用于不同类型的问题。本文将总结一些常见的损失函数,并附有易于理解的解释、用法和示例

均方误差损失(MSE)

 loss_fn = nn.MSELoss()py

均方误差(Mean Squared Error,简称 MSE)损失是在监督学习中,特别是在回归问题中经常使用的一种损失函数。它计算了预测值与真实值之间差异的平方的平均值,用于衡量模型预测的准确性。

MSE 损失的数学表达式定义如下:

由于误差被平方,因此较大的误差会受到更重的惩罚。这有助于模型在训练过程中减少较大的预测误差。MSE 损失函数在整个定义域上连续可微,这一特性使得使用基于梯度的优化算法(如梯度下降法)求解时更加高效。

MSE 的计算公式简单,实现起来容易,计算效率高。

平均绝对误差损失(MAE)

 loss_fn = nn.L1Loss()

平均绝对误差(Mean Absolute Error,简称 MAE)损失是在统计学和机器学习中常用于回归问题的另一种损失函数。MAE 损失计算了预测值与真实值之间差异的绝对值的平均,提供了一个对模型预测偏差的直观度量。与MSE相比,它对大误差的敏感性较低。

MAE 损失的数学表达式定义如下:

与 MSE 相比,MAE 对异常值或离群点的影响较小。这是因为它不对误差进行平方,因此较大的误差不会对总体损失产生过大的影响。MAE 直接反映了平均每个样本预测误差的绝对量,易于理解和解释。

MAE 在误差为零的点不可微,这可能使得基于梯度的优化方法在找到最优解时遇到困难。

Huber损失(平滑L1损失)

 loss_fn = nn.SmoothL1Loss()

Huber Loss,又被称为 Smooth L1 Loss,是一种在回归任务中常用的损失函数,它是平方误差损失(squared loss)和绝对误差损失(absolute loss)的结合。这种损失函数主要用于减少异常值(outliers)在训练模型时的影响,从而提高模型的鲁棒性。

Huber Loss 函数通过一个参数 δ\deltaδ(delta)来定义,该参数决定了损失函数从平方误差向绝对误差转变的点。具体的数学表达式为:

Huber Loss 通常用于回归问题,尤其是当数据中可能包含异常值时。它在金融、气象预测、机器人导航等领域找到了广泛的应用,这些领域中的预测任务常常需要对异常值具有较高的容忍度。

交叉熵损失(Cross-Entropy Loss)

 loss_fn = nn.CrossEntropyLoss()

交叉熵损失(Cross-Entropy Loss),在机器学习领域,尤其是分类问题中,扮演了重要的角色。它主要用于衡量两个概率分布之间的差异,通常用于评估模型预测的概率分布与实际标签的概率分布之间的距离。

对于二分类问题,交叉熵损失可以定义为:

交叉熵损失直接对模型输出的概率进行优化,使模型学习产生接近真实标签的概率分布。当模型的预测错误且置信度高时,交叉熵损失会给予更大的惩罚,反之则减少惩罚,这种特性使得训练过程更加高效。在实现时,通常会结合 softmax 函数和对数函数的数值稳定技术,以避免计算中的下溢或上溢问题。

交叉熵损失广泛应用于各种分类任务,如图像识别、文本分类和医学诊断等。它特别适合于处理输出为概率分布的场景,能够有效地推动模型在预测准确性和概率校准方面的性能。

二元交叉熵损失 (BCE)

 loss_fn = nn.BCELoss()

二元交叉熵损失(Binary Cross-Entropy Loss,简称 BCE)是交叉熵损失在二分类问题中的特定形式,它用于衡量模型预测的概率与实际标签之间的差异。BCE 损失在处理只有两个类别(通常标记为0和1)的分类任务时非常常见和有效。

二元交叉熵损失的公式定义如下:

BCE 损失直接优化模型输出的概率,使其尽可能接近实际的标签。这种优化帮助提升模型在概率预测的准确性。当预测的概率与实际标签差距较大时,损失会显著增加,从而使模型快速学习调整这些预测。这种特性使得模型在训练初期能快速改进其错误预测。实现时,为了避免计算中的数值问题(如对数函数的输入为0),通常结合使用sigmoid 函数。

BCE 损失广泛用于各种需要进行二分类的机器学习任务中,包括医疗影像分析、邮件垃圾分类、在线广告点击预测等。在这些场景中,预测是否属于某一类别(是或否)是核心任务。

在使用 BCE 损失时,标签值必须严格为0或1,因为对数函数在计算时要求输入必须位于(0,1)区间内。

二元交叉熵损失加对数损失

 loss_fn = nn.BCEWithLogitsLoss()

二元交叉熵损失加对数(Binary Cross-Entropy with Logits Loss,通常简称为 BCE with Logits Loss)是一种结合了二元交叉熵损失和逻辑斯蒂(sigmoid)激活函数的损失函数。这种损失函数常用于二分类问题中,尤其是当模型的输出还未通过sigmoid函数转换为概率时。

这个损失函数直接在一个步骤中处理了模型的原始输出(也称为logits)和真实标签之间的交叉熵,避免了先将logits转换为概率再计算损失的复杂度。其公式如下:

这个损失函数通过结合sigmoid激活和对数损失计算来改进数值稳定性,减少计算中可能出现的数值问题(如对数函数输入接近0或1时的数值不稳定)。直接在一个公式中处理logits,避免了单独使用sigmoid函数和交叉熵损失可能引入的额外计算开销。适用于任何需要输出概率预测的二分类模型,尤其是在深度学习中,这种损失函数被广泛用于训练二分类神经网络。

与BCE Loss类似,使用BCE with Logits Loss时,标签yi需要严格为0或1

Kullback-Leibler Divergence Loss (KLDivLoss)

 loss_fn = nn.KLDivLoss()

Kullback-Leibler Divergence(简称 KL 散度或 KL Divergence),在机器学习中通常用作损失函数,称为 KLDiv Loss。它是用来衡量两个概率分布之间差异的一种方法。在许多机器学习任务中,特别是在涉及概率分布、生成模型或信息理论的领域,KL 散度都有着重要的应用。

KL 散度用于测量两个概率分布 P 和 Q 之间的不相似性。对于离散概率分布,其表达式为:

对于连续概率分布,则表达为:

KL 散度是非对称的,在概率模型中,KL 散度可以用来衡量模型预测分布与真实分布之间的差异,常用于生成模型(如变分自编码器)的训练。KL 散度从信息论的角度解释为由于知道真实分布 P而不是预测分布 Q而获得的信息增益。

在机器学习中,特别是在生成模型如 GANs 和 VAEs 中,KL 散度用来确保生成的分布尽可能接近真实数据分布。在语言模型中,通过最小化模型分布与实际数据分布之间的 KL 散度,来优化模型。在文档分类或聚类中,使用 KL 散度来度量文档之间的相似性。

相较于其他损失函数,如交叉熵,KL 散度在计算上可能更复杂,特别是在处理连续分布时。

负对数似然损失(Negative Log-Likelihood Loss)

 loss_fn = nn.NLLLoss()

负对数似然损失(Negative Log-Likelihood Loss,简称 NLLLoss)是机器学习中一种常见的损失函数,尤其是在分类问题中与softmax函数结合使用时效果显著。它用于衡量模型输出概率分布与真实标签之间的匹配程度。

在分类任务中,NLLLoss 直接作用于模型的预测概率和真实标签。通常,该损失函数与 softmax 层一起使用,softmax 层用于将模型输出转化为概率分布。NLLLoss 的计算公式如下:

NLLLoss 通过最大化真实标签的预测概率来优化模型,有效地推动模型输出与目标标签的一致性。虽然 NLLLoss 常与分类任务关联,但它同样适用于任何涉及概率预测的场景,包括某些类型的回归任务。与交叉熵损失相比,当模型输出已是有效的概率分布时,使用 NLLLoss 可以省略将 logits 转化为概率的步骤,从而提高计算效率。

在多类分类问题中,NLLLoss 结合 softmax 层,常用于神经网络中,如图像分类、文本分类等。在使用 NLLLoss 之前,必须确保模型的输出是有效的概率值,即所有输出概率之和为1,且每个概率值都在0到1之间。

Hinge Loss

 def hinge_loss(outputs, targets):
     return torch.mean(torch.clamp(1 - outputs * targets, min=0))

Hinge Loss(铰链损失)是机器学习中常用于分类任务的一种损失函数,尤其是在支持向量机(SVM)中应用广泛。它旨在创建一个边界,该边界不仅能正确分类所有训练数据,而且能最大化边界与数据点之间的间隔。

在二分类问题中,Hinge Loss 的表达式通常定义为:

对于多类分类问题(多类SVM),Hinge Loss 可以扩展为:

Hinge Loss 试图确保正确分类的同时,最大化最近的类别边界,这有助于提高模型的泛化能力。例如Hinge Loss 是训练SVM的标准损失函数,广泛用于各种二分类和多分类问题。在训练SVM时,Hinge Loss 倾向于产生稀疏的模型解,这是因为只有那些在边界上或分类错误的样本才会对损失函数有贡献。

Hinge Loss 是一个非光滑函数,这使得优化过程较为复杂,通常需要使用次梯度方法或其他专门的优化算法。由于 Hinge Loss 的非光滑特性,选择合适的优化算法(如SMO、次梯度下降)对于实现有效训练至关重要。

在使用 Hinge Loss 时,标签 yi必须是 +1 和 −1,这与一些其他损失函数使用 0 和 1的标签编码方式不同。

Hinge Loss 提供了一种在确保分类精度的同时最大化分类间隔的方法,特别适用于那些需要高鲁棒性分类器的应用场景。

总结

本文介绍了几种常用的机器学习损失函数,包括均方误差(MSE)、平均绝对误差(MAE)、交叉熵损失、二元交叉熵损失、带对数的二元交叉熵损失、Kullback-Leibler散度、负对数似然损失和铰链损失。这些损失函数在回归、分类和概率模型评估中有着广泛的应用,各有其优势和特定的应用场景。

https://avoid.overfit.cn/post/1435dd9dc90e420e965b3ab939363216

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1954826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Navidrome音乐服务器 + 音流APP = 释放你的手机空间

20240727 By wdhuag 目录 前言: 参考: Navidrome音乐服务器 Demo试用: 支持多平台: 下载: 修改配置: 设置用NSSM成服务启动: 服务器本地访问网址: 音流 歌词封面API&am…

Golang | Leetcode Golang题解之第292题Nim游戏

题目: 题解: func canWinNim(n int) bool {return n%4 ! 0 }

网站打包封装成app,提高用户体验和商业价值

网站打包封装成app的优势 随着移动互联网的普及,用户对移动应用的需求越来越高。网站打包封装成app可以满足用户的需求,提高用户体验和商业价值。 我的朋友是一名电商平台的运营负责人,他曾经告诉我,他们的网站流量主要来自移动…

vite + xlsx + xlsx-style 导出 Excel

如下 npm i 依赖 npm i xlsxnpm i xlsx-style-vite1、简单的使用:.vue文件中使用 const dataSource ref([]) // 数据源const columns [{title: 用户名,key: userName,width: 120,},{title: 用户组,key: userGroup,width: 120,},{title: 状态,key: enable,width: …

MySQL 视图与事务

文章目录 视图事务事物的四大特性(ACID)事务的开启和结束事物隔离级别现象脏读不可重复度幻读 隔离级别读未提交(READ UNCOMMITTED)读提交 (READ COMMITTED)可重复读 (REPECTABLE READ)串行化 (SERIALIZABLE) 查看与设置事务隔离级别重复读的…

【前端 13】Vue快速入门

Vue快速入门 在现代Web开发中,尽管通过HTML、CSS和JavaScript我们能够构建出美观且功能丰富的页面,但随着项目规模的增大,这种传统的开发方式在效率上逐渐显得力不从心。为了提高开发效率,前端开发者们引入了多种框架和库&#x…

Python酷库之旅-第三方库Pandas(050)

目录 一、用法精讲 181、pandas.Series.var方法 181-1、语法 181-2、参数 181-3、功能 181-4、返回值 181-5、说明 181-6、用法 181-6-1、数据准备 181-6-2、代码示例 181-6-3、结果输出 182、pandas.Series.kurtosis方法 182-1、语法 182-2、参数 182-3、功能 …

异步通信方式的两种消息传输模型

文章目录 一、点对点模型1.1 什么是点对点模型1.2 点对点模型的特点 二、发布订阅模型2.1 什么是发布订阅模型2.2 发布订阅模式的日常案例2.3 发布订阅模型的特点 三、总结参考资料 一、点对点模型 1.1 什么是点对点模型 点对点模型(也叫队列模型) 1.2…

Shiro安全框架与SpringBoot的整合(下)

目录 一、整合前的配置 1.1 导入shiro依赖 1.2 config配置 1.2.1 ShiroConfig(⭐) 1.2.2 MyConfig(拦截器配置) 3. 拦截器(LoginInterceptor) 二、认证登录 2.1. controller 2.2 service和serviceImpl(不用) 2.3 mapper …

[Meachines] [Easy] Blocky Jar包反编译

信息收集 IP AddressOpening Ports10.10.10.37TCP:21,22,80,25565 $ nmap -p- 10.10.10.37 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 21/tcp open ftp ProFTPD 1.3.5a 22/tcp open ssh OpenSSH 7.2p2 Ubuntu 4ubuntu2.2 (Ubuntu …

自动驾驶的六个级别是什么?

自动驾驶汽车和先进的驾驶辅助系统(ADAS)预计将帮助拯救全球数百万人的生命,消除拥堵,减少排放,并使我们能够在人而不是汽车周围重建城市。 自动驾驶的世界并不只由一个维度组成。从没有任何自动化到完整的自主体验&a…

VScode调试Python代码

用Python debugger 参考 vscode-python的debug 教学(最全)

动手学深度学习55 循环神经网络 RNN 的实现

动手学深度学习55 循环神经网络 RNN 的实现 从零开始实现简洁实现QA 课件:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html 从零开始实现 %matplotlib inline import math import torch from torch import nn from torch.nn import fun…

【前端逆向】最佳JS反编译利器,原来就是chrome!

有时候需要反编译别人的 min.js。 比如简单改库、看看别人的 min,js 干了什么,有没有重复加载?此时就需要去反编译Javascript。 Vscode 里面有一些反编译插件,某某Beautify等等。但这些插件看人品,运气不好搞的话,反…

力扣高频SQL 50题(基础版)第二十题

文章目录 力扣高频SQL 50题(基础版)第二十题2356.每位教师所教授的科目种类的数量题目说明思路分析实现过程准备数据实现方式结果截图 力扣高频SQL 50题(基础版)第二十题 2356.每位教师所教授的科目种类的数量 题目说明 表: Te…

算法——二分查找(day10)

目录 69. x 的平方根 题目解析: 算法解析: 代码: 35. 搜索插入位置 题目解析: 算法解析: 代码: 69. x 的平方根 69. x 的平方根 - 力扣(LeetCode) 题目解析: 老…

2025第十九届中国欧亚国际军民两用技术及西安国防电子航空航天暨无人机展

2025第十九届中国欧亚国际军民两用技术及西安国防电子航空航天暨无人机展 时间:2025年3月14-16日 地点:西安国际会展中心 详询主办方陆先生 I38(前三位) I82I(中间四位) 9I72(后面四位&am…

中间层 k8s(Kubernetes) 到底是什么,架构是怎么样的?

你是一个程序员,你用代码写了一个博客应用服务,并将它部署在了云平台上。 但应用服务太过受欢迎,访问量太大,经常会挂。 所以你用了一些工具自动重启挂掉的应用服务,并且将应用服务部署在了好几个服务器上,…

【C++】实验六

题目: 1、苹果和虫子 描述:你买了一箱n个苹果,很不幸的是买完时箱子里混进了一条虫子。虫子每x小时能吃掉一个苹果,假设虫子在吃完一个苹果之前不会吃另一个,那么经过y小时你还有多少个完整的苹果? 输入…

Linux基础复习(三)

前言 接Linux基础复习二 一、常用命令及其解释 Tab补全 在上一篇文章配置了IP然后通过远程SSH连接软件控制主机,在配置过程中会发现有些命令过于长,那么,Tab键补全就可以很好的帮助我们去快速的敲出命令,同时如果有些命令有遗…