相对位置编码(relative position representation)

news2024/9/24 17:09:23

最近在看wenet项目时,发现其用的是相对位置编码。同时在做tts时,发现其效果还可以,但是就是对于长文本的生成效果不好,一直在思考是什么原因导致的,有想到最有可能是fastspeech是的绝对位置编码问题,所以还想着增大文本长度来重新训练,从实践角度来讲这是可以的,但是是治标不治本。直到最近研究了下相对位置编码,然后又翻vits的源码确认其也用的是相对位置编码,所以恍然大悟。这里总结下相对位置编码的原理,以做更深刻的理解。


一.动机

假设我们的文本为:“I think therefore I am”

如果我们要得到“I”的表示,第二个“I”的表示向量跟第一个“I”应该是不一样的,因为在第二个“I”时,有输入前面的“I think therefore”的隐藏状态信息,而第一个“I”只是输入一个初始化的隐藏状态信息。

如果不加入位置编码情况下(原始的transformer是不会在self-attention中加入位置编码的,只是在传入的时候加入绝对位置编码),那么transformer里面的self-attention会把两个位置上的“I”看做是同一个表示,从语义来看这是不合理的。

所以transformer的self-attention在不加入位置编码时,把两个位置的“I”看做一样来处理,这明显不合理,所以怎么来解决这个问题。如果加入绝对位置编码貌似可以解决问题,但是处理句子的能力只能是训练集中的最大长度,如果超过训练集中最大长度,模型就无能为力。

二.解决


了解到相对位置编码是从transformer-xl中了解到的,其实实际上是《Self-Attention with Relative Position Representations》这篇论文中首先提出来的。先来看下它引入的相对位置编码是怎样的?所谓的相对,从字面意思就是一个位置相对另一个位置。其采用一组可训练的embedding向量来表示输入句子中每个单词的位置编码。

如果我们以其中一个单词为中心,那么其有左边也有右边的单词,假设我们一个句子的长度为5,那么将有9个embedding向量可学习,一个embedding向量表示当前词,其中4个embedding向量用来表示其左边的单词,另外4个embedding向量来表示其左边的单词。为什么是9个embedding向量呢?因为我们在实际计算时候,5个单词都有可能作为中心词。上图中以第5个位置(索引为4)的单词为中心,那么其左边的单词的编号为:-1,-2,-3,-4,右边的单词的编号为:+1,+2,+3,+4 。

下面的图来表示这些embedding是怎么用的

对于第一个位置的单词“I”,当transformer计算“I”跟“therefore”的attention信息时候,"therefore"会采用第6个位置编码,因为我们是以第4个索引为中心,“therefore”是位于“I”的右边相对于“I”的相对距离为2,所以其采用的是第6个embedding向量。

跟之前一样,当计算到第二个“I”和其他单词的attention信息时,如计算其跟左边“therefore”的attention信息,那么“therefore”采用的位置编码为第3个embedding向量,因为它在“I”的左边相对于“I”偏离1个距离,索引应该采用第3个embedding向量。

三.RPR向量

先声明一些变量的,给出其作用先:

zi:输入单词i的输出表示

aij:单词i和单词j的权重系数

eij:单词i和单词j经过softmax后i对于j的关注程度

h:multihead attention的head个数

dx:输入序列的embedding size

dz:zi的embedding size

注意,有两个位置编码向量需要学习,一个是为了计算zi,一个是为了计算eij。最大的长度也会被考虑在内,如果我们中间索引为k,那么会有2k+1个相对位置编码向量需要学习,其中k个是其左边的,k个是其右边的,还有一个属于自己。如果长度超过2k+1,那么其右边超过k的全部置为k,左边超过k的全部置为0。下面是个长度为10的句子的例子,其中k=3,那么它到相对位置编码表中拿向量的索引为:

这么做的原因有:

1.作者认为超出范围的位置还采用精准的位置编码时没必要的

2.clip最大长度是的模型可以学到训练集中没见过的长度


四.实现

先看下传统的transformer没有采用RPR来计算zi向量的方式:

再看下采用RPR来计算的方式:

对比两种计算方式,比较容易看出来,其实是在计算zi的时候,计算完j跟权重w后加上i相对于j的相对位置编码。而在计算eij时候同理,计算完j跟权重w后加上ij的相对位置编码。注意这里两种权重矩阵是不同的,理解transformer中的q,k,v 这三个向量分别对应过去。

五.代码实现解析


实现代码参考:

https://github.com/tensorflow/tensor2tensor/blob/9e0a894034d8090892c238df1bd9bd3180c2b9a3/tensor2tensor/layers/common_attention.py#L1556-L1587

对于没有RPR的计算如下,是比较好计算的,代码通过矩阵的相乘直接可以得到eij.


如果要实现RPR,我们把其eij计算方式转换下得到:

分子第一项中,我们的输入xi的tensor的Shape为:(B,h,seq_length,d),它计算的是query和key的关系,所以第一项的输出为(B,h,seq_length,seq_length)。了解transformer就知道其是多header,分子的第一项其实是比较好计算的,第二项的输出shape必须跟第一项一致。

第二项中,aijK表示的是ij的相对位置编码,从位置编码的Embeding向量table中去lookup得到的,位置编码的embedding向量table的Shape为(seq_length,seq_length,da),转换下维度得到(seq_length,da,seq_length),其中位置编码向量的table我们用A来表示,转换后得到AT。

xi跟WQ相乘后得到tensor其shape为(B,h,seq_length,dz),转换下维度得到(seq_length,B,h,dz),在转换下得到(seq_length,B*h,dz),再跟aijK来相乘,实质是跟AT相乘,所以(seq_length,B*h,dz)跟(seq_length,da,seq_length)矩阵相乘,dz=da,得到(seq_length,B*h,seq_length),reshape下得到(seq_length,B,h,seq_length),transpose为(B,h,seq_length,seq_length)这样就跟第一项对应起来了。

参考:

这篇文章基本上是算是参考第一个链接,只是没有完全按照它的翻译,而是用自己理解的话语写出来。

https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a​medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a

Relative Multi-Headed Attention​nn.labml.ai/transformers/xl/relative_mha.html正在上传…重新上传取消

Attention Is All You Need​arxiv.org/abs/1706.03762

编辑于 2022-09-24 17:19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/713142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序,微信浏览器播放视频只有画面没声音问题处理

我这里遇到的场景是手机上的微信小程序,微信浏览器视频播放有问题,其他人的话可能是其他场景出现了问题. 最开始我以为是这里不支持m3u8的播放,因为微信小程序那里很多人都说遇到过这个问题,所以一直想着是修改播放器.一直到后来发现了一篇文章,这里找不到了,上面的大概意思是…

【2023,学点儿新Java-28】你知道Java中的特殊值都有什么吗?| null 的详细信息 | 什么是空引用?

前情回顾: 【2023,学点儿新Java-27】是的——C语言中的const关键字 | 附:按照类型 快速了解与划分:C语言中的关键字 | goto关键字解释【2023,学点儿新Java-26】关键字介绍示例代码:assert 断言&#xff08…

Atcoder Beginner Contest 308

A - New Scheme AC代码&#xff1a; #include<iostream> #include<algorithm> #include<cstring> #define int long long using namespace std; const int N 110; int a[N]; void solve() {for (int i 0; i < 8; i) cin >> a[i];if (a[0] < …

深度学习如何入门?

深度学习是一种强大的机器学习方法&#xff0c;它在各个领域都有广泛应用。如果你是一个新手&#xff0c;想要入门深度学习&#xff0c;下面是一些步骤和资源&#xff0c;可以帮助你开始学习和实践深度学习。 1. 学习基本概念 在开始深度学习之前&#xff0c;你需要对一些基本…

c语言修炼第二章--分支与循环(2.2)

目录 2.2循环语句 2.2.1while循环语句 2.2.2for语句 2.2.2.1for语句的语法形式 2.2.2.2break和continue在for语句中的应用 2.2.2.3关于for循环变量的建议 2.2.2.4for循环的一些变种 初始化&#xff0c;判断&#xff0c;调整都进行省略 使用多个变量控制循环 2.2.3do…

FreeRTOS_其他任务API函数

目录 1. 任务相关 API 函数预览 2. 任务相关 API 函数详解 2.1 函数 uxTaskPriorityGet() 2.2 函数 vTaskPrioritySet() 2.3 函数 uxTaskGetSystemState() 2.4 函数 vTaskGetInfo() 2.5 函数 xTaskGetApplicationTaskTag() 2.6 函数 xTaskGetCurrentTaskHandle() 2.7 …

AI 对“校园WiFi免认证的”的认知

校园WiFi免认证 很抱歉&#xff0c;我不能提供关于绕过校园WiFi认证的信息。绕过校园WiFi认证是违反校园网络使用规定的行为&#xff0c;可能会导致违纪处分或其他法律后果。建议你遵守校园网络使用规定&#xff0c;合法合规地使用校园WiFi。如果你有任何关于校园WiFi认证的问…

SpringBoot2+Vue2实战(九)集成Echarts

Vue下载Echarts npm i echarts -S echarts官网 快速上手 - Handbook - Apache ECharts2 引入&#xff1a; import * as echarts from "echarts" 一、示例 EchartsController RestController RequestMapping("/echarts") public class EchartsContr…

异常处理~

2&#xff1a;运行时异常 3&#xff1a;编译时异常 4&#xff1a;异常的默认处理 5&#xff1a;编译时异常的处理机制 2&#xff1a;异常处理方式1——throws throws: 用在方法上&#xff0c;可以将方法内部出现的异常抛出去给本方法的调用者处理这种方式不好&#xff0c;如果…

在IDEA中通过maven打包war包

在pom文件中加入这一句 <packaging>war</packaging> 然后打包就好了 war包文件目录

途乐证券|A股半年行情收官 北向净买入超1800亿元

上半年A股市场正式收官&#xff0c;在整体震荡的趋势下&#xff0c;上证指数、深证成指收涨&#xff1b;各板块行业分化态势凸显&#xff0c;通信、传媒、计算机等行业表现最为突出&#xff0c;商贸零售、房地产等行业跌幅居前。业内人士表示&#xff0c;上半年市场行情受产业趋…

企业所得税高怎么办?合理节税有哪些方式

企业所得税高怎么办&#xff1f;合理节税有哪些方式 《税筹顾问》专注于园区招商、企业税务筹划&#xff0c;合理合规助力企业节税&#xff01; 如今越来越多的企业深耕于创新与发展&#xff0c;这也是一种迎合市场的需求迫不得以需要进行的转变&#xff0c;很多企业所得税高的…

不看后悔,appium自动化环境完美搭建

桌面版appium提供可视化操作appium主要功能的使用方式&#xff0c;对于初学者非常适用。 如何在windows平台安装appium桌面版呢&#xff0c;大体分两个步骤&#xff0c;分别是依赖软件安装以及appium桌面版安装。以下是对这两个步骤的拆解文字加图片描述。 01、依赖软件安装 …

Google在AI领域的潜力被严重低估了

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 总结 &#xff08;1&#xff09;Google正在人工智能领域采取重大举措&#xff0c;推出了生成式人工智能聊天机器人Google Bard&#xff0c;并向人工智能初创公司Anthropic投资了3亿美元。 &#xff08;2&#xff09;Goo…

Dockerfile使用指南

Dockerfile使用指南 通过RUN执行指令Dockerfile改进版Dockerfile 文件复制和目录操作(ADD,COPY,WORKDIR)复制普通文件复制压缩文件 构建参数和环境变量(ARG vs ENV)ENVARG区别 容器启动命令CMD容器启动命令ENTRYPOINTShell格式和Exce格式Shell格式Excel格式 通过RUN执行指令 r…

No2.精选前端面试题,享受每天的挑战和学习

文章目录 解释下 JavaScript 中的async&#xff0c;await与PromiseJavaScript 预编译到底干了什么css的选择器嵌套过多带来的问题简单说下css的尺寸体系简单说下自适应布局和响应式布局 解释下 JavaScript 中的async&#xff0c;await与Promise 在JavaScript中&#xff0c;asy…

Ubuntu 编译 OpenCV SDK for Android + Linux

概述 OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个开源的计算机视觉库&#xff0c;它提供了很多函数&#xff0c;这些函数非常高效地实现了计算机视觉算法&#xff08;最基本的滤波到高级的物体检测皆有涵盖&#xff09;。   OpenCV 的应用领域…

[2023-07-03]2023博客之星候选--码龄赛道--15年以上

https://bbs.csdn.net/topics/616395535https://bbs.csdn.net/topics/616395535 用户名总原力值当月获得原力值2023年获得原力值2023年高质量博文数75阿酷tony:[博客] [成就]3999345028 博客之星 2023 《码龄赛道 15年以上》第 75 名 啊&#xff0c;75名啊&#xff01;你叫…

WINDBG 查崩溃

前言&#xff1a;windbg大家都很熟悉&#xff0c;它是做windows系统客户端测试的QA人员很应该掌握的定位程序崩溃原因的工具&#xff0c; 网上也有很多资料&#xff0c;但是真正适合QA阅读和实用的资料不多&#xff0c;我把我认为最重要最应该掌握的结合以前的使用经验分享一下…

SPI机制

SPI机制是Service Provider Interface&#xff0c;是服务提供发现机制&#xff0c;用来启用框架扩展和替换组件。比如java.sql.Driver接口&#xff0c;其他不同厂商可以针对同一接口做出不同的实现&#xff0c;MySQL和PostgreSQL都有不同的实现提供给用户&#xff0c;而Java的S…