Transformer原理详解

news2024/11/26 18:34:24

前言好久没有用了,我已经快忘记了自己还有一个CSDN账号了。 在某位不知名好友的提醒下,终于拾起来了,自己也从大二转变成了研二。
目前研究方向主要为:时间序列预测,自然语言处理,智慧医疗
欢迎感兴趣的小伙伴一起交流、一起进步!!!

目录

  • 一.背景
    • 1.1论文
    • 1.2发展动机
    • 1.3特点
    • 1.4地位
  • 二.模型整体架构
    • 2.1宏观理解
    • 2.2微观理解
    • 2.3工作流程
  • 三.模型架构详解
    • 3.1数据
    • 3.2输入(先做了解,后续部分会详细讲解)
    • 3.3Embedding
    • 3.4位置编码
    • 3.5Encoder_self-attention
    • 3.6Encoder_scaled dot-product attention
    • 3.7Encoder_Multi-Head
    • 3.8Encoder_padding掩码
    • 3.9Encoder_多头注意力整体
    • 3.10Encoder_Add&Norm
    • 3.11Encoder_前馈神经网络
    • 3.12Dncoder_Masked Multi-Head Attentuion
    • 3.13Encoder_Decoder_Multi-Head Attention
    • 3.14输出部分
    • 3.15模型训练_损失函数
    • 3.16模型训练_自定义学习率
  • 四.参考

一.背景

1.1论文

  • 2017年谷歌Brain团队发表了“Attention Is All You Need”(注意力就是你所需要的一切),正式提出了Transformer。这是一篇非常经典的自然语言处理领域的文章,把注意力机制推向了高潮。

文章链接:https://arxiv.org/abs/1706.03762

1.2发展动机

  • Transformer的提出,最初针对的是机器翻译任务。
  • 传统的机器翻译采用的主流框架为Encoder-Decoder框架,以RNN作为主要方法,基于RNN发展出来的LSTM和GRU也一度被认为是解决该问题的最优方法。
  • 然而,RNN模型的计算被限制为顺序,阻碍了样本训练的并行化,会导致在计算过程中发生信息丢失从而引起长期依赖问题。
  • RNN以及其衍生网络的特点就是,前后隐藏状态依赖强无法实现并行

1.3特点

  • Transformer的self-attention可以实现快速并行,改进了RNN类模型训练慢的缺点
  • Transformer可以增加到非常深的深度,充分挖掘DNN模型的特征,提升准确率

1.4地位

  • 开天辟地,各种竞赛屠榜,毕业论文首选
  • 对NLP、CV等各种领域,产生了深刻的影响
  • 截止目前,知网收录名称包含Transformer的期刊达18.51万,学位论文总数达4.16万
    在这里插入图片描述

二.模型整体架构

在这里插入图片描述

图1 Transformer模型架构图

2.1宏观理解

  • 输入一段法文,经过Transformer模型后,可以输出对应的英文

在这里插入图片描述

图2 Transformer宏观理解模型
  • Transformer本质上是一种Encoder-Decoder架构,因此中间部分的Transformer可以分为两个部分:Encoder(编码组件)和Decoder(解码组件)
    在这里插入图片描述
图3 Transformer宏观理解模型(Encoder-Decoder)
  • 其中,Encoder和Decoder分别有n1和n2层(这里n1=n2=6),因此又可以将模型进一步细分

在这里插入图片描述

图4 Transformer进一步宏观理解模型(Encoder-Decoder)

2.2微观理解

  • 每个编码器由两个子层组成:Self-Attention 层(自注意力层)和 Position-wise Feed Forward Network(前馈网络)
  • 解码器也有编码器中这两层,但是它们之间还有一个注意力层(即 Encoder-Decoder Attention),使得解码器关注输入句子的相关部分

在这里插入图片描述

图5 Transformer微观理解

2.3工作流程

  1. 输入需要翻译的数据(以:Why do we work为例)
  2. 经过Encoder层,得到隐藏层,输出到Decoder
  3. Decoder的输入(即<strat> 开始符号)
  4. 将Decoder的输入进行上述操作得到Decoder的输出(即输入Why后,得到翻译的结果:为)
  5. 将Decoder的输出和Decoder的输入进行拼接,共同作为下一步Decoder的输入(即:<strat>
  6. 得到Decoder的输出:什
    值得注意的是,第4步和第5步是并行进行的,这里为了详细解释原理,拆分为了两步。

在这里插入图片描述

图6 Transformer工作流程

三.模型架构详解

3.1数据

  • 原始数据为一组元组,前面的部分为中文,后面的部分为对应的英文。均为Tensor的数据类型。

在这里插入图片描述

图7 输出的原始数据

  • 拿到数据后,第一件事就是将字符串转换为机器可以识别的数字编码。
    在这里插入图片描述
图8 编码后的数据
  • 为每一句话加上起始符和终止符,这里分别以12,8作为起始符;13,9作为终止符(只要不在编码的数字里面,任何数字都是可以的)
    在这里插入图片描述
图9 编码后加上起始符和终止符
  • 之后,一般会经过shuffle(打乱顺序)、batch(分批提速)、padding(填充)等操作,使得模型学习的效果更好。这里为了方便,仅展示batch和padding的操作,不进行shuffle。
  • 将前两行数据作为一个batch,后两行数据作为一个batch,然后按照一个batch中最长的长度进行填充。
  • 值得注意的是,每一个batch最长的长度可能是不一样的
    在这里插入图片描述
图10 batch和padding

3.2输入(先做了解,后续部分会详细讲解)

  • 以第一个batch为例,左边为中文(input),右边为英文(target)在这里插入图片描述
图11 输入数据(以第一个batch为例)
  • 对于target,需要分为target_input(如图6中的步骤3-5。即经过Encoder后,传入Decoder的输入)和target_real(如图6中的步骤6。Decoder的输出)。
    在这里插入图片描述
图12 target部分拆分
  • 记录input和target_input的padding,消除影响。具体操作位:padding项变为1,非padding项变为0
  • 提醒:以下部分大家可能目前觉得难理解,没关系,现在有一个印象就好,后面到了这一部分就会理解
    在这里插入图片描述
图13 input的padding例子

在这里插入图片描述

图14 target_input的padding例子
  • 下面展示出4个padding的对应关系

在这里插入图片描述

图15 input的padding对应的位置(绿色:Encoder部分;红色:Decoder部分)

在这里插入图片描述

图16 target_input的padding对应的位置(对应于Decoder的Mask部分)

3.3Embedding

  • 将对应的数字编码,转为向量表示。这里以4维为例(每个数字,用一个4维的向量表示)
  • input由二的张量变成了一个三维的张量(即batchsize,input_sql,dmodel)
  • 大家不用纠结数字的具体表示,只需要了解结构就可以
    在这里插入图片描述
图17 Embedding示例

3.4位置编码

  • Transformer因为并行化,无法捕捉到位置信息
  • 解决方案:位置编码,引入绝对位置和相对位置
  • 下面公式中,pos指的是这句话的位置(向量的顺序),i = dim_index // 2 dmodel指的是向量的维度
  • 位置编码只和Embedding的维度和seq_length有关
    在这里插入图片描述
图18 位置编码公式
  • 值得注意的是,在原文中,作者给出了一个位置编码的示意图,如下图19所示。
  • 这里,d_model = 512,seq_length = 50。可以看到每一个位置(纵坐标)的颜色都不一样,这也是说明了位置编码可以反应真实的顺序。
    在这里插入图片描述
    图19 位置编码示例

3.5Encoder_self-attention

  • 没有self-attention,每个词只能表示自身的含义,不包含信息
  • self-attention会对词向量进行重构,使得词向量不单只包含自己,而是综合考虑全局,融入上下文
  • 举个例子,以“我宣你”为例,embedding后,会生成对应的词向量,经过线性变化后,每个词向量会生成q、k、v三个变量,来进行后续的计算,如图20所示。

在这里插入图片描述

图20 self-attention示例(生成q、k、v)
  • 接下来,q、k进行点乘操作,计算相关性,每一个q跟所有的k进行计算。计算后,得到一个相关性得分,经过softmax将得分映射到0-1之间。之后与v进行相乘,形成一个新的向量,这就是结合了上下文的向量,不仅仅包含embedding。

  • (这里要注意:q1和v2点乘 与 v2和q1点乘是不一样的。举个例子,有人对你笑,你可能是因为)

在这里插入图片描述

图21 self-attention示例(计算相关性得分)

3.6Encoder_scaled dot-product attention

  • 本文对self-attention进行了改进,公式如下图22所示。

  • 改进一:q、k、v是一个矩阵(原来是向量)

  • 改进二:除以根号dk,dk表示的是k的维度(这也就是为什么叫scaled)

在这里插入图片描述

图22 scaled dot-product attention公式
  • 为什么要除以根号dk呢?这里从两个角度进行解释

1.前向传播角度:softmax 是一种非常明显的 “马太效应”:强(大)的更强(大),弱(小)的更弱(小)。而缩放后,注意力值就分散些,这样一般就获得更好的泛化能力。

2.反向传播角度:不除以这个的话,注意力得分score是一个很大的值,softmax在反向传播时,容易造成梯度消失

  • 为什么选择根号dk,而不是其他的值呢?

假设向量q和k的各个分量是互相独立的随机变量,均值为0,方差为1,那么q和k点积后的均值为0,方差为dk

3.7Encoder_Multi-Head

  • 理论上,Embedding后,需要生成q、k、v,采用不同的参数,如图23所示,用两套不同的参数生成两套不同的q、k、v。这就叫不同的头(本文采用了8个头)
  • 本文中,用一套参数生成Q、K、V,进行拆分。拆分后,Q的维度减少(depth = d_model / heads_num),要保证d_model能够整除heads_num
    在这里插入图片描述
    图23 Multi-Head示例
  • 之后,每个头独立行动,进行scaled dot-product attention操作,将结果进行拼接,还原为之前的维度d_model,这就相当于将多个头提取到的特征进行了融合
  • 为什么使用多头?

能够从多个角度去理解内容,注意到不同子空间的信息,捕捉到更加丰富的特征信息

3.8Encoder_padding掩码

  • 在这之前,我们要知道padding只要不影响有效信息就可以了
  • 计算qk关系的时候,padding不会影响
  • 但是计算softmax的时候(公式如图24),padding会影响。因为softmax计算的是每一个值,可能让模型以为padding和我们的向量有了关系(如图25绿色部分所示),所以要让padding的结果变成0,因此需要padding mask

在这里插入图片描述

图24 softmax公式
  • padding mask:最开始的时候,就将padding变成1,非padding变成0,用来标记padding(注意,1只是一个标记作用,不是padding的值一直是1)

  • 计算的时候,获取标记,将padding变成一个非常小的数,那么和padding有关的项会变成极其小的数。对于指数函数,当x无限小的时候,那么值就无限趋近于0,因此可以把softmax中padding产生的影响消除掉。
    在这里插入图片描述

    图25 padding示例
  • 为什么padding变成1,非padding变成0?一开始padding为0不是更好吗?

如果padding一开始为0,那么我们看softmax公式,e的0次方,就是1,这没有消除padding的影响,所以一开始给padding为0是没有

3.9Encoder_多头注意力整体

  • 到了这里,学习完这些主要的操作后,我们回到最初的环节(接图17 embedding)

  • Embedding后,我们要经过scaled dot-product attention(缩放),位置编码。

  • 之后input需要经过多头注意力模块;target_input需要经过掩码多头注意力模块
    在这里插入图片描述

    图26 Embedding后的操作
  • 首先,经过线性变化,得到QKV三个矩阵,如图27所示
    在这里插入图片描述

    图27 线性变化

在这里插入图片描述

图28 线性变化结果(得到QKV矩阵)
  • 然后,进行分头(这里以2个头为例),如图29所示
    在这里插入图片描述

    图29 分头结果
  • 之后,利用attention进行注意力权重得分(scaled dot-product attention公式),结果如图30所示。

  • 注意:右图中使用padding mask消除padding的影响
    在这里插入图片描述

    图30 计算注意力权重
  • 接下来,经过softmax归一化,得到权重

  • softmax后,与V进行相乘,提取出新特征

  • 维度还原,将多头进行拼接
    在这里插入图片描述

    图31 维度还原
  • 到这里,多头注意力机制就结束了

3.10Encoder_Add&Norm

  • Add&Norm在许多地方都用到了
    在这里插入图片描述

    图32 Add&Norm
  • Add其实就是一个残差网络,能使训练层数到达比较深的层次,缓解梯度消失

  • Norm这里采用的是LayerNorm

BatchNorm:将特征进行标准化
LayerNorm:将样本进行标准化

  • 为什么不用BatchNorm?
    • 有padding项,BatchNorm不足以代表所有样本的均值和方差
    • batch size太小时,一个batch的样本,其均值与方差,不足以代表总体样本的均值与方差
    • NLP场景下不适合用BatchNorm

3.11Encoder_前馈神经网络

  • 前馈神经网络在Encoder和Decoder中也是都用到了
    在这里插入图片描述
    图33 前馈神经网络
  • 本质上,特别简单,就是一个两层的全连接层,可以参考如下代码:
def FeedForward(dff, d_model):
        return tf.keras.Sequential([
                    tf.keras.layers.Dense(units=dff, activation='relu'),
                    tf.keras.layers.Dense(d_model),
                ])
  • 第一层设置units个数(原文为2048),用relu激活函数
  • 第二层用d_model作为神经元个数

3.12Dncoder_Masked Multi-Head Attentuion

3.13Encoder_Decoder_Multi-Head Attention

3.14输出部分

3.15模型训练_损失函数

3.16模型训练_自定义学习率

四.参考

  1. http://t.csdnimg.cn/freZY
  2. https://github.com/LaoGong-zp/Transformer.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1216904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win通过WSL配置安装Redis

一共分为如下几步&#xff1a; 安装WSL发行版&#xff0c;如Ubuntu安装Redis配置Redis与WSL WSL安装 这里有微软官方的文档&#xff1a;https://learn.microsoft.com/zh-cn/windows/wsl/install 但我不建议零基础的这么做。很容易输完一些命令之后&#xff0c;把环境弄得乱七…

【python】OpenCV—Image Pyramid(8)

文章目录 1 图像金字塔2 拉普拉斯金字塔 1 图像金字塔 高斯金字塔 在 OpenCV 中使用函数 cv2.pyrDown()&#xff0c;实现图像高斯金字塔操作中的向下采样&#xff0c;使用函数 cv2.pyrUp() 实现图像金字塔操作中的向上采样 import cv2img cv2.imread(C://Users/Administrat…

C#winform门诊医生系统+sqlserver

C#winform门诊医生系统sqlserver说明文档 运行前附加数据库.mdf&#xff08;或sql生成数据库&#xff09; 主要技术&#xff1a;基于C#winform架构和sql server数据库 功能模块&#xff1a; 个人中心&#xff1a;修改个人信息、打开照片并进行修改 预约挂号&#xff1a;二级…

MIB 6.1810操作系统实验:准备工作(Tools Used in 6.1810)

6.1810 / Fall 2023 实验环境&#xff1a; Ubuntuxv6实验必要的依赖环境能通过make qemu进入系统 $ sudo apt-get update && sudo apt-get upgrade $ sudo apt-get install git build-essential gdb-multiarch qemu-system-misc gcc-riscv64-linux-gnu binutils-ri…

JVM——运行时数据区(堆+方法区+直接内存)

目录 1.Java堆2.方法区**方法区&#xff08;Method Area&#xff09;溢出**方法区&#xff08;Method Area&#xff09;字符串常量池静态变量的存储 3.直接内存(Direct Memory) 1.Java堆 ⚫ 一般Java程序中堆内存是空间最大的一块内存区域。创建出来的对象都存在于堆上。 ⚫ 栈…

NoC流量控制

参考链接1&#xff1a;https://blog.csdn.net/yang1573/article/details/128787167参考链接2&#xff1a;https://shili2017.github.io/posts/NOC5/参考文件&#xff1a;SE22_noc_flow_control.pdf

对象存储OSS服务器邀请试用

文章目录 试用产品领取产品试用权限上传文件开启加速传输提交作品小程序提交任务获取奖励 试用产品 先下载要上传的资源 电脑浏览器打开此页面开始试用&#xff0c;页面如下图 未登录的先登录 领取产品试用权限 在该页面中点击立即试用&#xff0c;弹框勾选服务协议并领取试…

差分详解(附加模板和例题)

一、一维差分 1.一维差分运用 设a[N]为原数组,b[N]为差分数组&#xff0c;c[N]为进行操作后得到的新数组 (1).先求出差分数组b[N] for(i1;i<n;i) {cin>>a[i];b[i]a[i]-a[i-1]; } (2).进行差分操作&#xff0c;利用void insert(int l,int r,int c)函数 void ins…

redis运维(五)再探redis

一 redis概述 ① redis简介 redis三大特性&#xff1a; 缓存、分布式内存数据库、持久化说明&#xff1a;非必须不建议在redis终端操作 ② redis亮点 ③ 初露锋芒 redis-benchmark redis-benchmark并发压力测试的问题解析 备注&#xff1a;多次测试取平均值,最好在物理机…

Android修行手册-Gson中不用实体类生成JsonObject或JsonArray

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列点击跳转>蓝桥系列点击跳转>ChatGPT和AIGC &#x1f449;关于作者 专…

那些年我们追过的 内部类

目录 1. 什么是内部类&#xff1f; 2. 内部类的分类 3. 内部类 3.1 实例内部类 3.2 静态内部类 4. 局部内部类 5. 匿名内部类 6.对象的打印 “不积跬步无以至千里&#xff0c;不积小流无以成江海。”每天坚持学习&#xff0c;哪怕是一点点&#xff01;&#xff01;&a…

Java毕业设计心得体会

1. 开始准备选题、开题报告 大四上学期开学时开始准备论文的&#xff0c;首先是确定论文主题&#xff0c;看自己想做什么毕业设计&#xff0c;可以选取之前接触过的&#xff0c;做过的东西&#xff0c;这样快一些&#xff0c;我之前一直是学Java的&#xff0c;就打算直接用Jav…

【网络知识必知必会】再谈Cookie和Session

文章目录 前言1. 回顾 Cookie1.1 到底什么是 CookieCookie 的数据从哪里来Cookie 的数据长什么样Cookie 有什么作用Cookie 到哪里去Cookie 怎么存储的 2. 理解 Session3. Cookie 和 Session 的区别总结 前言 在讲 HTTP 协议时, 我们就谈到了 Cookie 和 Session, 当时我们只是粗…

redis运维(六)redis-cli命令

一 redis-cli 注意&#xff1a; redis-cli核redis-server版本必须适配 --> 见 redis-cli --version提示&#xff1a; 不过一般安装服务端 redis-server 时内置了客户端 redis-cli说明&#xff1a; redis-cli 是 redis 的一种命令行的客户端工具备注&#xff1a; redis-se…

单链表经典OJ题(四)

目录 1、链表中倒数第k个结点 2、消失的数字 3、轮转数组 4、合并两个有序数组 5、数组串联 6、序列中删除指定数字 1、链表中倒数第k个结点 链表中倒数第k个结点_牛客题霸_牛客网 (nowcoder.com) 这道题依然利用双指针法&#xff0c;具体解题思路如下&#xff1a; 1…

场景图形管理-多视图与相机(3)

在OSG中多视图的管理是通过osgViewer::CompositeViewer类来实现的。该类负责多个视图的管理及同步工作&#xff0c;继承自osgViewer;:ViewerBase类&#xff0c;继承关系图如图8-13所示 图8-13 osgViewer::CompositeViewer 的继承关系图 在前面已经讲到&#xff0c;osgViewer:Vi…

GDPU 商务英语 [初入职场](持续更新……)

&#x1f468;‍&#x1f3eb; 商务英语&#xff08;初入职场电子书PDF&#xff09;提取码&#xff1a;t9ri Unit 1 Job-seeking ✨ 单词 recruitment n. 招聘physical adj. 有形的;物质的profitability n. 盈利launch vt. 将(新产品等)投放市场budget n. 预算account for 占…

Qt QLable 字符过长省略

前言&#xff1a; 项目中常用到字符过长问题&#xff0c;Qt默认的省略并不好用&#xff0c;不是自己想要的&#xff1b; QFontMetri 可使用 QFontMetri 当text的像素宽度超过width&#xff0c;将返回字符串的一个省略版本取决于mode。否则将返回原字符串&#xff1b; mode…

工厂是否需要单独的设备管理部门

设备是工厂生产过程中不可或缺的重要资源&#xff0c;其正常运行和有效管理对于工厂的生产效率和质量至关重要。为了确保设备的良好状态和高效运行&#xff0c;许多工厂选择设立专门的设备管理部门。本文将探讨设备管理部门的职责、与生产部门下的点检维保团队的区别&#xff0…

PyCharm 远程连接服务器并使用服务器的 Jupyter 环境

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…