GQA分组注意力机制

news2024/11/28 4:36:00

一、目录

  1. 定义
  2. demo

二、实现

  1. 定义
    grouped query attention(GQA)
    1 GQA 原理与优点:将query 进行分组,每组query 参数共享一份key,value, 从而使key, value 矩阵变小。
    2. 优点: 降低内存读取模型权重的时间开销:由于Key矩阵和Value矩阵数量变少了,因此权重参数量也减少了,需要读取到内存的数量量少了,因此减少了读取权重的等待时间。
    3. 效果(并未降低模型性能):GQA通过设置合适的分组大小,可以和MQA的推理性能几乎相等,同时逼近MHA的模型性能。
  2. llama3 分组数为4, chatglm2 分组数为2 .
    在这里插入图片描述
    在这里插入图片描述
    参考:https://zhuanlan.zhihu.com/p/693928854
    demo
import torch
import torch.nn as nn
import math

#GQA
bs=3
seq_len =5
hidden_size= 32
n_heads=4
n_kv_heads = 2
head_dim = hidden_size//n_heads #
groups = n_heads//n_kv_heads # 4/2
print("groups=",groups)
x=torch.randn((bs,seq_len,hidden_size))
print("x:", x.shape)
wq = nn.Linear(hidden_size,n_heads*head_dim,bias=False)
wk = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
wv = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
xq,xk,xv=wq(x),wk(x),wv(x)
xq = xq.view(bs,seq_len, n_heads, head_dim).transpose(1, 2)
xk = xk.view(bs,seq_len, n_kv_heads, head_dim).transpose(1, 2)
xv = xv.view(bs,seq_len, n_kv_heads, head_dim).transpose(1, 2)
print("xq:",xq.shape) #[bs,n_heads,seq_len, head_dim]
print("xk:", xk.shape)#[bs,n_kv_heads,seq_len, head_dim]
print("xv:", xv.shape)#[bs,n_kv_heads,seq_len, head_dim]
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values
#复制kv head
key,val = repeat_kv(xk,xv, groups,dim=1)
print("key:", key.shape)
print("val:", val.shape)
attn_weights = torch.matmul(xq, key.transpose(2, 3)) / math.sqrt(head_dim)
print("attn_weights:", attn_weights.shape) #[bs,n_heads,seq_len,seq_len]
attn_output = torch.matmul(attn_weights, val)
print("attn_output:", attn_output.shape)  # [bs,n_heads,seq_len,head_dim]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1631597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Llama 3 安装使用方法

Llama3简介: llama3是一种自回归语言模型,采用了transformer架构,目前开源了8b和70b参数的预训练和指令微调模型,400b正在训练中,性能非常强悍,并且在15万亿个标记的公开数据进行了预训练,比ll…

Java设计模式 _结构型模式_桥接模式

一、桥接模式 1、桥接模式 桥接模式(Bridge Pattern)是一种结构型设计模式。用于把一个类中多个维度的抽象化与实现化解耦,使得二者可以独立变化。 2、实现思路 使用桥接模式,一定要找到这个类中两个变化的维度:如支…

什么是中间件?中间件有哪些?

什么是中间件? 中间件(Middleware)是指在客户端和服务器之间的一层软件组件,用于处理请求和响应的过程。 中间件是指介于两个不同系统之间的软件组件,它可以在两个系统之间传递、处理、转换数据,以达到协…

[论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS)

引言 今天来看一下GELU的原始论文。 作者提出了GELU(Gaussian Error Linear Unit,高斯误差线性单元)非线性激活函数: GELU x Φ ( x ) \text{GELU} x\Phi(x) GELUxΦ(x),其中 Φ ( x ) \Phi(x) Φ(x)​是标准高斯累积分布函数。与ReLU激活函数通过输入…

Spring Web MVC入门(3)——响应

目录 一、返回静态页面 RestController 和 Controller之间的关联和区别 二、返回数据ResponseBody ResponseBody作用在类和方法的情况 三、返回HTML代码片段 响应中的Content-Type常见的取值: 四、返回JSON 五、设置状态码 六、设置Header 1、设置Content…

docker如何生成springboot镜像

1、在springboot的jar包所在的目录下创建Dockerfile文件,此案例的目录为/usr/java Dockerfile的文件内容如下: FROM openjdk:8 LABEL author"zengyanhui" LABEL email"1181159889qq.com" WORKDIR /usr/java/springbootdemo COPY s…

动漫渐显引导页HTML5单页源码

挺不错的动漫渐显引导页,记事本右键打开即可修改~ 动漫渐显引导页HTML5单页源码

重生之我是Nginx服务专家

nginx服务访问页面白色 问题描述 访问一个域名服务返回页面空白,非响应404。报错如下图。 排查问题 域名解析正常,网络通讯正常,绕过解析地址访问源站IP地址端口访问正常,nginx无异常报错。 在打开文件时,发现无法…

179. 最大数(LeetCode)

文章目录 前言一、题目讲解二、算法原理三、代码编写1.仿函数写法2.lambda表达式 四、验证五.总结 前言 在本篇文章中,我们将会带着大家采用贪心的方法解决LeetCode中最大数这道问题!!! 一、题目讲解 一组非负整数,包…

机器学习的指标评价

之前在学校的小发明制作中,在终期答辩的时候,虽然整个项目的流程都答的很流畅。 在老师提问的过程中,当老师问我recall,precision,accuracy等指标是如何计算的,又能够表示模型的哪方面指标做得好。我听到这个问题的时候&#xff…

信息系统项目管理师0076:应用集成(5信息系统工程—5.3系统集成—5.3.5应用集成)

点击查看专栏目录 文章目录 5.3.5应用集成5.3.5应用集成 随着网络和互联网的发展以及分布式系统的日益流行,大量异构网络及各计算机厂商推出的软、硬件产品分布在分布式系统的各层次(如硬件平台、操作系统、网络协议、计算机应用),乃至不同的网络体系结构上都广泛存在着互操…

10.通用定时器

驱动电机 RGB LED亮度(呼吸灯) 舵机(遥控车、机械臂) 通用定时器作用 1.延时 2.定时器更新中断 3.输出比较(PWM波、驱动IO输出波形(脉冲)) 4.输入捕获&…

VMware安装ubuntun虚拟机使用桥接模式无法上网问题解决

问题:最近准备使用VMware虚拟机搭建k8s集群服务,因为需要在同一个网段下,我使用桥接的方式,我发现主机在使用有线连接时虚拟机网络连接正常,但是使用无线网就显示连接不上网络。 解决方法 一、查看网络连接&#xff…

Codeforces Round 941 (Div. 2)(A-D)

A. Card Exchange(思维 Problem - A - Codeforces 题目大意: 给定n张牌,每次选k张相同的牌,把他们变成k-1张任意的牌,求最后手中最少能有几张牌。 思路: 直接判断这n张牌当中有没有k张一样的牌&#xff0c…

Python快速入门1数据类型(需要具有编程基础)

数据类型: Python 3.0版本中常见的数据类型有六种: 不可变数据类型可变数据类型Number(数字)List(列表)String(字符串)Dictionary(字典)Tuple(元…

NLP transformers - 文本分类

Text classification 文章目录 Text classification加载 IMDb 数据集Preprocess 预处理EvaluateTrainInference 本文翻译自:Text classification https://huggingface.co/docs/transformers/tasks/sequence_classification notebook : https://colab.research.googl…

明日周刊-第8期

现在露营的人越来越多了,都是带着帐篷或者遮阳篷聚在一起喝喝茶聊聊天,这是一种很好的放松方式。最近我养了一只金毛,目前两个月大,非常可爱。 文章目录 一周热点资源分享言论歌曲推荐 一周热点 一、人工智能领域 本周&#xff…

Java面试八股之main方法的参数中字符串数组的第一个元素是什么

Java main方法的参数中字符串数组的第一个元素是什么 Java main 方法的参数中字符串数组的第一个参数通常是指命令行启动Java应用程序时传递给该程序的第一个命令行参数。当您在命令行中执行一个Java应用程序,可以跟随类名后面附加一系列参数,这些参数将…

Debian 系统设置SSH 连接时长

问题现象: 通过finalshell工具连接Debian系统远程操作时,总是一下断开一下断开,要反复重新连接 ,烦人! 解决办法: 找到ssh安装目录下的配置文件:sshd_config vi sshd_config : 找到…

李沐70_bert微调——自学笔记

微调BERT 1.BERT滴哦每一个词元返回抽取了上下文信息的特征向量 2.不同的任务使用不同的特性 句子分类 将cls对应的向量输入到全连接层分类 命名实体识别 1.识别应该词元是不是命名实体,例如人名、机构、位置 2.将非特殊词元放进全连接层分类 问题回答 1.给…