论文学习笔记:Transformer Attention Is All You Need

news2024/12/23 22:42:30

Transformer: Attention Is All You Need

2022 年年底,一个大语言模型 ChatGPT 横空出世,并且迅速点燃了普罗大众对 AI 的热情,短短两个月, ChatGPT 就成为了史上最快成为上亿月活的应用,并且持续受到关注,ChatGPT 的问世,让谷歌感到深深的寒意,让微软重回巅峰,让 Meta 放弃元宇宙,重新押注 AIGC,也搅动着国内一众互联网公司,各种初创公司的内心,大家在短短的几个月内,都祭出了自己的大模型。

虽然各种各样的模型层出不穷,但是我们回过头看 ChatGPT 的演进历程,从 GPT-1, GPT-2, GPT-3 到 InstructGPT,可以发现都和一个称为 Transformer 的模型有关,而且这个 Transformer 的模型最早还是谷歌提出来的,结果却是 OpenAI 基于 Transformer 做出了 ChatGPT。

最近 AIGC 的持续火爆,让我也想对这一系列的模型的前世今生做个大致的了解,刚好 B 站上面有李沐大神的论文精读系列,把这些大语言模型以及 Transformer,还有 Transformer 后面在视觉领域的扩展包括 VIT,Swin Transformer 都做了详细的讲解,虽然很早以前自己也读过 Transformer 这篇文章,不过后面因为研究领域的重心转移,所以对后面的很多模型的发展都没有怎么关注了,最近结合李沐大神的论文精读系列,让我对这些模型的细节有了更为详细的理解。正好结合李沐的视频讲解和论文本身做一个论文学习笔记的整理。

Transformer 这篇文章发表于 2017 年的 NIPS 上,当年的 NIPS 还叫 NIPS,后来改成了 NeurIPS,这篇文章的 Title 叫 Attention Is All You Need,这个标题当年还成为了一种梗,引领了一种取文章标题的新潮流,后面有不少文章叫 XX is all you need。这篇文章提出了一种新的模型:Transformer。现在回过头来看,Transformer 可以被认为是与 CNN, MLP, RNN, LSTM 一样的基础模型,并且随着多模态学习的发展,大有一统语言与视觉的趋势。

接下来,进入正题。

Abstract

在主流的序列转录模型中,主要是依赖复杂的循环网络或者卷积神经网络,这类模型都包含一个 Encoder 和 Decoder。在性能最好的模型里,通常会在编码器与解码器之间使用一个注意力机制进行连接。这篇文章提出了一个简单的架构,称为 Transformer,仅仅只依靠注意力机制,不再需要循环或者卷积。文章提出的模型,在两个机器翻译任务上取得了更好的效果,同时具备更好的并行化以及更少的训练耗时。同时模型泛化到其它任务上也有不错的表现。

Introduction

在导言一开始,文章作者就提到在序列模型或者处理转录问题比如语言模型,机器翻译等问题上,比较常用的模型有 RNN, LSTM, GRU,这些模型都取得了不错的效果。其中,有很多的工作都用到了循环语言模型和编码器-解码器架构。

接下来,文章作者介绍了 RNN 模型的特点,RNN 是一种循环神经网络,可以比较方便的处理序列信息,RNN 处理序列信息的时候,是从左往右一个一个往下做,每个时刻都会输出一个称为 h t h_t ht 的隐藏态,当前时刻的隐藏态 h t h_t ht 是由前一时刻的隐藏态 h t − 1 h_{t-1} ht1 以及当前时刻的输入所决定。这样的话,就可以将学到的信息存在当前的隐藏态 h t h_t ht 中,RNN 的这种结构特点,可以让它将学到的历史信息存在隐藏态中,并且一步一步的传递下去。RNN 的这个特点,也给它带来了计算效率不高的问题,因为 RNN 需要时序地处理信息,所以难以并行化。第二个问题就是 RNN 处理长序列的时候,容易遗忘更早的信息,如果想保留更多的历史信息,就需要用更大的 h t h_t ht 这样会导致内存开销的增加。文章作者也提到说,虽然一直也有相关的工作在改进 RNN 的这些缺点,但是本质问题依然存在。

导言的第三段,作者讲的是如何在 RNN 模型中引入注意力机制,注意力机制并不是 Transformer 这篇文章的独创,很多语言模型或者转录模型里面都用到过注意力机制。不过这些注意力机制一般是和 RNN 联合起来一块使用的。

导言的最后,文章作者再次介绍了 Transformer 这个模型,相比于之前的模型,Transformer 是只使用纯注意力机制,不再使用 RNN,并且可以高效地做并行化计算。

Background

导言结束之后,接下来是相关背景工作的介绍,第一段主要是在介绍一些相关工作,如何用卷积神经网络替换循环神经网络处理时序输入,以增加并行化。不过,文章作者又提到说,卷积神经网络对于比较长的时序难以建模,因为卷积神经网络的基本运算形式是对相邻的输入信息做卷积运算,如果需要计算的信息隔得比较远,可能就需要一层一层做多次卷积,才能让并不相邻的输入信息得以运算。Transformer 可以一次看到所有的输入,所有一次就可以将所有的信息都纳入计算。不过,卷积神经网络的好处是可以方便的进行多个输出通道的输出,多个输出通道,意味着有更多的训练参数可以学习,不同的输出通道可以处理不同的特征。文章作者为了能让 Transformer 也能有多输出通道的好处,也提出了 multi-head attention 的机制,也就是多头注意力机制。

接下来第二段,作者提到说,自注意力机制并不是这篇文章的独创,在之前的一些工作中已经提出并使用。自注意力机制就是一个时序信息自己与自己计算相关性的一种方式。

第三段是介绍一种端到端的称为 memory-network 的工作,这个工作在文章发表的当年(2017)年还是比较流行的,不过现在已经不再流行了。

最后一段,就是作者再次强调 Transformer 的一些性质,就不再重复了。

Model Architecture

接下来,就是文章的核心部分,也就是模型架构的介绍,作者一开始就介绍说,目前比较好的处理序列的模型基本都会用 Encoder-Decoder 的架构。Encoder,也就是常说的编码器,会将一组长度为 n n n 的输入时序信息 ( x 1 , x 2 , . . . , x n ) (x_1, x_2, ..., x_n) (x1,x2,...,xn),映射到一组连续的特征表达上, z = ( z 1 , z 2 , . . . , z n ) \mathbf{z} = (z_1, z_2,...,z_n) z=(z1,z2,...,zn),得到 z \mathbf{z} z 之后,解码器再生成一组长为 m m m 的输出 ( y 1 , y 2 , . . . , y m ) (y_1, y_2, ..., y_m) (y1,y2,...,ym),在这个过程中的每一步都是自回归的,也就是上一步的输出,会作为当前的输入一部分参与模型的计算。

然后作者给出了模型的整体架构示意图,如下所示:

在这里插入图片描述

Transformer 主要模块就是多头自注意力以及全连接层,将这些模块一层一层堆在一起。从图上可以看出,整个架构是包含一个 Encoder 编码器,以及一个 Decoder 解码器的。我们先看编码器部分,编码器一开始是一组输入,比如说一组词向量,然后这组词向量,一般会通过一个嵌入层,将词向量转换成特征向量,然后这些词向量与一个叫做 positional encoding,也就是位置编码的东西相加,然后输入编码器的核心模块,编码器的核心模块其实是 N 个相同的模块叠加而成,文章中 N = 6,每个模块里面的组件都是一样的,文章作者把每个 transformer block 称为一个 layer,然后每个 layer 里面包括两个 sub-layer, 第一个 sub-layer 包括一个 multi-head attention, 也就是多头自注意力机制,然后是一个残差模块,接一个 layer-norm 模块;第二个 sub-layer 是一个全连接层,再接上一个残差模块以及 layer-norm 模块。每个 sub-layer 的输出可以表示成:LayerNorm(x + Sublayer(x)),这样就是一个标准的 Encoder 模块,然后 N 个这样的 Encoder 模块叠加,就构成了一个完整的 Encoder 编码器。文章作者也提到说,为了让所有的向量操作方便高效,所有模块的向量维度都统一为 512。

  • LayerNorm:文章提到了一个 LayerNorm 的机制,一般大家对 Batch-Norm 可能更熟悉一些,Batch-norm 是在特征层面进行归一化的操作,而 LayerNorm 是在样本层面进行归一化的操作。

介绍完编码器,接下来看解码器部分,编码器的输出,会作为解码器输入的一部分,如果看示意图的话,可以发现,解码器也是 N 个 block 组成的,每个 block 里面也是一些标准的组件,其中有一部分标准组件和编码器部分也是一致的。不过解码器的 block 里面,多了一个称为 masked multi-head attention 的组件,就是带掩码的多头自注意力机制模块。后面会详细介绍这个带掩码的多头自注意力机制模块。

接下来,我们看看这些标准模块里面的各个组件的细节。

Attention

首先,文章介绍了 attention,attention 函数是用来将一个 query 和一组 key-value 映射成一个输出的函数,其中,query,key, value 以及 output 都是向量。其中,output 可以看成是这组 value 的加权和,而权重就是由 query 和这组 key 之间计算相似度得到的。所以可以看到,output 其实和 value 的维度是一致的。

文章中也给出了一个表达式:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^{T}}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

根据上面的表达式,我们可以简单的推导一下,为了方便起见,这里假设 query, value, key 的维度都是一样的。比如说 Q Q Q 是一个 1 × d 1 \times d 1×d 的向量, K K K n n n 个 key 所对应的向量组成的矩阵,那么 $ K $ 就是一个 $ n \times d$ 的矩阵,与此类似, V V V n n n 个 value 所对应的向量组成的矩阵,那么 $ V $ 就是一个 $ n \times d$ 的矩阵,代入上面的表达式,可以得到最终的 output 就是一个 1 × d 1 \times d 1×d 的向量, softmax \text{softmax} softmax 这部分就是在计算权重。

在实际计算的时候,我们可以一次输入多个 query,也就是上式中的 Q Q Q 不再是一个向量,而是多个 query 向量组成的矩阵,比如说 Q Q Q m m m 个 query 向量组成的矩阵,那么 Q Q Q 就是一个 m × d m \times d m×d 的矩阵,从上面的表达式,可以看到 m m m 不管是 1 还是大于 1,并不会改变上面的运算性质,都是可以一次计算就能得到所有输出的。所有实际运算的时候,可以采取批处理的方式。

文章接下来介绍说,一般注意力机制有两种情况,一种是加性的注意力,另外一种是点乘的注意力机制,文章中用到的其实就是点乘的注意力机制,不过与一般的点乘注意力机制不同的地方在于,计算相似度的时候,会除以一个统一的 scale d k \sqrt{d_k} dk ,文章作者也解释了为什么需要除以这样一个统一的 scale,因为文章所用的向量维度 d k = 512 d_k = 512 dk=512,属于比较大的向量维度,为了让点乘之后的内积分布不至于太极端,使得 softmax 的梯度太小难以训练,除以一个统一的 scale 会让训练更为稳定。

文章将 attention 的计算流程也给出了一张图示,如下图所示:

在这里插入图片描述

整个流程差不多就和上面介绍的差不多,不过示意图里面,有一个 mask 的模块,这个是用于解码器中的,因为文章中介绍 transformer 模型用于自回归,自回归的形式,就是当前的输出作为下一次预测的输入,也就是 t t t 时刻的预测,应该基于 t − 1 t-1 t1 时刻之前的输出,而不应该看到 t t t 时刻之后的信息,但是对于 attention 来说,它并不区分哪些是以前的输出,哪些是之后的输出,因为 attention 机制是可以一次把所有的输入都算完,所以训练的时候,为了模拟真实的测试情况,文章非常巧妙地加入了一个掩码机制,也就是说我算 attention 的时候,还是可以一次性把所有的权重都计算出来,但是把这些权重送入 softmax 计算权重之前,可以将 t t t 时刻后面的权重置 0,具体做法就是给 t t t 时刻之前的 mask 值赋 1,然后 t t t 时刻后面的值赋以一个很大的负数,比如 -1e-20,这样经过 softmax 的时候,这些权重都会变成 0,从而实现掩码的效果。

Multi-Head Attention

接下来,我们看看什么是多头注意力机制,文章中也给出了解释,文章作者说,与其我用一个 attention 机制把一个长度 d d d 的向量一次性算完,不如先把这个向量投影到若干个低维的向量空间,然后在这些低维的向量空间分别进行 attention 的计算,得到若干个低维的 output 向量,再把这些低维的 output 向量 concat 到一起,再投影回来,就可以得到最终的输出。这样的好处,就是可以让投影参数变成可以学习的参数,不同的低维空间或许可以对应不同的相似性匹配,文章也介绍了具体的计算方式:

MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , . . . , head n ) W o \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_n)W^{o} MultiHead(Q,K,V)=Concat(head1,head2,...,headn)Wo

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_{i} = \text{Attention}(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}) headi=Attention(QWiQ,KWiK,VWiV)

其中, W i Q ∈ R d m o d e l × d k W_{i}^{Q} \in R^{d_{model} \times d_k} WiQRdmodel×dk, W i K ∈ R d m o d e l × d k W_{i}^{K} \in R^{d_{model} \times d_k} WiKRdmodel×dk, W i V ∈ R d m o d e l × d v , W o ∈ R h d v × d m o d e l W_{i}^{V} \in R^{d_{model} \times d_v}, W^{o} \in R^{hd_v \times d_{model}} WiVRdmodel×dv,WoRhdv×dmodel 都是可以训练的投影参数

文章中, h = 8 h=8 h=8 d k = d v = d m o d e l / h = 64 d_k = d_v = d_{model} / h = 64 dk=dv=dmodel/h=64

可以看到,相比单一的 attention,多头的 attention 可以学习的参数增加了很多。

Applications of Attention in our Model

讲完了多头注意力,接下来文章介绍了 transformer 模型使用注意力机制的几种方法,对应文章中的模型架构图,我们可以看到一共有三种使用方式:

  • 方式一:这个在编码器里面,我们看到一个输入被复制成三份送入了 attention 模块,说明这个地方的 query,key, value 其实是同一个东西
  • 方式二:这个在解码器里面,我们也可以看到有一个attention 模块与编码器的 attention 模块类似,一个输入被复制成三份送入后面的 attention 模块,这个地方的 query,key, value 也是同一个东西
  • 方式三:这个是将编码器与解码器进行连接的部分,可以看到 query 来自解码器上一层的输出,而 key, value 来自编码器的输出,这个 attention 就是将解码器中的 query 与编码器中的 key 进行相似度匹配,查询与 query 最近的 key 所对应的 value 进行加权输出

接下来,文章作者介绍了 Position-wise Feed-Forward Networks,这个其实就是一个 MLP,所以没有什么太多的技术细节,词嵌入层及 softmax 都是比较熟悉的操作。

Positional Encoding

最后,再介绍一下 Positional Encoding,位置编码也是 transformer 模型比较重要的一个性质,因为我们计算 attention 的时候,只关注了 query 和 key 之间的相似度,并不会考虑时序的位置信息,为了能够将时序信息被 transformer 所感知,所以文章作者另外引入了位置编码信息。
文章作者使用了三角周期函数来做位置编码:

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin(pos/10000^{2i/d_{model}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos(pos/10000^{2i/d_{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

其中, i i i 表示维度编号,如果 d m o d e l = 512 d_{model} = 512 dmodel=512,那么 i = 0 , 1 , . . . 511 i=0,1,...511 i=0,1,...511,pos 表示位置编号,对于序列长度为 n n n 的输入来说,pos 就是从 0 , 1 , . . . n − 1 0,1,...n-1 0,1,...n1

看到这里,文章的核心部分,也就是模型这块就介绍完了。

Why Self-Attention

第四部分,作者大概介绍了一下为什么会提出 self-attention 这种机制,主要是通过以下几个性质做了对比,如下表所示

在这里插入图片描述

可以看到,这个表里罗列了 Maximum path lengths, per-layer complexity and minimum number of sequential operations

  • Complexity per Layer:每种 layer 本身的算法复杂度
  • Sequential operations:顺序计算,你依赖的顺序计算越小,并行化的程度越高
  • Maximum path lengths:最大路径长度,也就是在序列中,从一个词跳到另外一个词需要多少步,也是越短越好,越短,处理长序列的能力越强

总体比较下来,self-attention 的并行化程度很高,同时处理长序列的能力也很强。

  • 参考:

https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.788&vd_source=bb80399e033aacf33a21a9f9864c6086

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/479652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue2.0源码学习】变化侦测篇-Object的变化侦测

文章目录 1. 前言2. 使Object数据变得“可观测”3. 依赖收集3.1 什么是依赖收集3.2 何时收集依赖?何时通知依赖更新?3.3 把依赖收集到哪里 4. 依赖到底是谁5. 不足之处6. 总结 1. 前言 我们知道:数据驱动视图的关键点则在于我们如何知道数据发…

记录docker swarm的使用

在前面的几篇文章中我们依次学习了dockerfile、docker-compose的使用,接下来是docker有一个比较 重要的使用,docker swarm的使用,与dockerfile和docker-compose相比较而言,docker swarm是在 多个服务器或主机上创建容器集群服务准…

Leetcode——66. 加一

💯💯欢迎来到的热爱编程的小K的Leetcode的刷题专栏 文章目录 1、题目2、暴力模拟(自己的第一想法)3、官方题解 1、题目 给定一个由 整数 组成的 非空 数组所表示的非负整数,在该数的基础上加一。最高位数字存放在数组的首位, 数组…

CTF-PHP反序列化漏洞2-典型题目

作者:Eason_LYC 悲观者预言失败,十言九中。 乐观者创造奇迹,一次即可。 一个人的价值,在于他所拥有的。可以不学无术,但不能一无所有! 技术领域:WEB安全、网络攻防 关注WEB安全、网络攻防。我的…

【纯属娱乐】随机森林预测双色球

目录 一、数据标准化二、预测代码三、后续 一、数据标准化 首先,我们需要对原始数据进行处理,将其转换为可用于机器学习的格式。我们可以将开奖号码中的红球和蓝球分开,将其转换为独热编码,然后将其与期数一起作为特征输入到机器…

ETL工具 - Kettle 查询、连接、统计、脚本算子介绍

一、 Kettle 上篇文章对 Kettle 流程、应用算子进行了介绍,本篇对查询、连接、统计、脚本算子进行讲解,下面是上篇文章的地址: ETL工具 - Kettle 流程、应用算子介绍 二、查询算子 数据输入使用 MySQL 表输入,表结构如下&#x…

给httprunnermanager接口自动化测试平台换点颜色瞧瞧

文章目录 一、背景1.1、修改注册表单的提示颜色1.2、修改后台代码:注册错误提示,最后提交注册,密码校验;1.3、修改了注册,那登录呢,也不能放过二、总结 一、背景 虽然咱给HttpRunnerManger引入进来&#xf…

【云台】开源版本SimpleBGC的电机驱动与控制方式

前言 最近想学习一下云台,发现资料确实还不太好找,比较有参考价值的是俄版的开源版本的云台代码,后面就不开源了,开源版本的是比较原始的算法,差不多是玩具级别的,不过还是决定学习一下,了解一…

PyCaret:低代码自动化的机器学习工具

PyCaret简介 随着ChatGPT和AI画图的大火,机器学习作为实现人工智能的底层技术被大众越来越多的认知,基于机器学习的产品也越来越多。传统的机器学习实现方法需要较强的编程能力和数据科学基础,这使得想零基础尝试机器学习变得非常困难。 机器…

Ucore lab5

实验目的 了解第一个用户进程创建过程了解系统调用框架的实现机制了解ucore如何实现系统调用sys_fork/sys_exec/sys_exit/sys_wait来进行进程管理 实验内容 练习0:已有实验代码改进 ​本实验中完成了用户进程的创建,能够对用户进程进行基本管理,并为…

C语言入门篇——自定义数据篇

目录 1、结构体 1.2、匿名结构体 1.3、结构体的自引用 1.4、结构体的声明和初始化 1.5、结构体的内存对齐 1.6、修改默认对齐数 1.7、结构体传参 2、枚举 3、共用体(联合体) 1、结构体 设计程序时,最重要的步骤之一是选择表示数据的…

【微机原理】8088/8086微处理器

目录 一、8088/8086的功能结构 1.总线接口部件(BIU) 2.执行部件(EU) 二、8088/8086的寄存器结构(14个) 溢出标志的概念 溢出和进位的区别 8086CPU是Intel系列的16位微处理器,他有16根数据…

框架学习之KOCA框架简介

KOCA框架简介 什么是KOCA术语定义发展历史 KOCA的总体架构产品优势开放性敏捷性(一体化解决方案)融合性安全性接入网关- KOCA Gateway KOCA DevOps流水线 KOCA技术栈 金证开发者社区:http://koca.szkingdom.com/ 什么是KOCA KOCA是金证基于…

LC-1376. 通知所有员工所需的时间(DFS:自上而下、自下而上)

1376. 通知所有员工所需的时间 难度中等125 公司里有 n 名员工,每个员工的 ID 都是独一无二的,编号从 0 到 n - 1。公司的总负责人通过 headID 进行标识。 在 manager 数组中,每个员工都有一个直属负责人,其中 manager[i] 是第…

JavaScript常用数组方法-汇总

快速检索 方法解析 1:concat(); 功能:合并数组,可以合并一个或多个数组,会返回合并数组之后的数据,不会改变原来的数组; var str1 [12,2,"hello"];var str2 ["world"]; console.lo…

简单毛概刷题网页制作 2.0(拖欠近一年版)

原因是大概一年之前学校的毛概期末刷题网站突然崩了,但是一直没有修复。当时眼看着复习时间逐渐被压缩,自己啥也做不了,遂自学前端完成毛概刷题网页一枚。 最早的毛概刷题网站仅仅是 1.0 版本(传送门),功能…

Excel技能之对齐,你可能都没想到

Excel表格,既然要做得漂漂亮亮,一定离不开对齐。拍照需要美颜,表格需要对齐。 内容全部挤到一边去。 有些靠左,有些靠右。 加上空格,感觉对齐。如果数据特别多,又逃不过加班的命运。 实在是混乱不堪。审美…

Linux常用命令——iostat命令

在线Linux命令查询工具 iostat 监视系统输入输出设备和CPU的使用情况 补充说明 iostat命令被用于监视系统输入输出设备和CPU的使用情况。它的特点是汇报磁盘活动统计情况,同时也会汇报出CPU使用情况。同vmstat一样,iostat也有一个弱点,就…

Linux安装MongoDB数据库,并内网穿透远程连接

文章目录 前言1. 配置Mongodb源2. 安装MongoDB3. 局域网连接测试4. 安装cpolar内网穿透5. 配置公网访问地址6. 公网远程连接7. 固定连接公网地址8. 使用固定地址连接 转载自Cpolar Lisa文章:Linux服务器安装部署MongoDB数据库 - 无公网IP远程连接「内网穿透」 前言 …

Qt中QDebug的使用

QDebug类为调试信息(debugging information)提供输出流。它的声明在<QDebug>中&#xff0c;实现在Core模块中。将调试或跟踪信息(debugging or tracing information)写出到device, file, string or console时都会使用QDebug。 此类的成员函数参考&#xff1a;https://doc…