解析transformer中的各模块结构

news2024/12/29 11:33:06

transformer是一种编解码(encoder-decoer)结构,用于自然语言处理、计算机视觉等领域,编解码结构是当前大模型必包含的部分。

文章目录

1. 词嵌入模块

2.位置编码模块

3. 多头注意力机制模块

3.1 自注意力机制模块

3.2 多头注意力机制

3.3 为什么使用自注意力机制模块

4. 层归一化模块

5. 残差模块

6. 前馈神经网络模块

7. 交叉多头注意力机制模块

8. 掩码多头注意力机制模块

编解码结构图:

图片

transformer模块编码输入得到特征,然后解码得到输出。

transformer论文的一张经典图:

结合transformer论文和代码,模块主要包括了:

  • 词嵌入模块(input embedding)

  • 位置编码模块(Positional Encoding)

  • 多头注意力机制模块(Multi-Head Attention)

  • 层归一化模块(LayNorm)

  • 残差模块

  • 前馈神经网络模块(FFN)

  • 交叉多头注意力机制模块(Cross Multi-Head Attention)

  • 掩膜多头注意力机制模块(Masked Multi-Head Attention)

1. 词嵌入模块

词嵌入模块调用nn.Embedding,其主要作用是将每个单词表示成一个向量,方便下一步计算和处理。


class TokenEmbedding(nn.Embedding):
    """
    Token Embedding using torch.nn
    they will dense representation of word using weighted matrix
    """

    def __init__(self, vocab_size, d_model):
        """
        class for token embedding that included positional information

        :param vocab_size: 字典中词的个数
        :param d_model: 嵌入维度
        """
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)

若batch_size是bs,输入词的个数是seq_len,词嵌入维度是n_model。

因此输入的维度x:tensor(bs,seq_len)

TokenEmbdding(x)后的维度:tensor(bs,seq_len,d_model)

2.位置编码模块

位置编码模块根据每个词的位置进行编码得到位置向量,原理是采用三角函数编码。如下图,左边采用正弦函数编码,右边采用余弦函数编码,最后进行拼接。

其中行表示嵌入维度,列表示词的位置。

另一种位置编码方法是正弦函数和余弦函数根据位置进行交叉,如位置1采用sin,位置2采用cos,位置3采用sin,位置4采用cos,以此类推。如下图:

若batch_size是bs,输入词的个数是seq_len,词嵌入维度是n_model。


class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """

    def __init__(self, d_model, max_len, device):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model,词嵌入维度
        :param max_len: max sequence length,最大语句的长度
        :param device: hardware device setting
        """
        super(PositionalEncoding, self).__init__()

        # same size with input matrix (for adding with input matrix)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradient

        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent word's position

        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i' means index of d_model (e.g. embedding size = 50, 'i' = [0,50])
        # "step=2" means 'i' multiplied with two (same with 2 * i)
    # 位置编码
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # compute positional encoding to consider positional information of words

因此输入的维度x:tensor(bs,seq_len)

PositionalEncoding(x),输出的维度是:tensor(seq_len,d_model)

将词嵌入模块与位置模块相加得到Encoder模块的输入。

维度表示这一过程:

tensor(bs,seq_len,n_model) + tensor(seq_len,d_model) = tensor(bs,seg_len,n_model)

3. 多头注意力机制模块

多头注意力机制模块是transformer的核心,是encoder和decoder的重要组成部分。

多头注意力机制模块通过多个自注意力机制拼接,然后通过全连接层得到输出,理解自注意力机制就能很好的理解多头注意力机制。

3.1 自注意力机制模块

自注意力机制模块通过WQ,WK,WV矩阵得到Q,K,V矩阵。如下图:

图片

然后矩阵Q与矩阵K的转置相乘得到每个词间的相关系数,并通过词嵌入维度和softmax对相关系数进行归一化。

图片

最后乘V矩阵,得到自注意力机制的输出,

图片

向量维度表示这一过程:

输入x:tensor(bs,seq_len,n_model)

WQ,WK,WV:tensor(n_model,n_model)

x@WQ、x@WK和x@WV得到Q,K,V的矩阵:tensor(bs,seq_len,n_model),其中@表示矩阵内即

Q@K^T:tensor(bs,seq_len,seq_len)

coef=softmax(Q@K^T)/sqrt(n_model):tensor(bs,seq_len,seq_len)

coef@V:tensor(bs,seq_len,n_model)

3.2 多头注意力机制

多头注意力机制,顾名思义,包含多个自注意力机制,然后将多个自注意力机制的输出进行拼接,最后通过全连接层得到输出。

如下图,输入x通过多个自注意力机制得到多个Q,K,V。

图片

然后按照上节描述的那样,得到多个Z。

图片

对多个注意力机制的输出Z进行拼接:

图片

最后喂入全连接层,得到最终的结果。

向量维度表示这一过程:

多头:n_head

输入x:tensor(bs,seq_len,n_model)

n_head的自注意力机制模块的输出:tensor(bs,seq_len,n_model)

拼接:(bs,seq_len,n_model*n_head)

全连接层权重:tensor(n_model*n_head,n_model)

喂入全连接层后的输出:tensor(bs,seq_len,n_model)

我们通过矩阵思想实现多头注意力机制模块,矩阵每行表示某词的嵌入向量,行数表示词个数,列数表示嵌入维度。

3.3 为什么使用自注意力机制模块

这个机制使得模型能够自动确定输入序列中的不同元素对生成特定输出的重要性,这种权重分配方法允许模型更好地理解序列中的上下文关系,同时处理顺序无关的数据。

4. 层归一化模块

层归一化模块用于调整数据范围,加速模型的训练,LayerNorm在每个样本上统计所有维度的值,计算均值和方差,这样的做的优点是可以不受样本数的限制。

向量维度表示这一过程:

x: tensor(bs,seq_len,n_model)

LayerNorm(x): tensor(bs,seq_len,n_model)

5. 残差模块

这个比较简单,相信只要了解深度学习都知道resnet网络。

向量维度表示这一过程:

x: tensor(bs,seq_len,n_model)

selfAttentions(x): tensor(bs,seq_len,n_model)

y: tensor(bs,seq_len,n_model)

层归一化LayerNorm(y): tensor(bs,seq_len,n_model)

6. 前馈神经网络模块

FFN模块通过特征变换和维度扩展,使得模型能够更好地理解和表示输入序列的信息,从而提高了自然语言处理任务的性能,模块通过两个全连接层实现。

self.linear1 = nn.Linear(d_model, hidden)
self.linear2 = nn.Linear(hidden, d_model)

向量维度表示这一过程:

输入x: tensor(bs,seq_len,n_model)

x=linear1(x): tensor(bs,seq_len,hidden)

linear2(x): tensor(bs,seq_len,n_model)

7. 交叉多头注意力机制模块

交叉多头注意力模块位置如下标注的红色矩形框:

除了Q,K,V的含义不一样,交叉多头注意力模块和多头注意力机制非常相似,交叉注意力机制计算两个序列的注意力,用于处理两个序列之间的语义关系,论文中是计算输入序列和输出序列的注意力,其中K和V来源于输入序列,Q来源于输出序列。多头注意力机制是计算单个序列的注意力,Q,K,V都来源于同一个序列。

向量维度表示这一过程:

q: tensor(bs,encoder_seq_len,d_model)

v: tensor(bs,encoder_seq_len,d_model)

k: tensor(bs,decoder_seq_len,d_model)

其中bs,encoder_seq_len,decoder_seq_len,d_model分别为样本个数,输入语句的长度,输出语句的长度,词嵌入维度

x = multi_head_attention(q,k,v)   # 多头注意力模块

x的维度:tensor(bs,decoder_seq_len,n_model)

8. 掩码多头注意力机制模块

掩码多头注意力机制模块位置如下红色矩形框:

掩码的目的是为了防止网络看到不该看的内容,输出语句是有前后关系的,Transformer推理时,我们是一个词一个词的输出,但在训练时这样做效率太低了。我们还是希望利用向量的思想囊括所有的输出单词,因此我们会将target一次性给到Transformer,利用掩码选择计算注意力机制的单词序列。

参考链接:http://jalammar.github.io/illustrated-transformer/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1638535.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang图片验证码的使用

一、背景 最近在使用到Golang进行原生开发,注册和登录页面都涉及到图片验证码的功能。找了下第三方库的一些实现,发现了这个库用得还是蛮多的。并且支持很多类型的验证方式,例如支持数字类型、字母类型、音频验证码、中文验证码等等。 项目地…

操作系统(2)——进程线程

目录 小程一言专栏链接: [link](http://t.csdnimg.cn/8MJA9)基础概念线程详解进程详解进程间通信调度常用调度算法 重要问题哲学家进餐问题问题的描述策略 读者-写者问题问题的描述两种情况策略 总结进程线程一句话 小程一言 本操作系统专栏,是小程在学操作系统的过…

Dockerfile实战(SSH、Systemctl、Nginx、Tomcat)

目录 一、构建SSH镜像 1.1 dockerfile文件内容 1.2 生成镜像 1.3 启动容器并修改root密码 二、构建Systemctl镜像 2.1 编辑dockerfile文件 ​编辑2.2 生成镜像 2.3 启动容器,并挂载宿主机目录挂载到容器中,然后进行初始化 2.4 进入容器验证 三、…

考研管理类联考(专业代码199)数学基础【2】整式与分式

一、整式及其运算 1.常用乘法公式(逆运算就是因式分解) 公式扩展① 公式扩展② 公式扩展③ 2.整式除法定理 若整式 F(x) 除以x-a的余式为r(x),则 F(x) (x -a) g(x) r(x) ,故r(a)F(a)成立 二、指数和对数的运算性质 1.指数运算…

【电路笔记】-石英晶体振荡器

石英晶体振荡器 文章目录 石英晶体振荡器1、概述2、石英晶体等效模型3、石英晶体振荡器示例14、Colpitts 石英晶体振荡器5、Pierce振荡器6、CMOS晶体振荡器7、微处理器水晶石英钟8、石英晶体振荡器示例21、概述 任何振荡器最重要的特性之一是其频率稳定性,或者换句话说,其在…

Linux migrate_type初步探索

1、基础知识 我们都知道Linux内存组织管理结构架构,顶层是struct pglist_data,然后再到struct zone,最后是struct page。大概的管理结构是这样的: 根据物理内存的地址范围可划分不同的zone,每个zone里的内存由buddy…

审计师能力与专长数据集(2014-2022年)

01、数据介绍 审计师是专门从事审计工作的人员,他们对企业、政府机关、金融机构等组织进行独立的、客观的、合法的审计,以评估这些组织的财务状况、经营绩效和风险水平。审计师通过收集和评估证据,以确定被审计单位的财务报表是否公允、合法…

【第3章】spring-mvc请求参数处理

文章目录 前言一、准备1. 增加mavan配置 二、简单参数1.JSP2.Controller 三、复杂参数1.JSP2.Controller 三、扩展1.JSP2.header3.cookie4.session 总结 前言 在上一章的基础上,我们来学习对于请求参数的解析,前后端分离已经是大势所趋,JSP相…

IOS上线操作

1、拥有苹果开发者账号 2、配置证书,进入苹果开发者官网(https://developer.apple.com/) 3、点击账户(account),然后创建一个唯一的标识符 4、点击"Identifiers",然后点击"&qu…

【C++】学习笔记——内存管理

文章目录 二、类和对象20. 友元1. 友元函数2.友元类 21. 内部类22. 匿名对象23. 拷贝对象时的一些编译器优化 三、内存管理1. C/C内存分布2. C语言中动态内存管理方式:malloc/calloc/realloc/free3. C内存管理方式 未完待续 二、类和对象 20. 友元 1. 友元函数 我…

ELK Stack 8 接入ElasticFlow

介绍 Netflow v5 / v9 / v10(IPFIX),支持大部分网络厂商及VMware的分布式交换机。 NetFlow是一种数据交换方式。Netflow提供网络流量的会话级视图,记录下每个TCP/IP事务的信息。当汇集起来时,它更加易于管理和易读。…

基于Springboot+Vue的Java项目-入校申报审批系统开发实战(附演示视频+源码+LW)

大家好!我是程序员一帆,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &am…

产业结构-整体升级、合理化、高级化数据集(1990-2022年)

一、数据介绍 数据名称:产业结构协调-高级化、合理化 数据年份:1990-2022年 数据范围:全国31个省份 数据来源:中国统计NJ、国家TJ局 数据类型:内含原始版本、线性插值版本、ARIMA填补版本 数据说明:参…

分类规则挖掘(二)

目录 三、决策树分类方法(一)决策树生成框架(二)ID3分类方法(三)决策树的剪枝(四)C4.5算法 三、决策树分类方法 决策树 (Decision Tree) 是从一组无次序、无规则,但有类别…

240 基于matlab的飞行轨迹仿真程序

基于matlab的飞行轨迹仿真程序,多种不同的飞行轨迹,输出经度、纬度、高度三维轨迹,三个方向的飞行速度。程序已调通,可直接运行。 240 飞行轨迹仿真 三维轨迹 飞行速度 - 小红书 (xiaohongshu.com)

Hive优化以及相关参数设置

1.表层面设计优化 1.1 表分区 分区表实际上就是对应一个 HDFS 文件系统上的独立的文件夹,该文件夹下是该分区所有的数据文件。Hive 中的分区就是分目录,把一个大的数据集根据业务需要分割成小的数据集。在查询时通过 WHERE 子句中的表达式选择查询所需要…

sunshine+n2n+moonlight串流远程控制全教程

远程主机说明(两台电脑不在同一局域网下): 控制台电脑 被控制电脑 所有工具下载地址:https://www.lanzouw.com/b00eepod7e 密码:1234 一、首先NTN组网 使用NTN技术创建虚拟局域网,实现设备之间的P2P连接。 NTN组网…

制作一个RISC-V的操作系统十五-软件定时器

文章目录 定时器分类定时器相关分类软件定时器设计初始化创建删除触发流程图形示意 优化代码 定时器分类 硬件定时器:由硬件频率和触发限制的大小决定,只有一个,精度高 软件定时器:基于硬件定时器实现,精度大于等于硬…

python学习之词云图片生成

代码实现 import jieba import wordcloudf open("D:/Pythonstudy/data/平凡的世界.txt", "r", encoding"utf-8") t f.read() print(t) f.close() ls jieba.lcut(t) txt " ".join(ls)w wordcloud.WordCloud(font_path"D:/cc…

Redis系列-1 Redis介绍

背景: 本文介绍Redis相关知识,包括Redis的使用、单线程机制、事务、内存过期和淘汰机制。后续将在《三方件-3 Redis持久化机制》中介绍Redis基于RDB和AOF的持久化机制;在《三方件-4 Redis集群》介绍主从、哨兵和Cluster集群相关的内容&#…