Transformer系列:注意力机制的优化,MQA和GQA原理简述

news2025/1/22 18:55:56
前言

多查询注意力(MQA)、分组查询注意力(GQA)是Transformer中多头注意力(MHA)的变种,它们大幅提高了解码器的推理效率,在LLaMA-2,ChatGLM2等大模型中有广泛使用,本篇介绍MQA、GQA的原理并分析其源码实现。


使用MQA,GQA的背景介绍

多查询注意力(Multi Query Attention,MQA)提出于2019年的论文《Fast Transformer Decoding: One Write-Head is All
You Need》,旨在解决Transformer增量推理阶段效率低下的问题,在当时并没有引起关注,而随着近几年Transformer和GPT成为生成式大模型的基座,面临着产业落地的实际情况,导致GPT的推理加速备受关注,因此MQA又重新被提及起来。
分组查询注意力(Group Query Attention,GQA)提出于2023年,是MQA更一般的形式,它介于MQA和MHA之间,是模型预测表现和模型推理性能之间的一个折衷。


MQA,GQA原理简述

MQA的原理很简单,它将原生Transformer每一层多头注意力的Key线性映射矩阵、Value线性映射矩阵改为该层下所有头共享,也就是说K、V矩阵每层只有一个,而Q矩阵不受影响,其数量和注意力头数相等。以ChatGLM2-6B为例,一共28层,32个注意力头,输入从4096经过Q、K、V矩阵映射维度为128,若采用原生多头注意力机制,则Q、K、V矩阵各有28×32个,而采用MQA的方式则整个模型包含28×32个Q矩阵,28个K矩阵,28个V矩阵,示意图如下

MHA和MQA的差别

可想而知,MQA这种方式大幅减小了参数数量,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定,因此在此基础上提出了GQA,它将Query进行分组,每个组内共享一组Key、Value,令组的数量为N,若N等于1此时等效于MQA,若N等于Query头的数量,此时退化为MHA。GQA是推理效率和模型性能的trade-off。

GQA(中)和MQA(右)对比


MQA,GQA推理加速分析

MQA能够大幅加速采用MHA的Transformer的推理,但是会有明显的性能损失,而GQA通过设置合适的分组大小,可以和MQA的推理性能几乎相等,同时逼近MHA的模型性能。作者在GQA的论文中给到了实验结论来印证这一点。
作者采用T5模型作为研究对象,模型版本采用T5-Large和T5-XXL,它们都采用MHA注意力方式,其中Large参数量770M,XXL参数量11B,在此基础上作者通过up-training方法将T5-XXL改造为MQA和GQA,最终一共四个版本模型进行精度和推理效率的对比,结果如下

横轴代表平均每条样本的推理耗时,越大代表延迟越大,纵轴代表在众多数据集上的评价得分,越大代表得分越高。在MHA方式下,由于XXL的模型参数更大,因此MHA-XXL的推理延迟高于MHA-Large,同时MHA-XXL的模型评分在所有版本里面最高。
经过up-training方法将MHA改造为MQA之后,MQA-XXL获得了所有版本的最低延迟,甚至还低于小一个型号的Large模型,同时MQA使得其模型评分比MHA-XXL降低了,但还是超越了小一个信号的Large的模型,表明大参数量的MQA模型不论在精度还是效率上都超越了小参数量的MHA模型。
而经过up-training方法将MHA改造为GQA,GQA-XXL的推理延迟几乎和MQA-XXL相等,而其性能评分也和参数量最大的MHA-XXL十分接近,整体上GQA达到了最佳效果,推理性能和模型评分都十分优秀。

分组的大小对模型推理延迟的影响

GQA的分组数是一个超参数,组数越大越接近MHA,推理延迟越大,同时模型精度也越高,作者给出了他的实验结论表明,当组数量从1逐渐上升到8时,模型推理的开销并没有明显的增长,在8以后推理开销显著变大,最终作者采用8个分组作为他的最佳选择。
以上从实验结果层给到结论,MQA略微损失了模型精度,但是确实能够大幅降低推理开销,而如果选择了合适的分组数,GQA能够两者皆得。在理论层,MQA和GQA对推理的帮助主要是以下两点

    1. 降低内存读取模型权重的时间开销:由于Key矩阵和Value矩阵数量变少了,因此权重参数量也减少了,需要读取到内存的数量量少了,因此减少了读取权重的等待时间
    1. KV-Cache空间占用降低:KV-Cache会将之前推理过的Key、Value向量存储在内存中,而随着步长和batch_size的增长,KV-Cache空间占用越来越高,使得KV-Cache不能被高效的读写,而MHA和GQA方式使得KV-Cache需要存储的参数量降低了head_num倍,从而提高KV-Cache的读写效率;另一方面,可以有空间来增大batch_size,从而提高模型推理的吞吐量

注意MQA和GQA并没有降低Attention的计算量(FLOPs),因为Key、Value映射矩阵会以广播变量的形式拓展到和MHA和一样,因此计算量不变,只是Key、Value参数共享。


ChatGLM2-6B中的MQA/GQA源码分析

本节采用ChatGLM2-6B的模型源码modeling_chatglm.py来说明MQA和GQA的实现,这两者在代码上没有区别,因为MQA是GQA的特例,当分组数等于1时就是MQA,而chatglm2-6B采用的是分组数为2的GQA,从它的配置文件config.json可以观察得到

{
  "_name_or_path": "THUDM/chatglm2-6b",
  "model_type": "chatglm",
  "architectures": [
    "ChatGLMModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_chatglm.ChatGLMConfig",
    "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
    "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
  },
  ...
  "multi_query_attention": true,
  "multi_query_group_num": 2,
  ....
}

其中multi_query_attention代表是否开启多查询注意力,multi_query_group_num代表分组数。
MQA和GQA仅涉及到注意力层,因此直接定位到SelfAttention的代码块

class SelfAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        # TODO 128 * 32
        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        # TODO true 多查询注意力
        self.multi_query_attention = config.multi_query_attention
        # TODO qkv线性映射层到3*d
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num  # 2
            self.qkv_hidden_size = (
                # TODO (128 * 32 + 2 * 128 * 2)
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

该SelfAttention代表某一层下,32个注意力头的运算。在初始化阶段,作者采用大矩阵方案用一个矩阵将所有头QKV创建出来,因此projection_size为128×3,然后以MHA的方式将映射维度乘以3得到qkv_hidden_size,当采用多查询注意力时对qkv_hidden_size重新修改,它等于所有头的Q矩阵,加上KV矩阵各一个,因此projection_size为128 * 32 + 2 * 128 * 2=4608,最后通过一个Linear层实现QKV矩阵的创建。
在推理阶段使用query_key_value进行同意QKV映射,然后通过split算子将QKV进行分解,分解之后所有头的Query为4096维,Key和Value为256维,因为分组数为2,所有存在两个Key和两个Value。

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # TODO [1(seq), 1(batch), 4608]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # TODO [17, 1, 4096], [17, 1, 256], [17, 1, 256]
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,  # TODO 32 *128
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,  # TODO 2 * 128
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,  # TODO 2 * 128
                ],
                dim=-1,
            )

然后作者将QKV向量的维度进行reshape,将头维度拿出来,准备在下面的代码中对KV进行广播

            # TODO [17, 1, 32, 128]
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            # TODO [17, 1, 2, 128]
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            # TODO TODO [17, 1, 2, 128]
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )

在这之前,需要将原始的Query和Key携带旋转位置编码RoPE,使得注意力能够感知到相对位置信息,此步骤和Value无关

        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

然后作者对Key和Value做广播,因为Query有32个头,而GQA有2组,因此要翻32/2=16倍数,通过torch的expand进行广播实现参数共享,广播之后QKV三者的shape变成一致

        if self.multi_query_attention:
            # TODO TODO [17, 1, 2, 128] => [17, 1, 2, 1, 128]
            key_layer = key_layer.unsqueeze(-2)
            # TODO [17, 1, 2, 16, 128]
            key_layer = key_layer.expand(
                # TODO expand 进行广播,k,v向量共享
                # TODO 只能对维度值是1的进行拓展,如果某些维不需要拓展,写为-1, 32 // 2=16
                # TODO 有32个头,KV组只有2组,要复制16份
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            # TODO [17, 1, 2, 16, 128] => [17, 1, 32, 128]
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )

最后所有处理好之后计算注意力权重,并且将权重和Value相乘得到注意力的输出,这里和传统的MHA没有任何却别

context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

MQA和GQA的主要代码流程结束,核心是先创建分组数的Key和Value矩阵,注意力点乘之前将Key和Value广播到和Query一致即可,全文完毕。

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1789617.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot前后端简单集成MinIo开发

Spring Boot前后端简单集成MinIo开发 源码地址重要配置和代码MinIO配置核心代码 最终效果 源码地址 minioStudy 重要配置和代码 MinIO配置 pom文件引入依赖 <!-- minio --> <dependency><groupId>io.minio</groupId><artifactId>minio</ar…

GaussDB的数种形态

GaussDB作为一种新兴的关系型数据库产品&#xff0c;似乎有点让人摸不着头脑。有朋友问我GaussDB单机版怎么样&#xff0c;有人说GaussDB是分布式数据库&#xff0c;还有人说它是云数据库&#xff0c;还有人会把GaussDB和华为的数据仓库GaussDB DWS混为一谈。确实&#xff0c;公…

AD域渗透链和工具推荐

xmind下载地址&#xff1a; 链接: https://pan.baidu.com/s/1_BsmqLvN6aBnan0AIk5iBA 提取码: j97j

笔记 | 软件工程02:软件工程概述

1 软件工程产生背景 1.1 历史发展 1960s的个体作坊式软件开发带来的问题 1.2 软件开发需要解决的问题 代码规模增长带来的影响&#xff1a; 1.3 软件开发面临的挑战 指挥信息系统中的软件&#xff1a;规模大、质量要求高 装备中嵌入式软件系统&#xff1a;规模大、质量要求…

【MySQL数据库】索引与事务

&#x1f525;个人主页&#xff1a; 中草药 &#x1f525;专栏&#xff1a;【MySQL】探秘&#xff1a;数据库世界的瑞士军刀 目录 &#x1f5f3;️一.索引 &#x1f4ee;1.工作原理 &#x1f4ec;2.类型 &#x1f4ed;3.作用 &#x1f4ea;4.优缺点 &#x1f4eb;5.使用…

一维时间序列突变检测方法(小波等,MATLAB R2021B)

信号的突变点检测问题是指在生产实践中&#xff0c;反映各种系统工作状态的信号&#xff0c;可能因为受到不同类型的噪声或外界干扰而发生了信号突变&#xff0c;导致严重失真的信号出现&#xff0c;因此必须探测突变出现的起点和终点。研究目的在于设计出检测方案&#xff0c;…

python-字符替换

[题目描述] 给出一个字符串 s 和 q 次操作&#xff0c;每次操作将 s 中的某一个字符a全部替换成字符b&#xff0c;输出 q 次操作后的字符串输入 输入共 q2 行 第一行一个字符串 s 第二行一个正整数 q&#xff0c;表示操作次数 之后 q 行每行“a b”表示把 s 中所有的a替换成b输…

docker 存储 网络 命令

文章目录 1 docker存储1.1 目录挂载2.1卷映射2.1.1卷映射和目录挂载的区别2.1.2卷映射的使用 2 docker网络2.1查看docker的默认网络2.2查看容器的IP2.3容器互通2.4自定义网络2.4.1 创建自定义网络2.4.2创建容器的时候加入到自定义的网络2.4.3使用域名进行容器之间的访问2.4.4re…

小米路由器如何设置去广告功能,如何设置小米路由器的自定义Hosts(小米路由器如何去除小米广告、去除小米电视盒子开屏广告、视频广告)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 实现方案 📒📝 操作步骤📝 注意事项⚓️ 相关链接 ⚓️📖 介绍 📖 小米设备的广告一直是用户头疼的问题,无论是开屏广告、应用内广告还是系统广告,都影响了用户体验。本文将详细介绍如何通过小米路由器实现去除广告…

低代码设计中的组织结构的作用与模式

一、组织结构的作用 在低代码设计中&#xff0c;组织结构是系统运作的基石&#xff0c;它定义了系统中的关键元素&#xff0c;包括人员、部门、角色&#xff0c;以及一人多部门、一人多部门多角色的复杂关系。这种定义不仅为系统提供了清晰的运行框架&#xff0c;还确保了系统…

学生问的一道CSS3媒体查询,实现响应式设计的题

目录 题目要求&#xff1a; 解题思路&#xff1a; 解题&#xff1a; 1&#xff09;大屏、3个DIV水平排列 2&#xff09;中屏、前2个DIV水平占一半&#xff0c;第三个另起一行&#xff0c;宽度占满 3&#xff09;小屏&#xff0c;3个DIV铺满&#xff0c;垂直排列 题目要求&…

深入理解计算机系统 家庭作业5.13

A:关键路径在xmm0那条路,书中几条关键路径全部是xmm0,有xmm1时,xmm1也是 B:3 C:1 D:按书中的定义: 关键路径才是下界!按书上的方法根据 图5-12 算出关键路径的CPE即可. 非关键路径把它视为黑盒子.因为是乱序和超标量的,没办法搞清楚处理器具体怎么处理这些指令.

17、Spring系列-SpringMVC-请求源码流程

前言 Spring官网的MVC模块介绍&#xff1a; Spring Web MVC是基于Servlet API构建的原始Web框架&#xff0c;从一开始就已包含在Spring框架中。正式名称“ Spring Web MVC”来自其源模块的名称&#xff08;spring-webmvc&#xff09;&#xff0c;但它通常被称为“ Spring MVC…

PHP 页面报错Warning</b>: Cannot modify header information - headers already sent by

先给出解决方案再解释&#xff0c;如果急着用就不用看解释了。 解决方案一&#xff1a;保存php文件编码为utf-8无BOM码&#xff0c;具体操作可以用notepad等编辑器完成&#xff0c;把 sesstion_start() 放在文档所有输出&#xff08;包括html标签和php的输出语句&#xff0c;具…

了解可燃气体报警器检验收费,守护企业安全新防线

在工业生产中&#xff0c;可燃气体报警器作为预防火灾和爆炸事故的重要设备&#xff0c;其准确性和可靠性至关重要。为了确保报警器的正常运行&#xff0c;定期的检验工作必不可少。 而关于检验收费&#xff0c;很多企业可能存在疑虑&#xff1a;这项费用是否合理&#xff1f;…

AC自动机(查询)

上面讲了AC自动机是如何建树和建自动机的&#xff0c;这里要讲的是AC自动机的查询和各个数组的功能和作用。 其实AC自动机的查询和KMP算法是及其相近的&#xff0c;都是一个指针跑主串&#xff0c;另一个指针跑ne串&#xff08;这里就是回跳边&#xff09;。 话都说到这了&…

C#中字节数组(byte[])末尾继续添加字节的示例

方法一&#xff1a;使用List 使用List可以很容易地在末尾添加字节&#xff0c;然后如果需要&#xff0c;可以将其转换回byte[]。 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Lin…

【C++小知识】为什么C语言不支持函数重载,而C++支持

为什么C语言不支持函数重载&#xff0c;而C支持 编译链接过程函数名修饰过程总结 在了解C函数重载前&#xff0c;如果对文件的编译与链接不太了解。可以看看我之前的一篇文章&#xff0c;链接: 文件的编译链接 想要清楚为什么C语言不支持函数重载而C支持&#xff0c;有俩个过程…

大数据学习问题记录

问题记录 node1突然无法连接finalshell node1突然无法连接finalshell 今天我打开虚拟机和finalshell的时候&#xff0c;发现我的node1连接不上finalshell,但是node2、node3依旧可以链接&#xff0c;我在网上找了很多方法&#xff0c;但是是关于全部虚拟机连接不上finalshell&a…

12 - 常用类

那就别跟他们比&#xff0c;先跟自己比&#xff0c;争取今天比昨天强一些&#xff0c;明天比今天强一些。 1.包装类 针对八种基本数据类型封装的相应的引用类型。 有了类的特点&#xff0c;就可以调用类中的方法。&#xff08;为什么要封装&#xff09; 基本数据类型包装类b…