nn.KLDivLoss,nn.CrossEntropyLoss,nn.MSELoss,Focal_Loss

news2024/11/14 3:53:08
  1. KL loss:https://blog.csdn.net/qq_50001789/article/details/128974654

https://pytorch.org/docs/stable/nn.html

1. nn.L1Loss

1.1 公式

L1Loss: 计算预测 x和 目标y之间的平均绝对值误差MAE, 即L1损失
l o s s = 1 n ∑ i = 1 , . . . n ∣ x i − y i ∣ loss=\frac{1}{n} \sum_{i=1,...n}|x_i-y_i| loss=n1i=1,...nxiyi

1.2 语法

torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
  • size_averagereduce已经被弃用,具体功能可由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sum。none:损失不做任何处理,直接输出一个数组mean:将得到的损失求平均值再输出,会输出一个数sum:将得到的损失求和再输出,会输出一个数

注意:输入的x y 可以是任意维数的数组,但是二者形状必须一致

1.3 应用案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torch

x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
L1_none = nn.L1Loss(reduction='none')
L1_mean = nn.L1Loss(reduction='mean')
L1_sum = nn.L1Loss(reduction='sum')
out_none = L1_none(x, y)
out_mean = L1_mean(x, y)
out_sum = L1_sum(x, y)
print(x)
print(y)
print(out_none)
print(out_mean)
print(out_sum)

在这里插入图片描述

2.nn.MSELoss

2.1 公式

MSELoss也叫L2 loss, 即计算预测x和目标y的平方误差损失。MSELoss的计算公式如下:

l o s s = 1 n ∑ i = 1 , . . n ( x i − y i ) 2 loss=\frac{1}{n} \sum_{i=1,..n}(x_i-y_i)^2 loss=n1i=1,..n(xiyi)2

:输入x 与y 可以是任意维数的数组,但是二者shape大小一致

2.2 语法

torch.nn.MSELoss(reduction = 'mean')

其中:

  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数

2.3 应用案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torch

x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
mse_none = nn.MSELoss(reduction='none')
mse_mean = nn.MSELoss(reduction='mean')
mse_sum = nn.MSELoss(reduction='sum')
out_none = mse_none(x, y)
out_mean = mse_mean(x, y)
out_sum = mse_sum(x, y)


print('x:',x)
print('y:',y)
print("out_none:",out_none)
print("out_mean:",out_mean)
print("out_sum:",out_sum)

在这里插入图片描述

3 nn.SmoothL1Loss

3.1 公式

SmoothL1Loss是结合L1 lossL2 loss改进的,其数学定义如下:

在这里插入图片描述
在这里插入图片描述
如果绝对值误差低于 β \beta β, 则使用 L 2 L2 L2损失,,否则使用绝对值损失 L 1 L1 L1, ,此损失对异常值的敏感性低于 L 2 L2 L2 ,即当 x x x y y y相差过大时,该损失数值要小于 L 2 L2 L2损失,在某些情况下该损失可以防止梯度爆炸。

3.2 语法

torch.nn.SmoothL1Loss( reduction='mean', beta=1.0)
  • reduction:指定损失输出的形式,有三种选择:none|mean|sum。none:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数
  • beta:损失在 L 1 L1 L1 L 2 L2 L2之间切换的阈值,默认beta=1.0

3.3 应用案例

import torch.nn as nn
import torch

# reduction设为none便于逐元素对比损失值
loss_none = nn.SmoothL1Loss(reduction='none')
loss_sum = nn.SmoothL1Loss(reduction='sum')
loss_mean = nn.SmoothL1Loss(reduction='mean')
x = torch.randn(10)
y = torch.randn(10)
out_none = loss_none(x, y)
out_sum = loss_sum(x, y)
out_mean = loss_mean(x, y)
print('x:',x)
print('y:',y)
print("out_none:",out_none)
print("out_mean:",out_mean)
print("out_sum:",out_sum)

在这里插入图片描述

4. nn.CrossEntropyLoss

nn.CrossEntropyLoss 在pytorch中主要用于多分类问题的损失计算。

4.1 交叉熵定义

交叉熵主要是用来判定实际的输出与期望的输出的接近程度,也就是交叉熵的值越小,两个概率分布就越接近。
假设概率分布p为期望输出(target),概率分布q为实际输出(pred), H ( p , q ) H(p,q) H(p,q)为交叉熵, 则:
在这里插入图片描述
在这里插入图片描述

Pytorch中的CrossEntropyLoss()函数

Pytorch中计算的交叉熵并不是采用交叉熵定义的公式得到的,其中q为预测值,p为target值:
在这里插入图片描述
而是交叉熵的另外一种方式计算得到的:
在这里插入图片描述
PytorchCrossEntropyLoss()函数的主要是将log_softmax NLLLoss最小化负对数似然函数)合并到一块得到的结果

CrossEntropyLoss()=log_softmax() + NLLLoss() 

在这里插入图片描述

  • (1) 首先对预测值pred进行softmax计算:其中softmax函数又称为归一化指数函数,它可以把一个多维向量压缩在(0,1)之间,并且它们的和为1
    在这里插入图片描述
  • (2) 然后对softmax计算的结果,再取log对数。
  • (3) 最后再利用NLLLoss() 计算CrossEntropyLoss, 其中NLLLoss() 的计算过程为:将经过log_softmax计算的结果与target 相乘并求和,然后取反。

其中(1),(2)实现的是log_softmax计算,(3)实现的是NLLLoss(), 经过以上3步计算,得到最终的交叉熵损失的计算结果。

4.2 语法

torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction='mean',label_smoothing=0.0)
  • 最常用的参数为 reduction(str, optional) ,可设置其值为 mean, sum, none ,默认为 mean。该参数主要影响多个样本输入时,损失的综合方法。mean表示损失为多个样本的平均值,sum表示损失的和,none表示不综合。
  • weight: 可手动设置每个类别的权重,weight的数组大小和类别数需保持一致

4.3 应用案例

import torch.nn as nn
import torch
loss_func = nn.CrossEntropyLoss()
pre = torch.tensor([[0.8, 0.5, 0.2, 0.5]], dtype=torch.float)
tgt = torch.tensor([[1, 0, 0, 0]], dtype=torch.float)
print('----------------手动计算---------------------')
print("1.softmax")
print(torch.softmax(pre, dim=-1))
print("2.取对数")
print(torch.log(torch.softmax(pre, dim=-1)))
print("3.与真实值相乘")
print(-torch.sum(torch.mul(torch.log(torch.softmax(pre, dim=-1)), tgt), dim=-1))
print('-------------调用损失函数-----------------')
print(loss_func(pre, tgt))
print('----------------------------------------')

在这里插入图片描述
由此可见:

  • ①交叉熵损失函数会自动对输入模型的预测值进行softmax。因此在多分类问题中,如果使用nn.CrossEntropyLoss()则预测模型的输出层无需添加softmax层

  • nn.CrossEntropyLoss()=nn.LogSoftmax()+nn.NLLLoss().

nn.CrossEntropyLoss() 的target可以是one-hot格式,也可以直接输出类别,不需要进行one-hot处理,如下示例:

import torch
import torch.nn as nn
x_input=torch.randn(3,3)#随机生成输入 
print('x_input:\n',x_input) 
y_target=torch.tensor([1,2,0])#设置输出具体值 print('y_target\n',y_target)

#计算输入softmax,此时可以看到每一行加到一起结果都是1
softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)

#在softmax的基础上取log
log_output=torch.log(soft_output)
print('log_output:\n',log_output)

#对比softmax与log的结合与nn.LogSoftmaxloss(负对数似然损失)的输出结果,发现两者是一致的。
logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)

#pytorch中关于NLLLoss的默认参数配置为:reducetion=True、size_average=True
nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)

#直接使用pytorch中的loss_func=nn.CrossEntropyLoss()看与经过NLLLoss的计算是不是一样
crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)
  • 其中pred为x_input=torch.randn(3,3,对应的target为y_target=torch.tensor([1,2,0]), target并没有处理为one-hot格式,也可以正常计算结果的。

5. nn.BCELoss和nn.BCEWithLogitsLoss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1228449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringCloud 微服务全栈体系(十四)

第十一章 分布式搜索引擎 elasticsearch 四、RestAPI ES 官方提供了各种不同语言的客户端,用来操作 ES。这些客户端的本质就是组装 DSL 语句,通过 http 请求发送给 ES。官方文档地址:https://www.elastic.co/guide/en/elasticsearch/client/…

Amazon EC2的出现,是时代的选择了它,还是它选择了时代

目录 Amazon EC2简介 友商云服务器对比(Amazon VS Tencent) 友商云服务器对比(Amazon VS Alibaba) Amazon 云服务器的绝对优势 Amazon EC2功能 Amazon EC2 Linux 实例入门 启动实例 连接到的实例 清除的实例 终止的实例…

Android 10.0 系统修改usb连接电脑mtp和PTP的显示名称

1.前言 在10.0的产品定制化开发中,在usb模块otg连接电脑,调整为mtp文件传输模式的时候,这时可以在电脑看到手机的内部存储 显示在电脑的盘符中,会有一个mtp名称做盘符,所以为了统一这个名称,就需要修改这个名称,接下来分析下处理的 方法来解决这个问题 2.系统修改usb连…

源码分析Mybatis拦截器(Interceptor)拦截saveBatch()获取不到实体id的原因

1.背景 由于业务需求想在Mybatis拦截器层面获取insert后实体id去做相关业务。但是发现执行saveBatch()方法时,获取参数实体的时候,拿不到自增id。但是save()方法可以。 save方法之所以可以是因为: MybatisPlus的BaseMapper执行insert方法后…

如何在虚拟机的Ubuntu22.04中设置静态IP地址

为了让Linux系统的IP地址在重新启动电脑之后IP地址不进行变更,所以将其IP地址设置为静态IP地址。 查看虚拟机中虚拟网络编辑器获取当前的子网IP端 修改文件/etc/netplan/00-installer-config.yaml文件,打开你会看到以下内容 # This is the network conf…

java拼图小游戏

第一步是创建项目 项目名自拟 第二部创建个包名 来规范class 然后是创建类 创建一个代码类 和一个运行类 代码如下: package heima;import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import jav…

深入理解注意力机制(下)——缩放点积注意力及示例

一、介绍 在这篇文章中,我们将重点介绍 Transformer 背后的 Scaled Dot-Product Attention,并详细解释其计算逻辑和设计原理。 在文章的最后,我们还会提供一个Attention的使用示例,希望读者看完后能够对Attention有更全面的了解。…

将word中的表格无变形的弄进excel中

在上篇文章中记录了将excel表拷贝到word中来: 记录将excel表无变形的弄进word里面来-CSDN博客 本篇记录:将word中的表格无变形的弄进excel中。 1.按F12,“另存为...”,保存类型:“单个文件页面”,保存。…

Java读写Jar

Java提供了读写jar的类库Java.util.jar,Java获取解析jar包的工具类如下: import java.io.File; import java.io.IOException; import java.net.URL; import java.net.URLClassLoader; import java.util.Enumeration; import java.util.HashMap; import …

【C++入门到精通】新的类功能 | 可变参数模板 C++11 [ C++入门 ]

阅读导航 引言一、新的类功能1. 默认成员函数2. 类成员变量初始化3. 强制生成默认函数的关键字default4. 禁止生成默认函数的关键字delete5. override 和 final(1)override(2)final 二、可变参数模板递归函数方式展开参数包逗号表…

C# Winform围棋棋盘

C# Winform简单的围棋棋盘vs2008winform小游戏C#vs2010winform棋盘C#窗体小游戏 这是一个简单的围棋棋盘小游戏,使用C# Winform编写棋盘界面,玩家可以在空白的交叉点上下棋子 项目获取: 项目获取:typora: typora/img (gitee.co…

支付宝沙箱支付

支付宝沙箱支付 支付宝沙箱(Alipay Sandbox)是支付宝提供的一个模拟环境,用于开发者在不影响真实交易的情况下进行支付宝相关功能的测试和调试。在软件开发中,沙箱环境通常指的是一个隔离的测试环境,可以模拟真实环境…

【论文阅读】MAG:一种用于航天器遥测数据中有效异常检测的新方法

文章目录 摘要1 引言2 问题描述3 拟议框架4 所提出方法的细节A.数据预处理B.变量相关分析C.MAG模型D.异常分数 5 实验A.数据集和性能指标B.实验设置与平台C.结果和比较 6 结论 摘要 异常检测是保证航天器稳定性的关键。在航天器运行过程中,传感器和控制器产生大量周…

Python 自动化(十八)admin后台管理

admin后台管理 什么是admin后台管理 django提供了比较完善的后台数据库的接口,可供开发过程中调用和测试使用 django会搜集所有已注册的模型类,为这些模型类提供数据管理界面,供开发使用 admin配置步骤 创建后台管理账号 该账号为管理后…

2023年中职“网络安全“—Web 渗透测试①

2023年中职"网络安全"—Web 渗透测试① Web 渗透测试任务环境说明:1.访问地址http://靶机IP/task1,分析页面内容,获取flag值,Flag格式为flag{xxx};2.访问地址http://靶机IP/task2,访问登录页面。…

【每日一题】三个无重叠子数组的最大和

文章目录 Tag题目来源题目解读解题思路方法一:滑动窗口 写在最后 Tag 【滑动窗口】【数组】【2023-11-19】 题目来源 689. 三个无重叠子数组的最大和 题目解读 解题思路 方法一:滑动窗口 单个子数组的最大和 我们先来考虑一个长度为 k 的子数组的最…

【开源】基于Vue.js的创意工坊双创管理系统

项目编号: S 049 ,文末获取源码。 \color{red}{项目编号:S049,文末获取源码。} 项目编号:S049,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 管理员端2.2 Web 端2.3 移动端 三、…

leetcode算法之分治-快排

目录 1.颜色分类2.排序数组3.数组中的第k个最大元素4.最小的k个数 1.颜色分类 颜色分类 class Solution { public:void sortColors(vector<int>& nums) {int n nums.size();int left -1,rightn,i0;while(i<right){if(nums[i] 0) swap(nums[left],nums[i]);e…

力扣 字母异位词分组 哈表 集合

&#x1f468;‍&#x1f3eb; 力扣 字母异位词分组 ⭐ 思路 由于互为字母异位词的两个字符串包含的字母相同&#xff0c;因此对两个字符串分别进行排序之后得到的字符串一定是相同的&#xff0c;故可以将排序之后的字符串作为哈希表的键。 &#x1f351; AC code class Solut…

设计模式-行为型模式-策略模式

一、什么是策略模式 策略模式是一种行为设计模式&#xff0c;它允许在运行时选择算法或行为&#xff0c;并将其封装成独立的对象&#xff0c;使得这些算法或行为可以相互替换&#xff0c;而不影响使用它们的客户端。&#xff08;ChatGPT生成&#xff09; 主要组成部分&#xff…