LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

news2025/1/27 15:55:03

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145368666


GQA
GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构,GroupQueryAttention 是注意力机制的变体,通过将查询(Query)分组,每组与相同的键(Key)值(Value)交互,优化计算效率和性能,保持模型对于输入信息有效关注,减少计算资源的消耗,适用于处理大规模数据和复杂任务的场景。KVCache 是缓存机制,用于存储和快速检索键值对(KV),当模型处理新的输入(Q)时,直接从缓存中读取KV数据,无需重新计算,显著提高模型的推理速度和效率。GQA 与 KVCache 在提升模型性能和优化资源利用方面,都发挥着重要作用,结合使用可以进一步增强模型在实际应用中的表现。

从 MHA 到 GQA,再到 GQA+KVCache,简单实现,参考:

  • GQA:从头实现 LLaMA3 网络与推理流程
  • KVCache:GPT(Decoder Only) 类模型的 KV Cache 公式与原理

Scaled Dot-Product Attention (缩放点积注意力机制),也称单头自注意力机制,公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QK)V

1. MultiHeadAttention

MultiHeadAttention (多头注意力机制),合计 43 行:

  1. __init__ 初始化 (10行):
    • 输入:heads(头数)、d_model(维度)、dropout (用于 scores)
    • 计算 d_k 每个 Head 的维度,即 d m o d e l = h e a d s × d k d_{model} = heads \times d_{k} dmodel=heads×dk
    • 线性层是 QKVO,Dropout 层
  2. attention 注意力 (10行):
    • q q q 的维度 [bs,h,s,d],与 k ⊤ k^{\top} k[bs,h,d,s],mm 之后 scores 是 [bs,h,s,s]
    • mask 的维度是 [bs,s,s],使用 unsqueeze(1),转换成 [bs,1,s,s]
    • QKV 的计算,额外支持 Dropout
  3. forward 推理 (12行):
    • QKV Linear 转换成 [bs,s,h,dk],再转换 [bs,h,s,dk]
    • 计算 attn 的 [bs,h,s,dk]
    • 转换 [bs,s,h,dk],再 contiguous(),再 合并 h × d k = d h \times d_{k} = d h×dk=d
    • 再过 O
  4. 测试 (11行):
    • torch.randn 构建数据
    • Mask 的 torch.tril(torch.ones(bs, s, s))

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制 MultiHeadAttention
    """
    def __init__(self, heads, d_model, dropout=0.1):  # 10行
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):  # 10行
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):  # 12行
        bs = q.size(0)
        # 进行线性操作划分为成 h 个头
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # 矩阵转置
        k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        print(f"[Info] attn: {attn.shape}")
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # 设置超参数
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # 创建 MultiHeadAttention 实例
    attention = MultiHeadAttention(h, d, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 测试无掩码的情况
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # 测试有掩码的情况
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # 检查输出是否符合预期
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()

2. GroupQueryAttention

GroupQueryAttention (分组查询注意力机制),相比于 MHA,参考 torch.nn.functional.scaled_dot_product_attention

  1. __init__ :增加参数 kv_heads,即 KV Head 数量,KV 的 Linear 层输出维度(kv_heads * self.d_k)也需要修改。
  2. forward:使用 repeat_interleave 扩充 KV 维度,其他相同,增加 3 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力机制(Group Query Attention)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # 进行线性操作
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]
        k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]
        v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)
        # 复制键值头以匹配查询头的数量
        group = self.h // self.kv_heads
        k = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = v.repeat_interleave(group, dim=2)
        # 矩阵转置, 将 head 在前
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # 设置超参数, GQA 8//4=2组
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # 创建 MultiHeadAttention 实例
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 测试无掩码的情况
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # 测试有掩码的情况
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # 检查输出是否符合预期
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()

3. GQA + KVCache

GroupQueryAttention + KVCache,相比于 GQA,增加 KVCache:

  1. forward :增加参数 kv_cache,合并 [cached_k, new_k],同时返回 new_kv_cache,用于迭代,增加 5 行。
  2. 设置 cur_qkvcur_mask,迭代序列s维度,合计 8 行。

即:

import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力机制(Group Query Attention)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None, kv_cache=None):
        bs = q.size(0)
        # 进行线性操作
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]
        new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        # 处理 KV Cache
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            new_k = torch.cat([cached_k, new_k], dim=1)
            new_v = torch.cat([cached_v, new_v], dim=1)
        # 复制键值头以匹配查询头的数量
        group = self.h // self.kv_heads
        k = new_k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = new_v.repeat_interleave(group, dim=2)
        # 矩阵转置, 将 head 在前
        # KV Cache 最后1轮: q—>[2, 8, 1, 64] k->[2, 8, 10, 64] v->[2, 8, 10, 64]
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 计算 attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]
        print(f"[Info] attn: {attn.shape}")
        # 连接多个头并输入到最后的线性层
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # 更新 KV Cache
        new_kv_cache = (new_k, new_v)  # 当前的 KV 缓存
        return output, new_kv_cache
def main():
    # 设置超参数
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # 创建 GroupQueryAttention 实例
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # 创建随机输入张量
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # 可选:创建掩码,因果掩码,上三角矩阵
    mask = torch.tril(torch.ones(bs, s, s))
    # 模拟逐步生成序列,测试 KV Cache
    print("Testing KV Cache...")
    kv_cache, output = None, None
    for i in range(s):
        cur_q = q[:, i:i+1, :]
        cur_k = k[:, i:i+1, :]
        cur_v = v[:, i:i+1, :]
        cur_mask = mask[:, i:i+1, :i+1]   # q是 i:i+1,k是 :i+1
        output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)
        print(f"Output shape at step {i}:", output.shape)
    # 检查输出是否符合预期
    assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"
    print("Test passed!")
if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2284013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kotlin开发(七):对象表达式、对象声明和委托的奥秘

Kotlin 让代码更优雅! 每个程序员都希望写出优雅高效的代码,但现实往往不尽人意。对象表达式、对象声明和 Kotlin 委托正是为了解决代码中的复杂性而诞生的。为什么选择这个主题?因为它不仅是 Kotlin 语言的亮点之一,还能极大地提…

数据库、数据仓库、数据湖有什么不同

数据库、数据仓库和数据湖是三种不同的数据存储和管理技术,它们在用途、设计目标、数据处理方式以及适用场景上存在显著差异。以下将从多个角度详细说明它们之间的区别: 1. 数据结构与存储方式 数据库: 数据库主要用于存储结构化的数据&…

【2024年华为OD机试】 (B卷,100分)- 字符串摘要(JavaScriptJava PythonC/C++)

一、问题描述 题目描述 给定一个字符串的摘要算法,请输出给定字符串的摘要值。具体步骤如下: 去除字符串中非字母的符号:只保留字母字符。处理连续字符:如果出现连续字符(不区分大小写),则输…

DIY QMK量子键盘

最近放假了,趁这个空余在做一个分支项目,一款机械键盘,量子键盘取自固件名称QMK(Quantum Mechanical Keyboard)。 键盘作为计算机或其他电子设备的重要输入设备之一,通过将按键的物理动作转换为数字信号&am…

mamba论文学习

rnn 1986 训练速度慢 testing很快 但是很快就忘了 lstm 1997 训练速度慢 testing很快 但是也会忘(序列很长的时候) GRU实在lstm的基础上改进,改变了一些门 transformer2017 训练很快,testing慢些,时间复杂度高&am…

智慧消防营区一体化安全管控 2024 年度深度剖析与展望

在 2024 年,智慧消防营区一体化安全管控领域取得了令人瞩目的进展,成为保障营区安全稳定运行的关键力量。这一年,行业在政策驱动、技术创新应用、实践成果及合作交流等方面呈现出多元且深刻的发展态势,同时也面临着一系列亟待解决…

解锁微服务:五大进阶业务场景深度剖析

目录 医疗行业:智能诊疗的加速引擎 电商领域:数据依赖的破局之道 金融行业:运维可观测性的提升之路 物流行业:智慧物流的创新架构 综合业务:服务依赖的优化策略 医疗行业:智能诊疗的加速引擎 在医疗行业迈…

javascript-es6 (一)

作用域(scope) 规定了变量能够被访问的“范围”,离开了这个“范围”变量便不能被访问 局部作用域 函数作用域: 在函数内部声明的变量只能在函数内部被访问,外部无法直接访问 function getSum(){ //函数内部是函数作用…

jenkins-k8s pod方式动态生成slave节点

一. 简述: 使用 Jenkins 和 Kubernetes (k8s) 动态生成 Slave 节点是一种高效且灵活的方式来管理 CI/CD 流水线。通过这种方式,Jenkins 可以根据需要在 Kubernetes 集群中创建和销毁 Pod 来执行任务,从而充分利用集群资源并实现更好的隔离性…

【云安全】云原生-K8S-简介

K8S简介 Kubernetes(简称K8S)是一种开源的容器编排平台,用于管理容器化应用的部署、扩展和运维。它由Google于2014年开源并交给CNCF(Cloud Native Computing Foundation)维护。K8S通过提供自动化、灵活的功能&#xf…

aws(学习笔记第二十六课) 使用AWS Elastic Beanstalk

aws(学习笔记第二十六课) 使用aws Elastic Beanstalk 学习内容: AWS Elastic Beanstalk整体架构AWS Elastic Beanstalk的hands onAWS Elastic Beanstalk部署node.js程序包练习使用AWS Elastic Beanstalk的ebcli 1. AWS Elastic Beanstalk整体架构 官方的guide AWS…

反向代理模块。。

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…

C语言的灵魂——指针(1)

指针是C语言的灵魂,有了指针C语言才能完成一些复杂的程序;没了指针就相当于C语言最精髓的部分被去掉了,可见指针是多么重要。废话不多讲我们直接开始。 指针 一,内存和地址二,编址三,指针变量和地址1&#…

14-6-2C++STL的list

(一&#xff09;list对象的带参数构造 1.list&#xff08;elem);//构造函数将n个elem拷贝给本身 #include <iostream> #include <list> using namespace std; int main() { list<int> lst(3,7); list<int>::iterator it; for(itlst.begi…

Ubuntu Server 安装 XFCE4桌面

Ubuntu Server没有桌面环境&#xff0c;一些软件有桌面环境使用起来才更加方便&#xff0c;所以我尝试安装桌面环境。常用的桌面环境有&#xff1a;GNOME、KDE Plasma、XFCE4等。这里我选择安装XFCE4桌面环境&#xff0c;主要因为它是一个极轻量级的桌面环境&#xff0c;适合内…

一个简单的自适应html5导航模板

一个简单的 HTML 导航模板示例&#xff0c;它包含基本的导航栏结构&#xff0c;同时使用了 CSS 进行样式美化&#xff0c;让导航栏看起来更美观。另外&#xff0c;还添加了一些 JavaScript 代码&#xff0c;用于在移动端实现导航菜单的展开和收起功能。 PHP <!DOCTYPE htm…

实现B-树

一、概述 1.历史 B树&#xff08;B-Tree&#xff09;结构是一种高效存储和查询数据的方法&#xff0c;它的历史可以追溯到1970年代早期。B树的发明人Rudolf Bayer和Edward M. McCreight分别发表了一篇论文介绍了B树。这篇论文是1972年发表于《ACM Transactions on Database S…

无人机微波图像传输数据链技术详解

无人机微波图像传输数据链技术是无人机通信系统中的关键组成部分&#xff0c;它确保了无人机与地面站之间高效、可靠的图像数据传输。以下是对该技术的详细解析&#xff1a; 一、技术原理 无人机微波图像传输数据链主要基于微波通信技术实现。在数据链路中&#xff0c;图像数…

macos的图标过大,这是因为有自己的设计规范

苹果官方链接&#xff1a;App 图标 | Apple Developer Documentation 这个在官方文档里有说明&#xff0c;并且提供了sketch 和 ps 的模板。 figma还提供了模板&#xff1a; Figma

微信阅读网站小程序的设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…