Delay Penalty for RNN-T and CTC

news2024/11/6 7:13:14

1. 背景

之前介绍了如何在 RNN-T 流式模型上应用时延正则,以及在 Conformer 和 LSTM 上的实验结果。

本期公众号重点带大家回顾下具体的思路,以及如何类似地在 CTC 流式模型上应用时延正则。

有些内容可能有所重复,读者可适当跳过。

2. Delay penalty for RNN-T

标准 RNN-T

如图1所示,RNN-T lattice 包含了特征序列标签序列之间所有可能的对齐路径,两个序列的长度通常不一致。在 lattice 中,从点 (t,u) 出发,向上走的边表示输出 yu+1,分数为 y(t,u);向右走的边表示输出 ∅,分数为 ∅(t,u)。

此处我们提及的 lattice 边上的分数,无特殊说明情况下,都是 log-probability。

图1

假设 lattice 中路径 i 的 分数为 si,RNN-T 的目标函数 L 为最大化 lattice 中所有路径的分数之和:

L=log⁡∑iexp⁡(si)

我们通常使用动态规划算法 forward-backward[1] 来高效地计算目标函数 L,不需要显式计算每条路径的分数 si。具体地,令 α(t,u) 表示在 lattice 中在看到了特征 x0…t 的条件下,输出标签 y0…u 的分数。我们可以得到状态转移方程:

α(t,u)=LogAdd(α(t,u−1)+y(t,u−1),α(t−1,u)+∅(t−1,u)),

lattice 中所有路径的总分数 L,即状态转移的终点,可以计算为:

L=α(T−1,U)+∅(T−1,U)

我们可以发现,RNN-T 的目标函数 L 并没有考虑不同的路径所对应的时延。如图1所示,红色的路径更早地输出 symbol,时延较低;而蓝色的路径更晚地输出 symbol,时延较高。

与非流式模型不同,流式模型无法看到句子中所有的 context。流式模型为了看到更多的上下文,以达到更好的识别性能,会倾向于增强时延较高的路径, 如图1中蓝色的路径。如图2蓝色线所示,随着训练进行,没有时延正则的 RNN-T 流式模型的时延逐渐上升。

图2

Delay-penalized RNN-T

为了惩罚 RNN-T 模型的时延,我们的想法是在目标函数 L 上增加一个时延正则项 Ldelay,得到一个新的目标函数 Laug:

Laug=L+Ldelay

Ldelay 表示 lattice 中所有路径的平均时延分数(值越大,代表时延越低),定义为:

Ldelay=λ∑idiwi

其中,di 为路径 i 的时延分数,λ 是一个超参数,wi 为路径 i 的分数在整个 lattice 中的比重:

wi=∂L∂si=exp⁡(si)∑iexp⁡(si)

此处,di 的值越大,表示路径 i 的时延越低。

下文会具体讲解时延分数 di 的定义。

因此,通过引入时延正则项 Ldelay,RNN-T 会被约束着去增强那些时延较低(di 较大)的路径 i,为他们赋予一个更高的分数 si。

上文提到,我们在优化 L 的过程中,并没有显式计算各个路径 i 的分数 si。那么问题来了,为了优化 Laug,难道我们还要去显示地求出各个路径 i 的分数 si,来计算 wi 吗?这无疑是一种极其低效且不优雅的做法。

此时,Daniel 抛出了一长串数学公式,证明了我们可以优雅地实现 Laug 的优化。

由于篇幅限制,我们不在此列出具体的证明过程。感兴趣的同学可以阅读论文  https://arxiv.org/pdf/2211.00490.pdf,保证学过高中数学的同学都能看懂。

简而言之,对于一个较小的超参数 λ,带时延正则的目标函数 Laug 对路径分数 si 的导数 ∂Laug∂si 可以近似为:

∂Laug∂si≈exp⁡(λdi+si)∑iexp⁡(λdi+si)

我们只需要在优化标准目标函数 L 的过程中,将 si 替换为 λdi+si,即可达到近似地优化 Laug 的效果:

si′=λdi+si

接下来我们来讲一下在 RNN-T lattice 中如何定义 di。令 π={πu}0U−1 为输出标签序列 y0...U−1 (即向上走的边)的帧索引。我们定义路径 i 的时延分数 di 为这些帧索引 πu 相对于句子中间帧的 offset:

di=∑u(T−12−πu)

此处,之所以要加上它们相对于中间帧的 offset,是为了使得引入时延正则后,loss 函数的数值不会和原来相差太大。

图3

如图3所示,为了实现 si′,我们只需要修改 lattice 中那些输出 symbol 的边(即向上走的边),加上与帧索引对应的 offset:

y′(t,u)=y(t,u)+λ×(T−12−t)

因此,在执行 forward-backward 算法之前,我们只需要将 y(t,u) 替换为 y′(t,u),即可以一种简单高效的方式,近似地优化带时延正则的目标函数 Laug。

如图2中红色的线所示,通过在 RNN-T 目标函数上添加时延正则项,随着训练的进行,我们可以逐步降低流式模型的时延。

代码可以参考 k2 的 PR  https://github.com/k2-fsa/k2/pull/976 和 icefall 的 PR  https://github.com/k2-fsa/icefall/pull/654。

3. Delay penalty for CTC

CTC 的目标函数[2]和 RNN-T 目标函数的公式一样,也是最大化 lattice 中所有可能的对齐路径分数之和 L:

L=log⁡∑iexp⁡(si)

我们希望可以像 RNN-T 一样,对于 lattice 中每条路径,根据时延对应地修改它的分数 si,即 si′=λdi+si,达到近似地优化带时延正则的目标函数 Laug 的效果。

下面将介绍如何使用 k2 fsa 巧妙地实现这个功能。

大家可以下载文件  https://github.com/k2-fsa/next-gen-kaldi-wechat/blob/master/pdf/LF-MMI-training-and-decoding-in-k2-Part-I.pdf,了解如何用 k2 fsa 实现计算 CTC 目标函数。

图4

假设特征序列的长度为5,标签序列为 Z,O,O。利用 k2 fsa 我们可以得到对应的 CTC lattice。在图4所示,在 CTC lattice 中,每条从起点到终点的路径为:特征序列和标签序列之间的合法对齐路径。每条边上有三个属性:(1)输入标签(label);(2)输出标签( aux_label);(3)分数,即 log_softmax(encoder_output)

例如,以下三条对齐路径对应着不同的输入标签序列,他们的输出标签序列经过去除 ϵ 后,都可以得到 Z,O,O:

Z,O,∅,O,∅→Z,O,ϵ,O,ϵ

Z,Z,O,∅,O→Z,ϵ,O,ϵ,O

Z,∅,O,∅,O→Z,ϵ,O,ϵ,O

每条对齐路径的时延,取决于那些首次输出 symbol 的边的帧索引 π={πu}0U−1 ,如下面加粗的 symbol:

Z,O,∅,O,∅→Z,O,ϵ,O,ϵ

Z,Z,O,∅,O→Z,ϵ,O,ϵ,O

Z,∅,O,∅,O→Z,ϵ,O,ϵ,O

每条路径中,那些首次输出 symbol 的边的数量是相同的,为标签序列的长度 U。我们可以像上文 RNN-T 一样,定义每个路径 i 的时延分数 di 为:这些帧索引 πu 相对于句子中间帧的 offset。

图5

如图5所示,为了在 CTC 中实现 si′,我们只需要修改 lattice 中首次输出 symbol 的边(标记为红色)上的分数 yt,加上与帧索引(相对于中间帧)的 offset:

yt′=yt+λ×(T−12−t)

因此,在执行动态规划算法求 CTC lattice 中所有路径总分数之前,我们只需要将 yt 替换为 yt′,即可以一种简单高效的方式,近似地优化带时延正则的目标函数 Laug。

在 k2-fsa CTC 实现过程中,利用  k2.Fsa.get_total_scores() 求得 lattice 所有路径总分数。

具体地,如何修改 lattice 上那些首次输出 symbol 的边的分数,可以参考 k2 的 PR https://github.com/k2-fsa/k2/pull/1086,和 icefall 的 PR https://github.com/k2-fsa/icefall/pull/669,里面有详细的注释。

4. 实验结果

RNN-T

如表1所示,在使用 RNN-T 训练的流式 Conformer(chunk=0.32s)和 LSTM 模型上,应用时延正则可以有效降低模型的时延。我们只需通过调节超参数 λ,即可控制 WER 和 symbol delay 之间的 trade-off。

关于 RNN-T 时延正则,大家可以阅读论文  https://arxiv.org/pdf/2211.00490.pdf 了解更详细的实验结果。

表1

CTC

表2展示了使用 CTC 训练的流式 Conformer 模型 (chunk=0.32s),应用了时延正则后,在 librispeech 数据集 test-clean 和 test-other 上的结果。可以看出,我们同样可以通过调节超参数 λ,即可控制 WER 和 symbol delay 之间的 trade-off。

由于模型只使用了 CTC 损失函数训练了 25 个 epoch,WER 较差,大家可忽略其绝对数值。

表2

5. 总结

最后,再附上论文地址 https://arxiv.org/pdf/2211.00490.pdf,感兴趣的同学可以阅读 Daniel 的详细证明过程。有疑问的同学欢迎通过 github issue 或者评论区和我们讨论。

参考资料

[1] forward-backward: https://arxiv.org/pdf/1211.3711.pdf

[2] CTC 的目标函数: https://www.cs.toronto.edu/~graves/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/54709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iwebsec靶场 SQL注入漏洞通关笔记12-等价函数替换绕过

系列文章目录 iwebsec靶场 SQL注入漏洞通关笔记1- 数字型注入_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记2- 字符型注入(宽字节注入)_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记3- bool注入(布尔型盲注&#…

Ajax学习:同源策略(与跨域相关)ajax默认遵循同源策略

同源策略:是浏览器的一种安全策略 同源意味着:协议、域名、端口号必须相同 违背同源便是跨域 当前网页的url和ajax请求的目标资源的url必须协议、域名、端口号必须相同 比如:当前网页:协议http 域名 a.com 端口号8000 目标请求…

python——spark入门

Hadoop是对大数据集进行分布式计算的标准工具,这也是为什么当你穿过机场时能看到”大数据(Big Data)”广告的原因。它已经成为大数据的操作系统,提供了包括工具和技巧在内的丰富生态系统,允许使用相对便宜的商业硬件集群进行超级计算机级别的…

Android Poco初始化时,不大起眼但可能存在坑点的参数们

1. 前言 进行Android poco初始化的时候,可能大多数同学都是直接在Poco辅助窗里选择Android模式,然后选择自动帮我们补充poco的初始化脚本: 这种情况下,我们大多数都不会关注初始化的参数。但如果我们不了解这些参数的含义&#x…

Spring之@RequestMapping、@GetMapping、 @PostMapping 三者的区别

我的理解:其实RequestMapping、GetMapping、 PostMapping 三者就是父类和子类的区别,RequestMapping是父类,GetMapping、 PostMapping为子类集成了RequestMapping更明确了http请求的类型 分析三者的源码: RequestMapping .class&…

C#教务管理大数据平台系统源码

校务管理系统是专门针对幼儿园、培训学校的业务应用而设计研发的一款行业应用软件。校管家校务管理系统融入先进的协同管理理念,运用领先的信息化、网络化处理技术,结合丰富的教育培训行业经验,切实有效的解决幼儿园、培训学校日常工作中的关…

[附源码]计算机毕业设计-菜篮子系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

KVM虚机添加磁盘

KVM虚拟机添加磁盘两种方法: 1、添加虚拟磁盘文件 2、添加物理磁盘 需求 1、/kvm/kvms目录是我们KVM磁盘镜像集中管理的位置,我们需要在/kvm/kvms下给ceph1虚拟机创建名为ceph1-vdb.qcow2的磁盘文件,大小为80G,作为ceph1的vdb磁盘…

Python和labview先学哪个

前言 在这之前,先跟大家分享个身边的故事。 大学同学小宏,在北京一家电子设备公司做运维,上周四刚被升为运维部经理,薪资涨了35%。 但你一定想不到,他平时从不加班,甚至还经常迟到。 (文末送…

Qt入门总结

文章目录Qt一、各文件基本概念1、main.cpp文件2、XXX.pro文件3、XXX.h文件二、基本知识1、命名规范2、快捷键三、入门操作1、添加按钮2、重置窗口大小3、设置窗口标题4、设置固定的窗口大小5、对象树6、添加源文件/头文件7、窗口坐标系四、信号与槽1、让按钮附带功能2、自定义信…

浅谈affine_trans_point_2d与affine_trans_pixel

先看下两个坐标图谱: 变换前: 变换后: 我们根据1号点和9号点前后的关系,计算变换后其他点的坐标:这其实就是根据MARK点进行定位的原理 halcon代码: 执行结果: 我们发现,两种变换方…

湘江新区:金融活水赋能实体经济

湘江早报全媒体记者 黄荣佳 通讯员 易芳 吴硕 4月26日,艾布鲁环保在创业板首发上市; 10月28日,“国产操作系统第一股"麒麟信安敲响上市钟声,成为今年全省第一家在科创板上市的公司; 11月24日,…

临床信息去冗余 临床数据处理分组不同的GSE数据集有不同的临床信息,不同的分组技巧

最近,我发现学徒在学习GEO数据挖掘的过程中,遇到了第一个也是至关重要的一个难题就是对下载后的数据集进行合适的分组,因为只有对样本进行合适的分组,才有可能得到我们想要的信息。但是不同的GSE数据集有不同的临床信息&#xff0…

SpringCloud全系列知识(4)——统一网关Gateway

统一网关Gateway 一 认识网关 1.网关的功能 1.身份认证和权限校验 2.服务路由&#xff0c;负载均衡 3.请求限流 2.技术实现 Gatewayzuul 二 Gateway的使用 1.搭建网关服务 1.创建新的Module,引入 Gateway 和 Nacos 服务发现依赖。 <!--nacos服务发现依赖-->…

天宇优配|研判明年下半年投资机会或更大 险资看好“安全”与“发展”

上海证券报记者昨日获悉&#xff0c;多家稳妥资管公司已经拟定2023年出资战略&#xff0c;跟着本年以来多项稳经济方针逐步落地&#xff0c;险资遍及看好下一年经济复苏带来的商场出资时机。 权益出资方面&#xff0c;险资以为&#xff0c;当时股票商场估值处于前史较低水平&am…

Java语言有多少优势(总结版)

现在有越来越多的新技术工具、新语言涌现&#xff0c;面对林林总总的语言&#xff0c;总会有人问&#xff1a; 这么多语言应该先学哪一种&#xff1f; 什么语言值得我们长时间地学习&#xff1f; 学完之后职业发展前景大吗&#xff1f; 那么&#xff0c;我给出的答案是Java …

C++手敲Roberts_Prewitt_Sobel实现阈值分割

使用OPENCV,编写代码&#xff0c;学习图像二值化算法&#xff0c;以及边缘检测算法&#xff0c;进行图像的分割。 下面主要介绍Robert算子的实现过程&#xff1a; ①任务分析 调入并显示图像&#xff1b;使用Roberts 算子对图像进行边缘检测处理&#xff1b; Roberts 算子为…

【Scala专栏】字符串与集合

本文索引一、String/StringBuilder二、Array三、List四、Set五、Map六、TupleScala中的字符串在本质上就是Java的字符串String&#xff0c; 所以在用法上除了要符合Scala的语法以外&#xff0c;其他方面是如出一辙的。   Scala中的集合分为可变和不可变&#xff0c;默认声明…

内核编译 --- 链接器

先回顾一下编译知识 将一个程序的编译分为两个大的阶段&#xff1a;编译阶段和链接阶段 编译阶段又分为三个步骤&#xff1a;预编译&#xff0c;编译&#xff08;此编译和上面程序的编译不是同一个意思… 上面那个是指宽泛的编译&#xff09;和汇编 编译阶段经过预编译、编译…

笔记 vue3如何引入iconfont

本次采用的免费字体图标是iconfont 1、点我进入官网 2、具体流程 1、 需要什么图标在上面搜索框查找&#xff0c;然后加入购物车&#xff0c;选完后再点右上角的购物车 2、添加到项目中&#xff0c;有项目就选项目添加&#xff0c;没有就创建项目 3、确定后进入你的项目(可以…