15.transformer全解

news2025/1/11 5:44:09

欢迎访问个人网络日志🌹🌹知行空间🌹🌹


文章目录

    • 1.基础介绍
    • 2.网络结构
      • 2.1 Input/Output Embedding
      • 2.2 自注意力机制 self-attention
      • 2.3 point-wise全连接层
      • 2.4 位置编码 Position Encoding
    • 3.输入处理过程示例
    • 4.代码实现

1.基础介绍

论文:Attention Is All You Need

这是Google2017年06月份发表的文章,在这篇文章中作者提出了后来对CV和NLP都产生了影响很大的Transformer网络结构,成为继MLPRNN后又一倍受关注的基础模型。用于序列化数据的学习以输出序列化的预测结果,如应用在NLP领域。Transformer最早的提出就是应用在机器翻译领域,在WMT2014 英语翻译成德语的任务上,BLEU指标达到了28.4,比之前的SOTA提升了2个点。Transformer中使用多头注意力层替换了之前序列转录模型中使用循环神经网络单元。

图片来自于1
在这里插入图片描述

在RNN中,如上图,要计算 h t h_t ht必须先计算 h t − 1 h_{t-1} ht1及其之前的所有输出,这导致模型的计算无法在时间上并行,导致运算效率比较低。此外,因时序信息是一步步向后传递的,因此对于序列早期的信息在后面的计算中有可能会丢掉,而存储 h t h_t ht当序列长度过长时又会占用过多的内存。而Transformer结构使用自注意力机制,使得模型能够进行并行化计算,提升训练速度。

2.网络结构

对于序列数据的学习,经典的结构就是编码-解码结构,编码器将输入序列 ( x 1 , x 2 , . . . , x n ) (x_1,x_2,...,x_n) (x1,x2,...,xn)映射成 ( z 1 , z 2 , . . . , z n ) (z_1,z_2,...,z_n) (z1,z2,...,zn),解码器以 z z z为输入得到 ( y 1 , y 2 , . . . , y m ) (y_1,y_2,...,y_m) (y1,y2,...,ym)作为输出,这里的输出过程是先输出 y 1 y_1 y1,再根据 y 1 y_1 y1输出 y 2 y_2 y2,再根据 y 1 , y 2 y_1,y_2 y1,y2再输出 y 3 y_3 y3,也称这种方式为自回归(auto-regressive)Transformer也是编码解码结构,其中编码解码模型都是由自注意力层和全连接层组成。其网络结构如下图:

在这里插入图片描述

如上图,编码器中的一个block由两个子层sublayer组成,分别是MultiHead Attension层和MLP层组成。MLP层中使用了残差结构,并使用了Layer Normalization,表示为 L a y e r N o r m ( x + S u b l a y e r ( x ) ) LayerNorm(x + Sublayer(x)) LayerNorm(x+Sublayer(x))

解码器中除了使用了于编码器中相同的两个子层外还引入了第三种子层Masked Multi-Head Attention层用于模型自回归的学习,保证在模型训练时t时刻不会看到 t t t时刻以后的序列信息,从而保证训练和预测的时候行为是一致的。

下面对上图中的各个组成单元分别进行介绍:

2.1 Input/Output Embedding

Embedding这个词字面意思表示嵌入,这里介绍,Embedding是将高维数据转换成低维数据,借此可将字词的稀疏向量进行向量化表示。常见的Embedding由自然语言处理中的word embedding,图神经网络中的node embedding等。在这篇文章中作者介绍了NLP中的Word Embedding

2.2 自注意力机制 self-attention

注意力函数可以看成是query值和key-value对到输出output的一个映射,query/key/value都是向量,outputvalue维度相同,outputvalue的加权和,每个value的权重通过计算querykey的相似度得到的,相似度的计算也被称为compatibility function,不同的注意力机制有不同的计算方法。

transformer中使用的querykey是等长的,维度都为 d k d_k dk,outputvalue的维度是 d v d_v dvtransformer中使用的querykey的相似度计算方式很简单,就是计算两个向量的内积再除以向量的维度,然后做softmax得到权重值。

实际计算中,会将多个query/key/value向量打包计算,写成矩阵的形式为:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

这里因为有除以 d k \sqrt{d_k} dk ,因此被称为scaled dot product attention。之所以除以 d k \sqrt{d_k} dk 是为了当序列长度比较大的时候还能比较好的衡量querykey之间的相似度,减少尺度导致的误差变化。

在这里插入图片描述

多头注意力机制 Multi-Head Attention

将前面介绍的attention中的query/key/value通过可学习参数的线性变换投影h次,得到h个query/key/value函数,将每个函数的输出并到一起再经过线性投影得到最终的输出。

在这里插入图片描述

计算公式为:

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) W O MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中,
h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i=Attention(QW_i^Q, KW_i^K,VW^V_i) headi=Attention(QWiQ,KWiKVWiV)

W i Q ∈ R d m o d e l × d k , W i K ∈ R d m o d e l × d k , W i V ∈ R d m o d e l × d v , W O ∈ R h d v × d m o d e l W_i^Q\in\mathbb{R}^{d_{model}\times d_k},W_i^K\in\mathbb{R}^{d_{model}\times d_k},W_i^V\in\mathbb{R}^{d_{model}\times d_v},W^O\in\mathbb{R}^{hd_{v}\times d_{model}} WiQRdmodel×dk,WiKRdmodel×dk,WiVRdmodel×dv,WORhdv×dmodel是线性投影的可学习参数。

从网络结构图中可以看到,在编码器中的注意力层和解码器的第一个注意力层,Q/K/V使用的是同一个输入,因此这种注意力机制被称为自注意力机制。

2.3 point-wise全连接层

普通的全连接层,其输入的shape:[N,C]其中,N表示的是样本的数量,C表示每个特征的维度,而point_wise全连接层的输入shape:[N,L,C]其中N表示的是样本的数量,L表示句子的长度,C表示的是单词的个数,然后每次全连接是作用在最后一个维度C上的。

pytorch中的nn.Linear函数在处理3dtensor时默认是作用在最后一维上的,可以写成下面形式:

fc = nn.Sequential(
        nn.Linear(512, 12),
        nn.ReLU(),
        nn.Linear(12, 28),
    )

t = torch.randn((3, 16, 512))
fc(t).shape
# torch.Size([3, 16, 28])

计算公式如下:

F F N ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FFN(x) = max(0, xW_1+b_1)W_2+b_2 FFN(x)=max(0,xW1+b1)W2+b2

2.4 位置编码 Position Encoding

前面介绍的attention中只是使用query/key的形式将输出表示成了value的加权和,这里没有输入序列的顺序信息,在RNN中是通过逐个词输出来学习序列信息的,而transformer中是将一个序列一次性输入到模型中,并没有序列中每个单词的信息,因此,这里引入位置编码来表示输入序列的时序信息,并将其作为模型的输入。

对于长度为L的输入序列,要标识每个单词的位置信息,一种方式是给每个位置生成一个唯一的表示位置的向量。transformer中使用如下的方式来计算输入序列的位置编码:

P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)}=sin(pos/10000^{2i/d_{model}})\\ PE_{(pos, 2i+1)}=cos(pos/10000^{2i/d_{model}}) PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)

其中, d m o d e l d_{model} dmodel表示的是位置向量的维度,和Input Embedding后得到的每个词的维度相同。 p o s pos pos表示长度为 L L L的序列中的第 p o s pos pos个单词, i i i表示位置向量 d m o d e l d_{model} dmodel维度上的第 i i i维。

使用pytorch实现的位置编码函数为:

import torch

def position_encoding(
    seq_len: int, dim_model: int, device: torch.device = torch.device("cpu"),
) -> Tensor:
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float, device=device).reshape(1, 1, -1)
    phase = pos / 1e4 ** (dim // dim_model)

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

从上面的代码可以看到,Position Encoding没有使用需要学习的参数,只是手动设计了表示序列位置信息的编码方式。

3.输入处理过程示例

transformer用于翻译任务为例:

输入: x = I am cold

输出: y = 我冷

输入句子的dictionary中有3个词,则输入可以表示成:

word2index = {"I":0,"am":1,"cold":2}

输入句子的向量表示为:

x = [[[0],[1],[2]]

transformer中输入的处理主要有input embeddingposition encoding这两步,如下图:

在这里插入图片描述

对输入句子序列处理结束后将其输入到attention中进行处理,其处理过程如下图所示:

在这里插入图片描述

上图中 L L L表示的序列的长度, d k d_k dkattention中使用的权重的维度, d k d_k dk的大小决定了模型的大小。上图,只描述了Single Head的计算过程,对于Multi Head,使用多组 W Q , W K , W V WQ,WK,WV WQ,WK,WV进行计算,然后将计算得到的结果再进行concatenate即可。

上图描述了attention的计算过程,在attention之后的计算是point wise feed forward。其计算过程表示如下图:

在这里插入图片描述

可以看到这里的FFN是作用在输入样本序列的每个单词向量上的,与之前常见的FFN作用在整个样本上不同。

pytorch中的nn.Linear层处理3d向量时,默认作用在最后一维进行计算,因此可以将attention层输出的结果直接输入到nn.Linear中。

fc = nn.Sequential(
        nn.Linear(512, 12)
    )

t = torch.randn((3, 16, 512))
fc(t).shape
# torch.Size([3, 16, 12])

4.代码实现

使用pytorch实现的transformer可以将代码仓库。

  • 1.http://colah.github.io/posts/2015-08-Understanding-LSTMs/
  • 2.https://zhuanlan.zhihu.com/p/164502624
  • 3.https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.337.search-card.all.click&vd_source=e75f432df49764db96371bce27ab9fd5

欢迎访问个人网络日志🌹🌹知行空间🌹🌹


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/421181.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文阅读和分析:Hybrid Mathematical Symbol Recognition using Support Vector Machines

HMER论文系列 1、论文阅读和分析:When Counting Meets HMER Counting-Aware Network for HMER_KPer_Yang的博客-CSDN博客 2、论文阅读和分析:Syntax-Aware Network for Handwritten Mathematical Expression Recognition_KPer_Yang的博客-CSDN博客 3、论…

自然语言处理(七): Deep Learning for NLP: Recurrent Networks

目录 1. N-gram Language Models 2. Recurrent Neural Networks 2.1 RNN Unrolled 2.2 RNN Training 2.3 (Simple) RNN for Language Model 2.4 RNN Language Model: Training 2.5 RNN Language Model: Generation 3. Long Short-term Memory Networks 3.1 Language M…

论文阅读【14】HDLTex: Hierarchical Deep Learning for Text Classification

论文十问十答: Q1论文试图解决什么问题? 多标签文本分类问题 Q2这是否是一个新的问题? 不是 Q3这篇文章要验证一个什么科学假设? 因为文本标签越多,分类就越难,所以就将文本类型进行分层分类,这…

【人工智能与深度学习】判别性循环稀疏自编码器和群体稀疏性

【人工智能与深度学习】判别性循环稀疏自编码器和群体稀疏性 判别类循环稀疏自编码器 (DrSAE)组稀疏组稀疏自编码器的问与答图像级别训练,无权重分享(weight sharing)的局域过滤器 (local filters)判别类循环稀疏自编码器 (DrSAE) DrSAE的设计结合了稀疏编码(稀疏自编码器)…

数据库并发控制基本概念和基本技术

并发控制与基本技术一、并发控制1. 概述2. 并发访问可能出现的问题二、并发控制的主要技术1、基本技术2、封锁及锁的类型2.1、什么是封锁2.2、基本封锁类型2.2.1、排它锁(Exclusive Locks,简记为 X 锁)2.2.2、共享锁(Share Locks&…

基于ArkUI框架开发-ImageKnife渲染层重构

ImageKnife是一款图像加载缓存库,主要功能特性如下: ●支持内存缓存,使用LRUCache算法,对图片数据进行内存缓存。 ●支持磁盘缓存,对于下载图片会保存一份至磁盘当中。 ●支持进行图片变换:支持图像像素源图…

【SSconv:全色锐化:显式频谱-空间卷积】

SSconv: Explicit Spectral-to-Spatial Convolution for Pansharpening (SSconv:用于全色锐化的显式频谱-空间卷积) 全色锐化的目的是融合高空间分辨率的全色(PAN)图像和低分辨率的多光谱(LR-MS&#xff…

【微服务】6、一篇文章学会使用 SpringCloud 的网关

目录一、网关作用二、网关的技术实现三、简单使用四、predicates(1) 网关路由可配置的内容(2) 路由断言工厂(Route Predicate Factory)五、filters(1) GatewayFilter(2) 给全部进入 userservice 的请求添加请求头(3) 全局过滤器 —— GlobalFilter(4) 过…

PX4从放弃到精通(二十七):固定翼姿态控制

文章目录前言一、roll/pitch姿态/角速率控制二、偏航角速率控制三、主程序前言 固件版本 PX4 1.13.2 欢迎交流学习,可加左侧名片 一、roll/pitch姿态/角速率控制 roll/pitch的姿态控制类似,这里只介绍roll姿态控制, 代码位置: …

如何确定NetApp FAS存储系统是否正常识别到了boot device?

近期处理了几个NetApp FAS存储控制器宕机的案例,其中部分有代表性的就是其实控制器并没有物理故障,问题是控制器里面的boot device的SSD盘出现了问题。这里给大家share一下如何确定系统是否成功识别到了boot device设备。 对于很多非专业人士来说&#…

mongodb使用docker搭建replicaSet集群与变更监听

在mongodb如果需要启用变更监听功能(watch),mongodb需要在replicaSet或者cluster方式下运行。 replicaSet和cluster从部署难度相比,replicaSet要简单许多。如果所存储的数据量规模不算太大的情况下,那么使用replicaSet方式部署mongodb是一个…

凹凸/法线/移位贴图的区别

你是否在掌握 3D 资产纹理的道路上遇到过障碍? 不要难过! 许多刚接触纹理或 3D 的艺术家在第一次遇到凹凸贴图(Bump Map)、法线贴图(Normal Map)和移位贴图(Displacement Map)时通常…

Linux Redis主从复制 | 哨兵监控模式 | 集群搭建 | 超详细

Linux Redis主从复制 | 哨兵监控模式 | 集群搭建 | 超详细一 Redis的主从复制二 主从复制的作用三 主从复制的流程四 主从复制实验4.1 环境部署4.2 安装Redis(主从服务器)4.3 修改Master节点Redis配置文件 (192.168.163.100)4.4 修改Slave节点Redis配置文…

MySQL-用户与权限

目录 🍁DB权限表 🍁新建普通用户 🍂创建新用户(create user) 🍂创建新用户(grant) 🍁删除普通用户 🍁修改用户密码 🍂Root用户修改自己的密码 🍂Root用户修改普通用户密码 &#x1f…

区块链概论

目录 1.概述 2.密码学原理 2.1.hash函数 2.2.签名 3.数据结构 3.1.区块结构 3.2.hash pointer 3.3.merkle tree 3.3.1.概述 3.3.2.证明数据存在 3.3.3.证明数据不存在 4.比特币的共识协议 4.1.概述 4.2.验证有效性 4.2.1.验证交易有效性 4.2.2.验证节点有效性 …

~~~~~不得不会的账号与权限管理小知识

目录一.用户账号和组账号概述二. useradd添加用户账号三. passwd 修改密码四. 修改用户账户的属性五 . userdel 删除用户账号六. 用户账号的初始配置文件七. 组账号文件八 . 文件/目录的权限及归属8.1设置文件和目录的权限chmod8.2 设置文件和目录的归属chown命令8.3 补充扩展:…

JAVA本地监听与远程端口扫描的设计与开发

随着Internet的不断发展,信息技术已成为社会进步的巨大推动力。不管是存储于服务器里还是流通于Internet上的信息都已成为一个关系事业成败的关键,这就使保证信息的安全变得格外重要。本地监听与远程端口扫描程序就是在基于Internet的端口扫描的基础上&a…

Optional类快速上手

目录 一、概述 二、使用 1、创建对象 2、安全消费值 3、安全获取值 4、过滤 5、判断 6、数据转换 一、概述 我们在编码的时出现最多的就是空指针异常,所以在很多情况下我们需要做各种非空的判断。 尤其是对象中的属性还是一个对象的情况下,这种…

Doris(3):创建用户与创建数据库并赋予权限

Doris 采用 MySQL 协议进行通信,用户可通过 MySQL client 或者 MySQL JDBC连接到 Doris 集群。选择 MySQL client 版本时建议采用5.1 之后的版本,因为 5.1 之前不能支持长度超过 16 个字符的用户名。 1 创建用户 Root 用户登录与密码修改 Doris 内置 ro…

从C出发 19 --- 函数定义细节剖析

因为编译器是自上而下执行代码的,当编译到 paw2 的时候不知道是什么东西,看起来像一个函数但是前面的代码没有发现它,这个时候编译器就会报错 为了防止编译器报错 应该在调用前先声明 ,注意声明的三要素 声明的作用: 让编译器先…