【AIGC入门一】Transformers 模型结构详解及代码解析

news2024/12/22 19:10:17

Transformers 开启了NLP一个新时代,注意力模块目前各类大模型的重要结构。作为刚入门LLM的新手,怎么能不感受一下这个“变形金刚的魅力”呢?

目录

Transformers ——Attention is all You Need

背景介绍

模型结构

位置编码

代码实现:

Attention

Scaled Dot-product Attention

Multi-head Attention

Position-Wise Feed-Forward Networks

Encoder and Decoder

Add & Norm

mask 机制

参考链接


论文链接:Attention Is All You Need

Transformers ——Attention is all You Need

背景介绍

        在Transformer提出之前,NLP主要基于RNN、LSTM等算法解救相关问题。这些模型在处理长序列时面临梯度消失和梯度爆炸等问题,且这些模型是串行计算的,运行时间较长。

        Transformer 模型的提出是为了摆脱序列模型的顺序依赖性,引入了注意力机制,使得模型能够在不同位置上同时关注输入序列的各个部分,且支持并行计算。该模型的提出对深度学习和自然语言处理领域产生了深远的影响,成为了现代NLP模型的基础架构,并推动了attention 机制在各种任务中的应用。

模型结构

位置编码

        任何一门语言,单词在句子中的位置以及排列顺序是非常重要的。一个单词在句子的位置或者排列顺序不同,整个句子的意义就发生了偏差。举个例子:

小明小王500块

小王小明500块

顺序不同,债主关系就发生变化了😑

        当采用了Attention之后,句子中的词序信息就会丢失,模型就没法知道每个词在句子中的相对和绝对的位置信息。目前位置编码有多种方法:

(1)整型值标记位置,即第一个token标记为1, 第二个token标记为2。。。以此类推

         可能存在的问题:

  • 随着序列长度的增加,位置值会越来越大;
  • 推理的序列长度比训练时所用的序列长度更长,不利于模型的泛化

(2)用[0,1] 范围标记位置

        将位置值的范围限制在[0,1]之内,即在第一种的方法进行归一化操作(除以序列长度)。比如有4个token,那么位置信息就是[0, 0.33, 0.69, 1]。 但这样产生的问题是,当序列长度不同时,token间的相对距离是不一样的。

        因此,一个好的位置编码方法应该满足以下特性:

(1)可以表示一个token 在序列中的绝对位置;

(2)在序列长度不同的情况下,不同序列中token 的相对位置/ 距离要保持一直;

(3)可以扩展到更长的句子长度;

        Transformers 中选择的是sincos编码法,其公式如下所示:

        其中,pos 是token在sentence中的位置,i是维度。

代码实现:

        假设句子长度是 s, embedding的维度是d, 最终生成的PE的shape是(s, d)。公式的核心是计算pos /10000\tfrac{2i}{d_{model}}, 这里可以借助对数和指数的性质进行如下操作:

a = e^{^{loga}}

        所以可以转换成1/ 10000^{^{2i/d_{model}}} = e^{ - log(10000) * 2i /d_{model})}(可对照代码进行推导理解)

class Position_Encoding(nn.Module):
    
    def __init__(self, max_length, d_model):
        self.max_length = max_length
        self.dim = d_model
        
        pe = torch.zeros(self.max_length, self.dim)
        
        position = torch.arange(0, self.max_length).unsqueeze(1)
        
        div_term = torch.exp( torch.arange(0, self.dim ,2) * (-1) *math.log(10000) / self.dim)
        
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        

    def forward(self, x):
        
        #  input_embedding + position_encoding
        #....

Attention

        Attention 是将query 和key、value映射为输出值,其中query 和 key 计算一个相似度,然后以这个相似度为权重,计算value的加权和,最终得到输出。

Scaled Dot-product Attention

        论文中用的是放缩点乘注意力(scaled dot-product attention),其公式是:

Attention(Q, K, V) = softmax(\frac{QK^{}{T}}{sqrt(d_k))})

其中, 计算时需要用到的矩阵Q(查询),K(键值),V(值)是输入单词的embedding 变换或者 上一个Encoder block的输出 。注意Q、K、V的shape 会存在一定的联系(因为需要做矩阵乘法运算)。

        公式中会除以dk的平方根,从而避免内积过大。还有解释是说 softmax 在 绝对值较大的区域梯度较小,梯度下降的速度比较慢。因此希望softmax的点乘数值尽可能小。

        论文中解释了为什么点积会变大。假设q 和 k中的元素满足独立分布,且均值是0,方差为1。点积 q*k = \sum_{i=1}^{dk} q_i * k_i   的均值是0, 方差是dk 。

Multi-head Attention

        作者发现,相比直接在dmodel 维度上的 q、k、v进行attention计算 ,使用不同的、可学习的linear function 分别地对q、k、v 进行多次映射(映射的维度是dk, dk, dv)  , 然后对每一组映射的q、k、v进行attention 并行计算,并concat得到最终输出。后一种方法更有效。就像卷积层可以用多个卷积核生成多个通道的特征,在Transformers中可以用多组self attention 生成多组注意力的结果,从而增加特征表示。其计算公式和流程图如下:

        注意: head的数量 * 每一组head中q的维度 = dmodel (输入Q的维度)

Position-Wise Feed-Forward Networks

        前馈网络比较简单,是一个两层的全连接层,第一层的激活函数是ReLU,第二层不使用激活函数,对应的公式如下所示:

Encoder and Decoder

       Transformer 从结构上可以分为Encoder 和Decoder 两个部分,这两者结构上比较类似,但也存在一些差异。

        上图红色区域对应的是Encoder部分,可以看出是由 Input Embedding 、Position Encoding 和6层的EncoderLayer组成。 EncoderLayer 主要包括Multi-head Attention, Add&Norm, Feed Forward ,Add&Norm。

        上图绿色区域对应的是Decoder部分,相比Encoder,需要注意Decoder中的Multi-head Attention 有所不同。首先是Masked Multi-head Attention, 是为了实现串行推理;第二个Multi-head Attention输入的Q、K、V来自不同的地方,其中Q是Masked Multi-head Attention 的输出, K和V是Encoder 的输出。

Add & Norm

        这部分主要由Add 和 Norm 组成,其计算公式如下所示:

        Add 是一种残差结构,和ResNet中的是一样的,可以帮助网络收敛。Norm 是指Layer Norm。

mask 机制

        Transformers中比较重要的一个知识点就是mask设置。mask主要来源有两个:第一个是填充操作的空白字符(为了保证batch内句子的长度一样会进行padding操作);第二个是因为模拟串行推理需要用到mask(Decoder部分)。

        一般情况下, query 和 key都是一样的,但是在Decoder的第二个多头注意力层中,query 来自目标语言,key来自源语言。为了生成mask, 首先要知道query 和 key中<pad> 字符的分布情况,它们的形状为[n, seq_len]。如果某处是True, 表明这个地方的字符是<pad>。

src_pad_mask = x == pad_idx
dst_pad_mask = y == pad_idx

        为了实现串行推理,即某字符只能知道该字符以及该字符之前的内容,即一个下三角全1矩阵。mask矩阵需要取反,实现方式如下所示:

mask = 1 - torch.tril(torch.ones(mask_shape))

        最后根据<pad>字符分布情况分别将mask对应的行或者列置1。

参考链接

  1. GitHub - P3n9W31/transformer-pytorch: Transformer model for Chinese-English translation.
  2. PyTorch Transformer 英中翻译超详细教程 - 知乎
  3. Transformer模型详解(图解最完整版) - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1389330.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

51单片机学习总结(自学)

1、模块化编程 c语言模块化编程实现思路设计代码 具体的程序实现代码如下所示 1&#xff1a;程序的头文件 2&#xff1a;程序的函数文件 3&#xff1a;程序的主文件控制函数的实现 持续更新中......

算法部署过程中如何确保数据的安全?

在数字化时代&#xff0c;数据安全成为了企业和个人面临的一项主要挑战。随着技术的迅速发展&#xff0c;尤其在算法部署过程中&#xff0c;确保敏感数据的安全性变得更加复杂和关键。在这个背景下&#xff0c;软件加密和授权机制的作用显得尤为重要。软件加密不仅仅是转换数据…

高纯气体市场调研:预计2029年将达到331亿美元

高纯气体应用领域极宽&#xff0c;在半导体工业&#xff0c;高纯氮、氢、氩、氦可作为运载气和保护气;高纯气体可作为配制混合气的底气。随着LED和半导体的发展&#xff0c;对于其原物料生产的所需要的高纯气体&#xff0c;特别是7N级别的高纯氨气的需求不断增加&#xff0c;近…

CMake TcpServer项目 生成静态库.a / 动态库.so

CMake 实战构建TcpServer项目 静态库/动态库-CSDN博客https://blog.csdn.net/weixin_41987016/article/details/135608829?spm1001.2014.3001.5501 在这篇博客的基础上&#xff0c;我们把头文件放在include里边&#xff0c;把源文件放在src里边&#xff0c;重新构建 hehedali…

深入理解 PyTorch 激活函数:从基础到高效实用技巧(4)

目录 torch.nn.functional激活层详解 tanh 1. 函数用途 2. 参数详解 3. Tanh函数的定义及数学解释 4. 使用示例 sigmoid 1. 函数用途 2. 参数详解 3. Sigmoid函数的定义及数学解释 4. 使用示例 hardsigmoid 1. 函数用途 2. 参数详解 3. Hardsigmoid函数的定义及…

蓝桥杯AcWing学习笔记 8-1数论的学习(上)

蓝桥杯 我的AcWing 题目及图片来自蓝桥杯C AB组辅导课 数论&#xff08;上&#xff09; 蓝桥杯省赛中考的数论不是很多&#xff0c;这里讲几个蓝桥杯常考的知识点。 欧几里得算法——辗转相除法 欧几里得算法代码&#xff1a; import java.util.Scanner ;public class Main…

大物②练习题解

1.【单选题】关于磁场中磁通量&#xff0c;下面说法正确的是&#xff08; D&#xff09; A、穿过闭合曲面的总磁通量不一定为零 B、磁感线从闭合曲面内穿出&#xff0c;磁通量为负 C、磁感线从闭合曲面内穿入&#xff0c;磁通量为正D、穿过闭合曲面的总磁通量一定为零 磁感线从…

(超详细)3-YOLOV5改进-添加SE注意力机制

1、在yolov5/models下面新建一个SE.py文件&#xff0c;在里面放入下面的代码 代码如下&#xff1a; import numpy as np import torch from torch import nn from torch.nn import initclass SEAttention(nn.Module):def __init__(self, channel512,reduction16):super()._…

云渲染的官网地址是什么?

云渲染的官网地址&#xff1a;http://www.xuanran100.com/?ycode1a12 云渲染能把渲染工作从本地移到云端进行&#xff0c;不需要设计师配置高性能电脑&#xff0c;十分方便。目前国内领先的云渲染平台是渲染100&#xff0c;它有以下几个优点&#xff1a;1、使用方便 一键提交渲…

Pandas加载大数据集

Scaling to large datasets — pandas 2.1.4 documentationhttps://pandas.pydata.org/docs/user_guide/scale.html#use-efficient-datatypes官方文档提供了4种方法&#xff1a;只加载需要的列、转化数据类型、使用chunking&#xff08;转化文件存储格式&#xff09;、使用Dask…

CXYGZL-程序员工作流,持续迭代升级中

概述 现在开源的工作流引擎&#xff0c;基本都是以BPMN.js为基础的&#xff0c;导致使用门槛过高&#xff0c;非专业人员无法驾驭。本工作流借鉴钉钉/飞书的方式&#xff0c;以低代码方式降低用户使用门槛&#xff0c;即使是普通企业用户也可以几分钟内就能搭建自己的工作流引…

O2066PM无线WIFI6E网卡Windows环境吞吐测试

从2023年开始&#xff0c;除手机外的无线终端设备也逐步向WIFI6/6E进行升级更新&#xff0c;基于802.11ax技术的设备能够进一步满足用户体验新一代Wi-Fi标准时获得优质的性能和覆盖范围。 用户对于WIFI模块&#xff0c;通常会关注WIFI模块的吞吐量&#xff0c;拿到样品之后&am…

详细的二进制安装部署Mysql8.2.0

目录 一、下载版本 二、卸载MariaDB 三、MySQL二进制安装 3.1 创建mysql工作目录&#xff1a; 3.2、上传软件&#xff0c;并解压并改名为app 3.3、修改环境变量 3.4、建立mysql用户和组(如果有可忽略) 3.5、创建mysql 数据目录&#xff0c;日志目录&#xff1b;并修改权…

高级分布式系统-第15讲 分布式机器学习--概念与学习框架

高级分布式系统汇总&#xff1a;高级分布式系统目录汇总-CSDN博客 分布式机器学习的概念 人工智能蓬勃发展的原因&#xff1a;“大” 大数据&#xff1a;为人工智能技术的发展奠定了坚实的物质基础。 大规模机器学习模型&#xff1a;具备超强的表达能力&#xff0c;可以解决…

vue2使用Lottie

文章目录 学习链接1.安装依赖2.创建lottie组件3.在相对应的页面应用4.相关data.json5.测试效果 学习链接 原文链接&#xff1a;lottie在vue中的使用 lottie官网&#xff1a;https://lottiefiles.com/ 1.安装依赖 npm install lottie-web2.创建lottie组件 <template>…

JNPF低代码引擎到底是什么?

最近听说一款可以免费部署本地进行试用的低代码引擎&#xff0c;源码上支持100%源码&#xff0c;提供的功能和技术支持比较完善。借助这篇篇幅我们了解下JNPF到底是什么&#xff1f; JNPF开发平台是一款PaaS服务为核心的零代码开发平台&#xff0c;平台提供了多租户账号管理、主…

短期交易离不开的工具!10日均线在现货白银中的应用

10日均线是一根短期均线&#xff0c;对于做短线交易的现货白银投资者来说&#xff0c;它是一个很好用的工具。下面我们就来讨论一下&#xff0c;在现货白银交易中10日均线的具体应用是什么&#xff1f; 验证趋势。我们可以使用10日均线来验证趋势。由于10日均线是短期均线&…

【51单片机系列】继电器使用

文章来源&#xff1a;《零起点学Proteus单片机仿真技术》。 本文是关于继电器使用相关内容。 继电器广泛应用在工业控制中&#xff0c;通过继电器对其他大电流的电器进行控制。 继电器控制原理图如下。继电器部分包括控制线圈和3个引脚&#xff0c;A引脚接电源&#xff0c;B引…

SD-WAN服务简介及挑选服务商指南

在跨境业务蓬勃发展的今天&#xff0c;越来越多的企业开始采用SD-WAN组网&#xff0c;这项技术不仅能够整合现有基础设施投资&#xff0c;还能以灵活、安全的方式支持跨境办公和访问海外网站。那么&#xff0c;如何为企业选择最适合的SD-WAN服务商呢&#xff1f; 首先&#xff…

RViz成功显示多个机器人模型以及解决显示的模型没有左右轮

RViz显示机器人模型没有左右轮 一、RViz成功显示多个机器人模型机器人模型的左右轮无法显示 一、RViz成功显示多个机器人模型 在RViz中显示多个机器人模型需要设置好几个关键的参数 首先点击Add&#xff0c;找到RobotModel&#xff0c;添加进来 Fixed Frame&#xff1a;选择T…