KV cache
For decoder-only models, such as today's booming large language models, the K and V matrices computed in the Transformer's self-attention are cached during generation to avoid redundant computation; this caching is what we call the KV cache.
Generation in a decoder-only model is auto-regressive: the model first generates the next token from the input, then appends that token to the input to generate the following one, repeating until a stop token appears or the limit on output tokens is reached. (GIF from illustrated-gpt2)
Because generation is auto-regressive and the decoder's self-attention is causal, i.e. each token's attention only depends on the tokens before it, the K and V of all earlier tokens would otherwise be recomputed every time a new token is generated. To save this computation, we can store the K and V already computed for earlier tokens: when generating the next token we reuse the stored values directly and append the newly computed token's K and V to the store. (The figure below, from a blog post, compares generation with and without the KV cache, ignoring softmax and scaling.)
At each step, only the K and V computed in previous steps need to be reused (past queries are never needed again), so the KV cache stores only K and V. The price of caching is the extra GPU memory it occupies:
- Caching one token takes 2 * precision_in_bytes * head_dim * n_heads * n_layers bytes (the factor 2 is because both K and V are cached, precision_in_bytes is the number of bytes per value at the chosen storage precision, head_dim is the attention head dimension, n_heads is the number of attention heads, and n_layers is the number of transformer layers).
- For a 16-bit model doing batched inference at the maximum context length max_context_length, the required cache is 2 * 2 * head_dim * n_heads * n_layers * max_context_length * batch_size bytes. For example, Llama-2-13B with its 4096-token maximum context window and a batch size of 8 needs up to roughly 25GB of cache memory (see the worked example below).
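To sanity-check that figure, we can plug the numbers into the formula. The sketch below assumes Llama-2-13B's published shape (40 layers, 40 heads, head_dim = 5120 / 40 = 128) and 2 bytes per value:
def kv_cache_bytes(precision_in_bytes, head_dim, n_heads, n_layers, context_length, batch_size):
    # the leading 2 accounts for storing both K and V
    return 2 * precision_in_bytes * head_dim * n_heads * n_layers * context_length * batch_size

llama2_13b = kv_cache_bytes(precision_in_bytes=2, head_dim=128, n_heads=40, n_layers=40,
                            context_length=4096, batch_size=8)
print(f"{llama2_13b / 1024**3:.1f} GiB")  # -> 25.0 GiB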
The transformers library uses the KV cache by default during generation (use_cache=True). We can use the following code to measure the performance difference with and without the KV cache.
## code from https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
    times = []
    for _ in range(10):  # measuring 10 generations
        start = time.time()
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
        times.append(time.time() - start)
    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")
Multi-query attention and Grouped-query attention
Multi-query attention
Multi-query attention (MQA) comes from the November 2019 paper "Fast Transformer Decoding: One Write-Head is All You Need". It lets all heads of multi-head attention share a single K matrix and a single V matrix. The paper's experiments show that the drop in model quality is minor, while the reduced parameters mean the KV cache takes far less memory and far less time to read during inference.
Grouped-query attention
Grouped-query attention (GQA) comes from the May 2023 paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints". As shown in the figure above, its degree of K/V sharing sits between multi-query attention (MQA) and multi-head attention (MHA); the paper's experiments show that GQA achieves speed close to MQA while keeping quality close to MHA.
Grouped-query attention divides the query heads into G groups, and each group shares one key head and one value head. Writing grouped-query attention with G groups as GQA-G, GQA-1 is exactly multi-query attention, while GQA-H (one group per query head) is equivalent to multi-head attention.
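To make the grouping concrete, here is a tiny illustration (the head counts are hypothetical, matching the paper's 64-head setting) of how query heads map onto shared KV heads:
H = 64                        # total number of query heads
for G in (1, 8, 64):          # GQA-1 == MQA, GQA-64 == MHA when H == 64
    group_size = H // G       # query heads that share one K/V head
    kv_head_of = [q // group_size for q in range(H)]  # query head index -> KV head index
    print(f"G={G}: first 9 mappings {kv_head_of[:9]}")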
The paper also proposes a recipe for converting a multi-head attention model into an MQA or GQA model, in two steps:
- Convert the MHA checkpoint into an MQA or GQA checkpoint by mean-pooling the multiple K and V head matrices into one, as illustrated in the figure below (the paper compared keeping the first head, random initialization, and mean pooling; mean pooling worked best).
- Continue pre-training on a small fraction (around 5%) of the original pre-training data so the model adapts to the new structure.
As for choosing the number of GQA groups, the paper's ablation settled on G = 8 for a total of 64 heads, and Llama-2-70B also uses 8 KV heads (again with 64 query heads in total).
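The KV cache shrinks in direct proportion to the number of KV heads. A rough comparison, assuming Llama-2-70B's published shape (80 layers, 64 query heads, head_dim 128) at 16-bit precision with a 4096-token context and batch size 8:
def kv_cache_gib(n_kv_heads, n_layers=80, head_dim=128, precision_in_bytes=2,
                 context_length=4096, batch_size=8):
    return 2 * precision_in_bytes * head_dim * n_kv_heads * n_layers * context_length * batch_size / 1024**3

print(kv_cache_gib(n_kv_heads=64))  # MHA: all 64 heads cached -> 80.0 GiB
print(kv_cache_gib(n_kv_heads=8))   # GQA with 8 groups (Llama-2-70B) -> 10.0 GiB
print(kv_cache_gib(n_kv_heads=1))   # MQA -> 1.25 GiB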
Implementation
A code sketch (ignoring performance concerns) looks like this:
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    n_kv_heads: int = 12 # number of KV heads, i.e. the number of groups for grouped-query attention
def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Perform repeat of kv heads along a particular dimension.

    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
    n_rep: amount of repetitions of kv_n_heads
    Unlike torch.repeat_interleave, this function avoids allocating new memory.
    from https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py#L47
    llama2 does much the same thing: https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py#L164C1-L165C1
    """
    if n_rep == 1:
        return hidden
    (b, s, kv_n_heads, d) = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)
## adapt from https://github.com/karpathy/nanoGPT/blob/master/model.py
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
### multi-query
class MultiQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_embd//self.n_head, self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
### grouped-query attention
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_kv_heads*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_kv_heads = config.n_kv_heads
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_kv_heads*self.n_embd//self.n_head, self.n_kv_heads*self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
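A small smoke test of the three modules (the toy config values are made up): each maps an input of shape (B, T, n_embd) back to (B, T, n_embd), while the c_attn projection shrinks as the number of KV heads drops:
config = GPTConfig(block_size=64, n_layer=2, n_head=12, n_embd=768, n_kv_heads=4)
x = torch.randn(2, 16, config.n_embd)
for cls in (MultiHeadAttention, MultiQueryAttention, GroupedQueryAttention):
    attn = cls(config)
    print(cls.__name__, tuple(attn(x).shape), "c_attn weight:", tuple(attn.c_attn.weight.shape))
# c_attn weights: MHA (2304, 768), MQA (896, 768), GQA with 4 KV heads (1280, 768)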
Sliding Window Attention
Mistral 7B uses Sliding Window Attention (SWA) to reduce the memory taken by the KV cache: each attention computation only looks at a fixed window of size W. For the hidden state at position i, only the hidden states in the window from i-W to i are attended to, as illustrated in the figure below, so at layer k position i can effectively reach information from up to W × k tokens. In Mistral 7B, W = 4096 and there are 32 layers, so the theoretical attention span is roughly 131K tokens.
Because the attention window is fixed, Mistral 7B uses a rolling buffer cache whose size is fixed at W: the K and V for position i are stored at slot i mod W of the cache, so once i exceeds W the values previously stored in that slot are overwritten. The figure below illustrates the case W = 3.
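Below is a minimal sketch of such a rolling KV buffer for a single layer (the RollingKVCache class is hypothetical, not Mistral's implementation; it assumes the per-token K and V have already been projected):
import torch

class RollingKVCache:
    """Fixed-size KV cache for sliding window attention: slot = position mod W."""
    def __init__(self, window: int, n_kv_heads: int, head_dim: int, dtype=torch.float16):
        self.window = window
        self.k = torch.zeros(window, n_kv_heads, head_dim, dtype=dtype)
        self.v = torch.zeros(window, n_kv_heads, head_dim, dtype=dtype)
        self.pos = 0  # number of tokens written so far

    def append(self, k_t: torch.Tensor, v_t: torch.Tensor):
        slot = self.pos % self.window  # once pos >= window, the oldest entry is overwritten
        self.k[slot] = k_t
        self.v[slot] = v_t
        self.pos += 1

    def get(self):
        # return the (at most `window`) most recent K/V in temporal order
        n = min(self.pos, self.window)
        start = self.pos % self.window if self.pos > self.window else 0
        idx = [(start + j) % self.window for j in range(n)]
        return self.k[idx], self.v[idx]
With W = 3, appending tokens 0 through 3 leaves tokens 1, 2 and 3 in the buffer, matching the figure above.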
References
- 看图学KV Cache
- Transformer Inference Arithmetic
- Transformers KV Caching Explained (its GIF animations help build intuition)
- KV caching memory growth
- KV cache as a major engineering challenge for scaling chatbots
- Techniques for KV Cache Optimization in Large Language Models
- KV cache quantization
- Inference Optimization