编码器 | 基于 Transformers 的编码器-解码器模型

news2024/12/23 13:35:39

基于 transformer 的编码器-解码器模型是 表征学习模型架构 这两个领域多年研究成果的结晶。本文简要介绍了神经编码器-解码器模型的历史,更多背景知识,建议读者阅读由 Sebastion Ruder 撰写的这篇精彩 博文。此外,建议读者对 自注意力 (self-attention) 架构 有一个基本了解,可以阅读 Jay Alammar 的 这篇博文 复习一下原始 transformer 模型。

本文分 4 个部分:

  • 背景 - 简要回顾了神经编码器-解码器模型的历史,重点关注基于 RNN 的模型。

  • 编码器-解码器 - 阐述基于 transformer 的编码器-解码器模型,并阐述如何使用该模型进行推理。

  • 编码器 - 阐述模型的编码器部分。

  • 解码器 - 阐述模型的解码器部分。

每个部分都建立在前一部分的基础上,但也可以单独阅读。这篇分享是第三部分 编码器

编码器

如前一节所述, 基于 transformer 的编码器将输入序列映射到上下文相关的编码序列:

仔细观察架构,基于 transformer 的编码器由许多 残差注意力模块 堆叠而成。每个编码器模块都包含一个 双向 自注意力层,其后跟着两个前馈层。这里,为简单起见,我们忽略归一化层 (normalization layer)。此外,我们不会深入讨论两个前馈层的作用,仅将其视为每个编码器模块 的输出映射层。双向自注意层将每个输入向量 与全部输入向量 相关联并通过该机制将每个输入向量 提炼为与其自身上下文相关的表征: 。因此,第一个编码器块将输入序列 (如下图浅绿色所示) 中的每个输入向量从 上下文无关 的向量表征转换为 上下文相关 的向量表征,后面每一个编码器模块都会进一步细化这个上下文表征,直到最后一个编码器模块输出最终的上下文相关编码 (如下图深绿色所示)。

我们对 编码器如何将输入序列 "I want to buy a car EOS" 变换为上下文编码序列这一过程进行一下可视化。与基于 RNN 的编码器类似,基于 transformer 的编码器也在输入序列最后添加了一个 EOS,以提示模型输入向量序列已结束 。

9a65d4c010e62a9773c466d0b7a326f5.png

上图中的 基于 transformer 的编码器由三个编码器模块组成。我们在右侧的红框中详细列出了第二个编码器模块的前三个输入向量: , 及 。红框下部的全连接图描述了双向自注意力机制,上面是两个前馈层。如前所述,我们主要关注双向自注意力机制。

可以看出,自注意力层的每个输出向量 都 直接 依赖于 所有 输入向量 。这意味着,单词 “want” 的输入向量表示 与单词 “buy” (即 ) 和单词 “I” (即 ) 直接相关。 因此,“want” 的输出向量表征, ,是一个融合了其上下文信息的更精细的表征。

我们更深入了解一下双向自注意力的工作原理。编码器模块的输入序列 中的每个输入向量 通过三个可训练的权重矩阵 ,, 分别投影至 key 向量 、value 向量 和 query 向量 (下图分别以橙色、蓝色和紫色表示):

请注意,对每个输入向量 ) 而言,其所使用的权重矩阵都是 相同 的。将每个输入向量 投影到 querykeyvalue 向量后,将每个 query 向量 ) 与所有 key 向量 进行比较。哪个 key 向量与 query 向量 越相似,其对应的 value 向量 对输出向量 的影响就越重要。更具体地说,输出向量 被定义为所有 value 向量的加权和 加上输入向量 。而各 value 向量的权重与 和各个 key 向量 之间的余弦相似度成正比,其数学公式为 ,如下文的公式所示。关于自注意力层的完整描述,建议读者阅读 这篇 博文或 原始论文。

好吧,又复杂起来了。我们以上例中的一个 query 向量为例图解一下双向自注意层。为简单起见,本例中假设我们的 基于 transformer 的解码器只有一个注意力头 config.num_heads = 1 并且没有归一化层。

f08f4a411957a2b34be7747a637ce7ed.png

图左显示了上个例子中的第二个编码器模块,右边详细可视化了第二个输入向量 的双向自注意机制,其对应输入词为 “want”。首先将所有输入向量 投影到它们各自的 query 向量 (上图中仅以紫色显示前三个 query 向量), value 向量 (蓝色) 和 key 向量 (橙色)。然后,将 query 向量 与所有 key 向量的转置 ( ) 相乘,随后进行 softmax 操作以产生 自注意力权重 。 自注意力权重最终与各自的 value 向量相乘,并加上输入向量 ,最终输出单词 “want” 的上下文相关表征, (图右深绿色表示)。整个等式显示在图右框的上部。 和   的相乘使得将 “want” 的向量表征与所有其他输入 (“I”,“to”,“buy”,“a”,“car”,“EOS”) 的向量表征相比较成为可能,因此自注意力权重反映出每个输入向量 对 “want” 一词的最终表征 的重要程度。

为了进一步理解双向自注意力层的含义,我们假设以下句子: “ 房子很漂亮且位于市中心,因此那儿公共交通很方便 ”。 “那儿”这个词指的是“房子”,这两个词相隔 12 个字。在基于 transformer 的编码器中,双向自注意力层运算一次,即可将“房子”的输入向量与“那儿”的输入向量相关联。相比之下,在基于 RNN 的编码器中,相距 12 个字的词将需要至少 12 个时间步的运算,这意味着在基于 RNN 的编码器中所需数学运算与距离呈线性关系。这使得基于 RNN 的编码器更难对长程上下文表征进行建模。此外,很明显,基于 transformer 的编码器比基于 RNN 的编码器-解码器模型更不容易丢失重要信息,因为编码的序列长度相对输入序列长度保持不变, ,而 RNN 则会将 压缩到 ,这使得 RNN 很难有效地对输入词之间的长程依赖关系进行编码。

除了更容易学到长程依赖外,我们还可以看到 transformer 架构能够并行处理文本。从数学上讲,这是通过将自注意力机制表示为 querykeyvalue 的矩阵乘来完成的:

输出 是由一系列矩阵乘计算和 softmax 操作算得,因此可以有效地并行化。请注意,在基于 RNN 的编码器模型中,隐含状态 的计算必须按顺序进行: 先计算第一个输入向量的隐含状态 ; 然后计算第二个输入向量的隐含状态,其取决于第一个隐含向量的状态,依此类推。RNN 的顺序性阻碍了有效的并行化,并使其在现代 GPU 硬件上比基于 transformer 的编码器模型的效率低得多。

太好了,现在我们应该对:
a) 基于 transformer 的编码器模型如何有效地建模长程上下文表征,以及
b) 它们如何有效地处理长序列向量输入这两个方面有了比较好的理解了。

现在,我们写一个 MarianMT 编码器-解码器模型的编码器部分的小例子,以验证这些理论在实践中行不行得通。


关于前馈层在基于 transformer 的模型中所扮演的角色的详细解释超出了本文的范畴。Yun 等人 (2017)  的工作认为前馈层对于将每个上下文向量 映射到目标输出空间至关重要,而单靠 自注意力 层无法达成这一目的。这里请注意,每个输出词元 都经由相同的前馈层处理。更多详细信息,建议读者阅读论文。

我们无须将 EOS 附加到输入序列,虽然有工作表明,在很多情况下加入它可以提高性能。相反地,基于 transformer 的解码器必须把 作为第 0 个目标向量,并以之为条件预测第 1 个目标向量。

from transformers import MarianMTModel, MarianTokenizer
import torch

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

embeddings = model.get_input_embeddings()

# create ids of encoded input vectors
input_ids = tokenizer("I want to buy a car", return_tensors="pt").input_ids

# pass input_ids to encoder
encoder_hidden_states = model.base_model.encoder(input_ids, return_dict=True).last_hidden_state

# change the input slightly and pass to encoder
input_ids_perturbed = tokenizer("I want to buy a house", return_tensors="pt").input_ids
encoder_hidden_states_perturbed = model.base_model.encoder(input_ids_perturbed, return_dict=True).last_hidden_state

# compare shape and encoding of first vector
print(f"Length of input embeddings {embeddings(input_ids).shape[1]}. Length of encoder_hidden_states {encoder_hidden_states.shape[1]}")

# compare values of word embedding of "I" for input_ids and perturbed input_ids
print("Is encoding for `I` equal to its perturbed version?: ", torch.allclose(encoder_hidden_states[0, 0], encoder_hidden_states_perturbed[0, 0], atol=1e-3))

输出:

Length of input embeddings 7. Length of encoder_hidden_states 7
    Is encoding for `I` equal to its perturbed version?: False

我们比较一下输入词嵌入的序列长度 ( embeddings(input_ids),对应于 ) 和 encoder_hidden_states 的长度 (对应于)。同时,我们让编码器对单词序列 “I want to buy a car” 及其轻微改动版 “I want to buy a house” 分别执行前向操作,以检查第一个词 “I” 的输出编码在更改输入序列的最后一个单词后是否会有所不同。

不出意外,输入词嵌入和编码器输出编码的长度, 和  ,是相等的。同时,可以注意到当最后一个单词从 “car” 改成 “house” 后, 的编码输出向量的值也改变了。因为我们现在已经理解了双向自注意力机制,这就不足为奇了。

顺带一提, 自编码 模型 (如 BERT) 的架构与 基于 transformer 的编码器模型是完全一样的。 自编码 模型利用这种架构对开放域文本数据进行大规模自监督预训练,以便它们可以将任何单词序列映射到深度双向表征。在 Devlin 等 (2018)  的工作中,作者展示了一个预训练 BERT 模型,其顶部有一个任务相关的分类层,可以在 11 个 NLP 任务上获得 SOTA 结果。你可以从 此处 找到 🤗 transformers 支持的所有 自编码 模型。

敬请关注其余部分的文章。


英文原文: https://hf.co/blog/encoder-decoder

原文作者: Patrick von Platen

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/614448.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AUTOSAR】Bootloader说明(一)---- 时序流程

电机控制器选用TI TMS28xx DSP,包括boot-loader与应用软件两个部分。其中boot-loader包括下列内容: RAM自检应用程序有效性检查UDS命令处理FLASH操作 下面分别说明DSP上电后整个软件运行流程及程序刷新过程。 DSP软件执行流程 DSP复位后,将…

【Mysql基础】-关于常用的函数简单案例

目录 一、系统函数 二、日期函数 三、字符串函数数 说明:以下所有的操作在8.0的mysql数据库操作系统上操作 一、系统函数 1 显示连接列表:show PROCESSLIST; 2 MD5加密:select MD5("root") 二、日期函数 1、 推算一周之后的…

QMI8658 - 姿态传感的零偏(常值零偏)标定

1. 零偏 理论上在静止状态下三轴输出为0,0,0,但实际上输出有一个小的偏置,这是零偏的静态分量(也称固定零偏)。 陀螺生产出来后就一直固定不变的零偏值。对于传统的高性能惯性器件来说,该误差在出厂标定时往往就被补偿…

《水经注地图服务》用户如何登录?

《水经注地图服务》(WeServer)是一款可快速发布全国乃至全球海量卫星影像的地图发布服务产品,该产品完全遵循OGC相关协议标准,是一个基于若干项目成功经验总结的产品。它可以轻松发布100TB级海量卫星影像,从而使“在内…

如何使用 Raycast 一键打开预设工作环境

工作中,你一定遇到过这样的场景:你正在认真写代码,线上突然出现报警。看到报警信息之后,你不得不打开浏览器,点开收藏夹,打开监控页面、告警页面、trace 页面、日志搜索平台……有时,还需要打开…

chatgpt赋能python:Python取值:了解基础知识和应用方法

Python取值:了解基础知识和应用方法 什么是Python取值? Python取值是指从一个对象中获取信息或者值。对象可以包括列表、字典、元组、变量等。Python提供了多种方法来取值,包括基础的索引和切片操作,以及高级的列表推导式、字典…

MySQL JDBC详解

文章目录 简介JDBC APIJDBC Driver ManagerJDBC 驱动 JDBC 开发步骤一,导入 JDBC 驱动包,并加载驱动类二,建立数据库连接三,发送 SQL 语句,并获取执行结果Statement 对象PreparedStatement 对象 四,处理返回…

ADAS方案的简单比较

ADAS方案的简单比较 1 概述2 厂商Tesla硬件布局网络基础结构:HydraNet多头网络 NVIDIA百度(Apollo)版本历史硬件布局软件框架各版本框架 WaymoVolvo-Uber 3 芯片4 其他from [最全自动驾驶技术架构和综述](https://blog.csdn.net/buptgshengod…

项目质量管理

质量与项目质量 质量的定义:一组固有特征满足要求的程序。 质量是反应实体主题明确和隐含需求的能力的特性总和 质量与等级的关系: 一个低等级(功能有限),高质量(无明显缺陷,用户手册易读&am…

《Datawhale南瓜书》出第二版啦!

Datawhale干货 作者:Datawhale开源项目团队 作为机器学习的入门经典教材,周志华老师的《机器学习》,自2016年1月底出版以来,首印5000册一周售罄,并在8个月内重印9次。先后登上了亚马逊,京东,当…

【运维知识进阶篇】iptables防火墙详解

这篇文章给大家介绍下iptables防火墙,防火墙大致分三种,分别是硬件、软件和云防火墙。硬件的话部署在企业网络的入口,有三层路由的H3C、华为、Cisco(思科),还有深信服等等;软件的话一般是开源软…

【服务器】iPad远程服务器进行开发

文章目录 前言1. 本地环境配置2. 内网穿透2.1 安装cpolar内网穿透(支持一键自动安装脚本)2.2 创建HTTP隧道 3. 测试远程访问4. 配置固定二级子域名4.1 保留二级子域名4.2 配置二级子域名 5. 测试使用固定二级子域名远程访问6. iPad通过软件远程vscode6.1 创建TCP隧道 7. ipad远…

人工智能 AI | ChatGPT 时代,程序员的生存之道

ChatGPT 近期炙手可热,仿佛没有什么问题是它不能解决的。出于对 ChatGPT 的好奇,我们决定探索下它对于前端开发人员来讲,是作为辅助工具多一些,还是主力工具更多一些? 2D 能力测试 我们就挑选一个著名的递归回溯问题—…

代码随想录算法训练营第三十九天|62.不同路径|63. 不同路径 II

LeetCode62.不同路径 动态规划五部曲: 1,确定dp数组(dp table)以及下标的含义:dp[i][j] :表示从(0 ,0)出发,到(i, j) 有dp[i][j]条不同的路径。 2&#xff0c…

cpu飚高的排查思路

cpu的衡量指标 使用率util:代表的是单位时间内CPU繁忙情况的统计。操作系统对cpu的管理就是利用周期的tick时钟中断,将cpu的使用划分时间片。每个时间片内去执行不同进程/线程里的代码。所以cpu的使用率统计其实也是以tick为单位的:统计周期…

开源代码分享(1)—考虑经济性的储能运行优化

参考文献: [1]Practical operation strategies for pumped hydroelectric energy storage (PHES) utilising electricity price arbitrage - ScienceDirect [2]Towards an objective method to compare energy storage technologies: development and validation of…

Python——Flask快速开发一个物资管理平台(源码+适合大作业)

目录 一、前言 二、项目展示 三、代码包 四、项目简介 五、运行步骤 一、前言 Flask 框架结合原生的 HTML 和 Bootstrap 可以快速开发 Web 应用程序。 Flask 框架是 Python 中一个轻量级的 Web 应用框架,它非常适合构建小型项目和原型化开发。Flask 框架具有可扩展的…

K8S利用nginx快速部署一个网站之基本概念(十)

在Kubernetes部署应用程序流程 使用Deployment控制器部署镜像: kubectl create deployment web --imagenginx --replicas3 kubectl get deploy,pods 使用Service将Pod暴露出去: kubectl expose deployment web --port80 --target-port80 --typeNodePor…

Pyside6-第六篇-各按钮的信号与槽

今天是Pyside6的第六篇内容。一起来看看各按钮的信号与槽。 from PySide6.QtCore import Qt from PySide6.QtGui import QAction from PySide6.QtWidgets import QApplication, QWidget, QRadioButton, \QPushButton, QCheckBox, QToolButton, QMenuclass Example(QWidget):def…

使用BERT进行文本分类

本范例我们微调transformers中的BERT来处理文本情感分类任务。 我们的数据集是美团外卖的用户评论数据集。 模型目标是把评论分成好评(标签为1)和差评(标签为0)。 #安装库 #!pip install datasets #!pip install transformers[torch] #!pip install torchkeras 公众号算法美食…