大模型入门 ch 03:注意力机制

news2024/9/20 19:36:40

本文是github上的大模型教程LLMs-from-scratch的学习笔记,教程地址:教程链接

Chapter 3: Attention Mechanism

本文首先从固定参数的注意力机制说起,然后拓展到可以训练的注意力机制,然后加入掩码mask,最后拓展到多头注意力机制。


1. 注意力机制

一个句子中的每一个token,都会受到其他token的影响(这里先不考虑忽略未来的单词,掩码的问题后面再说),注意力机制可以让一个token收到其他token的影响,生成一个最终我们想要的embedding。即每个token有一个原始的embedding,通过注意力机制后,得到了一个新的embedding,这个embedding是结合了上下文语义得到的。

举个简单的例子,我们直接使用tokens的embedding之间两两点乘,得到互相之间的点乘结果,然后将点乘结果归一化,得到embeddings之间的注意力得分。

归一化一般使用softmax函数,通过取指数,除以求和得到归一化结果
torch.exp(x) / torch.exp(x).sum(dim=0)

得到token之间的相关权重后,我们就可以加权求和,得到每一个token的最终embedding。


2. 可以训练的注意力头

在上面的例子中,我们直接使用token对应的embedding来计算相关系数,以及最终的加权求和,这显然是不合理的,如果这样的话,那么我们只能训练token对应的词嵌入来学习模型,或者是一些全连接层,因此我们需要引入新的矩阵,来学习到更多的参数,这就是transformer的QKV矩阵。
QKV都是对原始的embedding做线性变换,得到新的向量,但是模型就可以通过训练QKV,学习海量知识。

在这里插入图片描述

QKV的维度不固定,可以与原始嵌入相同,也可以不同。总之,通过QKV三个矩阵,我们将原始token的embedding转换成了3个新的向量。

  • Query vector: q ( i ) = W q   x ( i ) q^{(i)} = W_q \,x^{(i)} q(i)=Wqx(i)
  • Key vector: k ( i ) = W k   x ( i ) k^{(i)} = W_k \,x^{(i)} k(i)=Wkx(i)
  • Value vector: v ( i ) = W v   x ( i ) v^{(i)} = W_v \,x^{(i)} v(i)=Wvx(i)

可以使用矩阵乘法实现:

keys = inputs @ W_key 
values = inputs @ W_value

然后我们计算KQ之间的点积,作为两两token之间的关联度。为什么要用两个不一样的矩阵,我的猜测是,如果使用的是一个矩阵计算相似度,那么关于对角线对称的元素就会完全相同,但是使用两个不同的矩阵计算,就不会存在这样的情况,可以学习到的内容更多。

我们使用K和Q的点积得到了两两之间的注意力得分,同样使用softmax进行归一化,得到最终的注意力权重。

注意到没有直接对注意力得分softmax,而是除以维度的方根后再softmax,这是因为在计算注意力权重时,如果直接将Query和Key的点积结果用于softmax函数,当Key的维度较高时,点积的结果会变得非常大。这可能导致softmax函数在梯度下降过程中学习困难,因为大的数值会使softmax的梯度变得非常小(接近于0),这在数值稳定性上是一个问题,称为“梯度消失”。

最后一步,不再使用原始的embedding加权,我们使用V矩阵变换后的向量进行加权求和,得到结果向量。

代码如下

class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

3. 隐藏未来的单词

对语言任务来说,在训练模型的时候不能使用未来的文本来预测之前的文本。因此我们需要屏蔽未见文本对先前文本的影响。在我们计算得到注意力权重后,我们人为地将上三角矩阵的权重置为0。
有一种naive的方法,就是将上注意力权重都置为0后,重新对剩下的元素归一化。但是我们要介绍的是一般使用的方法:

我们在计算出注意力得分后,对右上角元素都赋值为负无穷大,负无穷大在经过softmax后就变为了0。

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)

最后为了防止过拟合,一般会使用dropout,对注意力权重矩阵进行随机丢弃,加强模型泛化性能。

总结以上的所有内容,我们现在就可以写出一个单头的注意力机制了,并且加入了对batch输入的处理:

class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights) # New

        context_vec = attn_weights @ values
        return context_vec

4. 多头注意力机制

我们已经实现了单个头的注意力机制,那么要实现多个头,就是使用多个不同的注意力头,各自对输入进行处理,然后将各自得到的输出$z_i$拼接起来,非常显而易见,我们有第一个最直白的写法:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

这是一种最简单直白的写法,直接声明num_heads个注意力单元,然后在前向传播的时候,依次调用这num_heads个注意力头,然后将输出拼接起来。(dim=-1代表最后一维拼接)


问题是,这样的话需要循环num_heads次得到结果,并且需要声明num_heads个注意力头,相信熟悉线性代数的朋友已经想到了,可以通过曾广矩阵来拓展注意力头。比如单个的注意力头是(d_in, d_out),那么有n个头的注意力机制就是(d_in, n*d_out)
假设输入是(tokens, d_in),那么(tokens, d_in) @ (d_in, n*d_out) --> (tokens, n*d_out),输出的结果完美得到了n个头对应的输出,我们只需要按照每d_out列拆开,得到n(tokens, d_out)的矩阵,就能还原出n个头对应的结果,进行后续的attention score计算。这样写起来虽然麻烦一些,但是效率更高。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

不同于前者,将不同的注意力头分开计算,第二种方法直接扩展query,key和value矩阵的列数,将多个矩阵运算简化为一个矩阵运算,计算完再更改维度还原成一个个注意力头,效率更高。这样,我们就完成了一个完整的注意力机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2130778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于 onsemi NCV78343 NCV78964的汽车矩阵式大灯方案

一、方案描述 大联大世平集团针对汽车矩阵大灯,推出 基于 onsemi NCV78343 & NCV78964的汽车矩阵式大灯方案。 开发板搭载的主要器件有 onsemi 的 Matrix Controller NCV78343、LED Driver NCV78964、Motor Driver NCV70517、以及 NXP 的 MCU S32K344。 二、开…

抖音微信超火国庆节国旗头像生成源码

源码介绍: 抖音微信超火国庆节国旗头像生成源码,静态页前端生成速度超快!源码直接上传到服务器即可使用。 1、打开地址后点击上传->选一张你喜欢的头像->然后点右边箭头符合选款式->最后点保存头像->按照提示 2、保存到手机即…

开源多场景问答社区论坛Apache Answer本地部署并发布至公网使用

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

CCDO|数据跃动未来:首席数据官如何引领构建活数据引擎

在数字化浪潮汹涌澎湃的今天,数据已成为企业最宝贵的资产之一,它不仅记录着过去,更预示着未来的方向。随着大数据、人工智能、云计算等技术的飞速发展,数据的潜力被前所未有地激发,而首席数据官(CDO&#x…

T4周:猴痘病识别

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客** >- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)** 1. 设置GPU 如果使用的是CPU可以忽略这步 …

Eclipse折叠if、else、try catch的{}

下载插件com.cb.eclipse.folding_1.0.6.jar。将插件放到eclipse的dropins文件夹中。修改配置,然后保存,重启Eclipse即可。

Flink快速上手

Flink快速上手 批处理Maven配置pom文件java编写wordcount代码 有界流处理无界流处理 批处理 Maven配置pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://ww…

《深度学习》深度学习 框架、流程解析、动态展示及推导

目录 一、深度学习 1、什么是深度学习 2、特点 3、神经网络构造 1&#xff09;单层神经元 • 推导 • 示例 2&#xff09;多层神经网络 3&#xff09;小结 4、感知器 神经网络的本质 5、多层感知器 6、动态图像示例 1&#xff09;一个神经元 相当于下列状态&…

通信原理:绪论

1、消息、信号与信息 消息&#xff1a; 通信系统要传输的对象&#xff0c;是具体的、物理上存在的东西。也是信息的载体。形式多种&#xff1a; 连续消息&#xff1a;语音、温度、活动图片.离散消息&#xff1a;数据、符号、文字. 信息&#xff1a; 消息中所蕴含的内容&…

proteus+51单片机+实验(LCD1620、定时器)

目录 1.LCD1602液晶显示屏 1.1基本概念 1.1.1LCD的简介 1.1.2LCD的显示原理 ​​​1.1.3LCD的硬件电路 1.1.4LCD的常见指令 1.1.5LCD的时序 ​​​​​​​1.2代码 1.2.1写命令和写数据操作 1.2.2初始化和测试代码 1. 3.3功能函数 1.3proteus代码 1.3.1器件代码 1.…

几种手段mfc140u.dll丢失的解决方法,了解mfc140u.dll

在使用Windows操作系统时&#xff0c;许多用户可能会遇到“找不到mfc140u.dll”或“mfc140u.dll未找到”的错误提示。这个错误通常是由于该文件丢失或损坏所致。本文将详细介绍mfc140u.dll文件的作用、丢失的原因及其解决方法&#xff0c;帮助您快速恢复系统的正常运行。 一、m…

无人机视角的道路损害数据集,2400张图像,包括纵向裂缝(LC)、横向裂缝(TC)、鳄鱼裂缝(AC)、斜裂(OC)、修补(RP)和坑洞(PH),共2.3GB

数据集名称 无人机视角的道路损害数据集 数据集描述 这是一个专注于道路损害检测的数据集&#xff0c;包含了从无人机视角拍摄的2400张高清图像&#xff0c;涵盖了六种典型的道路损害类型&#xff1a;纵向裂缝&#xff08;LC&#xff09;、横向裂缝&#xff08;TC&#xff0…

c++ 点云生成二维俯视图

🙋 结果预览 一、代码实现 #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include

S7_1200配方功能快速入门

配方数据文件按照标准 CSV 格式存储在 S7-1200 CPU 装载存储器或 S7-1200 SIMATIC 存储卡“程序卡”中。分别可通过 PLC Web 服务器或对于存储卡文件操作&#xff0c;将数据文件传送到 PC 进行管理和查看。也可将修改过后的配方数据文件上传至PLC&#xff0c;再通过“RecipeImp…

【数据结构】详细介绍各种排序算法,包含希尔排序,堆排序,快排,归并,计数排序

目录 1. 排序 1.1 概念 1.2 常见排序算法 2. 插入排序 2.1 直接插入排序 2.1.1 基本思想 2.1.2 代码实现 2.1.3 特性 2.2 希尔排序(缩小增量排序) 2.2.1 基本思想 2.2.2 单个gap组的比较 2.2.3 多个gap组比较(一次预排序) 2.2.4 多次预排序 2.2.5 特性 3. 选择排…

【AcWing】869. 试除法求约数

约数&#xff1a;当前数能整除这个数。 和判断质数一样的道理&#xff0c;同样是试除法。 约数也一定是成对出现的。在枚举的时候也可以只枚举较小的那一个约数就可以了&#xff0c;较大的那个约数直接算。 #include<iostream> #include<algorithm> #include<…

无人机之处理器篇

无人机的处理器是无人机系统的核心部件之一&#xff0c;它负责控制无人机的飞行、数据处理、任务执行等多个关键功能。以下是对无人机处理器的详细解析&#xff1a; 一、处理器类型 无人机中使用的处理器主要包括以下几种类型&#xff1a; CPU处理器&#xff1a;CPU是无人机的…

JDBC API详解一

DriverManager 驱动管理类&#xff0c;作用&#xff1a;1&#xff0c;注册驱动&#xff1b;2&#xff0c;获取数据库连接 1&#xff0c;注册驱动 Class.forName("com.mysql.cj.jdbc.Driver"); 查看Driver类源码 static{try{DriverManager.registerDriver(newDrive…

中间件常见漏洞

文章目录 中间件漏洞IIS文件解析漏洞1&#xff1a;/xx.asp/xx.jpg 、/xx.asa/xx.jsp2&#xff1a;xx.asp;.jpg3&#xff1a;xx.asa、xx.cer、xx.cdx4&#xff1a;IIS.7/8 CGI配置不当解析漏洞 Apache文件解析漏洞1&#xff1a;apache2.2版本解析漏洞2&#xff1a;其余配置问题…

IMX6 L508EN 模块调试(4G)

一、概述 提起 4G 网络连接&#xff0c;大家可能会觉得是个很难的东西&#xff0c;其实对于嵌入式 Linux 而言&#xff0c;4G 网络连接恰恰相反&#xff0c;不难&#xff01;大家可以看一下其他的嵌入式 Linux 或者 Android 开发板&#xff0c;4G 模块都是 MiniPCIE 接口的&…