PyTorch 中各类损失函数介绍

news2024/10/24 13:20:38

深度学习损失函数

在深度学习中,损失函数(Loss Function)是衡量模型预测值与真实值之间差异的函数。损失函数的选择对于模型训练的效果至关重要,因为它直接影响到模型优化的方向和效率。其主要作用包括:

  1. 衡量误差:损失函数为模型提供了一个衡量预测值与真实值之间差异的量化指标。这个指标告诉我们模型的预测有多准确,或者在某些任务中,预测有多不准确。

  2. 指导训练:损失函数定义了优化的目标。在训练过程中,模型的参数通过最小化损失函数来调整,以此来提高模型的预测性能。

  3. 反向传播:损失函数是神经网络中反向传播算法的核心。通过计算损失函数关于模型参数的梯度,反向传播算法能够更新网络权重,以减少预测误差。

  4. 影响模型泛化:选择合适的损失函数可以影响模型的泛化能力。一些损失函数可能在训练集上表现良好,但在未见过的数据上表现不佳,这可能导致过拟合。

  5. 正则化:损失函数可以内置正则化项,如 L1 或 L2 正则化,以控制模型的复杂度并减少过拟合。

  6. 处理不平衡数据:在处理类别不平衡的数据集时,可以选择合适的损失函数来给予少数类更多的权重,从而提高模型对这些类别的识别能力。

  7. 适应不同类型的任务:不同的任务可能需要不同的损失函数。例如,分类问题常用交叉熵损失,而回归问题可能使用均方误差损失。

  8. 多任务学习:在多任务学习中,可以设计损失函数来平衡不同任务之间的贡献,使得模型能够同时学习多个相关任务。

  9. 自适应学习率:一些损失函数,如 Focal Loss,可以自适应地调整不同样本的学习率,这在处理类别不平衡问题时非常有用。

  10. 可解释性和可视化:损失函数的值和变化可以提供模型学习过程的可解释性,有助于调试和理解模型的行为。

PyTorch 中各类损失函数

在 PyTorch 中,损失函数是构建和训练神经网络时不可或缺的一部分。它们用于评估模型的预测值与真实值之间的差异,并指导模型通过优化过程最小化这种差异。以下是 PyTorch 中一些常用的损失函数及其简要介绍:

  1. L1 Loss (Mean Absolute Error Loss)

    • 计算预测值与真实值之间的平均绝对误差。
    • 适用于回归问题,特别是当目标变量包含异常值时。
    • torch.nn.L1Loss()
  2. MSE Loss (Mean Squared Error Loss)

    • 计算预测值与真实值之间的均方误差。
    • 适用于回归问题,尤其是数据相对表现良好时。
    • torch.nn.MSELoss()
  3. Cross-Entropy Loss

    • 用于多类分类问题,结合了 LogSoftmaxNLLLoss
    • torch.nn.CrossEntropyLoss()
  4. Binary Cross-Entropy Loss

    • 用于二元分类问题,计算真实标签和预测标签之间的二元交叉熵。
    • torch.nn.BCELoss()
  5. Binary Cross-Entropy Loss with Logits

    • 结合了 Sigmoid 激活函数和 BCE 损失,用于二元分类问题。
    • torch.nn.BCEWithLogitsLoss()
  6. Hinge Loss

    • 用于最大间隔分类问题,如支持向量机(SVM)。
    • torch.nn.HingeLoss()
  7. Huber Loss (Smooth L1 Loss)

    • 结合了 MSE 和 MAE 的特点,对异常值不敏感。
    • torch.nn.SmoothL1Loss()
  8. KL Divergence Loss

    • 计算两个概率分布之间的 Kullback-Leibler 散度。
    • torch.nn.KLDivLoss()
  9. Poisson Loss

    • 用于预测计数数据的发生次数,例如文本生成。
    • torch.nn.PoissonNLLLoss()
  10. Margin Ranking Loss

    • 用于排名问题,预测相对距离。
    • torch.nn.MarginRankingLoss()
  11. Triplet Margin Loss

    • 用于度量学习,尤其是在训练使用三元组(anchor, positive, negative)的模型。
    • torch.nn.TripletMarginLoss()
  12. CTC Loss (Connectionist Temporal Classification Loss)

    • 用于序列建模问题,如语音识别。
    • torch.nn.CTCLoss()

损失函数使用示例

在 PyTorch 中,损失函数通常通过 torch.nn 模块提供,用于衡量模型预测值与真实值之间的差异。以下是一些常用损失函数的使用示例:

  1. 均方误差损失 (MSELoss)
import torch
import torch.nn as nn

# 预测值和真实值
预测值 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
真实值 = torch.tensor([2.0, 2.5, 3.5])

# 初始化损失函数
criterion = nn.MSELoss()

# 计算损失
loss = criterion(预测值, 真实值)

# 反向传播
loss.backward()

print(loss.item())
  1. 二元交叉熵损失 (BCELoss)
# 预测值和真实值
预测值 = torch.tensor([0.8, 0.4, 0.3], requires_grad=True)
真实值 = torch.tensor([1, 0, 1])

# 初始化损失函数
criterion = nn.BCELoss()

# 计算损失
loss = criterion(预测值, 真实值)

# 反向传播
loss.backward()

print(loss.item())
  1. 交叉熵损失 (CrossEntropyLoss)
# 预测值(需要使用log_softmax或softmax进行处理)
预测值 = torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]], requires_grad=True)
# 真实值(多类分类问题中,真实值是类别的索引)
真实值 = torch.tensor([2, 0])

# 初始化损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(预测值, 真实值)

# 反向传播
loss.backward()

print(loss.item())
  1. L1损失 (L1Loss)
# 预测值和真实值
预测值 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
真实值 = torch.tensor([2.0, 2.5, 3.5])

# 初始化损失函数
criterion = nn.L1Loss()

# 计算损失
loss = criterion(预测值, 真实值)

# 反向传播
loss.backward()

print(loss.item())
  1. Huber损失 (SmoothL1Loss)
# 预测值和真实值
预测值 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
真实值 = torch.tensor([2.0, 2.5, 3.5])

# 初始化损失函数
criterion = nn.SmoothL1Loss()

# 计算损失
loss = criterion(预测值, 真实值)

# 反向传播
loss.backward()

print(loss.item())
  1. KL散度损失 (KLDivLoss)
# 预测值(概率分布,需要使用log_softmax或softmax进行处理)
预测值 = torch.tensor([0.1, 0.2, 0.7], requires_grad=True)
# 真实值(概率分布)
真实值 = torch.tensor([0.2, 0.5, 0.3])

# 初始化损失函数
criterion = nn.KLDivLoss()

# 计算损失
loss = criterion(torch.log_softmax(预测值, dim=0), torch.softmax(真实值, dim=0))

# 反向传播
loss.backward()

print(loss.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2222414.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目一:3-8译码器的设计与实现(FPGA)

本文以Altera公司生产的Cyclone IV系列的EP4CE15F17C8为主芯片的CRD500开发板作为项目的硬件实现平台,并以Quarter 18.1和ModelSim为开发工具和仿真工具。 目录 一、3-8译码器工作原理 二、设计步骤 1、创建工程文件夹和编辑设计文件 (1)…

(三)将PaddleOCR编译成dll通过Java调用实现ocr识别

说明: 本文编译的PaddleOCR版本:v2.8.1,关于windows下如何生成c项目及如何编译PaddleOCR请参照我的上一篇文章《(二)Windows通过vs c编译PaddleOCR-2.8.1-CSDN博客》,本文是上一个篇文章的延伸。 背景&…

douyin uid转sec_uid 各种进行转换

第一步输入uid: 进行转换: 同时支持接口转换,批量转换,是一个很实用的工具 uid转sec_uid

微信小程序上传图片添加水印

微信小程序使用wx.chooseMedia拍摄或从手机相册中选择图片并添加水印&#xff0c; 代码如下&#xff1a; // WXML代码&#xff1a;<canvas canvas-id"watermarkCanvas" style"width: {{canvasWidth}}px; height: {{canvasHeight}}px;"></canvas&…

如何使用 Spring Cloud 实现客户端负载平衡

微服务系统通常运行每个服务的多个实例。这是实施弹性所必需的。因此&#xff0c;在这些实例之间分配负载非常重要。执行此操作的组件是负载均衡器。Spring 提供了一个 Spring Cloud Load Balancer 库。在本文中&#xff0c;您将学习如何使用它在 Spring Boot 项目中实现客户端…

QPainterPath路径类

函数drawPath()绘制的是一个复合的图形&#xff0c;它使用一个QPainterPath类型的参数作为绘图的对象,QPainterPath类用于记录绘图的操作顺序&#xff0c;优点是绘制复杂图形时只需要创建一个painterpath,然后重复调用就可以了 在使用QPainterPath把路径画好之后&#xff0c;我…

脚本-把B站缓存m4s文件转换成mp4格式

js脚本&#xff0c;自动处理视频 1. 需求简介1.1 pc安装b站客户端1.2 设置视频缓存目录1.3 找个视频缓存1.4 打开缓存文件夹![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/0eb346a84d5f42a7908f1d39bf410c3b.png)1.5 用notepad编辑后缀m4s文件&#xff0c;删除文件内…

Windows系统启动MongoDB报错无法连接服务器

文章目录 发现问题解决办法 发现问题 1&#xff09;、先是发现执行 mongo 命令&#xff0c;启动报错&#xff1a; error: MongoNetworkError: connect ECONNREFUSED 127.0.0.1:27017&#xff1b; 2&#xff09;、再检查 MongoDB 进程 tasklist | findstr mongo 发现没有进程&a…

澳元/美元价格预测:不排除跌至0.6600的可能

澳元/美元一路下跌至0.6620附近。美元保持强劲上涨势头&#xff0c;升至创下三个月新高。汇价的下跌让关键的200日均线受到考验。 澳元/美元周三再度遭遇抛售兴趣&#xff0c;迅速扭转周二的多头尝试&#xff0c;滑落至0.6630附近的新低。这次急剧下跌也对关键的200日均线构成…

yjs机器学习常见算法01——KNN(02)Kd树

1.什么是Kd树&#xff0c;为什么要引入Kd树 knn是寻找k个邻近的点&#xff0c;在这个过程中&#xff0c;需要一个点一个点的与未分类点进行比较&#xff0c;这样的时间复杂度非常高&#xff0c;因此引入了一种原理类似二叉树的Kd树&#xff0c;以减少比较搜索的次数。 kd树的本…

PyTorch求导相关

PyTorch是动态图&#xff0c;即计算图的搭建和运算是同时的&#xff0c;随时可以输出结果&#xff1b;而TensorFlow是静态图。 在pytorch的计算图里只有两种元素&#xff1a;数据&#xff08;tensor&#xff09;和 运算&#xff08;operation&#xff09; 运算包括了&#xf…

Psychophysiology:脑-心交互如何影响个体的情绪体验?

摘要 情绪的主观体验与对身体(例如心脏)活动变化的情境感知和评估相关。情绪唤醒增加与高频心率变异性(HF-HRV)降低、EEG顶枕区α功率降低以及心跳诱发电位(HEP)振幅较高有关。本研究使用沉浸式虚拟现实(VR)技术来研究与情绪唤醒相关的脑心相互作用&#xff0c;以实现自然而可…

SSM考研科目学习APP-计算机毕业设计源码90377

摘 要 基于Android的考研科目学习系统的设计与实现&#xff0c;旨在为广大考研学子提供一个便捷、高效的学习平台。该系统充分利用Android操作系统的广泛普及与灵活定制性&#xff0c;结合考研科目的特点和需求&#xff0c;实现了个性化的学习方案、丰富的题库资源以及智能化…

【个人同步与备份】电脑(Windows)与手机/平板(Android)之间文件同步

文章目录 1. syncthing软件下载2. syncthing的使用2.1. 添加设备2.1.1. syncthing具备设备发现功能&#xff0c;因此安装好软件&#xff0c;只需确认设备信息是否对应即可2.1.2. 如果没有发现到&#xff0c;可以通过设备ID连接2.1.3. 设置GUI身份验证用户&#xff0c;让无关设备…

LeetCode: 3274. 检查棋盘方格颜色是否相同

一、题目 给你两个字符串 coordinate1 和 coordinate2&#xff0c;代表 8 x 8 国际象棋棋盘上的两个方格的坐标。   以下是棋盘的参考图。   如果这两个方格颜色相同&#xff0c;返回 true&#xff0c;否则返回 false。   坐标总是表示有效的棋盘方格。坐标的格式总是先…

大模型技术学习过程梳理,零基础入门到精通,收藏这一篇就够了

“ 学习是一个从围观到宏观&#xff0c;从宏观到微观的一个过程 ” 今天整体梳理一下大模型技术的框架&#xff0c;争取从大模型所涉及的理论&#xff0c;技术&#xff0c;应用等多个方面对大模型进行梳理。 01 — 大模型技术梳理 这次梳理大模型不仅仅是大模型本身的技术…

接口测试(八)jmeter——参数化(CSV Data Set Config)

一、CSV Data Set Config 需求&#xff1a;批量注册5个用户&#xff0c;从CSV文件导入用户数据 1. 【线程组】–>【添加】–>【配置元件】–>【CSV Data Set Config】 2. 【CSV数据文件设置】设置如下 3. 设置线程数为5 4. 运行后查看响应结果

vue3项目页面实现echarts图表渐变色的动态配置

完整代码可点击vue3项目页面实现echarts图表渐变色的动态配置-星林社区 https://www.jl1mall.com/forum/PostDetail?postId202410151031000091552查看 一、背景 在开发可配置业务平台时&#xff0c;需要实现让用户对项目内echarts图表的动态配置&#xff0c;让用户脱离代码也…

基于Matlab 人脸识别技术

Matlab 人脸识别技术 算法流程&#xff1a; 本系统运用PCA算法来实现人脸特征提取&#xff0c;然后通过计算欧式距离来判别待识别测试人脸&#xff0c;本个系统框架图如下&#xff1a; 图&#xff1a; 人脸识别系统框架图 整个系统的流程是这样的&#xff0c;首先通过图像采…

给哔哩哔哩bilibili电脑版做个手机遥控器

前言 bilibili电脑版可以在电脑屏幕上观看bilibili视频。然而&#xff0c;电脑版的bilibili不能通过手机控制视频翻页和调节音量&#xff0c;这意味着观看视频时需要一直坐在电脑旁边。那么&#xff0c;有没有办法制作一个手机遥控器来控制bilibili电脑版呢&#xff1f; 首先…