【Python】nn.CTCLoss()函数详解与示例

news2024/11/13 15:14:36

前言

在深度学习领域,特别是在处理序列到序列的预测任务时,如语音识别和手写识别,nn.CTCLoss函数是一个非常重要的工具。本文将详细解析PyTorch中的nn.CTCLoss函数,包括其原理、原型和示例。

  • 前言
  • 函数原理
    • CTC算法简介
    • CTC Loss函数
    • 函数原型
    • 调用方式
      • 注意事项
    • 示例
  • 小结

函数原理

CTC算法简介

CTC(Connectionist Temporal Classification)是一种针对序列数据的端到端训练方法,尤其适用于RNN(循环神经网络)模型。传统的RNN序列学习任务需要事先标注好输入序列和输出序列之间的映射关系,但在实际应用中,这种标注往往非常昂贵且难以获得。CTC算法通过引入多对一的映射关系,使得RNN模型能够直接对序列数据进行学习,而无需预先标注输入和输出的映射关系。

CTC Loss函数

CTC Loss函数的目标是最大化所有能够映射到正确标签序列的输出序列的概率之和。具体来说,CTC Loss通过以下步骤计算:

  1. 扩展字符集:在原始的字符集中增加一个空白标签(blank),用于分隔不同的字符。
  2. 多对一映射:定义从RNN输出层到最终标签序列的多对一映射函数,去除连续的相同字符和空白标签。
  3. 计算路径概率:对于每一个可能的输出序列(即路径),计算其映射到正确标签序列的概率。
  4. 累加概率:将所有能够映射到正确标签序列的路径概率相加。
  5. 取负对数:将上一步得到的概率和取负对数,作为损失值。
    CTC Loss通过动态规划算法有效地计算所有路径的概率,从而避免了暴力计算的复杂性。

函数原型

PyTorch中的nn.CTCLoss函数原型如下:

torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
参数说明:
blank(int,可选):空白标签的索引,默认为0。
reduction(str,可选):指定损失的计算方式,可选值为'none''mean''sum',默认为'mean'。
zero_infinity(bool,可选):当设置为True时,任何无限或NaN的损失值将被视为0,默认为False。

调用方式

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
参数说明:
log_probs(Tensor):模型输出的张量,形状为(T, N, C),其中T是序列长度,N是batch size,C是包括空白标签在内的字符集总长度。这个张量通常需要经过torch.nn.functional.log_softmax处理。
targets(Tensor或LongTensor):标签张量,形状为(N, S)(sum(target_lengths)),其中N是batch size,S是标签长度。注意,标签中不能包含空白标签。
input_lengths(Tensor):形状为(N)的张量,包含每个输入序列的长度。
target_lengths(Tensor):形状为(N)的张量,包含每个目标序列的长度。

注意事项

输入形状:nn.CTCLoss期望的输入logits(或log_probs,如果logits=False)形状通常为(T, N, C),其中T是序列长度,N是batch大小,C是类别数(包括空白标签)。
目标格式:目标targets需要是长度为N的列表,其中每个元素是长度为Si的整数列表或张量(Si是第i个样本的目标序列长度),或者是一个形状为(N, S_max)的二维张量,其中S_max是所有目标序列中的最大长度,并使用特定的值(如CTC Loss期望的最小类别索引以下的值)来填充较短的序列。
序列长度:需要提供input_lengths和target_lengths,分别表示每个输入序列和目标序列的长度。
空白标签索引:在初始化nn.CTCLoss时,需要指定空白标签的索引。

示例

import torch
import torch.nn as nn

# 假设有模型输出和标签
# 假设log_probs已经通过log_softmax处理,但注意这里我们简化了形状以匹配示例
# 在实际应用中,T, N, C应该是根据你的数据来确定的
T = 50  # 序列长度
N = 20  # batch size
C = 28  # 类别数(假设有26个字母加上空格和空白标签,这里空白标签设为27)
log_probs = torch.randn(T, N, C).log_softmax(2)  # [T, N, C]

# 假设标签长度不一,这里我们构造一个简化的例子
# 注意:targets应该是列表的列表或二维张量,但为了简化,我们使用二维张量并填充-100(PyTorch的CTCLoss会忽略小于等于最小类别的索引)
# 在实际应用中,你应该使用真实的标签索引,并且不需要填充(除非你使用二维张量并希望统一形状)
max_target_length = 10
targets = torch.randint(1, C-1, (N, max_target_length), dtype=torch.long)  # 假设所有目标序列都不包含空白标签
# 假设有些序列较短,我们用-100填充(注意:这里使用-100只是示例,实际中应确保它小于最小的类别索引)
targets[targets == C-1] = -100  # 假设C-1不是有效的类别索引,我们用它来模拟较短的序列

# input_lengths和target_lengths
input_lengths = torch.full((N,), T, dtype=torch.long)  # 每个输入序列的长度都是T
# target_lengths需要真实反映每个目标序列的长度
# 这里我们假设所有目标序列都是完整的max_target_length长度(在实际应用中,你需要计算每个序列的真实长度)
target_lengths = torch.full((N,), max_target_length, dtype=torch.long)

# 但是,由于我们使用了-100来填充较短的序列,实际上我们需要计算每个序列的真实长度
# 这里我们手动设置几个较短的序列长度作为示例
target_lengths[0:5] = torch.tensor([5, 7, 3, 8, 9], dtype=torch.long)

# 注意:如果targets是二维张量并且包含填充值,你需要确保CTCLoss能够忽略这些填充值
# PyTorch的CTCLoss通过忽略小于等于最小类别索引的值来实现这一点
# 在这个例子中,我们假设C-1(即27)不是有效的类别索引,并且所有有效的类别索引都大于它

# 初始化CTC Loss,注意设置正确的空白标签索引
ctc_loss = nn.CTCLoss(blank=C-1)  # 假设空白标签的索引是C-1(即27)

# 但是,由于PyTorch的CTCLoss不直接支持二维张量作为targets(如果包含填充),
# 我们需要将targets转换为CTCLoss期望的格式:列表的列表或TensorList
# 这里我们为了简化,假设所有目标序列都没有填充,并直接传递二维张量(在实际中,你可能需要转换)

# 如果targets包含填充,并且你使用的是二维张量,你需要先处理它,或者改用列表的列表
# 这里我们假设没有填充,直接传递
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

print(loss.item())

在这里插入图片描述

小结

nn.CTCLoss是处理序列到序列预测任务时的强大工具,它简化了序列数据的标注过程,并允许RNN模型直接对序列数据进行端到端的学习。通过合理利用CTC Loss,我们可以有效地训练出性能优异的序列预测模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1990571.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang在整洁架构基础上实现事务

前言 大家好,这里是白泽,这篇文章在 go-kratos 官方的 layout 项目的整洁架构基础上,实现优雅的数据库事务操作。 视频讲解 📺:B站:白泽talk 本期涉及的学习资料: 我的开源Golang学习仓库&am…

【随笔】详解Java POI及其使用方法

引言 随着企业和开发者对数据处理需求的不断增加,操作Excel文件已经成为日常编程工作的重要部分。在Java中,Apache POI(Poor Obfuscation Implementation)库虽然首页其貌不扬,但它绝对是处理Excel文件的强大工具。本文…

JavaWeb之servlet关于Ajax实现前后端分离

一、什么是Ajax: AJAX Asynchronous JavaScript and XML(异步的 JavaScript 和 XML)。 AJAX 不是新的编程语言,而是一种使用现有标准的新方法。 AJAX 最大的优点是在不重新加载整个页面的情况下,可以与服务器交换数据并更新部…

proteus仿真c51单片机(四)双机串口通信(电路设计及代码)

实验要求 1.通过甲机的按键给乙机发送控制字符,同时也可以实现乙机给甲机发送控制字符 2.用PROTEUS软件根据所给电路画出电路图,用KEIL软件调试程序和编译,最后在PROTEUS软件中实现仿真。 3.甲乙两个单片机通过串口进行通信&am…

Please refer to dump files (if any exist) [date].dump, [date]-jvmRun[N]……解决

一、问题 Please refer to dump files (if any exist) [date].dump, [date]-jvmRun[N].dump and [date].dumpstream.二、解决方案 1、当打包构建的时候出现这个问题,如果你只是打包部署,那么就是将maven的test禁止可以成功打包 2、当你是本地服务器启动…

前端模块化-理解Tapable与Webpack中的Hooks

前言 Webpack 中的核心架构是基于 Tapable 实现的,Tapable 是一个类似于 Node.js 的 EventEmitter 的库,专门用于实现发布-订阅模式。Webpack 中的核心组件 Compiler、Compilation、Module、Chunk、ChunkGroup、Dependency、Template 都是通过 Tapable …

Fiddler安装与使用

下载Fiddler 访问Fiddler官方网站,下载适用于您操作系统的最新版本Fiddler。目前,Fiddler支持Windows、macOS和Linux平台。 Web Debugging Proxy and Troubleshooting Tools|Fiddler (telerik.com) 安装Fiddler,以Windows为例 Windows用户…

gitlab给用户添加项目权限

1.进入管理员界面 2.进入群组 3.添加用户

【RISC-V设计-04】- RISC-V处理器设计K0A之架构

【RISC-V设计-04】- RISC-V处理器设计K0A之架构 文章目录 【RISC-V设计-04】- RISC-V处理器设计K0A之架构1. 简介2. 主要特点3. 结构框图4. 指令列表5. CSR指令集6. 中断返回指令7. 总结 1. 简介 在前几篇文章中,介绍了RISC-V处理器的结构和指令集,从本…

Animate软件基本概念:视频及音频

视频和音频是ANimate软件中比较重要的素材类型。 FlashASer:AdobeAnimate2021软件零基础入门教程https://zhuanlan.zhihu.com/p/633230084 FlashASer:实用的各种Adobe Animate软件教程https://zhuanlan.zhihu.com/p/675680471 FlashASer:A…

DSP如何进行竞价

下面根据DSP的系统构成还拆解讲解里面的各个模块,这一节将竞价系统,也就是竞价流程 0、负载均衡 增加吞吐量、加强网络数据处理能力、提高网络的灵活性和可用性。 1、ADX发起竞价请求 上面会携带User ID等用户信息和广告信息一大堆信息。 2、解析竞价…

fastadmin 表单添加默认搜索条件

项目场景:员工列表,查看员工邀约客户明细,在 dialog 窗口中的 table怎么获取当前员工的数据呢?看似简单的需求,实际操作起来还是有点考究的,记录一下实现步骤。 页面1:员工列表 页面2&#xff…

sql_day14(获取各门店的面积)

描述:获取各门店的面积 获取各门店的面积 门店面积信息可以从分店面积明细表中获取。 先取实际经营面积(8), 如果取不到(实际经营面积为空)再取经营面积(7)。 如果取不到(经营面积为空)再取合同面积(1)。…

AI大模型赋能开发者|海云安创始人谢朝海受邀在ISC.AI 2024大会就“大模型在软件开发安全领域的应用”主题发表演讲

近日,ISC.AI 2024 第十二届互联网安全大会在北京国家会议中心盛大开幕。作为全球规格最高、规模最大、影响力最深远的安全峰会之一,本次大会以“打造安全大模型 引领安全行业革命”为主题,聚焦安全与AI两大领域,吸引了众多行业领袖…

您知道Jmeter中Redirect Automatically 和 Follow Redirects的使用场景吗?

相信很多使用过jmeter的同学都没有关注过请求中的Redirect Automatically 和 Follow Redirects选项,如下图: 在 JMeter 中,Redirect Automatically 和 Follow Redirects 是与 HTTP 请求重定向相关的两个选项,它们之间是有很大区别…

Ubuntu小键盘消失,并且安装好搜狗输入法后无法打出中文的问题

Ubuntu右上角的键盘图标不见了_ubuntu虚拟机键盘选项消失了-CSDN博客解决Ubuntu18.04安装好搜狗输入法后无法打出中文的问题_ubuntu18.04 搜狗输入法无法输入中文-CSDN博客 sudo apt install libqt5qml5 libqt5quick5 libqt5quickwidgets5 qml-module-qtquick2sudo apt instal…

小智常见报表-自由报表

概述 自由报表:具有自由设计、修改、完善的能力的报表。 应用场景 如下图所示,简单展示数据 示例说明 数据准备 在数据面板中添加数据集,可选择Json数据集和API服务数据集。Json数据集输入如下图所示: [{"姓名"…

Keytool:Uniapp 云打包需要生成证书的操作笔记

文章目录 背景操作步骤概述安装 Java 并检测版本生成证书 xxx.keystore问题:报错,没有权限使用证书 背景 我用 uniapp 想要用云打包,但是需要本机生成一个证书 操作步骤概述 安装 Java在终端输入 /usr/libexec/java_home -V 之后&#xff…

2024华为数通HCIP-datacom最新题库(H12-831变题更新⑨)

请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 近期打算考HCIP的朋友注意了,如果你准备去考试,还是用的之前的题库,切记暂缓。 如…

【Python学习笔记】序列化

【Python学习笔记】序列化 文章目录 【Python学习笔记】序列化1.python使用pickle序列化数据1.1. 环境准备1.2. 序列化datetime对象1.3. 序列化DataFrame对象1.3.1 json1.3.2 pickle 1.4 序列化list列表 2. flaskweb接口传输序列化数据2.1. bytes形式传输2.1.1. datetime对象2.…