本文记录大模型推理阶段 KV Cache 的原理及显存占用情况。
Self-Attention 与 KV Cache
如图,当新生成的 token x 进到模型计算 Attention 时,先分别乘上参数矩阵
W
q
W_q
Wq、
W
k
W_k
Wk、
W
v
W_v
Wv 得到向量 q,以及矩阵 K、V。然后根据下面公式计算当前 token 跟前面 tokens 的注意力权重(本文为了简化,不考虑多头 MHA)。
自回归生成过程中,K和V矩阵并没有太大变化,比如下图中 cold 这个词对应了 K 的某一列和 V 的某一行,算完就放那里不再变了。
轮到生成 chill 这个词时,其实只需要在原始 K 矩阵追加一列,原始 V 矩阵追加一行,而没必要每生成一个 token 都重新计算一遍 K、V 矩阵,这便是 KV Cache
的意义。
因此在推理的时候,不用每次传入前面全部 token 序列的 embedding,而只需传入 KV Cache 以及当前 token x 的 embedding。Transformer 在算完当前 token x 的 Attention 之后,会把新的 K’ 和 V’ 更新到 GPU 显存中。 左图中 Masked Multi Self Attention
这块也是唯一和前面序列有交互的模块,其他模块(比如 Layer Norm、FFN、位置编码等)都不涉及跟已生成 token 的交互。
KV Cache 显存占用分析
KV Cache 显存计算方式如下:
2
∗
p
r
e
c
i
s
i
o
n
∗
n
l
a
y
e
r
∗
d
m
o
d
e
l
∗
s
e
q
_
l
e
n
∗
b
a
t
c
h
_
s
i
z
e
2 * precision * n_{layer} * d_{model} * seq\_len * batch\_size
2∗precision∗nlayer∗dmodel∗seq_len∗batch_size
- 2 2 2 是指 K 跟 V 俩矩阵。
- p r e c i s i o n precision precision 是模型每个参数的字节数,比如 fp32 精度下每个参数 4 字节。
- n l a y e r n_{layer} nlayer 和 n m o d e l n_{model} nmodel 分别是模型 Decoder layer 层数和 embedding 维度大小。
- s e q _ l e n seq\_len seq_len、 b a t c h _ s i z e batch\_size batch_size 顾名思义分别是最大序列长度和 global batch size。
比如以 OPT-30B 模型(bf16,48层,7168维,1024上下文,128 batch size)为例,KV Cache 占的显存是:
2
∗
2
∗
48
∗
7168
∗
1024
∗
128
=
180
,
388
,
626
,
432
b
y
t
e
s
≈
180
G
B
2*2*48*7168*1024*128 \\=180,388,626,432 bytes \\≈ 180GB
2∗2∗48∗7168∗1024∗128=180,388,626,432bytes≈180GB
模型本身仅占显存:
2
∗
30
B
=
60
B
b
y
t
e
s
=
60
G
B
2*30B=60Bbytes=60GB
2∗30B=60Bbytes=60GB
光 KV Cache 就顶模型本身占显存的3倍。(当然一般推理时 batch size是1,这时候KV Cache显存占用就砍到 1/128 了,不过 batch 模式能够最大化利用显存,所以这也是为啥各个大模型厂商 batch 模型都比较便宜了)
参考资料:油管《The KV Cache: Memory Usage in Transformers》