Group Query Attention (GQA) 机制详解以及手动实现计算

news2024/9/23 3:13:27

Group Query Attention (GQA) 机制详解

1. GQA的定义

Grouped-Query Attention (GQA) 是对 Multi-Head Attention (MHA) 和 Multi-Query Attention (MQA) 的扩展。通过提供计算效率和模型表达能力之间的灵活权衡,实现了查询头的分组。GQA将查询头分成了G个组,每个组共享一个公共的键(K)和值(V)投影。

2. GQA的变体

GQA有三种变体:

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
  • GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。
3. GQA的优势

使用G个组可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。GQA提供了对模型质量和效率的细致控制。

4. GQA的实现

GQA的最简形式可以通过实现 GroupedQueryAttention 类来实现。GroupedQueryAttention 类继承自 Attention 类,重写了 forward 方法,其中使用了 MultiQueryAttention 类的实例来处理每个组的查询。通过将每个组的结果拼接起来,然后与投影矩阵进行矩阵乘法运算,最终得到 GQA 的输出。[1]

pytorch 示例实现:

假设我们有以下初始化的query, key, value:

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 2, 64)
value = torch.randn(1, 256, 2, 64)
1. 确定分组数量

首先,我们需要确定将查询头分为多少组。在这个例子中,我们有8个查询头,而键和值的头数为2,所以我们可以将查询头分为4组,每组有2个查询头。

2. 对查询进行分组

然后,我们将查询头分组。我们可以使用 torch.chunk 函数将查询张量沿着头维度分割成4个组,每个组有2个头。

query_groups = torch.chunk(query, 4, dim=2)  # shape of each group: (1, 256, 2, 64)

3. 计算注意力分数

对于每一个查询组,我们计算它与键的注意力分数。我们首先计算查询组和键的点积,然后通过 torch.softmax 函数得到注意力分数。

attention_scores = []
for query_group in query_groups:
    score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)
    score = torch.softmax(score, dim=-1)
    attention_scores.append(score)
4. 计算注意力输出

接下来,我们使用注意力分数对值进行加权求和,得到每一个查询组的注意力输出。

attention_scores = []
for query_group in query_groups:
    score = torch.matmul(query_group, key.transpose(-2, -1))  # shape: (1, 256, 2, 256)
    score = torch.softmax(score, dim=-1)
    attention_scores.append(score)
5. 拼接输出

最后,我们将所有查询组的注意力输出拼接起来,得到最终的 Grouped Query Attention 的输出。

attention_outputs = []
for score in attention_scores:
    output = torch.matmul(score, value)  # shape: (1, 256, 2, 64)
    attention_outputs.append(output)

这就是 Grouped Query Attention 的实现过程。在这个过程中,我们将查询头分组,然后对每一个查询组分别计算注意力分数和输出,最后将所有查询组的输出拼接起来。这样可以减少存储每个头的键和值所需的内存开销,特别是在具有大的上下文窗口或批次大小的情况下。


  1. Grouped-Query Attention (GQA) - The Large Language Model Playbook

  2. 安全验证 - 知乎
  3. 安全验证 - 知乎
  4. 安全验证 - 知乎
  5. Grouped-Query Attention (GQA) - The Large Language Model Playbook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1616611.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文学会Amazon transit GateWay

这是一个中转网关,使用时候需要在需要打通的VPC内创建一个挂载点,TGW会管理一张路由表来决定流量的转发到对应的挂载点上。本质上是EC2的请求路由到TGW,然后在查询TGW的路由表来再来决定下一跳,所以需要同时修改VPC 内子网的路由表…

ssm071北京集联软件科技有限公司信息管理系统+jsp

北京集联软件科技有限公司信息管理系统 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本信息管理系统就是在这样的大环境下诞生,其可以帮助管理…

使用PlantUML绘制活动图、泳道图

最近在学PlantUML 太漂亮了 给大家欣赏一下 我也记录一下 startuml |使用前| start :用户打开旅游App; |#LightSkyBlue|使用后| :用户浏览旅游信息; |#AntiqueWhite|登机前| :用户办理登机手续; :系统生成登机牌; |使用前| :用户到达机场; |登机前| :用户通过安检; |#Light…

2024HVV在即| 最新漏洞CVE库(1.5W)与历史漏洞POC总结分享!

前言 也快到护网的时间了,每年的护网都是一场攻防实战的盛宴,那么漏洞库就是攻防红蓝双方人员的弹药库,红队人员可以通过工具进行监测是否存在历史漏洞方便快速打点,而蓝队则可以对资产进行梳理和监测历史漏洞,及时处理和修复,做好准备. 下面分享的…

发布自己的Docker镜像到DockerHub

学会了Dockerfile生成Docker image 之后,如何上传自己的镜像到 DockerHub呢?下面我以自己制作的 bs-cqhttp 镜像为例,演示一下如何将自己的镜像发布到 Docker 仓库。 1 生成自己的 Docker 镜像 1.1 实例镜像用到的文件 图1 实例镜像制作用到…

Web前端安全问题分类综合以及XSS、CSRF、SQL注入、DoS/DDoS攻击、会话劫持、点击劫持等详解,增强生产安全意识

前端安全问题是指发生在浏览器、单页面应用、Web页面等前端环境中的各类安全隐患。Web前端作为与用户直接交互的界面,其安全性问题直接关系到用户体验和数据安全。近年来,随着前端技术的快速发展,Web前端安全问题也日益凸显。因此&#xff0c…

注意libaudioProcess.so和libdevice.a是不一样的,一个是动态链接,一个是静态

libaudioProcess.so是动态链接,修改需要改根文件系统,需要bsp重新配置 libdevice.a是静态链接,直接替换就行 动态链接文件修改 然后执行fw_update.sh

HarmonyOS ArkUI实战开发-手势密码(PatternLock)

ArkUI开发框架提供了图案密码锁 PatternLock 组件,它以宫格图案的方式输入密码,用于密码验证,本节读者简单介绍一下该控件的使用。 PatternLock定义介绍 interface PatternLockInterface {(controller?: PatternLockController): PatternL…

3D MINS 多模态影像导航系统

3D MINS多模态影像导航系统(Multimodal Image Navigation System)是SunyaTech研发的建立在DICOM(Digital Imaging and Communications in Medicine)图像基础之上的多模态影像导航系统,集二维影像PACS管理、三维影像层级…

shell进阶之正则表达式:字符转义(十七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

NovaMSS音乐源分离v1.3.3社区版

软件介绍 NovaMSS 基于最新 AI 模型优化的音乐源分离工具。它能够轻松地批量提取伴奏、人声、贝斯、鼓点等音轨,并且支持 GPU 加速,以提高处理速度和效率。社区版完全免费,简单易用,上传文件,点击处理,查看…

BI建设案例:FineBI大数据分析平台助力工程机械行业降本增效

工程机械行业作为国民经济的重要支柱,产品多样化、应用广泛,市场集中度高。其上游涉及原材料和核心零部件,下游则与房地产、基建工程和采矿等行业紧密相连。 如今,中国已崛起为全球工程机械制造大国,各类机械产品产量…

java开发之路——node.js安装

1. 安装node.js 最新Node.js安装详细教程及node.js配置 (1)默认的全局的安装路径和缓存路径 npm安装模块或库(可以统称为包)常用的两种命令形式: 本地安装(local):npm install 名称全局安装(global):npm install 名称 -g本地安装和全局安装…

基于spring boot的实习管理系统

基于spring boot的实习管理系统设计与实现 开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件&…

丁晴无硫指套:高科技产品保护的利器

Nitrile Sulphur-Free Finger Cots: A Weapon for Protecting High-Tech Products 随着科技的不断发展,微型电机、精密电子器件、仪器仪表等高科技产品的制造与应用日益普及。然而,这些产品的制造过程中往往需要特殊的保护措施,以防止静电、…

自动驾驶---OpenSpace之Hybrid A*规划算法

1 背景 笔者在上周发布的博客《自动驾驶---低速场景之记忆泊车》中,大体介绍了记忆泊车中的整体方案,其中详细阐述了planning模块的内容,全局规划及局部规划(会车)等内容,包括使用的算法,但是没…

.gitignore语法及配置问题

语法及配置 前言.gitignore语法Git 忽略规则优先级gitignore规则不生效Java项目中常用的.gitignore文件c项目中常用的.gitignore注意事项 前言 在工程中,并不是所有文件都需要保存到版本库中,例如“target”目录及目录下的文件就可以忽略。在Git工作区的…

四信AI睿析—边缘智脑:赋能农业新时代,开启智慧种植新篇章

方案简介 本系统前端安装土壤墒情监测站,包括温湿度传感器、二氧化碳传感器、PH值传感器、土壤电导率传感器、土壤温湿度传感器、光照传感器等组成;高清枪机摄像头等、负责种植区域温湿度、土壤EC、土壤温湿度、光照等环境因子、视频数据、农作物生长图…

Spring 依赖-ApiHug准备-测试篇-015

🤗 ApiHug {Postman|Swagger|Api...} 快↑ 准√ 省↓ GitHub - apihug/apihug.com: All abou the Apihug apihug.com: 有爱,有温度,有质量,有信任ApiHug - API design Copilot - IntelliJ IDEs Plugin | Marketplace gradle…

Golang | Leetcode Golang题解之第44题通配符匹配

题目: 题解: func isMatch(s string, p string) bool {for len(s) > 0 && len(p) > 0 && p[len(p)-1] ! * {if charMatch(s[len(s)-1], p[len(p)-1]) {s s[:len(s)-1]p p[:len(p)-1]} else {return false}}if len(p) 0 {retur…