上海声通团队在WeNet中开源Branchformer

news2024/12/24 9:07:16

上海声通信息科技股份有限公司作为交互式人工智能市场的领导者,具有极强的技术优势和突出的产品特点。公司基于自研的融合通信及人工智能两项核心技术,打造了丰富的、高度标准化的产品模块,为客户提供高效、稳定的产品体验。公司主要的业务场景为智慧城市、智慧出行、智慧通信和智慧金融,同时公司也在积极开辟产品的其他场景以及创新应用。

论文介绍

Branchformer是由卡内基梅隆大学提出的一种结构更加灵活,可解释性更强,且可以更加灵活配置的新一代encoder结构。在ESPnet框架中,在同等参数量的情况下实验测试多个常用数据集(aishell等)结果均齐平或优于Conformer结构。其文章已被ICML2022收录,本文主要讲解其大致结构,并对其在WeNet框架中对其进行复现。

引言:自从被提出后,Conformer结构凭借其高效性被广泛的应用在包括ASR等任务的语音领域,并在多项任务保持着state-of-the-art。相对于Transformer结构,它能够更好的捕获局部与全局特征。然而,Conformer利用了一种串行的方式在每个encoder_layer将音频依次通过self-attention模块与卷积模块并传入下一层。纵然,这种方式也取得了很不错的效果,但是其可解释性可能有点迷惑。局部特征与全局特征之间的关系是怎样的?他们是怎么融合的,他们同等重要吗?还是其中的哪一种扮演着更重要的角色呢?

带着上面的问题,一种新的encoder结构Branchformer被提出了。相较于Conformer马卡龙夹心堆叠结构的结构,Branchformer做了如下改进:

  • 采用了并行的双分支结构。其中分支一利用multiheaded self-attention机制提取输入序列中的全局特征,分支二则引入了cgMLP结构,意在捕获音频序列中的局部特征。

  • MLP with convolutional gating(cgMLP)模块,利用深度可分离卷积于线性门控单元的组合来学习序列中的特征表示。

  • Concat与可学习参数加权等多种特征组合方式

  • Stochastic Layer Skip,训练时通过随机丢弃encoder_layer来增强模型鲁棒性(Espnet代码中添加,论文中没有提到)

模型实现

通过对论文及源代码的阅读,我们发现Branchformer与Conformer的区别主要在于其encoder_layer中对特征的提取及组合方式,而缩小后看整体的处理流程相差并不大。我们参考ESPnet中Branchformer的代码,完成了其在WeNet框架中的实现。

cgMLP模块

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        cache: torch.Tensor = torch.zeros((0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward cgMLP"""

        xs_pad = x

        # size -> linear_units
        xs_pad = self.channel_proj1(xs_pad)

        # linear_units -> linear_units/2
        xs_pad, new_cnn_cache = self.csgu(xs_pad, cache)

        # linear_units/2 -> size
        xs_pad = self.channel_proj2(xs_pad)

        out = xs_pad

        return out, new_cnn_cache

CSGU是cgMLP中的关键,其先将输入序列按feature dimension一分为二,其中一部分会通过layer norm及depth-wise convolution,而后再与另一部分做element-wise multiplication得到输出。由于WeNet中引入了cache,我们在这里计算并更新cnn_cache。

 def forward(
        self,
        x: torch.Tensor,
        cache: torch.Tensor = torch.zeros((0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward CSGU"""
        
        x_r, x_g = x.chunk(2, dim=-1)
        # exchange the temporal dimension and the feature dimension
        x_g = x_g.transpose(1, 2)  # (#batch, channels, time)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x_g.size(0)  # equal batch
                assert cache.size(1) == x_g.size(1)  # equal channel
                x_g = torch.cat((cache, x_g), dim=2)
            assert (x_g.size(2) > self.lorder)
            new_cache = x_g[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x_g.dtype, device=x_g.device)

        x_g = x_g.transpose(1, 2)
        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out, new_cache

Merge Two Branches 合并分支特征

作者提出了三种不同的特征融合方法,直接concat,等权重线性融合,可学习权重融合,代码整体改动不大,仅需要注意替换numpy等模块以免影响torch.jit模型导出。

if self.merge_method == "concat":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(torch.cat([x1, x2], dim=-1))
                )
            elif self.merge_method == "learned_ave":
                if (
                    self.training
                    and self.attn_branch_drop_rate > 0
                    and torch.rand(1).item() < self.attn_branch_drop_rate
                ):
                    # Drop the attn branch
                    w1, w2 = torch.tensor(0.0), torch.tensor(1.0)
                else:
                    # branch1
                    score1 = (self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5)
                    score1 = score1.masked_fill(mask_pad.eq(0), -float('inf'))
                    score1 = torch.softmax(score1, dim=-1).masked_fill(
                        mask_pad.eq(0), 0.0
                    )

                    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)

                    # branch2
                    score2 = (self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5)
                    score2 = score2.masked_fill(mask_pad.eq(0), -float('inf'))
                    score2 = torch.softmax(score2, dim=-1).masked_fill(
                        mask_pad.eq(0), 0.0
                    )

                    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)

                    # normalize weights of two branches
                    merge_weights = torch.softmax(
                        torch.cat([weight1, weight2], dim=-1), dim=-1
                    )  # (batch, 2)
                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
                        -1
                    )  # (batch, 2, 1, 1)
                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)

                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(w1 * x1 + w2 * x2)
                )
            elif self.merge_method == "fixed_ave":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(
                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
                    )
                )
            else:
                raise RuntimeError(f"unknown merge method: {self.merge_method}")

文中作者对比了不同merge操作对模型的影响,并且可视化了可学习参数在不同深度encoder_layer中的权重分布

Stochastic Layer Skip

在ESPnet代码中添加了Stochastic depth,在配置参数中启用此选项可以在训练时随机跳过某些层。从而使得Branchformer能够训练更加深的网络,在训练时随机跳过层可以加速训练且使模型更加鲁棒。

stoch_layer_coeff = 1.0
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
if self.training and self.stochastic_depth_rate > 0:
    skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
    stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

流式推理

虽然Branchformer利用俩个分支分别计算global与local特征,但在流式计算时实际上与Conformer类似,可以分别计算出atten_cache与cnn_cache做更新。方法与Conformer无异,基本可直接套用。

for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
            )
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))

实验结果

我们在WeNet上贡献了完整的Branchformer训练方案,并针对encoder layer number,linear units 等参数在aishell数据集上做了相关实验。

模型配置attentionattention_rescorectc_prefix_beam_searchctc_greedy_search

24 layers + 2048 linear units

5.124.815.285.28
24 layers + 1024 linear units5.334.885.415.40
12 layers + 2048 linear units5.375.085.695.69

参考资料

Branchformer:https://arxiv.org/abs/2207.02971

ESPnet:https://github.com/espnet/espnet/blob/master/espnet2/asr/encoder/branchformer_encoder.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/705365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Arduino Proteus仿真空气净化器温湿度PM2.5空气质量MQ135-0049

Arduino Proteus仿真空气净化器温湿度PM2.5空气质量MQ135-0049 Proteus仿真小实验&#xff1a; Arduino Proteus仿真空气净化器温湿度PM2.5空气质量MQ135-0049 功能&#xff1a; 硬件组成&#xff1a;ARDUINO -UNO-R3开发板、 LCD1602 、DHT11温湿度传感器、电位器模拟PM2.…

垃圾收集策略与算法

垃圾收集策略与算法 程序计数器、虚拟机栈、本地方法栈随线程而生&#xff0c;也随线程而灭&#xff1b;栈帧随着方法的开始而入栈&#xff0c;随着方法的结束而出栈。这几个区域的内存分配和回收都具有确定性&#xff0c;在这几个区域内不需要过多考虑回收的问题&#xff0c;因…

详解C++类型转换特性(代码+详解)

C类型转换 引言1. C语言中的类型转换2. 为什么C需要四种类型转换 C强制类型转换1.static_cast补充 2.dynamic_cast3.const_cast4.reinterpret_cast RTTI 引言 1. C语言中的类型转换 在C语言中&#xff0c;如果赋值运算符左右两侧类型不同&#xff0c;或者形参与实参类型不匹配…

虚拟机Centos7环境下如何安装wget

一、wget简介 wget 是一个从网络上自动下载文件的自由工具&#xff0c;支持通过 HTTP、HTTPS、FTP 三个最常见的 TCP/IP协议 下载&#xff0c;并可以使用 HTTP 代理。“wget” 这个名称来源于 “World Wide Web” 与 “get” 的结合。所谓自动下载&#xff0c;是指 wget 可以在…

JVM oop内存模型

一、oop模型 1、非数组对象 InstaceOopDesc 2、数组对象 arrayOopDesc 2.1 基本数据类型数组 typeArrayOopDesc 2.2 引用类型数组 objArrayOopDesc 3、MarkOopDesc 存放锁信息、分代年龄等 二、对象的内存结构 对象内存结构分成三大部分 对象头 &#xff08;64位操作系统&a…

软考02原码反码和补码

文章目录 前言一、原码二、反码三、补码总结 前言 机器是通过二进制来存储数据的&#xff0c;最好是在学习了软考01进制转换基础上开始学习原码反码和补码。 一、原码 原码通常以固定位数表示,不足补0&#xff0c;由于需要区分正负数所以&#xff0c;最高位为符号位(0为正&…

Electron中启动node服务

记一次遇到的问题&#xff0c;我们知道Electron 中主进程是在node环境中&#xff0c;所以打算在node环境中再启动一个node服务。但是直接使用exec命令启动就会卡主。对应的代码如下 // 启动Node server const startServer async () > {try {console.log(开始启动node serv…

React | 再战Redux

✨ 个人主页&#xff1a;CoderHing &#x1f5a5;️ React.js专栏&#xff1a;React.js 再战Redux &#x1f64b;‍♂️ 个人简介&#xff1a;一个不甘平庸的平凡人&#x1f36c; &#x1f4ab; 系列专栏&#xff1a;吊打面试官系列 16天学会Vue 7天学会微信小程序 Node专栏…

chatgpt赋能python:下载Python的方法及使用指南

下载Python的方法及使用指南 Python是一种高级编程语言&#xff0c;被广泛应用于各种领域。如果你是一名程序员或者对编程有兴趣&#xff0c;那么学习Python会是一个不错的选择。本文将介绍Python的下载方法&#xff0c;并提供使用Python的基础指南。 Python的下载方法 Pyth…

Istio与Mcp Server服务器讲解与搭建演示

01Istio与外部注册中心 Istio为何需要对接外部注册中心 Istio 对 Kubernetes 具有较强的依赖性&#xff1a; 1.服务发现就是基于 Kubernetes 实现的&#xff0c;如果要使用 Istio&#xff0c;首先需要迁移到 Kubernetes 上&#xff0c;并使用 Kubernetes 的服务注册发现机制…

【数据挖掘】时间序列教程【二】

2.4 示例&#xff1a;颗粒物浓度 在本章中&#xff0c;我们将使用美国环境保护署的一些空气污染数据作为运行样本。该数据集由 2 年和 5 年空气动力学直径小于或等于 3.2017 \&#xff08;mu\&#xff09;g/m\&#xff08;^2018\&#xff09; 的颗粒物组成。 我们将特别关注来自…

认识GCC

GNU GNU是Linux系统下的一些工具包&#xff0c;GNU是GNU is Not Unix的缩写&#xff0c;因为当年Unix收费后&#xff0c;理查德马修斯托曼打算做一套GNU操作系统&#xff0c;当时GNU的工具包已经写好&#xff0c;就差内核即可组装成一个完整的操作系统&#xff0c;正好Linux写…

跨链 vs 多链

跨链 dApp 可以在部署在多个不同区块链上的多个不同智能合约上运行&#xff0c;而多链 dApp 则可以在不同网络上以多个单独的版本部署。 由于对区块空间的需求不断增加&#xff0c;Web3 应用层现在存在于数百个不同的区块链、二层网络和应用链上。这种现实催生了两个新术语——…

【教程】解决php微擎中的goto加密解密,一键解密工具

今天&#xff0c;我将向大家揭秘一款神奇的工具——goto解密工具&#xff0c;轻松解密这个看似棘手的问题。 无数开发者都曾因为php中的goto功能而头疼不已。goto解密工具其中之一就是解密goto代码。通过精妙的算法和强大的解析能力&#xff0c;它能够解密被goto加密的代码段&…

Vue项目设置网站小徽标

一、预期效果 自定义Vue项目的网站小徽标&#xff0c;用于显示网站的logo&#xff0c;效果大致如下 二、制作 .ico文件 2.1 打开比特虫官网 比特虫官网&#xff1a;https://www.bitbug.net/ 2.2 操作步骤如图 三、引入Vue项目 3.1 将生成的 .ico文件放入我们的 Vue 项目 3.…

servlet+JSP与SpringBoot+Vue项目交互——servlet请求SpringBoot接口

问题 servletJSP与SpringBootVue项目交互——servlet请求SpringBoot接口 详细问题 笔者前一段时间开发一个项目&#xff0c;使用的技术框架是servletJSP&#xff0c;现阶段开发的项目技术框架为SpringBootVue&#xff0c;笔者现在需要输入servletJSP请求SpringBoot接口&…

C语言编程—递归

递归指的是在函数的定义中使用函数自身的方法。 举个例子&#xff1a;从前有座山&#xff0c;山里有座庙&#xff0c;庙里有个老和尚&#xff0c;正在给小和尚讲故事呢&#xff01;故事是什么呢&#xff1f;"从前有座山&#xff0c;山里有座庙&#xff0c;庙里有个老和尚&…

2024考研408-计算机组成原理第六章-总线学习笔记

文章目录 前言初识总线一、总线概述1.1、总线的概述1.1.1、认识总线1.1.2、设计总线需要的特性1.1.3、总线的分类①按照数据传输格式分&#xff08;串行、并行&#xff09;②按照总线功能连接的总线&#xff08;片内总线、系统总线、通信总线&#xff09;③按照时序控制方式&am…

css新特性(五)

css基础&#xff08;一&#xff09;css基础&#xff08;一&#xff09;_上半场结束&#xff0c;中场已休息&#xff0c;下半场ing的博客-CSDN博客Emmet语法Emmet语法_上半场结束&#xff0c;中场已休息&#xff0c;下半场ing的博客-CSDN博客css基础&#xff08;二&#xff09;c…

Retrofit注解

1. 注解类型 Retrofit路径结合的规则 2. 网络请求方法 2.1 Get请求 完整地址&#xff1a;http://mock-api.com/2vKVbXK8.mock/getUserInfo?iduserid 2.1.1 Query 创建Retrofit实例必须传入baseurl(http://mock-api.com/2vKVbXK8.mock/)&#xff0c;在GET("getUserIn…