Transformer架构笔记

news2024/11/24 9:32:58

Attention is All You Need.


3.Model Architecture

3.1 整体架构如图

在这里插入图片描述

3.2 Encoder与Decoder

  • Encoder:由 N = 6 N=6 N=6个相同的Block/Layer堆叠而成。每个Block有两个子层sub-layer:多头注意力和MLP(FFN,前馈神经网络),每个子层被一个残差链接包围,且后接一个LayerNorm。由于残差链接要求被包围的Block输入输出shape一致(比如Resnet中利用PWConv达到降维、尺寸缩减,来保证输入输出的shape一致)。多头注意力层和FFN层就要满足此条件,因此作者选择固定网络中的Embedding dimension为512。每个子层的输出可以用公式表示为 LayerNorm ( x + SubLayer ( x ) ) \text{LayerNorm}(x + \text{SubLayer}(x)) LayerNorm(x+SubLayer(x)) x x x为输入。

Q: SliceGPT如何解决各层hidden dimension一致的问题?

  • Decoder:同样为6个Block。有两个多头注意力子层,一个FNN子层。且使用自回归的方式:t时刻Decoder的输入,为t时刻以前所有Decoder的输出的总和

    • masked self multi-head attention: mask样本t时刻以后的输入,保证t时刻无法看到未来的输入,避免模型“作弊”。理解起来很简单,因为预测第i个位置的输出时只能依靠第i个位置以前的所有输出单词的语义信息
  • 为什么用LN而不是BN?
    BN针对特征做归一化,LN则针对样本。LN在机器翻译中会用的更多,主要是因为输入序列的长度通常不一致。(在训练中,使用zero padding来解决输入长度不一致的问题)。如果一个batch中输入序列的长度差异很大,则得到的mean,square也会产生震荡。并且在预测时,由于使用全局的mean square,如果输入序列很长,比train set的序列都要长,则预测效果也会变差(因为之前没有统计过)。

3.3 Attention

概括来说,Attention就是一个将query,key,value映射为output的函数。output,k,v,q均为矩阵/向量。output为value的加权平均,权重由key和query的相似度函数计算得到,不同的attention实现会有不同的计算相似度的方法。

  • 注:这里的query,key,value都是指attention的输入。当然有些文章也会指Attention中q, k, v对应的权重矩阵。这里需要指出,输入的shape一般都是一样的: x ∈ R N × d model x \in R^{N \times d_{\text{model}}} xRN×dmodel。其中 N N N为序列长度, d model d_{\text{model}} dmodel为embedding长度。如果是权重矩阵,比如key的 w k ∈ R d model × d k w_k \in R^{d_{\text{model}}\times d_{\text{k}}} wkRdmodel×dk,相当于把输入 x x x投影到另一个维度 d m o d e l → d k d_{model} \to d_k dmodeldk

3.3.1 Scaled Dot-Product Attention

最基本的Attention。 Q Q Q K K K的hidden dimension维度 d k d_k dk V V V d v d_v dv,计算 K K K Q Q Q的内积作为 V V V的权重。具体公式如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V (1) \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V \tag1 Attention(Q,K,V)=softmax(dk QKT)V(1)

Q , K , V Q, K, V QKV 在这里就是输入乘以权重矩阵后的结果: Q ∈ R N × d k , K ∈ R d k × m , V ∈ R m × d v Q \in R^{N \times d_k}, K \in R^{d_k \times m}, V \in R^{m \times d_v} QRN×dkKRdk×mVRm×dv。比如 Q = x w k Q = xw_k Q=xwk

Scaled:除以了 d k \sqrt{d_k} dk 。当 d k d_k dk很大时,维度数过多,各个值可能差别很大,softmax后大的值很大,小的值很小,softmax函数的梯度,会变得很小,train不动。(这个具体是根据softmax函数的一次函数来看)。

  • Mask:mask掉的是key-value,即对t时刻以后的key-query weight设为0(即进入softmax前设一个很大的负数)。

3.3.2 多头注意力

将高维的 Q , V , K Q,V,K QVK(输入)投影到低维,投影h次分别得到h个低维的 Q , V , K Q,V,K QVK,再分别做attention(相当于h个头),把h个头的attention输出并在一起,然后投影回原维度。多头注意力能让模型学习更多参数,相当于学习不同角度的信息。

示意图如Figure 2右图所示:h个头的attention。原特征维度 d model d_\text{model} dmodel,则投影后的维度为 d model / h = d k = d v d_\text{model} / h = d_k = d_v dmodel/h=dk=dv K , Q , V K,Q,V KQV经投影矩阵(权重矩阵),也就是Linear层降维后送入Attention,将h个头的Attention拼接,再经过最后的Linear升维度,得到输出。原文中 h = 8 ,   d model = 512 ,   d k = d v = 64 h=8, \ d_\text{model}=512, \ d_k = d_v = 64 h=8, dmodel=512, dk=dv=64为每个头的hidden dimension长度。

在这里插入图片描述

Attention的参数量。https://blog.csdn.net/qq_46009046/article/details/134417286

具体Multi-Head实现如以下公式:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) W O w h e r e   h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) \begin{aligned} \mathrm{MultiHead}(Q,K,V)& =\mathrm{Concat}(\mathrm{head}_1,...,\mathrm{head}_\mathrm{h})W^O \\ \mathrm{where~head_i}& =\mathrm{Attention}(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V}) \end{aligned} MultiHead(Q,K,V)where headi=Concat(head1,...,headh)WO=Attention(QWiQ,KWiK,VWiV)

W i Q ∈ R d m o d e l × d k , W i K ∈ R d m o d e l × d k , W i V ∈ R d m o d e l × d v W_i^Q\in\mathbb{R}^{d_{\mathrm{model}}\times d_k},W_i^K\in\mathbb{R}^{d_{\mathrm{model}}\times d_k},W_i^V\in\mathbb{R}^{d_{\mathrm{model}}\times d_v} WiQRdmodel×dk,WiKRdmodel×dk,WiVRdmodel×dv W O ∈ R h d v × d m o d e l W^O\in\mathbb{R}^{hd_v\times d_{\mathrm{model}}} WORhdv×dmodel均为权重矩阵(投影矩阵)。 Q ,   K ,   V ∈ R N × d m o d e l Q ,\ K,\ V \in \mathbb{R}^{N \times d_{model}} Q, K, VRN×dmodel

3.3.3 Attention的不同应用

假设输入序列长为 N N N

  1. Encoder:输入shape为 R ∈ N × d model R \in {N \times d_\text{model}} RN×dmodel。输出shape与输入相同。

  2. Decoder中的Masked-MHSA:与Encoder中的attention类似,不过用了mask保证自回归性。

  3. Cross-Attention:Encoder的输出作为key-value,之前的Decoder输出信息的总和得到的embedding作为query,相当于考虑decoder与encoder的embedding之间的相关性。具体到机器翻译中,如英译中,就是中文某个字与英文某个句子(多个字)的关系。

3.4 Point-Wise FFN

简而言之就是两个全连接层。Point指的是每个word embedding,即针对每个word做处理。与PWConv类似但参数量不同。
PWFFN为: in_embedding × out_embedding \text{in\_embedding} \times \text{out\_embedding} in_embedding×out_embedding,而PWConv为: in_channel × out_channel \text{in\_channel} \times \text{out\_channel} in_channel×out_channel
为什么能用Point-Wise FFN,也就是对每个word做投影?因为Attention部分已经汇聚了语义信息(词与词之间的相关度),MLP相当于只是做维度的变换。

3.5 Embeddings and Softmax

embedding层中权重乘以 d model \sqrt{d_\text{model}} dmodel ,因为L2Norm把权重惩罚得很小。

3.6 Positional Encoding

Attention中没有考虑单词的位置信息,所以理论上来说,如果不加入位置信息,把输入word顺序打乱之后,attention输出应该一致。所以要加入位置编码信息。

预测过程:

Encoder:直接获得整个句子的输入的Embedding,经过6个block,然后得到Encoder的输出,比较直接。
Decoder:先输入起始符(S),经过n层decoder,每一层decoder都要通过masked自注意力机制,以及交互注意力机制(和encoder输出的k,v进行计算,这里就说明decoder在交互注意力不需要再算k,v),在最后一层输出预测结果。然后把最新的预测结果以及历史预测结果综合起来,再放到decoder中,做下一个word的预测,直到得到终止符(E)。

训练过程

Encoder:与预测过程类似。

Decoder:与预测过程不同的是,采用Teacher Forcing的方式,即直接告诉Decoder整个目标序列(也就是正确的输出结果),但是由于预测阶段中,不可能直接告诉你答案是上面,因此会加一些随机mask,比如对15%的单词加入mask,以保证训练效果。


文献标注

Attention的参数数量: https://blog.csdn.net/qq_46009046/article/details/134417286

Attention的训练,推理过程: https://blog.csdn.net/AIcar_lrm/article/details/138577652


问题

  1. auto regressive 是什么意思?

    自回归模型:自变量和因变量应当属于同一个分布。比如根据前n天的股票价格,预测下一天的股票价格。(过去时候的输入作为当前时刻的输入)

  2. 为什么要除以 d k \sqrt {d_k} dk ?
    • 在数值过大/过小的情况下,softmax的偏导数值过小,造成学习困难。除以$\sqrt {d_k} $相当于将输入控制在均值为0,方差为1,控制了数量级。还起到了归一化的作用
  3. BN和LN是什么。Transformer中用的是什么?

    LN:对每个Batch中,所有样本基于所有维度进行归一化,均值和方差数量根据batch_size决定()。
    BN:对各个Batch中的样本,在各个维度进行归一化,均值和方差数量根据batch中的word数量决定(图片的话是维度)。
    使用的实际上是InstanceNorm,对每个instance的所有维度做归一化,但是pytorch中可以用nn.LayerNorm函数实现

实现细节:

  • d m o d e l = 512 d_{model}=512 dmodel=512,模型中的embedding dimension和output的维度固定为512,以保证训练效率
  • h = 8 h = 8 h=8代表使用8个头的注意力
    d k = d v = d m o d e l / h = 64 d_k = d_v = d_{model} / h=64 dk=dv=dmodel/h=64 为每个矩阵得到的向量维度。
    ding dimension和output的维度固定为512,以保证训练效率
  • h = 8 h = 8 h=8代表使用8个头的注意力
    d k = d v = d m o d e l / h = 64 d_k = d_v = d_{model} / h=64 dk=dv=dmodel/h=64 为每个矩阵得到的向量维度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2246584.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【大数据学习 | Spark-Core】spark-shell开发

spark的代码分为两种 本地代码在driver端直接解析执行没有后续 集群代码,会在driver端进行解析,然后让多个机器进行集群形式的执行计算 spark-shell --master spark://nn1:7077 --executor-cores 2 --executor-memory 2G sc.textFile("/home/ha…

增量预训练(Pretrain)样本拼接篇

增量预训练(Pretrain)样本拼接篇 一、Pretrain阶段,为什么需要拼接拼接? 为了提高pretrain效率、拓展LLM最大长度,随机将若干条短文本进行拼接是pretrain阶段常见手段。 二、有哪些拼接方式? 拼接方式一…

【AI最前线】DP双像素sensor相关的AI算法全集:深度估计、图像去模糊去雨去雾恢复、图像重建、自动对焦

Dual Pixel 简介 双像素是成像系统的感光元器件中单帧同时生成的图像:通过双像素可以实现:深度估计、图像去模糊去雨去雾恢复、图像重建 成像原理来源如上,也有遮罩等方式的pd生成,如图双像素视图可以看到光圈的不同一半&#x…

从零开始-VitePress 构建个人博客上传GitHub自动构建访问

从零开始-VitePress 构建个人博客上传GitHub自动构建访问 序言 VitePress 官网:VitePress 中文版 1. 什么是 VitePress VitePress 是一个静态站点生成器 (SSG),专为构建快速、以内容为中心的站点而设计。简而言之,VitePress 获取用 Markdown…

使用uniapp编写APP的文件上传

使用uniapp插件文件选择、文件上传组件(图片,视频,文件等) - DCloud 插件市场 实用效果: 缺陷是只能一个一个单独上传

【51单片机】红外遥控

学习使用的开发板:STC89C52RC/LE52RC 编程软件:Keil5 烧录软件:stc-isp 开发板实图: 文章目录 红外遥控硬件电路 NEC协议编码编程实例LCD1602显示Data红外遥控控制扇叶转速 红外遥控 红外遥控是利用红外光进行通信的设备&#…

【解决】Unity TMPro字体中文显示错误/不全问题

问题描述:字体变成方块 原因:字体资源所承载的长度有限 1.找一个中文字体放入Assets中 2.选中字体创建为TMPro 字体资源 3.选中创建好的字体资源(蓝色的大F) 在右边的属性中找到Atlas Width h和 Atlas Heigth,修改的大一点&…

深度学习:GPT-1的MindSpore实践

GPT-1简介 GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点: NLP领域的迁移学习:通过最…

CKA认证 | Day2 K8s内部监控与日志

第三章 Kubernetes监控与日志 1、查看集群资源状态 在 Kubernetes 集群中,查看集群资源状态和组件状态是非常重要的操作。以下是一些常用的命令和解释,帮助你更好地管理和监控 Kubernetes 集群。 1.1 查看master组件状态 Kubernetes 的 Master 组件包…

概念解读|K8s/容器云/裸金属/云原生...这些都有什么区别?

随着容器技术的日渐成熟,不少企业用户都对应用系统开展了容器化改造。而在容器基础架构层面,很多运维人员都更熟悉虚拟化环境,对“容器圈”的各种概念容易混淆:容器就是 Kubernetes 吗?容器云又是什么?容器…

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后&#xff…

移动充储机器人“小奥”的多场景应用(上)

在当前现代化城市交通体系中,移动充储机器人“小奥”发挥着至关重要的作用。该机器人不仅是一个简单的设备,而是一个集成了高科技的移动充电站,为新能源汽车提供了一种前所未有的便捷充电解决方案。该机器人配备了先进的电池管理系统&#xf…

element dialog会隐藏body scroll 导致tab抖动 解决方案如下

element dialog会隐藏body scroll 导致tab抖动 解决方案如下 在dialog标签添加 :lockScroll"false"搞定

Android 功耗分析(底层篇)

最近在网上发现关于功耗分析系列的文章很少,介绍详细的更少,于是便想记录总结一下功耗分析的相关知识,有不对的地方希望大家多指出,互相学习。本系列分为底层篇和上层篇。 大概从基础知识,测试手法,以及案例…

Bugku CTF_Web——my-first-sqli

Bugku CTF_Web——my-first-sqli 进入靶场 随便输一个看看 点login没有任何回显 方法一: 上bp抓包 放到repeter测试 试试万能密码(靶机过期了重新开了个靶机) admin or 11--shellmates{SQLi_goeS_BrrRrRR}方法二: 拿包直接梭…

BUUCTF—Reverse—easyre(1)

非常简单的逆向 拿到exe文件先查下信息,是一个64位程序,没有加壳(壳是对代码的加密,起混淆保护的作用,一般用来阻止逆向)。 然后拖进IDA(64位)进行反汇编 打开以后就可以看到flag flag{this_Is_a_EaSyRe}

全面击破工程级复杂缓存难题

目录 一、走进业务中的缓存 (一)本地缓存 (二)分布式缓存 二、缓存更新模式分析 (一)Cache Aside Pattern(旁路缓存模式) 读操作流程 写操作流程 流程问题思考 问题1&#…

React基础知识一

写的东西太多了,照成csdn文档编辑器都开始卡顿了,所以分篇写。 1.安装React 需要安装下面三个包。 react:react核心包 react-dom:渲染需要用到的核心包 babel:将jsx语法转换成React代码的工具。(没使用jsx可以不装)1.1 在html中…

Vue3中使用:deep修改element-plus的样式无效怎么办?

前言:当我们用 vue3 :deep() 处理 elementui 中 el-dialog_body和el-dislog__header 的时候样式一直无法生效,遇到这种情况怎么办? 解决办法: 1.直接在 dialog 上面增加class 我试过,也不起作用,最后用这种…

鸿蒙进阶-状态管理

大家好啊,这里是鸿蒙开天组,今天我们来学习状态管理。 开始组件化开发之后,如何管理组件的状态会变得尤为重要,咱们接下来系统的学习一下这部分的内容 状态管理机制 在声明式UI编程框架中,UI是程序状态的运行结果&a…