MQA(Multi-Query Attention)详解

news2024/9/20 8:54:18

论文名称:Fast Transformer Decoding: One Write-Head is All You Need

论文地址:https://arxiv.org/abs/1911.02150v1

        MQA(Multi-Query Attention)是Google团队在2019年提出的,是MHA (Multi-head Attention,多头注意力机制)的一种变体,也是用于自回归解码的一种注意力机制。

        传统的MHA是将输入划分为多个Head,并为每个Head独立计算注意力。在MHA中的,Q、K、V会根据每个head做不同的转换(模拟:每个Head都有自己的感知域/parameter sets,可以独立学习输入中的不同特性)。这在Head数量较多时候可能会存在计算密集的问题。

        而与MHA 不同的是,MQA 让所有的Head之间共享同样的一份 K 和 V 矩阵(意味K和V的计算唯一),只让 Q 保留了原始多头的性质(每个Head存在不同的转换),从而大大减少 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来达到提升推理速度,但是会带来精度上的损失。技术被大量应用于大预言模型,如ChatGLM2。

从代码角度来看,形式如下:

K_shared = WK * K
V_shared = WV * V

for i in range(num_heads):
    Qi = WQi * Q
    ...
    ...

下面一段代码来自于下面这个链接的作者的实现chatGLM2中的Multi Query Attention_multi-query attention-CSDN博客

源码请看huggingface的transformers包中的bertselfattention源码实现。

    class MultiQuerySelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
 
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.attention_head_size)
        self.value = nn.Linear(hidden_size, self.attention_head_size)
 
        self.dropout = nn.Dropout(0.1)
 
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
 
    def forward(self,hidden_states):
        # hidden_states (B, L, D)
        mixed_query_layer = self.query(hidden_states)
        # query_layer  (B, h, L, d)
        # 在此处,将query划分为多头[batch_size, head_num, 序列长度, embedding长度]
        query_layer = self.transpose_for_scores(mixed_query_layer)
 
        # 每个key、value head参数都是共享的,只计算一次
        key = self.key(hidden_states)
        #key_layer  (B, 1, L, d)
        key_layer = key.unsqueeze(1)
        value = self.value(hidden_states)
        # value_layer  (B, 1, L, d)
        value_layer = value.unsqueeze(1)
 
        # key_layer  (B, 1, d, L)
        key_layer = key_layer.transpose(-1, -2)
        #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L)
        attention_scores = torch.matmul(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d)
        context_layer = torch.matmul(attention_probs, value_layer)
        #(B, h, L, d) => (B, L, h, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # (B,L, h*d) => (B,L,D)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        # (B,L, h*d) => (B,L,D)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer

稍微补充一下:原论文中的MQA伪代码如下,和自注意力的MQA实现有些区别,个人猜测如下

        这里简单理解下,一般情况下我们讲的都是自注意力XXX,比如自注意力MHA,这时Q、K、V都来自于输入X;但是,论文中讲述的应该是纯粹的MHA和MQA,此时构成Q和K的输入就不同。(猜想来自于传统注意力机制,该机制多应用于seq-seq任务)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1967760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信运营新助手:自动回复神器,让沟通更高效!

在现代职场中,效率是成功的关键。然而,我们经常会面对大量重复且繁琐的日常任务,消耗宝贵的时间和精力。 今天,我想向大家分享一个强大的微信自动回复神器,它将帮助你高效管理沟通,提升工作效率。 1、自动…

GraphHopper:开源路线规划引擎

在当今信息爆炸的时代,我们越来越依赖于智能路线规划来帮助我们节省时间、提高效率。GraphHopper作为一款开源的路线规划引擎,为我们提供了一个强大而灵活的工具,让我们可以在自己的应用程序中实现高效的路径计算。 什么是GraphHopper&#…

电脑录屏怎么录?2024四大工具助你轻松录制每一刻!

无论是教学演示、游戏直播,还是工作汇报,一款好用的录屏软件都能帮助我们轻松完成任务。那么,电脑录屏怎么录呢?今天为大家推荐几款实用的电脑录屏工具,让你轻松成为录屏达人! Foxit REC:专业与…

Linux进程控制——进程程序替换、bash的模拟实现

文章目录 exec系列函数execlexeclp和execle execv系列函数bash的模拟实现实现思路完整代码其他问题 在学习进程的时候,我们想fork一个子进程,然后就可以给他布置任务了 但是如果我们分成两个人开发,父子进程分别负责不同的任务,等…

揭秘智能工牌:如何成为房企销售团队的数字化转型加速器

在这个竞争激烈的市场环境中,房企想要脱颖而出,不仅需要优质的产品和服务,更需要高效的销售团队。而销售团队的能力提升,离不开精细化管理和科技的赋能。DuDuTalk智能语音工牌,正是这样一款融合了AI技术与销售实战智慧…

无人机之森林防火篇

无人机在森林火灾中的应用是一个快速发展的领域,它们在火灾预防、监测、救援和灾后评估等方面发挥着重要作用。 一、无人机在森林火灾监测中的应用 在森林火灾的监测方面,无人机凭借其高空、高速、长时间巡查的优势,能够全面覆盖监测区域&am…

体育器材管理系统(完整开发文档)

1.1研究背景及意义 研究背景: 体育器材是高校体育教学和课外体育活动的重要物质基础,其使用和管理对于保障教学质量、提高学生体育素质具有重要意义。随着高校体育教学和课外活动的不断发展,体育器材的种类和数量不断增加,传统的…

Linux进程(一)

目录 一.进程的介绍1.引出进程2.进程的介绍 二.创建进程1.创建进程的原理2.什么是fork函数(1).通过手册查看fork 3.例子 一.进程的介绍 1.引出进程 Google Chrome 是一个进程 Google Chrome 底下的选项是多个线程 通过top命令可以查看正在运行的进程 2.进程的介绍 课本概念 …

F5云安全防护能力如何?一文为你解惑

伴随云计算的快速发展,云安全已成为实施云战略的重要保障。来自F5 SOAS报告的调查显示,近三分之二的企业将使用AI和机器学习划入优先事项,并把云安全列为最关键的应用场景。作为一家提供多云应用安全和应用交付的公司,F5的云安全防…

python实现微信聊天图片DAT文件还原

完整代码如下: from glob import glob import os from tqdm import tqdmdef get_sign(dat_r):signatures [(0x89, 0x50, 0x4e), (0x47, 0x49, 0x46), (0xff, 0xd8, 0xff)]mats [".png", ".gif", ".jpg"]for now in dat_r:for j, x…

嵌入式行业,中年危机是否存在?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」,点个关注在评论区回复“666”之后私信回复“666”,全部无偿共享给大家!!! 肯定有,你看到的…

嵌入式人工智能(40-基于树莓派4B的水滴传感器和火焰传感器)

虽然这两个传感器水火不容,我还是把他们放到一起了。本文是有线传感器的最后一个部分了。后面如果还有文章介绍有线传感器,也是补充学习其他内容不得已而为之。如果不是,就当我没说,哈哈。 1、水滴传感器 水滴传感器又称雨滴传感…

实现字母的大小写转换。多组输入输出(c语言)

1.我们先输入字母&#xff08;用getchar的函数&#xff09;&#xff0c;判断是不是字母&#xff0c;我们可以用a<tmp<z或者A<tmp<Z,注意&#xff1a;小写转换大写用tmp-32&#xff0c;大写转换小写用tmp32.. #include<stdio.h> int main() {int a 0;while …

以太坊交易手续费计算

Gas 中译是&#xff1a;瓦斯、汽油&#xff0c;代表一种可燃气体。 这形象地比喻以太坊的交易手续费计算模式&#xff0c;不同于比特币中直接支付比特币作为转账手续费&#xff0c; 以太坊视为一个去中心化的计算网络&#xff0c;当你发送Token、执行合约、转移以太币或者在此区…

东巴古籍——纳西族古老文字的见证

关注我们 - 数字罗塞塔计划 - 华夏大地上的每个民族都有各自独特的文化传承&#xff0c;在前面的文章中&#xff0c;我们已经介绍过中国档案文献遗产名录中收录的永州女书和水族水书&#xff08;详细参见《永州女书——世上唯一专属于女性的文字》、《水书——破解象形文字含义…

二叉树LeetCode热题

94.二叉树的中序遍历 题目 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 输入&#xff1a;root [1,null,2,3]输出&#xff1a;[1,3,2] 代码 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* …

一文搞懂大模型在多GPU环境的分布式训练!

随着大模型时代的到来&#xff0c;模型参数量、训练数据量、计算量等各方面急剧增长。大模型训练面临新的挑战&#xff1a; 显存挑战&#xff1a;例如&#xff0c;175B的GPT-3模型需要175B*4bytes即700GB模型参数空间&#xff0c;而常见的GPU显存如A100是80G显存&#xff0c;这…

zabbix使用脚本自定义监控项

1. 在zabbix_agent的配置文件中配置自定义key和脚本位置 vim /etc/zabbix/zabbix_agentd.confUserParametermq_check_log,/etc/zabbix/zabbix_agentd.d/mqlog.shmq_check_log&#xff1a;是这个自定义参数的名称。在Zabbix的监控项&#xff08;item&#xff09;配置中&#xf…

WinForm中使用Graphics画元素

前言 有时候我们需要在一个图像上显示一些文字&#xff0c;或者画一些标志&#xff0c;这就想我们平时截图也需要做一些描述信息。在C#中我们可以Graphics这个对象来绘制自己所需要描述的信息&#xff0c;当然在WPF中的它的设计思路又不一样了&#xff0c;在WPf中考虑使用的矩…

upload-labs靶场:1—10通关教程

目录 Pass-01&#xff08;JS 验证&#xff09; Pass-02&#xff08;MIME&#xff09; Pass-03&#xff08;黑名单绕过&#xff09; Pass-04&#xff08;.htaccess 绕过&#xff09; Pass-05&#xff08;大小写绕过&#xff09; Pass-06&#xff08;空格绕过&#xff09; …