损失函数总结(四):NLLLoss、CTCLoss

news2024/9/22 15:37:26

损失函数总结(四):NLLLoss、CTCLoss

  • 1 引言
  • 2 损失函数
    • 2.1 NLLLoss
    • 2.2 CTCLoss
  • 3 总结

1 引言

在前面的文章中已经介绍了介绍了一系列损失函数 (L1LossMSELossBCELossCrossEntropyLoss)。在这篇文章中,会接着上文提到的众多损失函数继续进行介绍,给大家带来更多不常见的损失函数的介绍。这里放一张损失函数的机理图:
在这里插入图片描述

2 损失函数

2.1 NLLLoss

NLLLoss(Negative Log Likelihood Loss,负对数似然损失)通常用于训练分类模型,尤其是在多类别分类任务中。它是一种用于度量模型的类别概率分布实际类别分布之间的差距的损失函数。NLLLoss 的数学表达式如下:
L NLL ( Y , Y ′ ) = − 1 n ∑ i = 1 n ∑ j = 1 C y i j log ⁡ ( y i j ′ ) L_{\text{NLL}}(Y, Y') = -\frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{C} y_{ij} \log(y_{ij}') LNLL(Y,Y)=n1i=1nj=1Cyijlog(yij)

其中:

  • L CE ( Y , Y ′ ) L_{\text{CE}}(Y, Y') LCE(Y,Y) 是整个数据集上的交叉熵损失
  • n n n 是样本数量。
  • C C C 是类别数量。
  • y i j y_{ij} yij 是第 i i i 个样本的实际类别分布,通常是一个独热编码(one-hot encoding)向量,表示实际类别
  • y i j ′ y_{ij}' yij 是第 i i i 个样本的模型预测的类别概率分布,通常是一个概率向量,表示模型对每个类别的预测概率

注意:上面的公式和 CrossEntropyLoss 公式相同,但实际上是不同的。实际关系为:
NLLLoss + LogSoftmax = CrossEntropyLoss

代码实现(Pytorch):

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

NLLLoss 通常用于分类任务,特别是当模型输出的是类别概率分布时。NLLLoss 和 CrossEntropyLoss 是等价的,可以相互替换。。。

2.2 CTCLoss

论文链接:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

CTC Loss(Connectionist Temporal Classification Loss,连接时序分类损失)通常用于训练序列到序列(sequence-to-sequence)模型,尤其是在语音识别自然语言处理中的任务,其中输出序列的长度与输入序列的长度不一致。CTC Loss 的主要目标是将模型的输出与目标序列对齐,以度量它们之间的相似度。CTCLoss 的数学表达式如下:
L CTC ( S ) = − ln ⁡ ∑ ( x , z ) ∈ S p ( z ∣ x ) = − ∑ ( x , z ) ∈ S l n p ( z ∣ x ) L_{\text{CTC}}(S) = -\ln \sum_{(x,z) \in S} p(z|x) = -\sum_{(x,z) \in S} lnp(z|x) LCTC(S)=ln(x,z)Sp(zx)=(x,z)Slnp(zx)

其中:

  • S S S 表示训练集
  • L CTC ( S ) L_{\text{CTC}}(S) LCTC(S) 表示 给定标签序列和输入,最终输出正确序列的概率

代码实现(Pytorch):

# Target are to be padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded and unbatched (effectively N=1)
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
# Initialize random batch of input vectors, for *size = (T,C)
input = torch.randn(T, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.tensor(T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

CTCLoss 在语音识别自然语言处理中具有广泛的应用,可以广泛用于sequence-to-sequence任务。

3 总结

到此,使用 损失函数总结(四) 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的损失函数也可以在评论区提出,后续会对其进行添加!!!!

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1132133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Note】一般二叉树的顺序存储

一般二叉树的存储示意图 位置i012345678910结点ABC^D^E^^F

离子阱领域新突破!滑铁卢团队在量子比特控制方面取得进展

&#xff08;图片来源&#xff1a;网络&#xff09; 要想打造未来量子计算机&#xff0c;就需要对量子比特进行精准可靠地控制。研究人员利用激光开发出了一种新方法&#xff0c;可以控制由化学元素钡制成的单个量子比特&#xff0c;这是迄今为止最好的离子阱量子比特控制方法…

使用Redis部署 PHP 留言板应用

使用Redis部署 PHP 留言板应用 启动 Redis 领导者&#xff08;Leader&#xff09;启动两个 Redis 跟随者&#xff08;Follower&#xff09;公开并查看前端服务清理 启动 Redis 数据库 创建 Redis Deployment apiVersion: apps/v1 kind: Deployment metadata:name: redis-le…

【IDEA查看一个jar包的依赖】

首先install,打jar包 重新创建一个项目 选中刚才的jar包 在这个包下就能看到jar包的依赖了

每日一练 | 华为认证真题练习Day122

1、路由器所有的接口属于同一个广播域。 A. 对 B. 错 2、下列配置默认路由的命令中&#xff0c;正确的是&#xff08;&#xff09;。 A. [Huawei]ip route-static 0.0.0.0 0.0.0.0 192.168.1.1 B. [Huawei-Serial0]ip route-static 0.0.0.0 0.0.0.0 0.0.0.0 C. [Huawei]ip…

做测试5年,熬到阿里P6,月薪25k,我总结了这些技能点

你是不是经常在工作中遇到过这些问题&#xff1a; Linux下查看端口占用命令你还记得吗&#xff1f; python容器数据操作你清楚吗&#xff1f; Devops要用了&#xff0c;各种必备技能你还不熟&#xff1f; 接口调不同&#xff0c;排查思路不清楚&#xff1f; …… 以上这些…

JMeter + Ant + Jenkins持续集成-接口自动化测试

需要安装的工具&#xff1a; jdk1.8jmeter3.2ant1.9jenkins2.1 1、Jdkwin7系统如何安装jdk及环境变量的配置-百度经验 安装包安装设置环境变量验证是否安装正确 Java -version检查&#xff0c;如下就代表安装成功了&#xff0c;环境变量设置就去搜索了&#xff0c;网上很多…

基于springboot实现就业信息管理系统项目【项目源码+论文说明】

基于springboot实现就业信息管理系统演示 摘要 随着信息化时代的到来&#xff0c;管理系统都趋向于智能化、系统化&#xff0c;就业信息管理系统也不例外&#xff0c;但目前国内仍都使用人工管理&#xff0c;市场规模越来越大&#xff0c;同时信息量也越来越庞大&#xff0c;人…

【网安大模型专题10.19】※论文5:Keep the Conversation Going: $0.42 each using ChatGPT

Keep the Conversation Going: Fixing 162 out of 337 bugs for $0.42 each using ChatGPT 写在最前面背景介绍自动程序修复流程Process of APR (automated program repair)1、漏洞程序2、漏洞定位模块3、补丁生成4、补丁验证 &#xff08;可以学习的PPT设计&#xff09;经典的…

微软成AI热潮大赢家,继续押注大模型和人工智能

KlipC报道&#xff1a;微软在官网发布了财报&#xff0c;据数据显示该公司营收同比增长13%达565亿美元&#xff0c;营业利润同比增长25%达269亿美元&#xff0c;净利润同比增长27%达223亿美元。 KlipC的合伙人Andi D表示&#xff1a;“微软的智能云部门收入同比增长19%&#xf…

echarts插件-liquidFill(水球图)

echarts插件-liquidFill&#xff08;水球图&#xff09; 1.下载2.引入&#xff1a;3.使用 1.下载 echarts.js下载&#xff1a;https://cdnjs.com/libraries/echarts echarts-liquidfill.js下载&#xff1a;https://github.com/ecomfe/echarts-liquidfill 2.引入&#xff1a; …

Ubuntu虚拟机部署OpenStack

1、部署环境 系统&#xff1a;ubuntu-22.04.3-desktop-amd64DevStack版本&#xff1a;2024.1VMware Workstation&#xff1a;8G内存、4核处理器、100G硬盘/1、网络NAT模式/1 2、Ubuntu环境设置 点击show applications&#xff0c;选择Software&Updates 跟换Ubuntu的镜像…

Selenum八种常用定位(案例解析)

Selenium是一个备受推崇的工具。它有着丰富的功能&#xff0c;让我们能够与网页互动&#xff0c;执行各种任务&#xff0c;能为测试工程师和开发人员提供了很大的便利。 要充分利用Selenium&#xff0c;就需要了解如何正确定位网页上的元素。 接下来我将带大家共同探讨Seleni…

什么是WMS系统条码化管理

WMS系统是一种用于仓库管理的信息化系统&#xff0c;旨在提高仓库操作的效率和准确性。而在WMS系统中&#xff0c;条码化管理是一项关键的技术和方法&#xff0c;它通过将商品和物料打上条码&#xff0c;并利用扫描设备进行数据采集和处理&#xff0c;实现了仓库管理的全面自动…

全网最全的阿里云ACP认证介绍,看这篇就够了!

IT行业的朋友们在找工作的时候&#xff0c;一定会经常看到“有阿里云ACP认证优先”。这也就是说&#xff0c;有一个ACP认证会是你的加分项。 获得阿里云ACP认证的好处&#xff1a; 1、增加职场竞争力&#xff0c;为企业招投标提供资质证书&#xff1b; 2、官方认证证书&#…

echarts-进度条

echarts-进度条 option {title: {text:"xxxx统计",left: 1%,top: 0%,textStyle: {color: "#2E3033",fontSize:18,},},tooltip: {axisPointer: {type: "shadow",},},grid: {top: 9%,left: "12%",right:"22%",bottom:"0…

算法通过村第十六关-滑动窗口|白银笔记|经典题目讲解

文章目录 前言最长字串专场无重复字符的最长字串至多包含两个不同字串的最长子串至多包含K个不同字串的最长子串 长度最小的子数组盛水最多的容器寻找字串异位词(排序)字符串的排序找到字符串中所有字母异位 总结 前言 提示&#xff1a;所有的话语都颇为类似&#xff0c;而沉默…

视频号视频提取工具,操作简单!一键搞定

在当下信息爆炸的时代&#xff0c;视频成为了人们获取信息、娱乐和交流的重要方式。而随着视频创作的普及&#xff0c;越来越多的人希望能够从各类视频中提取出有价值的素材和片段&#xff0c;以便用于自己的创作需求。然而&#xff0c;对于大多数人来说&#xff0c;费时费力地…

极智项目 | 实战静默活体人脸检测

欢迎关注我的公众号 [极智视界]&#xff0c;获取我的更多经验分享 大家好&#xff0c;我是极智视界&#xff0c;本文来介绍 实战静默活体人脸检测。 本文介绍的 实战静默活体人脸检测&#xff0c;提供完整的可以一键执行的项目工程源码&#xff0c;获取方式有两个&#xff1a…

云音乐Android Cronet接入实践

背景 网易云音乐产品线终端类型广泛&#xff0c;除了移动端&#xff08;IOS/安卓&#xff09;之外&#xff0c;还有PC、MAC、Iot多终端等等。移动端由于上线时间早&#xff0c;用户基数大&#xff0c;沉淀了一些端侧相对比较稳定的网络策略和网络基础能力。然而由于各端在基础…