一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA

news2024/11/26 10:27:19

第一部分 多头注意力

// 待更

第二部分 LLaMA2之分组查询注意力——Grouped-Query Attention

自回归解码的标准做法是缓存序列中先前标记的键 (K) 和值 (V) 对,从而加快注意力计算速度
然而,随着上下文窗口或批量大小的增加,多头注意力 (MHA)模型中与 KV 缓存大小相关的内存成本显着增长

对于较大的模型,KV 缓存大小成为瓶颈,键和值投影可以在多个头之间共享,而不会大幅降低性能,可以使用

  • 具有单个 KV 投影的原始多查询格式(MQA)
    ChatGLM2-6B即用的这个,详见此文《ChatGLM两代的部署/微调/实现:从基座GLM、ChatGLM的LoRA/P-Tuning微调、6B源码解读到ChatGLM2的微调与实现》的3.1.2节
    不过,多查询注意(Multi-query attention,简称MQA)只使用一个键值头,虽大大加快了解码器推断的速度,但MQA可能导致质量下降,而且仅仅为了更快的推理而训练一个单独的模型可能是不可取的
  • 或具有多个 KV 投影的分组查询注意力(grouped-query attention,简称GQA),速度快 质量高
    23年,还是Google的研究者们提出了一种新的方法,即分组查询注意(GQA,论文地址为:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints),这是一种多查询注意的泛化,它通过折中(多于一个且少于查询头的数量,比如4个)键值头的数量,使得经过强化训练的GQA以与MQA相当的速度达到接近多头注意力的质量

经实验论证,GQA 变体在大多数评估任务上的表现与 MHA 基线相当,并且平均优于 MQA 变体

第三部分 多查询注意力(Muti Query Attention)

3.1 MQA的核心特征:各自Query矩阵,但共享Key 和 Value 矩阵

多查询注意力(Muti Query Attention)是 19 年Google一研究者提出的一种新的 Attention 机制(对应论文为:Fast Transformer Decoding: One Write-Head is All You Need、这是其解读之一),其能够在保证模型效果的同时加快 decoder 生成 token 的速度

那其与17年 Google提出的transformer中多头注意力机制(简称MHA)有啥本质区别呢?有意思的是,区别在于:

  • 我们知道MHA的每个头都各自有一份不同的Key、Query、Value矩阵
  • 而MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量
    总之,MQA 实际上是将 head 中的 key 和 value 矩阵抽出来单独存为一份共享参数,而 query 则是依旧保留在原来的 head 中,每个 head 有一份自己独有的 query 参数

下图对比了多头注意力(Multi-Head Attention)、LLaMA2中分组查询注意力(Grouped-Query Attention)、多查询注意力(Muti Query Attention)的差别

总之,MHA 和 MQA 之间的区别只在于建立 Wqkv Layer 上

# Multi Head Attention
self.Wqkv = nn.Linear(                        # 【关键】Multi-Head Attention 的创建方法
    self.d_model, 
    3 * self.d_model,                         # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
    device=device
)

query, key, value = qkv.chunk(                # 【关键】每个 tensor 都是 (1, 512, 768)
    3, 
    dim=2
)


# Multi Query Attention
self.Wqkv = nn.Linear(                                # 【关键】Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,                      # 只创建 query 的 head 向量,所以只有 1 个 d_model
    device=device,                                    # 而 key 和 value 不再具备单独的头向量
)

query, key, value = qkv.split(                        # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
    dim=2                                             # value -> (1, 512, 96)
)

对比上面的代码,你可以发现

  • 在 MHA 中,query, key, value 每个向量均有 768 维度
  • 而在 MQA 中,只有 query 是 768 维,而 key 和 value 均只剩下 96 维了,恰好是 1 个 head_dim 的维度

因此,可以确认:在 MQA 中,除了 query 向量还保存着 8 个头,key 和 value 向量都只剩 1 个「公共头」了,这也正好印证了论文中所说的「所有 head 之间共享一份 key 和 value 的参数」

剩下的问题就是如何将这 1 份参数同时让 8 个头都使用,代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享:

def scaled_multihead_dot_product_attention(
        query,
        key,
        value,
        n_heads,
        multiquery=False,
    ):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
    
    attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
    attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)

    out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
    out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)

    return out, attn_weight, past_key_value

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1173338.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Buuctf-Crypto-之深夜刷题部分wp

萌萌哒的八戒 首先下载好附件,解压,是一幅猪图,图的下方是一串看不懂的字,百度输入关键词猪、密码,可知这是猪圈密码, 手撸得WHENTHEPIGWANTTOEAT 大写不对,换成小写。 …

数据结构——常见简答题汇总

目录 1、绪论 2、线性表 3、栈、队列和数组 4、串 5、树与二叉树 6、图 7、查找 8、排序 1、绪论 什么是数据结构? 数据结构是相互之间存在一种或多种特定关系的数据元素的集合。数据结构包括三个方面:逻辑结构、存储结构、数据的运算。 逻辑结…

中小学智慧校园电子班牌管理系统源码

智慧校园云平台电子班牌系统,利用先进的云计算技术,将教育信息化资源和教学管理系统进行有效整合,实现基础数据共享、应用统一管理。借助全新的智能交互识别终端和移动化教育管理系统,以考勤、课表、通知、家校互通等功能为切入点…

如何将 XxlJob 集成达梦数据库

1. 前言 在某些情况下,你的项目可能会面临数据库选择的特殊要求,随着国产化的不断推进,达梦数据库是一个常见的选择。本篇博客将教你如何解决 XxlJob 与达梦数据库之间的 SQL 兼容性问题,以便你的任务调度系统能够在这个数据库中…

NNDL 作业6 卷积

一、概念 (一)卷积 (1)什么叫卷积 卷积、旋积或褶积(英语:Convolution)是通过两个函数f和g生成第三个函数的一种数学运算,其本质是一种特殊的积分变换,描述一个函数和另一个函数在某个维度上…

类和对象解析

导言: Java是一门纯面向对象的语言,在面对对象的世界里,一切皆为对象。而对象的创建又和类的定义息息相关。本文主要阐述了类和对象的使用与理解。解释类的定义方式以及对象的实例化,类中的成员变量和成员方法的使用,…

【qemu逃逸】D3CTF2021-d3dev

前言 题目给的是一个 docker 环境,所以起环境非常方便,但是该怎么调试呢?有无佬教教怎么在 docker 中调试? 我本来想着直接起一个环境进行调试,但是缺了好的库,所以就算了,毕竟本题也不用咋调…

044_第三代软件开发-保存PDF

第三代软件开发-保存PDF 文章目录 第三代软件开发-保存PDF项目介绍保存PDF头文件源文件使用 关键字: Qt、 Qml、 pdf、 painter、 打印 项目介绍 欢迎来到我们的 QML & C 项目!这个项目结合了 QML(Qt Meta-Object Language&#xff…

阿里5年经验之谈 —— 记录一次jmeter压测的过程!

在软件架构与中间件实验的最后,要求进行非功能测试,那得非压力测试莫属了。虽然之前学习秒杀项目的时候看视频里面用过jmeter,但没有自己实操过,趁着这次机会,使用一下。 QPS与TPS 1、TPS: Transactions …

力扣周赛 -- 370周赛

先更新前两道题目,下午更新后两道 两道模板题(拓扑排序) 拓扑排序 拓扑排序(Topological Sorting):一种对有向无环图(DAG)的所有顶点进行线性排序的方法,使得图中任意一点 $u$ 和 $v$&#xf…

【LeetCode】每日一题 2023_11_5 重复的DNA序列

文章目录 刷题前唠嗑重复的DNA序列题目描述代码和解题思路偷看大佬题解结语 刷题前唠嗑 LeetCode? 启动!!! 重复的DNA序列 题目链接:187. 重复的DNA序列 题目描述 代码和解题思路 func findRepeatedDnaSequences(s string) …

fastapi-Headers和Cookies

在FastAPI中,Headers是一个特殊的类型,用于处理HTTP请求头(Headers)。Headers允许你接收、访问和修改HTTP请求中的头部信息。 使用Headers,你可以在FastAPI的路由视图中将请求头作为参数接收,并对它们进行…

linux基本用法

文章目录 前言一、开关机操作1.1 开机登陆1.2 关机1.3 系统目录结构 二、常用的基本命令(重点)2.1 相对路径与绝对路径2.2 处理目录的常用命令2.2.1 ls2.2.2 cd 切换目录2.2.3 pwd ( 显示目前所在的目录 )2.2.4 mkdir (创建新目录)2.2.5 rmdir ( 删除空的…

【Vue.js】Vue3全局配置Axios并解决跨域请求问题

系列文章目录 文章目录 系列文章目录背景一、部署Axios1. npm 安装 axios2. 创建 request.js,创建axios实例3. 在main.js中全局注册axios4. 在页面中使用axios 二、后端解决跨域请求问题方法一 解决单Contoller跨域访问方法二 全局解决跨域问题 背景 对于前后端分离…

回溯算法--4后问题

1.问题描述 四皇后问题&#xff1a;在4 4 的方格棋盘上放置4个皇后&#xff0c;使得没有两个皇后在同一行、同一列、也不在同一条45度的斜线上。问有多少种可能的布局&#xff1f; 解是4维向量 比如上面这个解<2,4,1,3> 分别表示圆圈的第2列、第4列等 还可以得到另一解…

LeetCode题:83删除排序链表中的重复元素 141环形链表

83删除排序链表中的重复元素 题目内容 给定一个已排序的链表的头 head &#xff0c; 删除所有重复的元素&#xff0c;使每个元素只出现一次 。返回 已排序的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,1,2] 输出&#xff1a;[1,2]示例 2&#xff1a; 输入&#xf…

下载安装PyCharm的步骤

1、首先进入Pycharm官网&#xff0c;并进行下载&#xff0c;日常使用社区版也是OK的 官网&#xff1a;https://www.jetbrains.com/pycharm/download/?sectionwindows 2、可以自定义路径进行安装&#xff0c;注意路径要全英哈 3、大家可以根据自己的需要来进行勾选 4、安装完成…

【漏洞复现】Webmin 远程命令执行(CVE-2019-15107)

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞验证 1.5、深度利用1、反弹Shell 1.6、修复建议1.7、参考链接 说明内容漏洞编号CVE-2019-15107漏洞…

蒙哥马利算法模乘(四)

一 蒙哥马利算法模乘介绍 蒙哥马利模乘算法主要为了进行大数运算a*b mod n,在介绍蒙哥马利模乘之前,先让我们来了解蒙哥马利约减。 1.1 蒙哥马利约减 a mod n 如果a是一个2048位的整数,n是一个1024位的整数,如果直接采用相除的方式,不论在空间还是时间上都会产生非常大…

【漏洞复现】Django_debug page_XSS漏洞(CVE-2017-12794)

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞分析3、漏洞验证 说明内容漏洞编号CVE-2017-12794漏洞名称Django_debug page_XSS漏洞漏洞评级影响范…