【深度学习】实验 — 动手实现 GPT【二】:注意力机制、注意力掩码、多头注意力机制

news2024/11/27 2:19:20

【深度学习】实验 — 动手实现 GPT【二】:注意力机制、多头注意力机制

  • 注意力机制
    • 简单示例:单个元素的情况
    • 简单示例:计算所有输入词元的注意力权重
      • 推广到所有输入序列词元:
  • 注意力掩码
  • 代码实现多头注意力
  • 测试

注意力机制

简单示例:单个元素的情况

  • 假设我们有以下输入句子,已按照第 3 章中的描述嵌入为 3 维向量(此处使用非常小的嵌入维度,仅用于说明,方便在页面上显示而不换行):
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
  • (在本书中,我们遵循机器学习和深度学习的常见惯例,即训练样本表示为行,特征值表示为列;在上面的张量中,每一行表示一个词,每一列表示一个嵌入维度。)

  • 本节的主要目的是演示如何使用第二个输入序列 x ( 2 ) x^{(2)} x(2) 作为查询,计算上下文向量 z ( 2 ) z^{(2)} z(2)

  • 图示展示了该过程的初始步骤,其中通过点积操作计算 x ( 2 ) x^{(2)} x(2) 与所有其他输入元素之间的注意力分数 ω。

请添加图片描述

  • 我们使用输入序列中的元素 2,即 x ( 2 ) x^{(2)} x(2),作为示例来计算上下文向量 z ( 2 ) z^{(2)} z(2);在本节稍后,我们将推广此方法来计算所有的上下文向量。
  • 第一步是通过计算查询 x ( 2 ) x^{(2)} x(2) 与所有其他输入词元之间的点积,得到未归一化的注意力分数:
query = inputs[1]  # 2nd input token is the query

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)

print(attn_scores_2)

输出

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
  • 步骤 2: 将未归一化的注意力分数(“omegas”, ω \omega ω)归一化,使其总和为 1。
  • 以下是一种简单的归一化方法,使未归一化的注意力分数总和为 1(这种方式是约定俗成的,有助于解释,并对训练稳定性非常重要):

请添加图片描述

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

输出

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)
  • 然而,在实际操作中,通常推荐使用 softmax 函数进行归一化,因为它在处理极端值方面更有效,并且在训练过程中具有更理想的梯度特性。
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

输出

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
  • 步骤 3:通过将嵌入的输入词元 x ( i ) x^{(i)} x(i) 与注意力权重相乘,并将所得向量求和,计算上下文向量 z ( 2 ) z^{(2)} z(2)请添加图片描述
query = inputs[1] # 2nd input token is the query

context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i

print(context_vec_2)

输出

tensor([0.4419, 0.6515, 0.5683])

简单示例:计算所有输入词元的注意力权重

推广到所有输入序列词元:

  • 上面我们计算了输入 2 的注意力权重和上下文向量。

  • 接下来,我们将推广该计算,以求得所有的注意力权重和上下文向量。
    请添加图片描述

  • (请注意,此图中的数字已截取至小数点后两位,以减少视觉杂乱;每行的值应相加为 1.0 或 100%;同样,其他图中的数字也被截取。)

  • 在自注意力机制中,首先计算注意力分数,随后对其进行归一化以得出总和为 1 的注意力权重。

  • 然后,这些注意力权重被用于通过输入的加权求和生成上下文向量。

请添加图片描述

  • 将之前的步骤 1应用于所有成对元素,以计算未归一化的注意力分数矩阵:
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

输出

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 我们可以通过矩阵乘法更高效地实现上述计算:
attn_scores = inputs @ inputs.T
print(attn_scores)

输出

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 与之前的步骤 2类似,我们对每一行进行归一化,使每一行的值相加为 1:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

输出

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 应用之前的步骤 3来计算所有上下文向量:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

输出

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

注意力掩码

  • 模型在序列中某一位置的预测仅依赖于之前位置的已知输出,而不依赖未来位置的输出。
  • 简单来说,这确保了每个下一个词的预测仅依赖于前面的词。
  • 为了实现这一点,对于每个给定词元,我们将未来的词元(即在当前词元之后的词元)进行掩码处理:
    请添加图片描述
attn_weights

输出

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 最简单的方式是通过 PyTorch 的 tril 函数创建一个掩码,将主对角线下方的元素(包括主对角线)设置为 1,主对角线上方的元素设置为 0,以掩盖未来的注意力权重:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
  • 然后,我们可以将注意力权重与此掩码相乘,以将对角线上方的注意力分数置为零:
masked_simple = attn_weights*mask_simple
print(masked_simple)
tensor([[0.2098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1385, 0.2379, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1390, 0.2369, 0.2326, 0.0000, 0.0000, 0.0000],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.0000, 0.0000],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.0000],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
  • 然而,如果在 softmax 之后应用掩码(如上所述),会破坏 softmax 创建的概率分布。
  • Softmax 确保所有输出值的总和为 1。
  • 在 softmax 之后进行掩码处理则需要重新归一化输出以再次使其总和为 1,这会使过程复杂化,并可能导致意想不到的效果。
  • 为确保每行的总和为 1,我们可以按如下方式归一化注意力权重:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
  • 让我们简单了解一种更高效的方法来实现上述目标。
  • 因此,与其将对角线上方的注意力权重置零并重新归一化结果,我们可以在未归一化的注意力分数进入 softmax 函数之前,将对角线上方的分数掩码为负无穷大。请添加图片描述
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

输出

tensor([[0.9995,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.9544, 1.4950,   -inf,   -inf,   -inf,   -inf],
        [0.9422, 1.4754, 1.4570,   -inf,   -inf,   -inf],
        [0.4753, 0.8434, 0.8296, 0.4937,   -inf,   -inf],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654,   -inf],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
  • 如下所示,现在每行的注意力权重再次正确地总和为 1:
attn_weights = torch.softmax(masked, dim=-1)
print(attn_weights)

输出

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3680, 0.6320, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2284, 0.3893, 0.3822, 0.0000, 0.0000, 0.0000],
        [0.2046, 0.2956, 0.2915, 0.2084, 0.0000, 0.0000],
        [0.1753, 0.2250, 0.2269, 0.1570, 0.2158, 0.0000],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

代码实现多头注意力

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

测试

batch = torch.stack((inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

输出

tensor([[[-0.6033, -0.2785],
         [-0.5409, -0.2509],
         [-0.5241, -0.2439],
         [-0.4974, -0.2357],
         [-0.5224, -0.2520],
         [-0.4887, -0.2361]],

        [[-0.6033, -0.2785],
         [-0.5409, -0.2509],
         [-0.5241, -0.2439],
         [-0.4974, -0.2357],
         [-0.5224, -0.2520],
         [-0.4887, -0.2361]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
  • 另外请注意,我们在上面的 MultiHeadAttention 类中添加了一个线性投影层 (self.out_proj)。这只是一个不会改变维度的线性变换。在大型语言模型的实现中,使用这样的投影层是一个标准惯例,但并非绝对必要(最近的研究表明,移除该层不会影响模型性能);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231881.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

简单的kafkaredis学习之kafka

简单的kafka&redis学习整理之kafka 1. kafka 1.1 什么是消息队列 在学习Kafka之前我们先来看一下什么是消息队列&#xff0c;消息队列(Message Queue)&#xff1a;可以简称为MQ 例如&#xff1a;Java中的Queue队列&#xff0c;也可以认为是一个消息队列 消息队列&#x…

基于人工智能的搜索和推荐系统

互联网上的搜索历史分析和用户活动是个性化推荐的基础&#xff0c;这些推荐已成为电子商务行业和在线业务的强大营销工具。随着人工智能的使用&#xff0c;在线搜索也在改进&#xff0c;因为它会根据用户的视觉偏好提出建议&#xff0c;而不是根据每个客户的需求和偏好量身定制…

ssm042在线云音乐系统的设计与实现+jsp(论文+源码)_kaic

摘 要 随着移动互联网时代的发展&#xff0c;网络的使用越来越普及&#xff0c;用户在获取和存储信息方面也会有激动人心的时刻。音乐也将慢慢融入人们的生活中。影响和改变我们的生活。随着当今各种流行音乐的流行&#xff0c;人们在日常生活中经常会用到的就是在线云音乐系统…

TVS 静电管 选型

参数选型举例: 静电管选型举例: 针对信号引脚一般只需ESD防护,关注其在IEC 61000−4−2波形下的测试结果:最大耐压值、钳位电压等,注意此时钳位电压的限值就不是Absolute maximum ratings值了,原因有2 1、Absolute maximum ratings值是指持续加压会损坏芯片 2、如果关…

监控调度台在交通运输行业的优势?

在当今快速发展的交通运输行业中&#xff0c;高效、安全的管理成为确保运营顺畅和乘客满意的关键。监控调度台作为这一领域的核心设备&#xff0c;正发挥着越来越重要的作用。它集成了视频监控、数据分析、实时通讯等多种功能&#xff0c;为交通运输行业带来了诸多优势。下面我…

华为ENSP--ISIS路由协议

项目背景 为了确保资源共享、办公自动化和节省人力成本&#xff0c;公司E申请两条专线将深圳总部和广州、北京两家分公司网络连接起来。公司原来运行OSFP路由协议&#xff0c;现打算迁移到IS-IS路由协议&#xff0c;张同学正在该公司实习&#xff0c;为了提高实际工作的准确性和…

设计模式07-结构型模式2(装饰模式/外观模式/代理模式/Java)

4.4 装饰模式 4.4.1 装饰模式的定义 1.动机&#xff1a;在不改变一个对象本身功能的基础上给对象增加额外的新行为 2.定义&#xff1a;动态地给一个对象增加一些额外的职责&#xff0c;就增加对象功能来说&#xff0c;装饰模式比生成子类实现更为灵活 4.4.2 装饰模式的结构…

Spring @RequestMapping 注解

文章目录 Spring RequestMapping 注解一、引言二、RequestMapping注解基础1、基本用法2、处理多个URI 三、高级用法1、处理HTTP方法2、参数和消息头处理 四、总结 Spring RequestMapping 注解 一、引言 在Spring框架中&#xff0c;RequestMapping 注解是构建Web应用程序时不可…

【Linux】IPC 进程间通信(一):管道(匿名管道命名管道)

✨ 无人扶我青云志&#xff0c;我自踏雪至山巅 &#x1f30f; &#x1f4c3;个人主页&#xff1a;island1314 &#x1f525;个人专栏&#xff1a;Linux—登神长阶 ⛺️ 欢迎关注&#xff1a;&#x1f44d;点赞 &#…

单片机串口接收状态机STM32

单片机串口接收状态机stm32 前言 项目的芯片stm32转国产&#xff0c;国产芯片的串口DMA接收功能测试不通过&#xff0c;所以要由原本很容易配置的串口空闲中断触发DMA接收数据的方式转为串口逐字节接收的状态机接收数据 两种方式各有优劣&#xff0c;不过我的芯片已经主频跑…

信息学科平台系统开发:基于Spring Boot的最佳实践

3系统分析 3.1可行性分析 通过对本基于保密信息学科平台系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本基于保密信息学科平台系统采用Spring Boot框架&a…

探索 ONLYOFFICE 8.2 版本:更高效、更安全的云端办公新体验

引言 在当今这个快节奏的时代&#xff0c;信息技术的发展已经深刻改变了我们的工作方式。从传统的纸质文件到电子文档&#xff0c;再到如今的云端协作&#xff0c;每一步技术进步都代表着效率的飞跃。尤其在后疫情时代&#xff0c;远程办公成为常态&#xff0c;如何保持团队之间…

51c自动驾驶~合集4

我自己的原文哦~ https://blog.51cto.com/whaosoft/12413878 #MCTrack 迈驰&旷视最新MCTrack&#xff1a;KITTI/nuScenes/Waymo三榜单SOTA paper&#xff1a;MCTrack: A Unified 3D Multi-Object Tracking Framework for Autonomous Driving code&#xff1a;https://gi…

STM32HAL-最简单的长、短、多击按键框架(多按键)

概述 本文章使用最简单的写法实现长、短、多击按键框架,非常适合移植各类型单片机,特别是资源少的芯片上。接下来将在stm32单片机上实现,只需占用1个定时器作为时钟扫描按键即可。 一、开发环境 1、硬件平台 STM32F401CEU6 内部Flash : 512Kbytes,SARM …

【论文精读】LPT: Long-tailed prompt tuning for image classification

&#x1f308; 个人主页&#xff1a;十二月的猫-CSDN博客 &#x1f525; 系列专栏&#xff1a; &#x1f3c0;论文精读_十二月的猫的博客-CSDN博客 &#x1f4aa;&#x1f3fb; 十二月的寒冬阻挡不了春天的脚步&#xff0c;十二点的黑夜遮蔽不住黎明的曙光 目录 1. 摘要 2. …

队列的模拟实现

概念&#xff1a; 队列 &#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出 FIFO(First In First Out) 入队列&#xff1a;进行插入操作的一端称为 队尾&#xff08; Tail/Rear &#xff09; 出队列&a…

Centos安装配置Jenkins

下载安装 注意&#xff1a;推荐的LTS版本对部分插件不适配&#xff0c;直接用最新的版本&#xff0c;jenkins还需要用到git和maven&#xff0c;服务器上已经安装&#xff0c;可查看参考文档[1]、[2]&#xff0c;本次不再演示 访问开始使用 Jenkins 下载jenkins 上传至服务器…

在Python中最小化预测函数的参数

在 Python 中&#xff0c;最小化预测函数的参数通常涉及使用优化算法来调整模型的参数&#xff0c;以减少预测误差。下面介绍几种常见的方法来实现这一目标&#xff0c;主要使用 scipy 和 numpy 库。 1、问题背景 我正在尝试通过解决自己想出的问题来学习Python&#xff0c;我…

统信UOS系统应用开发

包括cpu 、内存 、安全等接口描述。 文章目录 一、内存管理非文件形式的内存动态函数库调用接口二、cpu内置安全飞腾国密加速硬件用户态驱动API说明真随机数真随机数三、cpu多核调度cpu亲和性获取接口用于cpu set集操作的相关宏定义一、内存管理 非文件形式的内存动态函数库调…

postman 获取登录接口中的返回token并设置为环境变量的方法 postman script

postman是一个比较方便的API开发调试工具&#xff0c; 我们在访问API时一般都需要设置一个token来对服务进行认证&#xff0c; 这个token一般都是通过登录接口来获取。 这个postman脚本放到登录接口的sctipt--> post-response里面即可将登陆接口中返回的token值设置到postma…