从零开始实现大语言模型(四):简单自注意力机制

news2024/11/28 16:52:27

1. 前言

理解大语言模型结构的关键在于理解自注意力机制(self-attention)。自注意力机制可以判断输入文本序列中各个token与序列中所有token之间的相关性,并生成包含这种相关性信息的context向量。

本文介绍一种不包含训练参数的简化版自注意力机制——简单自注意力机制(simplified self-attention),后续三篇文章将分别介绍缩放点积注意力机制(scaled dot-product attention),因果注意力机制(causal attention),多头注意力机制(multi-head attention),并最终实现OpenAI的GPT系列大语言模型中MultiHeadAttention类。

2. 从循环神经网络到自注意力机制

解决机器翻译等多对多(many-to-many)自然语言处理任务最常用的模型是sequence-to-sequence模型。Sequence-to-sequence模型包含一个编码器(encoder)和一个解码器(decoder),编码器将输入序列信息编码成信息向量,解码器用于解码信息向量,生成输出序列。在Transformer模型出现之前,编码器和解码器一般都是一个循环神经网络(RNN, recurrent neural network)。

RNN是一种非常适合处理文本等序列数据的神经网络架构。Encoder RNN对输入序列进行处理,将输入序列信息压缩到一个向量中。状态向量 h 0 h_0 h0包含第一个token x 0 x_0 x0的信息, h 1 h_1 h1包含前两个tokens x 0 x_0 x0 x 1 x_1 x1的信息。以此类推, Encoder RNN最后一个状态 h m h_m hm是整个输入序列的概要,包含了整个输入序列的信息。Decoder RNN的初始状态等于Encoder RNN最后一个状态 h m h_m hm h m h_m hm包含了输入序列的信息,Decoder RNN可以通过 h m h_m hm知道输入序列的信息。Decoder RNN可以将 h m h_m hm中包含的信息解码,逐个元素地生成输出序列。

RNN的神经网络结构及计算方法使Encoder RNN必须用一个隐藏状态向量 h m h_m hm记住整个输入序列的全部信息。当输入序列很长时,隐藏状态向量 h m h_m hm对输入序列中前面部分的tokens的偏导数(如对 x 0 x_0 x0的偏导数 ∂ h m x 0 \frac{\partial h_m}{x_0} x0hm)会接近0。输入不同的 x 0 x_0 x0,隐藏状态向量 h m h_m hm几乎不会发生变化,即RNN会遗忘输入序列前面部分的信息。

本文不会详细介绍RNN的原理,大语言模型的神经网络中没有循环结构,RNN的原理及结构与大语言模型没有关系。对RNN的原理感兴趣读者可以参见本人的博客专栏:自然语言处理。

2014年,文章Neural Machine Translation by Jointly Learning to Align and Translate提出了一种改进sequence-to-sequence模型的方法,使Decoder每次更新状态时会查看Encoder所有状态,从而避免RNN遗忘的问题,而且可以让Decoder关注Encoder中最相关的信息,这也是attention名字的由来。

2017年,文章Attention Is All You Need指出可以剥离RNN,仅保留attention,且attention并不局限于sequence-to-sequence模型,可以直接用在输入序列数据上,构建self-attention,并提出了基于attention的sequence-to-sequence架构模型Transformer。

3. 简单自注意力机制

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。如下图所示,简单自注意力机制生成context向量的计算步骤如下:

  1. 计算注意力分数(attention score):简单注意力机制使用向量的点积(dot product)作为注意力分数,注意力分数可以衡量两个向量的相关性;
  2. 计算注意力权重(attention weight):将注意力分数归一化得到注意力权重,序列中每个token与序列中所有tokens之间的注意力权重之和等于1;
  3. 计算context向量:简单注意力机制将所有tokens对应Embedding向量的加权和作为context向量,每个token对应Embedding向量的权重等于其相应的注意力权重。

图一

3.1 计算注意力分数

对输入文本序列***“Your journey starts with one step.”* **做tokenization,将文本中每个单词分割成一个token,并转换成Embedding向量,得到 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6。自注意力机制分别计算 x i x_i xi x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力权重,进而计算 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6与其相应注意力权重的加权和,得到context向量 z i z_i zi

如下图所示,将context向量 z i z_i zi对应的向量 x i x_i xi称为query向量,计算query向量 x 2 x_2 x2对应的context向量 z 2 z_2 z2的第一步是计算注意力分数。将query向量 x 2 x_2 x2分别点乘向量 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6,得到实数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26,其中 ω 2 i \omega_{2i} ω2i是query向量 x 2 x_2 x2与向量 x i x_i xi的注意力分数,可以衡量 x 2 x_2 x2对应token与 x i x_i xi对应token之间的相关性。

图二

两个向量的点积等于这两个向量相同位置元素的乘积之和。假如向量 x 1 = ( x 11 , x 12 , x 13 ) x_1=(x_{11}, x_{12}, x_{13}) x1=(x11,x12,x13),向量 x 2 = ( x 21 , x 22 , x 23 ) x_2=(x_{21}, x_{22}, x_{23}) x2=(x21,x22,x23),则向量 x 1 x_1 x1 x 2 x_2 x2的点积等于 x 11 × x 21 + x 12 × x 22 + x 13 × x 23 x_{11}\times x_{21} + x_{12}\times x_{22} + x_{13}\times x_{23} x11×x21+x12×x22+x13×x23

可以使用如下代码计算query向量 x 2 x_2 x2 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力分数:

import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

执行上面代码,打印结果如下:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

3.2 计算注意力权重

如下图所示,将注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化可得到注意力权重 α 21 , α 22 , ⋯   , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。每个注意力权重 α 2 i \alpha_{2i} α2i的值均介于0到1之间,所有注意力权重的和 ∑ i α 2 i = 1 \sum_i\alpha_{2i}=1 iα2i=1。可以用注意力权重 α 2 i \alpha_{2i} α2i表示 x i x_i xi对当前context向量 z 2 z_2 z2的重要性占比,注意力权重 α 2 i \alpha_{2i} α2i越大,表示 x i x_i xi x 2 x_2 x2的相关性越强,context向量 z 2 z_2 z2 x i x_i xi的信息量比例应该越高。使用注意力权重对 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6加权求和计算context向量,可以使context向量的数值分布范围始终与 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6一致。这种数值分布范围的一致性可以使大语言模型训练过程更稳定,模型更容易收敛。

图三

可以使用softmax函数将注意力分数归一化得到注意力权重:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

执行上面代码,打印结果如下:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

3.3 计算context向量

简单注意力机制使用所有tokens对应Embedding向量的加权和作为context向量,context向量 z 2 = ∑ i α 2 i x i z_2=\sum_i\alpha_{2i}x_i z2=iα2ixi

图四

可以使用如下代码计算context向量 z 2 z_2 z2

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

执行上面代码,打印结果如下:

tensor([0.4419, 0.6515, 0.5683])

3.4 计算所有tokens对应的context向量

将向量 x 2 x_2 x2作为query向量,按照3.1所述方法,可以计算出注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26。使用softmax函数将注意力分数 ω 21 , ω 22 , ⋯   , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化,可以得到注意力权重 α 21 , α 22 , ⋯   , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。Context向量 z 2 z_2 z2是使用注意力权重对 x 1 , x 2 , ⋯   , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的加权和。

计算所有tokens对应的context向量,可以使用矩阵乘法运算,分别将各个 x i x_i xi作为query向量,一次性批量计算注意力分数及注意力权重,并最终得到context向量 z i z_i zi

如下面代码所示,可以使用矩阵乘法,一次性计算出所有注意力分数:

attn_scores = inputs @ inputs.T
print(attn_scores)

执行上面代码,打印结果如下:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

@操作符是PyTorch中的矩阵乘法运算符号,与函数torch.matmul运算逻辑相同。

将一个 n n n m m m列的矩阵 A A A与另一个 m m m n n n B B B的矩阵相乘,结果 C C C是一个 n n n n n n列的矩阵。其中矩阵 C C C i i i j j j列元素等于矩阵 A A A的第 i i i行与矩阵 B B B的第 j j j列两个向量的内积。

如下面代码所示,使用softmax函数注意力分数归一化,可以一次批量计算出所有注意力权重:

attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

执行上面代码,打印结果如下:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

可以同样使用矩阵乘法运算,一次性批量计算出所有context向量:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

执行上面代码,打印结果如下:

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

4. 结束语

自注意力机制是大语言模型神经网络结构中最复杂的部分。为降低自注意力机制原理的理解门槛,本文介绍了一种不带任何训练参数的简化版自注意力机制。

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。简单自注意力机制生成context向量共3个步骤,首先计算注意力分数,然后使用softmax函数将注意力分数归一化得到注意力权重,最后使用注意力权重对所有tokens对应的Embedding向量加权求和得到context向量。

接下来,该去看看大语言模型中真正使用到的注意力机制了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1902887.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【瑞数补环境实战】某网站Cookie补环境与后缀分析还原

文章目录 1. 写在前面2. 特征分析3. 接口分析3. 补JS环境4. 补后缀参数 【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走…

2.Python学习:数据类型和变量

1.标识符命名规则 只能由数字、字母、下划线组成不能以数字开头不能是关键字(如class等python内部已经使用的标识符)区分大小写 查看关键字: print(keyword.kwlist)2.数据类型 2.1常见数据类型 2.1.1Number数值型: 整数int&a…

LLM - 卷积神经网络(CNN)

1. 卷积神经网络结构:分为输入层,卷积层,池化层,全连接层; (1)首先进入输入层,对数据数据进行处理,将输入数据向量化处理,最终形成输入矩阵。 (…

python-22-零基础自学python-数据分析基础 打开文件 读取文件信息

学习内容:《python编程:从入门到实践》第二版 知识点: 读取文件 、逐行读取文件信息等 练习内容: 练习10-1:Python学习笔记 在文本编辑器中新建一个文件,写几句话来总结一下你至此学到的Python知识,其中…

Android C++系列:Linux Socket编程(三)CS模型示例

1. TCP通信 下图是基于TCP协议的客户端/服务器程序的一般流程: 服务器调用socket()、bind()、listen()完成初始化后,调用accept()阻塞等待,处于 监听端口的状态,客户端调用socket()初始化后,调用connect()发出SYN段并阻塞等待服 务器应答,服务器应答一个SYN-ACK段,客户…

【Python实战因果推断】22_倾向分2

目录 Propensity Score Propensity Score Estimation Propensity Score and ML Propensity Score and Orthogonalization Propensity Score 倾向加权法围绕倾向得分的概念展开,其本身源于这样一种认识,即不需要直接控制混杂因素 X 来实现条件独立性…

ARM架构和Intel x86架构

文章目录 1. 处理器架构 2. ARM架构 3. Intel x86架构 4. 架构对比 1. 处理器架构 处理器架构是指计算机处理器的设计和组织方式,它决定了处理器的性能、功耗和功能特性。处理器架构影响着从计算机系统的硬件设计到软件开发的各个方面。在现代计算技术中&#…

暑期备考2024年汉字小达人:吃透18道选择题真题(持续)

结合最近几年的活动安排,预计2024年第11届汉字小达人比赛还有3个多月就启动。那么孩子如何在2024年的汉字小达人活动中取得好成绩呢?根据以往成绩优秀学员的经验,利用暑假集中备考效率最高。把汉字小达人的备考纳入到暑期学习计划&#xff0c…

【Selenium配置】WebDriver安装浏览器驱动(ChromeEdge)

【Selenium配置】WebDriver安装浏览器驱动(Chrome&Edge) 文章目录 【Selenium配置】WebDriver安装浏览器驱动(Chrome&Edge)Chrome确认Chrome版本下载对应driver把解压后的chromedriver文件放在chrome安装目录下&#xff0…

blender 纹理绘制-贴花方式

贴画绘制-1分钟blender_哔哩哔哩_bilibili小鸡老师的【Blender风格化角色入门教程】偏重雕刻建模https://www.cctalk.com/m/group/90420100小鸡老师最新的【风格化角色全流程进阶教程】偏重绑定。早鸟价进行中!欢迎试听https://www.cctalk.com/m/group/90698829, 视…

STM32-TIM定时器

本内容基于江协科技STM32视频内容,整理而得。 文章目录 1. TIM1.1 TIM定时器1.2 定时器类型1.3 基本定时器1.4 通用定时器1.4 高级定时器1.5 定时中断基本结构1.6 预分频器时序1.7 计数器时序1.8 计数器无预装时序1.9 计数器有预装时序1.10 RCC时钟树 2. TIM库函数…

基于单片机的多功能电子时钟的设计

摘要:提出了一种基于单片机的多功能电子时钟的设计方法,以 AT89C52单片机作为系统的主控芯片,采用DS1302作为时钟控制芯片,实现日期时钟显示并且提供精准定时的功能。此外,还可经由DHT22所构成的温湿度传感电路&#x…

JVM原理(十九):JVM虚拟机内存模型

1. 硬件的效率与一致性 数据不安全的原因:缓存一致性的问题 共享内存多核系统:在多路处理器系统中,每个处理器都有自己的高速缓存,而他们又共享同一主内存。 线程先后执行结果不一致问题:除了增加高速缓存之外&#…

Kotlin算法:把一个整数向上取值为最接近的2的幂指数值

Kotlin算法&#xff1a;把一个整数向上取值为最接近的2的幂指数值 import kotlin.math.ln import kotlin.math.powfun main(args: Array<String>) {val number intArrayOf(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)number.forEach {println("$…

IT高手修炼手册(3)程序员命令

一、前言 程序员在日常工作中&#xff0c;掌握一些高效的快捷键可以大大提高编码和开发效率。 二、通用快捷键 文本操作Ctrl A&#xff1a;全选当前页面内容 Ctrl C&#xff1a;复制当前选中内容 Ctrl V&#xff1a;粘贴当前剪贴板内的内容 Ctrl X&#xff1a;剪切当前选中…

【基于R语言群体遗传学】-8-代际及时间推移对于变异的影响

上一篇博客&#xff0c;我们学习了在非选择下&#xff0c;以二项分布模拟遗传漂变的过程&#xff1a;【基于R语言群体遗传学】-7-遗传变异&#xff08;genetic variation&#xff09;-CSDN博客 那么我们之前有在代际之间去模拟&#xff0c;那么我们就想知道&#xff0c;遗传变…

Java版Flink使用指南——安装Flink和使用IntelliJ制作任务包

大纲 安装Flink操作系统安装JDK安装Flink修改配置启动Flink测试 使用IntelliJ制作任务包新建工程Archetype 编写测试代码打包测试 参考资料 在《0基础学习PyFlink》专题中&#xff0c;我们熟悉了Flink的相关知识以及Python编码方案。这个系列我们将使用相对主流的Java语言&…

SQL Server 2022 中的 Tempdb 性能改进非常显著

无论是在我的会话中还是在我写的博客中&#xff0c;Tempdb 始终是我的话题。然而&#xff0c;当谈到 SQL Server 2022 中引入的重大性能变化时&#xff0c;我从未如此兴奋过。他们解决了我们最大的性能瓶颈之一&#xff0c;即系统页面闩锁并发。 在 SQL Server 2019 中&#x…

python函数和c的区别有哪些

Python有很多内置函数&#xff08;build in function&#xff09;&#xff0c;不需要写头文件&#xff0c;Python还有很多强大的模块&#xff0c;需要时导入便可。C语言在这一点上远不及Python&#xff0c;大多时候都需要自己手动实现。 C语言中的函数&#xff0c;有着严格的顺…

【pyhton学习】深度理解类和对象

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 文章目录 一、一切皆对象1.1 对象的概念1.2 如何创建类对象1.3 类型检测 二、属性与方法2.1 如何查看属性与方法2.2 属性和方法…