Transformer模型原理—论文精读

news2024/11/8 18:49:31

文章目录

    • 前言
    • 模型架构
      • Encoder和Decoder
        • Encoder
        • Decoder
      • Attention
      • FFN
      • Embeddings和Positional Encoding
        • Embeddings
        • Positional Encoding
    • 总结

前言

今天来看一下Transformer模型,由Google团队提出,论文名为《Attention Is All You Need》。论文地址。
正如标题所说的,注意力是你所需要的一切,该模型摒弃了传统的RNN和CNN结构,网络结构几乎由Attention机制构成,该论文的亮点在于提出了Multi-head attention机制,其又包含了self-attention,接下来我们将慢慢介绍该模型的原理。

模型架构

正如文中提到大多数的序列传导模型都含有encoder-decoder结构,Transformer的encoder是将一段表征序列 ( x 1 , ⋯   , x n ) (x_1,\cdots,x_n) (x1,,xn)映射为另一种连续表示的序列 ( z 1 , ⋯   , z n ) (z_1,\cdots,z_n) (z1,,zn),即encoder的输出信息;而decoder是将encoder输出和decoder前一步的输出自回归的共同生成序列 ( y 1 , ⋯   , y m ) (y_1,\cdots,y_m) (y1,,ym)。举个例子,现在有一个机器翻译任务,首先将句子embedding为高维向量,输入encoder中,其输出随后输入decoder进行解码得到最终翻译结果,如下图所示。
Encoder-Decoder

需要注意的是,Transformer的输出 y i y_i yi是一次一次自回归的生成的,也就是每一次输出都需要调用最后一层encoder的输出序列。这里不像多层RNN隐层的并行传递,Transformer是串行的。如下图所示。
多层Transformer

Encoder和Decoder

好了,接下来该介绍encoder和decoder的神秘面纱了,如下图所示。
Transformer模型架构
在读论文时第一眼看这个架构图,一开始是比较懵的,这到底做了些啥操作。后来看了李沐老师讲解的Transformer才有了一定的理解。

Encoder

回到论文的讲解!
Enc

这里说到作者实验用到了6层的encoder,这里是为了学到更多的语义信息。并且每层encoder都包含两个子层,分别是多头注意力机制Multi-head attention前馈神经网络FFN。当然了,作者对两个子层的输出都做了residual连接Layer normalization(LN),加了残差连接是为了网络能搭的更深,并且容易训练,防止梯度消失;而LN完全是针对每一个样本自身的特征缩放,能将每个词都归一为相同空间的语义信息。BN也是一种常见的特征缩放方法,常用于CNN,不适用于NLP任务,因为其对所有batch的同一个特征做缩放,在图像中是非常友好的,而NLP中每一个sequence的长度是不一样的,所以在同一个batch中越长的语句得不到充分的缩放表示。

Decoder

Dec

同样的,作者实验用到了6层decoder,不同于encoder,这里作者还设置了mask的multi-head attention,其原因在于在解码时,模型是看不到整条句子的,因此,必须在当前时刻掩码掉后面的词,才能做到正确训练和有效预测。

Attention

谈到注意力机制,像我们人一样,看到一幅图片,我们会关注其强烈的表征现象,能让我们快速了解新事物的信息,如下图所示。特别在处理NLP任务中,长距离的记忆能力是一个难题,引入注意力机制,关注更重要的词,可以缓解这一现象。
在这里插入图片描述

在Transformer中,每个单词embedding为三个不同的向量,分别是 Q u e r y Query Query向量 Q Q Q K e y Key Key向量 K K K V a l u e Value Value向量 V V V。具体来说,对于一个句子,只需要将其输入到三个linear层,通过学习三个 W W W参数就能得到不一样的 Q 、 K 、 V Q、K、V QKV。至于为什么说 Q 、 K 、 V Q、K、V QKV要不一样,其实一样也可以,但是这里为了增强数据的表达能力,保证在 Q K T QK^T QKT矩阵内积时可以得到不同的空间投影提高模型泛化能力

生成的 Q 、 K 、 V Q、K、V QKV矩阵后便可以进行attention计算了,如下图所示
在这里插入图片描述

假设有三个矩阵 Q 、 K 、 V Q、K、V QKV,维度分别为 ( d q , d m o d e l ) 、 ( d k , d m o d e l ) 、 ( d v , d m o d e l ) (d_q, d_{model})、(d_k, d_{model})、(d_v, d_{model}) (dq,dmodel)(dk,dmodel)(dv,dmodel),其中 q = k = v q=k=v q=k=v

  1. 首先进入的是Multi-Head Attention多头注意力机制。这里可以h层,也就是我们说的多头,类似cv中的channel数量,能学习更多维度信息。多头注意力机制中包含了Scaled Dot-product Attention,也是self-attention
  2. 其次进入self-attention,对于每个sequence,用它的query矩阵: ( d q , d m o d e l ) (d_q, d_{model}) (dq,dmodel)和key向量shape: ( d k , d m o d e l ) (d_k, d_{model}) (dk,dmodel)进行内积,本质上是求解每个词之间的余弦相似度,如果两者相似度较高,则赋予较大的值来反应两者的关系,反之如果是正交的,内积为0,则它们就没有相似性,这里输出的attention score矩阵维度是shape: ( d q , d k ) (d_q, d_k) (dq,dk)
  3. 再次将输出矩阵进行scale缩放相似度,为了防止softmax推向梯度平缓区,使得收敛困难,公式如 Attention ( Q 、 K 、 V ) \text{Attention}(Q、K、V) Attention(QKV)所示。
  4. 从次是通过可选的mask操作,为了保证decoder得到sequence的leak信息。具体来说是通过将权重矩阵添加一个上三角的负无穷矩阵,这样softmax就能将这些值推为0,即无权重,保证mask的作用。
  5. 最后将attention score矩阵shape: ( d q , d k ) (d_q, d_{k}) (dq,dk)与Value矩阵shape: ( d v , d m o d e l ) (d_v, d_{model}) (dv,dmodel)内积,得到encoder后的sequence信息表征shape: ( d q , d m o d e l ) (d_q, d_{model}) (dq,dmodel)

Scale缩放公式:
Attention ( Q 、 K 、 V ) = softmax ( Q K T ( d k ) ) V \begin{aligned} \text{Attention}(Q、K、V)=\text{softmax}(\frac{QK^T}{\sqrt(d_k)})V \end{aligned} Attention(QKV)=softmax(( dk)QKT)V

Multi-head公式:
Multihead ( Q 、 K 、 V ) = Concat ( head i , ⋯   , head h ) W O where  head i = Attention ( Q W i Q , Q W i K , Q W i V ) \begin{aligned} \text{Multihead}(Q、K、V)=\text{Concat}(\text{head}_i,\cdots,\text{head}_h)W^O\\ \text{where} \text{ } \text{head}_i=\text{Attention}(QW_{i}^{Q},QW_{i}^{K},QW_{i}^{V}) \end{aligned} Multihead(QKV)=Concat(headi,,headh)WOwhere headi=Attention(QWiQ,QWiK,QWiV)

上述的操作执行完后,便可以通过多个头的concat将矩阵拼接,随后通过linear层降维,完成Multi-head attention的过程。

在这里插入图片描述
值得注意的是多头数量必须可被 d m o d e l d_{model} dmodel整除。这个很好理解,在CNN中,我们经常将feature map的width和height升维后,会把channel数降低,学到更深的信息一个道理。

FFN

除了注意子层外,encoder和decoder中的每个层都包含一个完全连接的前馈网络,它分别和相同地应用于每个位置。这由两个线性变换组成,中间有一个ReLU激活。换句话说就是MLP模型。公式如下所示:

FFN ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 \begin{aligned} \text{FFN}(x) = \text{max}(0, xW_1 + b_1)W_2 + b_2 \end{aligned} FFN(x)=max(0,xW1+b1)W2+b2

在这里插入图片描述
论文作者给定了MLP的中间状态输出维度为2048,而最后输出维度为512,当然就是512->2048->512这样变换。

Embeddings和Positional Encoding

Emdedding+Positional Encoding

Embeddings

在Transformer中,在嵌入层中,这些权重乘以 d m o d e l \sqrt{d_{model}} dmodel 。其原因是在嵌入层学emdedding的时候,在L2norm后,不管维度多大最终权重都会比较小,但后续要和positional encoding相加(不会经过norm),需要保持差不多的scale。

Positional Encoding

self-attention对输入sequence中各单词的位置或者说顺序不敏感,因为通过Query向量和Key向量的内积,本质上就是一些词由其他词的线性表出,并没有说有位置的信息存在。比如“我吃牛肉”这句话,在Transformer看来和“牛吃我肉”是没什么区别的。
为了缓解该问题,作者提出了位置编码。简而言之,就是在词向量输入到注意力模块之前,与该词向量等长的位置向量进行了按位相加,进而赋予了词向量相应的位置信息。

作者给出了位置编码的定义公式,具体如下:

P E p o s , 2 i = s i n ( p o s / 1000 0 2 i / d m o d e l ) P E p o s , 2 i + 1 = c o s ( p o s / 1000 0 2 i / d m o d e l ) PE_{pos, 2i}=sin(pos/10000^{2i/d_{model}}) \\ PE_{pos, 2i+1}=cos(pos/10000^{2i/d_{model}}) PEpos,2i=sin(pos/100002i/dmodel)PEpos,2i+1=cos(pos/100002i/dmodel)

这样通过 s i n ( α + β ) = s i n ( α ) c o s ( β ) + c o s ( α ) s i n ( β ) sin(\alpha+\beta)=sin(\alpha)cos(\beta)+cos(\alpha)sin(\beta) sin(α+β)=sin(α)cos(β)+cos(α)sin(β)。可以将牛(pos=3)可以由pos=2和pos=4表达,使得Transformer可以更容易掌握单词间的相对位置。

总结

关于Transformer比较重要的点基本上就这些,当然还有很多细节的地方需要去探索,接下来我将会写更多的论文分享,总结一些经典的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/606214.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Openwrt_XiaoMiR3G路由器_刷入OpenWrt

刷入Openwrt之前请保证小米R3G路由器已经刷入Breed控制台固件。 刷入Breed请参考: Openwrt_XiaoMiR3G路由器_刷入Breed固件 路由器具体配置参考 小米路由器3G参数 - 小米商城 既然要刷入OpwnWrt就需要线编译固件,使用lede的OpenWrt源码编译。 进入 …

K8S集群安装(二)

目录 1 安装说明.... 1 1.1 安装环境.... 1 1.2 生产环境可部署Kubernetes集群的两种方式.... 2 1.3 操作系统初始化配置.... 3 2 安装Docker/kubeadm/kubelet【所有节点】.... 4 2.1 安装Docker. 4 2.2 添加阿里云YUM软件源.... …

《Apollo 智能驾驶进阶课程》

来自 : https://www.bilibili.com/video/BV1G341117NQ/ https://apollo.baidu.com/ 主要学习资源如下: Apollo社区公众号,直接有整个视频教程的微信推文教程:链接一个CSDN博主记录的笔记: https://blog.csdn.net/qq_45…

08 redis经典五种数据类型及底层实现

redis是字典数据库KV键值对是什么 redis 是 key-value 存储系统,其中key类型一般为字符串,value 类型则为redis对象(redisObject)Redis定义了redisObjec结构体来表示string、hash、list、set、zset等数据类型 C语言struct结构体语法简介Redis 中每个对象…

【华为机试】死记硬背没思路?一般人我劝你还是算了吧

大家好,我是哪吒。 五月份之前,如果你参加华为OD机试,收到的应该是2022Q4或2023Q1,这两个都是A卷题。 5月10日之后,很多小伙伴收到的是B卷,那么恭喜你看到本文了,抓紧刷题吧。B卷新题库正在更…

Spring依赖注入解析

目录 依赖注入大致要点 依赖注入大致流程 Bean的预实例化 doGetBean createBean 完备Bean的创建过程 createBeanInstance populateBean 依赖注入大致要点 Spring在Bean实例的创建过程中做了很多精细化控制finishBeanFactoryInitialization方法里面的preInstantiateSing…

【计算机网络复习】第六章 局域网 LAN

局域网( LAN)概述 LAN的特点 • 覆盖范围小 房间、建筑物、园区范围 • 高传输速率 10Mb/s~1000Mb/s • 低误码率 10-8 ~ 10-11 • 拓扑:总线型、星形、环形 • 介质:UTP、Fiber、C…

6年测试经验之谈,为什么要做自动化测试?

一、自动化测试 自动化测试是把以人为驱动的测试行为转化为机器执行的一种过程。 个人认为,只要能服务于测试工作,能够帮助我们提升工作效率的,不管是所谓的自动化工具,还是简单的SQL 脚本、批处理脚本,还是自己编写…

智能优化算法:指数分布优化算法-附代码

智能优化算法:指数分布优化算法 文章目录 智能优化算法:指数分布优化算法1.指数分布优化算法1.1种群初始化1.2EDO开发1.3EDO探索 2.实验结果3.参考文献4.Matlab5.python 摘要:指数分布优化算法(Exponential distribution optimize…

全新好用的窗口置顶工具WindowTop

打开WindowTop软件,所有已打开的窗口都会在左上角出现一个置顶栏,点击置顶栏的置顶复选框即可置顶窗口或取消窗口。   在WindowTop软件的置顶栏一项里可以自由调整置顶栏的元素(包含增删位置)。   可改变置顶栏的外观&#x…

剖析ffmpeg视频解码播放:时间戳的处理

一、视频播放基础理论 1.1 视频编码和解码基础 视频编码和解码是视频播放的基础,理解它们的工作原理对于深入理解视频播放至关重要。在这一部分,我们将详细介绍视频编码和解码的基础知识。 视频编码(Video Encoding)是将原始视…

离散数学_十章-图 ( 5 ):连通性 - 上

📷10.5 图的连通性 1. 通路1.1 通路1.2 回路1.3 其他术语 2. 无向图的连通性2.1 无向图的连通与不连通2.2 定理2.3 连通分支 3. 图是如何连通的3.1 割点( 关节点)3.2 割边( 桥)3.3 不可分割图3.4 𝑘(&#…

Linux内核模块开发 第 5 章

The Linux Kernel Module Programming Guide Peter Jay Salzman, Michael Burian, Ori Pomerantz, Bob Mottram, Jim Huang译 断水客(WaterCutter) 5 预备知识(Preliminaries) 5.1 模块的入口函数和出口函数 C 程序通常从 ma…

建筑与建材行业相关深度学习数据集大合集

近期又整理了一批建筑与建材行业相关深度学习数据集,分享给大家。废话不多说,直接上干货!! 1、埃及的地标数据集 自从历史开始以来,埃及一直是许多文明、文化和非常著名的地标的家园,现在你(和你的ML模型…

守护进程【Linux】

文章目录 前导知识shell、terminal、console进程组作业会话测试 会话控制jobfgbgps 守护进程作用查看守护进程创建守护进程 前导知识 shell、terminal、console terminal(终端)是一种可以和计算机交互的设备,通常有键盘和显示器&#xff0c…

RocketMq 的基本知识1

一RocketMq的基本知识 1.1 RocketMq的基本知识 MQ , Message Queue ,是一种提供 消息队列服务 的中间件,也称为消息中间件。 1.2 作用 1.流量消峰 2.异步传输 3.日志收集 1.3 核心概念 1消息: 消息是指,消息系统所…

基于内存操作的Redis数据库--详解

目录 基本概念 基本操作 redis的五个基本类型 Redis-key(不区分大小写) 字符串 string Redis的特殊类型 geospatial地理空间 事务 Redis的持久化 RDB(.rdb) 触发机制 优点 缺点 AOF(.aof) 优点…

冈萨雷斯DIP第8章知识点

8.1 基础 图像中的冗余 编码冗余:用于表示灰度的8比特编码所包含的比特数,要比表示该灰度所需要的比特数多。可通过变长编码来解决。 空间和时间冗余:与相邻像素相似(图像);时间:相邻帧中的像素(视频)。可以使用行程…

缺陷管理利器推荐:介绍几款好用的缺陷管理工具

缺陷管理是项目管理工作中的重要环节。Excel表格是国内团队常用的缺陷管理工具,具备上手容易,免费的优点,不过也存在协同不便,不易管理,效率低的不足之处。 一套缺陷管理工具可以帮助我们进行规范化自动化的缺陷管理&a…