LLM - GPT(Decoder Only) 类模型的 KV Cache 公式与原理 教程

news2024/9/21 13:52:05

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/141605718

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


Img

在 GPT 类模型中,KV Cache (键值缓存) 是用于优化推理效率的重要技术,基本思想是通过缓存先前计算的 键(Key) 和 值(Value),避免在推理过程中,重复计算 Mask 的 注意力(Attention) 矩阵,从而加速生成过程。

1. 公式

矩阵乘法的基础性质:

A ⋅ B = [ A 1 A 2 … A n ] ⋅ [ B 1 B 2 ⋮ B n ] = A 1 B 1 + A 2 B 2 + ⋯ + A n B n A \cdot B = \begin{bmatrix} A_{1} & A_{2} & \dots & A_{n} \end{bmatrix} \cdot \begin{bmatrix} B_{1} \\ B_{2} \\ \vdots \\ B_{n} \end{bmatrix} = A_{1}B_{1} + A_{2}B_{2} + \dots + A_{n}B_{n} AB=[A1A2An] B1B2Bn =A1B1+A2B2++AnBn

其中 A i A_{i} Ai A A A 的行向量, B i B_{i} Bi B B B 的列向量,也就是说相同维度的向量相乘,可拆解成行向量乘以列向量,即 A A A n n n 列, B B B n n n 行。

例如:基础的矩阵乘法:

A = [ 1 2 3 4 ] , B = [ 5 6 7 8 ] C = [ 1 ∗ 5 + 2 ∗ 7 1 ∗ 6 + 2 ∗ 8 3 ∗ 5 + 4 ∗ 7 3 ∗ 6 + 4 ∗ 8 ] = [ 19 22 43 50 ] A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \quad B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \\ C = \begin{bmatrix} 1*5 + 2*7 & 1*6 + 2*8 \\ 3*5 + 4*7 & 3*6 + 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} A=[1324],B=[5768]C=[15+2735+4716+2836+48]=[19432250]

也可以写成,行列向量相乘的形式,即 A 拆分出多个行向量,B 拆分出多个列向量,即:

C = [ 1 3 ] ⋅ [ 5 6 ] + [ 2 4 ] ⋅ [ 7 8 ] = [ 1 ∗ 5 1 ∗ 6 3 ∗ 5 3 ∗ 6 ] + [ 2 ∗ 7 2 ∗ 8 4 ∗ 7 4 ∗ 8 ] C = \begin{bmatrix} 1 \\ 3 \end{bmatrix} \cdot \begin{bmatrix} 5 & 6 \end{bmatrix} + \begin{bmatrix} 2 \\ 4 \end{bmatrix} \cdot \begin{bmatrix} 7 & 8 \end{bmatrix}= \begin{bmatrix} 1*5 & 1*6 \\ 3*5 & 3*6 \end{bmatrix} + \begin{bmatrix} 2*7 & 2*8 \\ 4*7 & 4*8 \end{bmatrix} C=[13][56]+[24][78]=[15351636]+[27472848]
= [ 1 ∗ 5 + 2 ∗ 7 1 ∗ 6 + 2 ∗ 8 3 ∗ 5 + 4 ∗ 7 3 ∗ 6 + 4 ∗ 8 ] = [ 19 22 43 50 ] =\begin{bmatrix} 1*5 + 2*7 & 1*6 + 2*8 \\ 3*5 + 4*7 & 3*6 + 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} =[15+2735+4716+2836+48]=[19432250]

进一步拆解:

A ⋅ B = A 1 B 1 + A 2 B 2 + ⋯ + A n B n = [ a 1 , 1 B 1 a 2 , 1 B 2 ⋮ a m , 1 B n ] + [ a 1 , 2 B 1 a 2 , 2 B 2 ⋮ a m , 2 B n ] + ⋯ + [ a 1 , n B 1 a 2 , n B 2 ⋮ a m , n B n ] = [ a 1 , 1 B 1 + a 1 , 2 B 1 + ⋯ + a 1 , n B 1 a 2 , 1 B 2 + a 2 , 2 B 2 + ⋯ + a 2 , n B 2 ⋯ a m , 1 B n + a m , 2 B n + ⋯ + a m , n B n ] A \cdot B = A_{1}B_{1} + A_{2}B_{2} + \dots + A_{n}B_{n} \\ = \begin{bmatrix} a_{1,1}B_{1} \\ a_{2,1}B_{2} \\ \vdots \\ a_{m,1}B_{n} \end{bmatrix} + \begin{bmatrix} a_{1,2}B_{1} \\ a_{2,2}B_{2} \\ \vdots \\ a_{m,2}B_{n} \end{bmatrix} + \cdots + \begin{bmatrix} a_{1,n}B_{1} \\ a_{2,n}B_{2} \\ \vdots \\ a_{m,n}B_{n} \end{bmatrix} \\ = \begin{bmatrix} a_{1,1}B_{1} + a_{1,2}B_{1} + \cdots + a_{1,n}B_{1} \\ a_{2,1}B_{2} + a_{2,2}B_{2} + \cdots + a_{2,n}B_{2} \\ \cdots \\ a_{m,1}B_{n} + a_{m,2}B_{n} + \cdots + a_{m,n}B_{n} \end{bmatrix} AB=A1B1+A2B2++AnBn= a1,1B1a2,1B2am,1Bn + a1,2B1a2,2B2am,2Bn ++ a1,nB1a2,nB2am,nBn = a1,1B1+a1,2B1++a1,nB1a2,1B2+a2,2B2++a2,nB2am,1Bn+am,2Bn++am,nBn

基础的矩阵乘法的另一种形式:

C = [ 1 3 ] ⋅ [ 5 , 6 ] + [ 2 4 ] ⋅ [ 7 , 8 ] C=\begin{bmatrix} 1 \\ 3 \end{bmatrix} \cdot \begin{bmatrix} 5,6 \end{bmatrix} + \begin{bmatrix} 2 \\ 4 \end{bmatrix} \cdot \begin{bmatrix} 7,8 \end{bmatrix} C=[13][5,6]+[24][7,8]
[ 1 ∗ [ 5 6 ] 3 ∗ [ 5 6 ] ] + [ 2 ∗ [ 7 8 ] 4 ∗ [ 7 8 ] ] \begin{bmatrix} 1*[5&6] \\ 3*[5&6] \end{bmatrix} + \begin{bmatrix} 2*[7&8] \\ 4*[7&8] \end{bmatrix} [1[53[56]6]]+[2[74[78]8]]
[ 1 ∗ 5 1 ∗ 6 3 ∗ 5 3 ∗ 6 ] + [ 2 ∗ 7 2 ∗ 8 4 ∗ 7 4 ∗ 8 ] = [ 19 22 43 50 ] \begin{bmatrix} 1*5 & 1*6 \\ 3*5 & 3*6 \end{bmatrix} + \begin{bmatrix} 2*7 & 2*8 \\ 4*7 & 4*8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix} [15351636]+[27472848]=[19432250]

如果 A A A 是下三角矩阵,即包含 Mask 信息,Decoder 无法观察到之后的推理部分,则 A ⋅ B A \cdot B AB,输出:

A ⋅ B = [ a 1 , 1 B 1 a 2 , 1 B 2 + a 2 , 2 B 2 ⋯ a m , 1 B n + a m , 2 B n + ⋯ + a m , n B n ] A \cdot B = \left[ \begin{array}{llll} a_{1,1}B_{1}\\ a_{2,1}B_{2} + a_{2,2}B_{2}\\ \cdots \\ a_{m,1}B_{n} + a_{m,2}B_{n} + \cdots + a_{m,n}B_{n} \end{array} \right] AB= a1,1B1a2,1B2+a2,2B2am,1Bn+am,2Bn++am,nBn

2. 推理

第1步:

在 Decoder 解码过程中,只关注 Transformer 的 自注意力(Self-Attention),输入第 1 个 Token,将 Token 转换成 输入特征 I n p u t 1 = [ 1 , d e m b ] Input_{1}=[1,d_{emb}] Input1=[1,demb],暂时忽略 batch_size d e m b d_{emb} demb 表示 Embedding Size。

  1. 输入特征 I n p u t 0 = [ 1 , d e m b ] Input_{0}=[1,d_{emb}] Input0=[1,demb],乘以权重 W = [ d e m b , 3 ∗ d e m b ] W=[d_{emb}, 3*d_{emb}] W=[demb,3demb] (已训练完成,值是固定的),输出维度 [ 1 , 3 ∗ d e m b ] [1, 3*d_{emb}] [1,3demb],即作为 Q\K\V,每个向量 [ 1 , d e m b ] [1,d_{emb}] [1,demb]

    • Q 1 = [ 1 , d e m b ] Q_{1}=[1,d_{emb}] Q1=[1,demb] K 1 = [ 1 , d e m b ] K_{1}=[1,d_{emb}] K1=[1,demb] V 1 = [ 1 , d e m d ] V_{1}=[1,d_{emd}] V1=[1,demd],只与输入特征 I n p u t 0 Input_{0} Input0 的 Embedding 相关。
  2. 根据 Self-Attention 的公式,忽略 d \sqrt{d} d ,只有1维,mask 不起作用,即
    A t t ( Q , K , V ) = s o f t m a x ( Q K ⊤ + m a s k ) ∗ V A t t 1 ( Q , K , V ) = s o f t m a x ( Q 1 K 1 ⊤ ) V 1 其中  s o f t m a x ( x i ) = e x i ∑ j = 1 n e x j Att(Q,K,V)=softmax(QK^{\top}+mask)*V \\ Att_{1}(Q,K,V)=softmax(Q_{1}K_{1}^{\top})V_{1} \\ 其中 \ softmax(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} Att(Q,K,V)=softmax(QK+mask)VAtt1(Q,K,V)=softmax(Q1K1)V1其中 softmax(xi)=j=1nexjexi

  3. A t t 0 Att_{0} Att0 ( [ 1 , d e m b ] [1,d_{emb}] [1,demb]) 经过一系列推理,最后输出 [ 1 , d v ] [1, d_{v}] [1,dv] d v d_{v} dv 是全部词元 Token 的数量,根据概率值即可获得最后的 Token。

第 2 步

将第 1 步输出的 Token 转换成 [ 1 , d e m b ] [1,d_{emb}] [1,demb],与第 1 步组合至一起,即 输入特征 I n p u t 1 = [ 2 , d e m b ] Input_{1}=[2,d_{emb}] Input1=[2,demb]

  1. 输入特征 I n p u t 1 = [ 2 , d e m b ] Input_{1}=[2,d_{emb}] Input1=[2,demb],乘以权重 W = [ d e m b , 3 ∗ d e m b ] W=[d_{emb}, 3*d_{emb}] W=[demb,3demb],权重是固定的,因此只需要计算第 2 个输入的特征 [ 1 , d e m b ] [1,d_{emb}] [1,demb],第 1 个不需要计算,也就是说 Q\K\V 的维度是 [ 2 , d e m b ] [2, d_{emb}] [2,demb],只需计算一次即可,剩余的可以直接 c o n c a t concat concat 到一起。

  2. 根据 Self-Attention 的公式,忽略 d \sqrt{d} d ,注意第1行,已经计算,第2行,需要使用 Q 2 Q_{2} Q2 K 2 K_{2} K2 V 2 V_{2} V2,进行计算,即:
    A t t 2 ( Q , K , V ) = s o f t m a x ( Q K ⊤ + m a s k ) ∗ V s o f t m a x ( [ Q 1 K 1 ⊤ Q 2 K 1 ⊤ + Q 2 K 2 ⊤ ] ) ⋅ [ V 1 V 2 ] = [ s o f t m a x ( Q 1 K 1 ⊤ ) V 1 s o f t m a x ( Q 2 K 1 ⊤ ) V 1 + s o f t m a x ( Q 2 K 2 ⊤ ) V 2 ] = [ A t t 1 ( Q , K , V ) s o f t m a x ( Q 2 K 1 ⊤ ) V 1 + s o f t m a x ( Q 2 K 2 ⊤ ) V 2 ] Att_{2}(Q,K,V) = softmax(QK^{\top}+mask)*V \\ softmax(\left[ \begin{array}{ll} Q_{1}K_{1}^{\top}\\ Q_{2}K_{1}^{\top} + Q_{2}K_{2}^{\top}\\ \end{array} \right]) \cdot \begin{bmatrix} V_{1} \\ V_{2} \\ \end{bmatrix} \\= \left[ \begin{array}{ll} softmax(Q_{1}K_{1}^{\top})V_{1}\\ softmax(Q_{2}K_{1}^{\top})V_{1} + softmax(Q_{2}K_{2}^{\top})V_{2}\\ \end{array} \right] \\ = \left[ \begin{array}{} Att_{1}(Q,K,V) \\ softmax(Q_{2}K_{1}^{\top})V_{1} + softmax(Q_{2}K_{2}^{\top})V_{2}\\ \end{array} \right] Att2(Q,K,V)=softmax(QK+mask)Vsoftmax([Q1K1Q2K1+Q2K2])[V1V2]=[softmax(Q1K1)V1softmax(Q2K1)V1+softmax(Q2K2)V2]=[Att1(Q,K,V)softmax(Q2K1)V1+softmax(Q2K2)V2]

  3. KV 都是成对出现的,如果 缓存 KV,则可以加快推理速度。

第 3 步:重复进行。

3. 缓存占用

关于 Llama3 的 KV Cache 源码,参考 model.py:

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

关于 KV 的缓存内存占用:

相关参数 batch_size=32head=32layer=32dim_size=4096seq_length=2048,float32(4个字节)类,计算 KV cache 的缓存占用:
M = 2 ∗ N b s ∗ ( N d i m / N h e a d ∗ N h e a d ) ∗ N l a y e r ∗ N s e q ∗ 4 = 2 ∗ 32 ∗ 4096 ∗ 32 ∗ 2048 ∗ 4 / 1024 / 1024 / 1024 = 64 G M=2*N_{bs}*(N_{dim}/N_{head}*N_{head})*N_{layer}*N_{seq}*4 \\ =2*32*4096*32*2048*4/1024/1024/1024=64G M=2Nbs(Ndim/NheadNhead)NlayerNseq4=23240963220484/1024/1024/1024=64G
也就是说 head 数量无关,因为维度除以 Head 再乘以 Head。Llama3 使用 GQA (Grouped Query Attention) 分组查询注意力机制,降低 4 倍的 KV Cache,head=32,kv_head=8,即 scale=head/kv_head=4

参考:

  • CSDN - 从头开始实现 LLaMA3 的网络结构与推理流程 教程
  • Transformers KV Caching Explained

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2078929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

element+plus中导航菜单关于index报错

elementplus中导航菜单关于index报错 官方解释必须为string 修改后正常使用 而当动态插入的时候,需要使用:index,且需要写为: :index“String(index)” index是为dom加的属性(是定值),而:index加的值相当…

Python开发工具:PyCharm

本文是 Python 系列教程第 2 篇,完整系列请查看 Python 专栏。 1 安装 官网下载地址https://www.jetbrains.com.cn/pycharm/,文件比较大(约861MB)请耐心等待 双击exe安装 安装成功后会有一个30天的试用期,激活教程见先…

Level3 — PART 4 机器学习算法 — 决策树

目录 引言 信息量 信息熵 案例 ID3 属性选择—信息增益 决策树生成 Python实现ID3 C4.5 属性选择—信息增益率 连续型属性 缺失值 剪枝 CART 分类树属性选择—基尼系数 回归树属性选择—方差 剪枝 Python实现CART CHAID GBRT 决策树对比 模拟题 CDA L…

GDB基础指令分类与汇总

前言 在图形化界面中,我们进行调试一般而言比较方便,举例如下: 不过有时候,我们在Linux下没有这样的图形界面,这时可以使用GDB调试器来帮我们完成上面的工作。 GDB基础指令分类与汇总 类别指令含义举例基本使用gcc pro…

源代码防泄密的途径有很多种,如何保护源代码呢

随着各行各业业务数据信息化发展,各类产品研发及设计等行业,都有关乎自身发展的核心数据,包括业务数据、代码数据、机密文档、用户数据等敏感信息,这些信息数据有以下共性: — 属于核心机密资料,万一泄密会…

【实践经验】端口被占用问题:listen tcp:bind:only one usage of each socket address

文章目录 一. 问题描述二. 分析1. 适用错误 三. 解决方法1. 打开控制台2. 查看端口的使用情况2.1 不知道端口号——查看所有运行的端口2.2 知道端口号 3. 查看使用进程的程序4. 杀死进程5. 验证端口是否释放 一. 问题描述 goland启动项目后报错:“listen tcp:bind:…

四、监控搭建-Prometheus-采集端批量部署

四、监控搭建-Prometheus-采集端批量部署 1、背景2、目标3、传承4、操作4.1、准备部署工具4.2、编制部署脚本4.3、服务端添加客户端 1、背景 在前三篇中我们搭建了Prometheus平台,采集端部署和配合图形化grafana部署,将Linux主机进行监控。基本完成了一…

day41.动态规划

一.动态规划 121.买卖股票的最佳时机I 思路: dp[i][1] 表示第i天不持有股票所得最多现金 dp[i][0] 表示第i天持有股票所得最多现金 相信有人和我有一样的疑惑,为什么dp【i】【0】的转移是dp【i-1】【0】,-price【i】,因为题目规定了只能进行一次买卖&…

Linux信号处理机制基础

什么是信号 信号在最早的UNIX系统中即被引入,已有30多年的历史,但只有很小的变化。信号是提供异步事件处理机制的软件中断。进程之间可以相互发送信号,这使信号成为一种进程间通信(Inter-ProcessCommunication,lPC)的基本手段 信号的名称与…

水控器数码管驱动方案

目录 方案1 方案2 总结 方案1 数码管驱动电路选用2片74HC595和外围电阻实现,如图1所示。74HC595的封装为S0-16(窄体),芯片价格0.42,整个LED驱动电路成本约0.9元(不包含数码管)。 图1、74HC595驱动电路 方案2 为减少PCB板密度,数…

x86中部署docker环境

使用dockerhub搜索Ubuntu x86 1、拉取镜像 docker pull balenalib/odyssey-x86-ubuntu 2、查看镜像 docker images 3、保存镜像 docker save -o ubuntuX86.tar ubuntu/x86:v1 4、加载镜像 docker load -i ubuntuX86.tar 5、创建并运行容器 docker run -itd balenalib/odyssey-…

灵办AI搜索引擎和文档总结工具

前言—— 在信息爆炸的时代,如何高效地获取和处理知识成为了每个人面临的挑战。随着人工智能技术的迅猛发展,本文将深入探讨这一创新工具的功能与优势,以及如何在日常生活和工作中充分利用它,开启智能化的信息获取新篇章。 点击…

计算机毕业设计选题推荐-剧本杀服务平台-剧本杀拼团管理系统-Java/Python项目实战

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

【刷课利器】一条指令完成网页视频完播

网页右击“检查” 点击控制台控制台输入 document.querySelector(‘video’).currentTime document.querySelector(‘video’).duration

封装CUDA为动态链接库+Qt调用

由于工作需要在Qt中调用CUDA做并行计算,加速算法实现时间,发现有两种方法可以在Qt中调用CUDA代码。 第一种是在项目中创建CUDA的cu文件,编写CUDA的核函数给其他的QT代码调用,Qt的代码正常编译,CUDA代码使用nvcc编译器编…

无敌保姆级华为认证 HCIE 笔试+实验考试指引,简直不要太详细

HCIE(Huawei Certified ICT Expert,华为认证ICT专家)是华为认证体系中最高级别的ICT技术认证,旨在打造高含金量的专家级认证,为技术融合背景下的ICT产业提供新的能力标准,以实现华为认证引领ICT行业技术认证…

网安面试设备篇幅:安全准入

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330#rd 《网安面试指南》http://mp.weixin.qq.com/s…

悦数 RAG 正式亮相 :从知识到应用的飞跃,只要几分钟

自 2023 年 8 月悦数与 LlamaIndex 联合发布 Graph RAG 以来,该技术就一直处于技术潮流的前沿。它通过提供更具上下文感知的能力和数据训练的方法,缓解了传统搜索增强技术的幻觉,确保所提供的回复不仅精确,而且有足够丰富的信息。…

科目三灯光模拟满分操作大全!建议收藏

今天一起备考一下科三的灯光模拟考试吧~它可以说是科目三中容易被扣分的操作了,考试开始一旦操作错误,就直接挂科了!要想满分通过,这里为大家总结了下面这些窍门~ 操作步骤归类总结 01.开启近光灯 语音指令: 夜间与…

sql 4,创建表类型

1,整数类型(类型,占有空间,范围)标准sql:int / integer 4字节 无符号 0 - 2/32-1 有符号 -2 31 / 2 / 31 -1 smallint 2字节 无符号 0 - 2/16-1 有符号 -2 17 / 2 / 17 -1mysql方言:tinyint 1字节 无符号 0 - 2/8 -1 有符号 -2 7 / 2/7-1med…