PyTorch内置损失函数汇总 !!

news2024/9/25 2:24:49

文章目录

一、损失函数的概念

二、Pytorch内置损失函数

1. nn.CrossEntropyLoss

2. nn.NLLLoss

3. nn.NLLLoss2d

4. nn.BCELoss

5. nn.BCEWithLogitsLoss

6. nn.L1Loss

7. nn.MSELoss

8. nn.SmoothL1Loss

9. nn.PoissonNLLLoss

10. nn.KLDivLoss

11. nn.MarginRankingLoss

12. nn.MultiLabelMarginLoss

13. nn.SoftMarginLoss

14. nn.MultilabelSoftMarginLoss

15. nn.MultiMarginLoss

16. nn.TripletMarginLoss

17. nn.HingeEmbeddingLoss

18. nn.CosineEmbeddingLoss

19. nn.CTCLoss


一、损失函数的概念

损失函数(loss function):衡量模型输出与真实标签的差异。

损失函数也叫代价函数(cost function)/ 准测(criterion)/ 目标函数(objective function)/ 误差函数(error function)。

二、Pytorch内置损失函数

1. nn.CrossEntropyLoss

功能:交叉熵损失函数,用于多分类问题。这个损失函数结合了nn.LogSoftmaxnn.NLLLoss的计算过程。通常用于网络最后的分类层输出

主要参数:

  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为 none /sum /mean:

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.CrossEntropyLoss(weight=None, 
					size_average=None, 
					ignore_index=-100, 
					reduce=None, 
					reduction=‘mean’)

用法示例:

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

2. nn.NLLLoss

功能:负对数似然损失函数,当网络的最后一层是nn.LogSoftmax时使用。用于训练 C 个类别的分类问题

主要参数:

  • weight:各类别的loss设置权值,必须是一个长度为 C 的 Tensor
  • ignore _index:设置一个目标值, 该目标值会被忽略, 从而不会影响到 输入的梯度
  • reduction :计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.NLLLoss(weight=None,
		   size_average=None, 
		   ignore_index=-100, 
		   reduce=None, 
		   reduction='mean')

用法示例:

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)

3. nn.NLLLoss2d

功能:对于图片输入的负对数似然损失. 它计算每个像素的负对数似然损失。它是nn.NLLLoss的二维版本。适用于图像相关的任务,比如像素级任务或分割

torch.nn.NLLLoss2d(weight=None, ignore_index=-100, reduction='mean')

4. nn.BCELoss

功能:二元交叉熵损失函数,用于二分类问题。计算的是目标值和预测值之间的交叉熵。

注意事项:输入值取值在 [0,1]

主要参数:

  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.BCELoss(weight=None, 
				 size_average=None,
				 reduce=None, 
 				 reduction='mean')

用法示例:

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)

5. nn.BCEWithLogitsLoss

功能:结合了nn.Sigmoid层和nn.BCELoss的损失函数,用于二分类问题,尤其在预测值没有经过nn.Sigmoid层时

注意事项:网络最后不加sigmoid函数

主要参数:

  • pos_weight:正样本的权值
  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.BCEWithLogitsLoss(weight=None, 
					 size_average=None, 
					 reduce=None, reduction='mean', 
					 pos_weight=None)

用法示例:

loss = nn.BCEWithLogitsLoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(input, target)

6. nn.L1Loss

功能:L1损失函数,也称为最小绝对偏差(LAD)。它是预测值和真实值之间差的绝对值的和

主要参数:

  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.L1Loss(reduction='mean')

用法示例:

loss = nn.L1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

7. nn.MSELoss

功能:均方误差损失函数,计算预测值和真实值之间差的平方的平均值,用于回归问题。

主要参数:

  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.MSELoss(reduction='mean')

用法示例:

loss = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

8. nn.SmoothL1Loss

功能:平滑L1损失,也称为Huber损失,主要用于回归问题,尤其是当预测值与目标值差异较大时,比起L1损失更不易受到异常值的影响

  • size_average
  • reduce
  • reduction
  • beta
torch.nn.SmoothL1Loss(reduction='mean')

其中,

用法示例:

loss = nn.SmoothL1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

9. nn.PoissonNLLLoss

功能:泊松负对数似然损失,适用于计数或事件率预测,其中预测的是事件发生的平均率

主要参数:

  • log_inpput:输入是否为对数形式,决定计算公式
  • full:计算所有loss,默认为False
  • eps:修正项,避免log(input)为nan
torch.nn.PoissonNLLLoss(log_input=True, full=False,  eps=1e-08,  reduction='mean')

用法示例:

loss = nn.PoissonNLLLoss()
log_input = torch.randn(5, 2, requires_grad=True)
target = torch.randn(5, 2)
output = loss(log_input.exp(), target)

10. nn.KLDivLoss

功能::KL散度损失,用于衡量两个概率分布之间的差异。通常用于模型输出与某个目标分布或另一个模型输出之间的相似性度量

注意事项:需提前将输入计算 log-probabilities,如通过nn.logsoftmax()

主要参数:

  • reduction:none / sum / mean / batchmean

①. batchmean:batchsize维度求平均值

②. none:逐个元素计算

③. sum:所有元素求和,返回标量

④. mean:加权平均,返回标量

torch.nn.KLDivLoss(reduction='mean')

用法示例:

loss = nn.KLDivLoss(reduction='batchmean')
input = torch.log_softmax(torch.randn(5, 10), dim=1)
target = torch.softmax(torch.randn(5, 10), dim=1)
output = loss(input, target)

11. nn.MarginRankingLoss

功能:边缘排序损失,用于排序学习任务,它鼓励正例的得分比负例的得分更高一个边界值

注意事项:该方法计算两组数据之间的差异,返回一个 n*n 的loss 矩阵

主要参数:

  • margin:边界值,x1和x2之间的差异值
  • reduction:计算模式,可为none / sum / mean

①. y=1时,希望x1比x2大,当x1>x2时,不产生loss

②. y=-1时,希望x2比x1大,当x2>x1时,不产生loss

torch.nn.MarginRankingLoss(margin=0.0, reduction='mean')

用法示例:

loss = nn.MarginRankingLoss()
input1 = torch.randn(3, requires_grad=True)
input2 = torch.randn(3, requires_grad=True)
target = torch.randn(3).sign()
output = loss(input1, input2, target)

12. nn.MultiLabelMarginLoss

功能:多标签边缘损失,用于多标签分类问题,其中每个类别的损失是独立计算的。

举例:四分类任务,样本x属于0类或3类

主要参数:

  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiLabelMarginLoss(reduction='mean')

对于mini-batch(小批量) 中的每个样本按如下公式计算损失:

用法示例:

loss = nn.MultiLabelMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([[3, 0, -1, -1, -1],
                       [1, 3, -1, -1, -1],
                       [1, 2, 3, -1, -1]])
output = loss(input, target)

13. nn.SoftMarginLoss

功能:软边缘损失,用于二分类任务,是逻辑回归损失的平滑版本。

主要参数:

  • reduction:计算模式,可为none / sum / mean
torch.nn.SoftMarginLoss(reduction='mean')

用法示例:

loss = nn.SoftMarginLoss()
input = torch.randn(3, requires_grad=True)
target = torch.tensor([-1, 1, 1], dtype=torch.float)
output = loss(input, target)

14. nn.MultilabelSoftMarginLoss

功能:多标签软边缘损失,用于多标签分类问题,它是每个标签的二元交叉熵损失的加权版本

主要参数:

  • weight:各类别的loos设置权值
  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiLabelSoftMarginLoss(weight=None, reduction='mean')

用法示例:

loss = nn.MultiLabelSoftMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, 5).random_(2)
output = loss(input, target)

15. nn.MultiMarginLoss

功能:多类别边缘损失,是SVM(支持向量机)的一个变种,用于多类别分类问题。

主要参数:

  • p:可选1或2
  • weight:各类别的loos设置权值
  • margin:边界值
  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiMarginLoss(p=1, margin=1.0, weight=None,  reduction='mean')

用法示例:

loss = nn.MultiMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(input, target)

16. nn.TripletMarginLoss

功能:三元组边缘损失,用于度量学习,其中学习的是输入样本之间的相对距离。人脸验证中常用

主要参数:

  • p:范数的阶,默认为2
  • margin:边界值
  • reduction:计算模式,可为none / sum / mean

和孪生网络相似,具体例子:给一个A,然后再给B、C,看看B、C谁和A更像。

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, reduction='mean')

其中,

用法示例:

loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
output = loss(anchor, positive, negative)

17. nn.HingeEmbeddingLoss

功能:铰链嵌入损失,用于学习基于距离的相似性,当两个输入被认为是不相似的时,会惩罚它们的距离。常用于非线性embedding和半监督学习

注意事项:输入x 应为两个输入之差的绝对值

主要参数:

  • margin:边界值
  • reduction:计算模式,可为none / sum / mean

torch.nn.HingeEmbeddingLoss(margin=1.0,  reduction='mean')

用法示例:

loss = nn.HingeEmbeddingLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, -1, 1])
output = loss(input, target)

18. nn.CosineEmbeddingLoss

功能:余弦嵌入损失,用于学习输入之间的余弦相似性,适用于确定两个输入是否在方向上是相似的

主要参数:

  • margin:可取值[-1, 1],推荐为 [0,0.5]
  • reduction:计算模式,可为none / sum / mean
torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')

用法示例:

loss = nn.CosineEmbeddingLoss()
input1 = torch.randn(3, 5, requires_grad=True)
input2 = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, -1, 1])
output = loss(input1, input2, target)

19. nn.CTCLoss

功能:连接时序分类(CTC)损失,用于无对齐或序列到序列问题,如语音或手写识别。

主要参数:

  • blank:blank label
  • zero_infinity:无穷大的值或梯度置0
  • reduction:计算模式,可为none / sum / mean
torch.nn.CTCLoss(blank=0, reduction='mean')

用法示例:

T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
loss = nn.CTCLoss()
output = loss(input, target, input_lengths, target_lengths)

在实际的代码实现中,你需要根据你的模型和数据来调整输入和目标张量的尺寸

参考:https://yolov5.blog.csdn.net/article/details/123441628

参考:深度学习爱好者

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1408553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

sylar高性能服务器-日志(P26-P29)内容记录

文章目录 P26:协程01一、方法函数二、结果展示 P27-28:协程02-03一、方法函数二、结果展示 P29:协程04一、方法函数二、结果展示 P26:协程01 ​ 本节内容主要介绍了开始协程的一些准备工作,平常我们使用assert断言时&…

香港web3盛会:Unisat确认参加Big Demo Day项目路演

本次“Big Demo Day”将于1月31日举办第十期,是由Zeepr 总冠名,Central Research、Techub News联合主办、数码港、852web3支持举行的大型线下活动。Big Demo Day集结了Web2和Web3行业精英聚焦香港市场。 Unisat确认参加 Big Demo Day 线下活动&#xff0…

HIS项目介绍、项目环境准备、版本控制介绍、Git基础、Git指针、Git分支、Git标签

案例1:项目环境准备 环境准备说明: 本阶段共使用虚拟机6台,操作系统使用RockyLinux8.6 环境准备要求: 最小化安装即可配置好主机名和IP地址搭建好yum源关闭防火墙和SELinux!!! 项目主机列表 主机名IP地址规格角色服务Progra…

python内置函数有哪些?整理到了7大分类48个函数,都是工作中常用的函数

python内置函数 一、入门函数 1.input() 功能: 接受标准输入,返回字符串类型 语法格式: input([提示信息])实例: # input 函数介绍text input("请输入信息:") print("收到的数据是:%s" % (text))#输出…

“趣味夕阳,乐享生活”小组活动(第二节)

立冬以来,天气日渐寒冷,气温变化较大,各种传染病多发,为进一步增强老年人冬季预防传染病保健意识及科学合理健康的生活方式。近日,1月22日,南阳市人人社工灌涨站开展了“趣味夕阳,乐享生活”小组…

在IntelliJ IDEA中通过Spring Boot集成达梦数据库:从入门到精通

目录 博客前言 一.创建springboot项目 新建项目 选择创建类型​编辑 测试 二.集成达梦数据库 添加达梦数据库部分依赖 添加数据库驱动包 配置数据库连接信息 编写测试代码 验证连接是否成功 博客前言 随着数字化时代的到来,数据库在应用程序中的地位越来…

RS450服务器硬盘亮黄灯故障及从MegaRAID9240-4i阵列卡的恢复业务过程

最近一台ThinkCenter RS450服务器硬盘亮黄灯,引起进入系统很慢,于是将业务系统备份后,对该服务器硬盘进行修复。 该服务器的总共三块硬盘组件了Raid5,因此待第一块盘亮红灯后,尝试进入Raid管理器,将报错的…

gitlab备份-迁移-升级方案9.2.7升级到15版本最佳实践

背景 了解官方提供的版本的升级方案 - GitLab 8: 8.11.Z 8.12.0 8.17.7 - GitLab 9: 9.0.13 9.5.10 9.2.7 - GitLab 10: 10.0.7 10.8.7 - GitLab 11: 11.0.6 11.11.8 - GitLab 12: 12.0.12 12.1.17 12.10.14 - GitLab 13: 13.0.14 13.1.11 13.8.8 13.12.15 - G…

HTML小白入门学习-列表标签

前言 在上一篇文章中,我们学习了下图所示的几个文本格式标签,分别是加粗、斜体、下划线、删除线、下标和上标,忘记了的小伙伴可以回去再看看哦。 在网页中,我们也会经常看到列表,比如某资讯网页的信息列表&#xff…

C# Bitmap类学习1

Bitmap对象封装了GDI中的一个位图,此位图由图形图像及其属性的像素数据组成.因此Bitmap是用于处理由像素数据定义的图像的对象。 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using …

【新加坡机器人学会支持】第三届工程管理与信息科学国际学术会议 (EMIS 2024)

第三届工程管理与信息科学国际学术会议 (EMIS 2024) 2024 3rd International Conference on Engineering Management and Information Science 【国际高级别专家出席/新加坡机器人学会支持】 第三届工程管理与信息科学国际学术会议 (EMIS 2024)将于2024年4月12-14日在中国洛…

SpringBoot项目多数据源配置与MyBatis拦截器生效问题解析

在日常项目开发中,由于某些原因,一个服务的数据源可能来自不同的库,比如: 对接提供的中间库,需要查询需要的数据同步数据,需要将一个库的数据同步到另一个库,做为同步工具的服务对接第三方系统…

黑马Java——面向对象进阶(static继承)

1.static静态变量 静态变量是随着类的加载而加载的,优先与对象出现的

“豚门”、“吗喽”,为啥品牌宣传瞄上网红动物?

近期,新茶饮品牌喜茶联名红山动物园,凭借可爱周边拿捏无数消费者,再往前一段时间,还有奈雪联名“吗喽”表情包,为什么品牌宣传会瞄上网红动物,今天媒介盒子就来和大家聊聊。 一、 萌元素引起用户情绪共鸣 …

C#使用DateTime.Now.AddDays方法获取任一天的信息

目录 一、使用DateTime对象的AddDays方法获取任一天信息方法 二、举例说明获取昨天的信息 三、涉及到的知识点 1. MessageBox.Show()中信息分行的办法 使用DateTime.Now属性可以得到当前的日期信息,此时调用ToString方法,并在该方法中添加…

使用PHP自定义一个加密算法,实现编码配合加密,将自己姓名的明文加密一下

<meta charset"UTF-8"> <?phpfunction customEncrypt($lin, $key mySecretKey){// 定义一个简单的替换规则$li array(L > M, I > Y, Y > O, A > N, E > Q, );$yan ;for($i 0; $i < strlen($lin); $i){$char $lin[$i];if(isset($li[…

27.移除元素(力扣LeetCode)

文章目录 27.移除元素&#xff08;力扣LeetCode&#xff09;题目描述方法一&#xff1a;vector成员函数&#xff1a;erase方法二&#xff1a;暴力解法方法三&#xff1a;双指针法 27.移除元素&#xff08;力扣LeetCode&#xff09; 题目描述 给你一个数组 nums 和一个值 val&…

6.php开发-个人博客项目Tp框架路由访问安全写法历史漏洞

目录 知识点 php框架——TP URL访问 Index.php-放在控制器目录下 ​编辑 Test.php--要继承一下 带参数的—————— 加入数据库代码 --不过滤 --自己写过滤 --手册&#xff08;官方&#xff09;的过滤 用TP框架找漏洞&#xff1a; 如何判断网站是thinkphp&#x…

最小二乘2D圆拟合(高斯牛顿法)

欢迎关注更多精彩 关注我&#xff0c;学习常用算法与数据结构&#xff0c;一题多解&#xff0c;降维打击。 本期话题&#xff1a;最小二乘2D圆拟合 相关背景资料 点击前往 2D圆拟合输入和输出要求 输入 8到50个点&#xff0c;全部采样自圆上&#xff0c;z轴坐标都为0。每个…

算法练习-螺旋矩阵(思路+流程图+代码)

难度参考 难度&#xff1a;中等 分类&#xff1a;数组 难度与分类由我所参与的培训课程提供&#xff0c;但需要注意的是&#xff0c;难度与分类仅供参考。以下内容均为个人笔记&#xff0c;旨在督促自己认真学习。 题目 给定一个正整数n&#xff0c;生成一个包含1到 n^2 所有元…