pytorch中常用的损失函数

news2024/12/23 6:51:33

1 损失函数的作用

损失函数是模型训练的基础,并且在大多数机器学习项目中,如果没有损失函数,就无法驱动模型做出正确的预测。 通俗地说,损失函数是一种数学函数或表达式,用于衡量模型在某些数据集上的表现。损失函数在深度学习主要作用如下:

  • 衡量模型性能:损失函数用于评估模型的预测结果与真实结果之间的误差程度。较小的损失值表示模型的预测结果与真实结果更接近,反之则表示误差较大。因此,损失函数提供了一种度量模型性能的方式。
  • 参数优化:在训练机器学习和深度学习模型时,损失函数被用作优化算法的目标函数。通过最小化损失函数,可以调整模型的参数,使模型能够更好地逼近真实结果。
  • 反向传播:在深度学习中,通过反向传播算法计算损失函数对模型参数的梯度。这些梯度被用于参数更新,以便优化模型。损失函数在反向传播中扮演着重要的角色,指导参数的调整方向。
  • 防止过拟合:过拟合是指模型在训练数据上表现良好,但在新数据上表现较差的现象。损失函数可以帮助在训练过程中监控模型的过拟合情况。通过观察训练集和验证集上的损失,可以及早发现模型是否过拟合,从而采取相应的措施,如正则化等。

2 pytorch中常见的损失函数

损失函数名称适用场景
torch.nn.MSELoss()均方误差损失回归
torch.nn.L1Loss()平均绝对值误差损失回归
torch.nn.CrossEntropyLoss()交叉熵损失多分类
torch.nn.NLLLoss()负对数似然函数损失多分类
torch.nn.NLLLoss2d()图片负对数似然函数损失图像分割
torch.nn.KLDivLoss()KL散度损失回归
torch.nn.BCELoss()二分类交叉熵损失二分类
torch.nn.MarginRankingLoss()评价相似度的损失
torch.nn.MultiLabelMarginLoss()多标签分类的损失多标签分类
torch.nn.SmoothL1Loss()平滑的L1损失回归
torch.nn.SoftMarginLoss()多标签二分类问题的损失

多标签二分类

2.1 L1损失函数

预测值与标签值进行相差,然后取绝对值,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.L1Loss

import torch
import torch.nn as nn

Loss_fn = nn.L1Loss(size_average=None, reduce=None, reduction='mean')

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = Loss_fn(input, target)
print(output)

运行结果显示如下:

tensor(1.4177, grad_fn=<MeanBackward0>)

2.2 L2损失函数

预测值与标签值进行相差,然后取平方,根据实际应用场所,可以设置是否求和,求平均,公式可见下,Pytorch调用函数:nn.MSELoss

import torch.nn as nn
import torch

loss = nn.MSELoss(size_average=None, reduce=None, reduction='mean')

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
print(output)

运行结果显示如下:

tensor(1.7956, grad_fn=<MseLossBackward0>)

2.3 Huber Loss损失函数

简单来说就是L1和L2损失函数的综合版本,结合了两者的优点,公式可见下,Pytorch调用函数:nn.SmoothL1Loss

import matplotlib.pyplot as plt
import torch

# 定义函数和参数
smooth_l1_loss = nn.SmoothL1Loss(reduction='none')
x = torch.linspace(-1, 1, 10000)
y = smooth_l1_loss(torch.zeros(10000), x)

# 绘制图像
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('SmoothL1Loss')
plt.title('SmoothL1Loss Function')
plt.show()

 运行结果显示如下:

2.4 二分类交叉熵损失函数

简单来说,就是度量两个概率分布间的差异性信息,在某一程度上也可以防止梯度学习过慢,公式可见下,Pytorch调用函数有两个,一个是nn.BCELoss函数,用的时候要结合Sigmoid函数,另外一个是nn.BCEWithLogitsLoss()

import torch.nn as nn
import torch

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
print(output)

运行结果显示如下:

tensor(0.6214, grad_fn=<BinaryCrossEntropyBackward0>)
import torch
import torch.nn as nn

label = torch.empty((2, 3)).random_(2)
x = torch.randn((2, 3), requires_grad=True)

bce_with_logits_loss = nn.BCEWithLogitsLoss()
output = bce_with_logits_loss(x, label)

print(output)

 运行结果显示如下:

tensor(0.7346, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

2.5 多分类交叉熵损失函数

也是度量两个概率分布间的差异性信息,Pytorch调用函数也有两个,一个是nn.NLLLoss,用的时候要结合log softmax处理,另外一个是nn.CrossEntropyLoss

import torch
import torch.nn.functional as F

input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = F.nll_loss(F.log_softmax(input, dim=1), target)
print(output)

运行结果显示如下:

tensor(2.9503, grad_fn=<NllLossBackward0>)
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
inputs = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(inputs, target)

print(output)

运行结果显示如下:

tensor(1.6307, grad_fn=<NllLossBackward0>)

2.6 自定义损失

通过对 nn 模块进行子类化,将损失函数创建为神经网络图中的节点。 这意味着我们的自定义损失函数是一个 PyTorch 层,与卷积层完全相同。

class Custom_MSE(nn.Module):
  def __init__(self):
    super(Custom_MSE, self).__init__();

  def forward(self, predictions, target):
    square_difference = torch.square(predictions - target)
    loss_value = torch.mean(square_difference)
    return loss_value
  
  # def __call__(self, predictions, target):
  #   square_difference = torch.square(y_predictions - target)
  #   loss_value = torch.mean(square_difference)
  #   return loss_value

可以在“forward”函数调用或“call”内部定义损失的实际实现。

3 总结

损失函数在人工智能领域中起着至关重要的作用,它不仅是模型训练和优化的基础,也是评估模型性能、解决过拟合问题以及指导模型选择的重要工具。不同的损失函数适用于不同的问题和算法,选择合适的损失函数对于取得良好的模型性能至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1189817.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数模之线性规划

线性规划 优化类问题&#xff1a;有限的资源&#xff0c;最大的收益 例子: 华强去水果摊找茬&#xff0c;水果摊上共3个瓜&#xff0c;华强总共有40点体力值,每劈一个瓜能带来40点挑衅值,每挑一个瓜问“你这瓜保熟吗”能带来30点挑衅值,劈瓜消耗20点体力值&#xff0c;问话消耗…

Linux awk命令

除了使用 sed 命令&#xff0c;Linux 系统中还有一个功能更加强大的文本数据处理工具&#xff0c;就是 awk。 曾有人推测 awk 命令的名字来源于 awkward 这个单词。其实不然&#xff0c;此命令的设计者有 3 位&#xff0c;他们的姓分别是 Aho、Weingberger 和 Kernighan&#x…

7+差异分析+WGCNA+PPI网络,学会了不吃亏

今天给同学们分享一篇生信文章“Integrated PPI- and WGCNA-Retrieval of Hub Gene Signatures Shared Between Barretts Esophagus and Esophageal Adenocarcinoma”&#xff0c;这篇文章发表在Front Pharmacol期刊上&#xff0c;影响因子为5.6。 结果解读&#xff1a; 选定研…

【解决方案】vue 项目 npm run dev 时报错:‘cross-env‘ 不是内部或外部命令,也不是可运行的程序

报错 cross-env 不是内部或外部命令&#xff0c;也不是可运行的程序 或批处理文件。 npm ERR! code ELIFECYCLE npm ERR! errno 1 npm ERR! estate1.0.0 dev: cross-env webpack-dev-server --inline --progress --config build/webpack.dev.conf.js npm ERR! Exit status 1 np…

什么是final修饰 使用final修饰类、方法、变量的区别?

简介: 变量成为常量&#xff0c;不允许修改 当final修饰类时&#xff0c;该类变为最终类&#xff08;或称为不可继承的类&#xff09;。不能从最终类派生子类。这样做的目的是为了防止其他类修改或扩展最终类的行为。当final修饰方法时&#xff0c;该方法成为最终方法&#xf…

Qt QtCreator调试Qt源码配置

目录 前言1、编译debug版Qt2、QtCreator配置3、调试测试4、总结 前言 本篇主要介绍了在麒麟V10系统下&#xff0c;如何编译debug版qt&#xff0c;并通过配置QtCreator实现调试Qt源码的目的。通过调试源码&#xff0c;我们可以对Qt框架的运行机制进一步深入了解&#xff0c;同时…

计算摄像技术03 - 数字感光器件

一些计算摄像技术知识内容的整理&#xff1a;感光器件的发展过程、数字感光器件结构、数字感光器件的指标。 目录 一、感光器件的发展过程 二、数字感光器件结构 &#xff08;1&#xff09;CCD结构 ① 微透镜 ② 滤光片 ③ 感光层 电荷传输模式 &#xff08;2&#xff09;CMOS结…

代码随想录算法训练营第16天|104. 二叉树的最大深度111.二叉树的最小深度222.完全二叉树的节点个数

JAVA代码编写 104. 二叉树的最大深度 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3示例 2&#xff1a; …

API接口自动化测试

本节介绍&#xff0c;使用python实现接口自动化实现。 思路&#xff1a;讲接口数据存放在excel文档中&#xff0c;读取excel数据&#xff0c;将每一行数据存放在一个个列表当中。然后获取URL,header,请求体等数据&#xff0c;进行请求发送。 结构如下 excel文档内容如下&#x…

【vue会员管理系统】篇五之系统首页布局和导航跳转

一、效果图 1.首页 2.会员管理&#xff0c;跳转&#xff0c;跳其他页面也是如此&#xff0c;该页的详细设计会在后面的章节完善 二、代码 新增文件 components下新增文件 view下新增文件&#xff1a; 1.componets下新建layout.vue 放入以下代码&#xff1a; <template…

计算机组成原理之指令

引言 关于riscv操作数 32个寄存器 | X0~X31|快速定位数据。在riscv中&#xff0c;只对寄存器中的数据执行算术运算 2^61个存储字 | 只能被数据传输指令访问。riscv体系采用的是字节寻址。 一个寄存器是8bytes&#xff0c;64位&#xff08;double word&#xff09; 每次取的…

Python高级语法----深入asyncio:构建异步应用

文章目录 异步I/O操作示例:异步网络请求异步任务管理示例:并发执行多个任务使用异步队列示例:生产者-消费者模式在现代软件开发中,异步编程已经成为提高应用性能和响应性的关键技术之一。Python的asyncio库为编写单线程并发代码提供了强大的支持。本文将深入探讨asyncio的三…

Hadoop原理,HDFS架构,MapReduce原理

Hadoop原理&#xff0c;HDFS架构&#xff0c;MapReduce原理 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;可能很多算法学生都得去找开发&#xff0c;测开 测开的话&#xff0c;你就得学数据库&#xff0c;sql&#xff0c…

C++ vector 动态数组的指定元素删除

文本旨在对 C 的容器 vector 进行肤浅的分析。 文章目录 Ⅰ、vector 的指定元素删除代码结果与分析 Ⅱ、vector 在新增元素后再删除指定元素代码结果与分析 Ⅲ、vector 在特定条件下新增元素代码结果与分析 参考文献 Ⅰ、vector 的指定元素删除 代码 #include <iostream&g…

另辟蹊径者 PoseiSwap:背靠潜力叙事,构建 DeFi 理想国

前不久&#xff0c;灰度在与 SEC 就关于 ETF 受理的诉讼案件中&#xff0c;以灰度胜诉告终。灰度的胜利&#xff0c;也被加密行业看做是加密 ETF 在北美地区阶段性的胜利&#xff0c; 该事件也带动了加密市场的新一轮复苏。 此前&#xff0c;Nason Smart Money 曾对加密市场在 …

深度学习 opencv python 公式识别(图像识别 机器视觉) 计算机竞赛

文章目录 0 前言1 课题说明2 效果展示3 具体实现4 关键代码实现5 算法综合效果6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于深度学习的数学公式识别算法实现 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学…

Spring Cloud LoadBalancer基础知识

LoadBalancer 概念常见的负载均衡策略使用随机选择的负载均衡策略创建随机选择负载均衡器配置 Nacos 权重负载均衡器创建 Nacos 负载均衡器配置 自定义负载均衡器(根据IP哈希策略选择)创建自定义负载均衡器封装自定义负载均衡器配置 缓存 概念 LoadBalancer(负载均衡器)是一种…

【Linux】文件重定向以及一切皆文件

文章目录 前言一、重定向二、系统调用dup2三、重定向的使用四、一切皆文件 前言 Linux进程默认情况下会有3个缺省打开的文件描述符&#xff0c;分别是标准输入0&#xff0c; 标准输出1&#xff0c; 标准错误2&#xff0c; 0,1,2对应的物理设备一般是&#xff1a;键盘&#xff…

2007-2022年上市公司工业机器人渗透度数据

2007-2022年上市公司工业机器人渗透度数据 1、时间&#xff1a;2007-2022年 2、指标&#xff1a;股票代码、年份、工业机器人渗透度 3、计算方式&#xff1a;首先&#xff0c;计算行业层面的工业机器人渗透度指标&#xff1b;其次&#xff0c;构建企业层面的工业机器人渗透度…

渗透必备:Proxifier玩转代理

目录 0# 概述 1# Proxifier介绍 2# 操作过程 2.1 配置代理服务器 2.2 配置代理规则 3# Proxifier玩转代理 3.0 配置说明 3.1 通过Proxifier进行内网渗透 3.2 通过Proxifier将VM虚拟机代理 3.3 通过Proxifier进行小程序抓包 3.4 补充 4# 总结 0# 概述 在日常的渗透过…