大模型推理优化技术-KV Cache

news2025/1/22 19:38:39

近两年大模型火出天际;同时,也诞生了大量针对大模型的优化技术。本系列将针对一些常见大模型优化技术进行讲解。

  • 大模型推理优化技术-KV Cache
  • 大模型推理服务调度优化技术-Continuous batching
  • 大模型底显存推理优化-Offload技术
  • 大模型推理优化技术-KV Cache量化
  • 大模型推理优化技术-KV Cache优化方法综述
  • 大模型访存优化技术-FlashAttention
  • 大模型显存优化技术-PagedAttention
  • 大模型解码优化-Speculative Decoding及其变体

另外,我撰写的大模型相关的博客及配套代码均整理放置在Github:llm-action,有需要的朋友自取。

而本文将针对仅解码器Transformer架构(Decoder-Only Transformer)的模型必备推理优化技术 KV Cache 进行讲解。

image.png

KV Cache 简介

KV Cache 是大模型推理性能优化的一个常用技术,该技术可以在不影响任何计算精度的前提下,通过空间换时间的思想,提高推理性能。

KV Cache 诞生的背景

对于仅解码器Transformer架构的模型的推理,我们给一个输入文本,模型会输出一个回答(长度为 N),其实该过程中执行了 N 次推理过程。即类 GPT 的仅解码器模型一次推理只输出一个token,输出的 token 会与输入 tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

针对一个仅解码器Transformer架构的模型,假设用户输入为“recite the first law”,模型续写得到的输出为“A robot may not ”,模型的生成过程如下:

  1. 将“ecite the first law”输入模型,得到每个token的注意力表示。使用“law”的注意力表示,预测得到下一个token为“A”(实际还需要将该注意力表示映射成概率分布logits,为了方便叙述,我们忽略该步骤)。
  2. 将“A”拼接到原来的输入,得到“recite the first law A”,将其输入模型,得到注意力表示,使用“A”的注意力表示,预测得到下一个token为“robot”。
  3. 将“robot”拼接到原来的输入,依此类推,预测得到“robot”,最终得到“recite the first law A robot may not”

image.png

仅解码器Transformer架构的自回归模型为带 Masked 的 Self Attention。因此,在没有KV Cache的情况下,其计算过程如下所示。

image.png

正常情况下,Attention的计算公式如下:

image.png

为了看上去方便,我们暂时忽略scale项,因此,Attention的计算公式如下所示(softmaxed 表示已经按行进行了softmax):

image.png

image.png

Q K T QK^T QKT变为矩阵时,softmax 会针对行进行计算,详细如下(softmaxed 表示已经按行进行了softmax):

image.png

其中, A t t 1 ( Q , K , V ) Att_1(Q,K,V) Att1(Q,K,V)表示 Attention 的第一行, A t t 2 ( Q , K , V ) Att_2(Q,K,V) Att2(Q,K,V)表示 Attention 的第二行。

image.png

对于 A t t 1 ( Q , K , V ) Att_1(Q,K,V) Att1(Q,K,V),由于 Q 1 K 2 T Q_1K_2^T Q1K2T这个值会mask掉,你会发现, Q 1 Q_1 Q1 在第二步参与的计算与第一步是完全一样的,并且 V 1 V_1 V1 参与计算Attention时也仅仅依赖于 Q 1 Q_1 Q1 ,与 Q 2 Q_2 Q2 毫无关系。

对于 A t t 2 ( Q , K , V ) Att_2(Q,K,V) Att2(Q,K,V) V 2 V_2 V2 参与计算Attention时也仅仅依赖于 Q 2 Q_2 Q2 ,与 Q 1 Q_1 Q1 毫无关系。

image.png

其计算方式如 Step2 所示。

image.png

image.png

其计算方式如 Step2 所示。

image.png

对于 A t t k ( Q , K , V ) Att_k(Q,K,V) Attk(Q,K,V) V k V_k Vk 参与计算Attention时也仅仅依赖于 Q k Q_k Qk

看上面图和公式,我们可以得出以下结论:

  1. 当前计算方式存在大量冗余计算,每一次生成新的Token都需要计算之前的KV。
  2. A t t k ( Q , K , V ) Att_k(Q,K,V) Attk(Q,K,V)的计算过程中,主要与 Q k Q_k Qk 有关。 V k V_k Vk 参与计算Attention时也仅仅依赖于 Q k Q_k Qk
  3. 每一步中,其实只需要根据 Q k Q_k Qk 计算 A t t k ( Q , K , V ) Att_k(Q,K,V) Attk(Q,K,V) 就可以,之前已经计算的Attention完全不需要重新计算。但是 K 和 V 是全程参与计算的,所以这里我们需要把每一步的 K 、 V 缓存起来。

KV Cache 步骤

正是因为 Self Attention 中带 Masked ,因此,在推理的时候,前面已经生成的 Token 不需要与后面的 Token 产生 Attention ,从而使得前面已经计算的 K 和 V 可以缓存起来。

一个典型的带有 KV cache 优化的生成大模型的推理过程包含了两个阶段:

  1. 预填充阶段:输入一个prompt序列,为每个transformer层生成 key cache 和 value cache(KV cache)。

  2. 解码阶段:使用并更新KV cache,一个接一个地生成token,当前生成的token词依赖于之前已经生成的token。

预填充阶段计算过程如下:

image.png

解码阶段计算过程如下:

image.png

使不使用 KV Cache 的对比

下图展示了使用KV Cache和不使用KV Cache的对比,其中,紫色部分表示从缓存获取,灰色部分表示会被Masked。

image.png

image.png

image.png

image.png

下面使用 transformers 来比较有 KV Cache 和没有 KV Cache的情况下,GPT-2的生成速度。

import numpy as np  
import time  
import torch  
from transformers import AutoModelForCausalLM, AutoTokenizer  
  
device = "cuda" if torch.cuda.is_available() else "cpu"  
tokenizer = AutoTokenizer.from_pretrained("gpt2")  
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)  
  
for use_cache in (True, False):  
    times = []  
    for _ in range(10): # measuring 10 generations  
        start = time.time()  
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)  
        times.append(time.time() - start)  
    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

运行结果:

  • 使用 KV caching: 11.885 ± 0.272 秒
  • 不使用 KV caching: 56.197 ± 1.855 秒

可以看到使不使用 KV cache 推理性能果差异显存。

使用 KV Cache 解码阶段计算量分析

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小。
如何计算矩阵乘法的FLOPs呢?
对于 A ∈ R 1 × n , B ∈ R n × 1 A∈R^{1×n},B∈R^{n×1} AR1×n,BRn×1 ,计算 AB 需要进行 n 次乘法运算和 n-1 次加法运算(假设 n > > 1 n>>1 n>>1,则约等于n),共计 2n 次浮点数运算,需要 2 n 2n 2n 的FLOPs。对于 A ∈ R m × n , B ∈ R n × p A∈R^{m×n}, B∈R^{n×p} ARm×n,BRn×p ,计算 AB 需要的浮点数运算次数为 m ∗ 2 n ∗ p = 2 m n p m*2n*p=2mnp m2np=2mnp

下面来看看在一个 Token 生成过程中一层 Transformer 的计算量。

首先,分析 self-attention 块的计算,计算公式如下:

Q = x W Q , K = x W K , V = x W V Q=xW_Q,K=xW_K,V=xW_V Q=xWQ,K=xWK,V=xWV

x o u t = s o f t m a x ( Q K T h ) ⋅ V ⋅ W O + x x_{out}=softmax(\frac {QK^T}{\sqrt h}) \cdot V \cdot W_O + x xout=softmax(h QKT)VWO+x

我们来看看不使用 KV Cache 时,假设输入数据的形状为 [b, s],隐藏层维度为 h,则输入的形状为 [b, s, h]。self-attention块的计算如下:

  1. 计算 Q,K,V :矩阵乘法的输入和输出形状为 [b, s, h]x[h,h]->[b, s, h] 。计算量为 $ 3* bs2hh = 3∗2bsh2=6bsh2$ 。
  2. Q K T QK^T QKT 矩阵乘法的输入和输出形状为 [b, head_num, s, per_head_hidden_size]×[b, head_num, per_head_hidden_size, s]→[b, head_num, s, s],计算量为 b ∗ s ∗ 2 h ∗ s = 2 b s 2 h b*s*2h*s=2bs^2h bs2hs=2bs2h
  3. 计算在 V 上的加权 $score \cdot V $,矩阵乘法的输入和输出形状为 [b, head_num, s, s]×[b, head_num, s, per_head_hidden_size]→[b, head_num, s, per_head_hidden_size] 。计算量为 b ∗ s ∗ 2 s ∗ h = 2 b s 2 h b*s*2s*h=2bs^2h bs2sh=2bs2h
  4. attention后的线性映射,矩阵乘法的输入和输出形状为 [b, s, h]x[h,h]->[b, s, h] 。计算量为 2 b s h 2 2bsh^2 2bsh2

不使用 KV Cache 时,输入的形状为 [b, 1, h],kv cache中含有 𝑘 𝑣 𝑙 𝑒 𝑛 𝑔 𝑡 h 𝑘𝑣_{𝑙𝑒𝑛𝑔𝑡ℎ} kvlength 个 past word。self-attention块的计算如下:

  1. 计算 𝑄 , 𝐾 , 𝑉 𝑄,𝐾,𝑉 Q,K,V :矩阵乘法的输入和输出形状为 [b, 1, h]×[h, h]→[b, 1, h] 。计算量为 3 ∗ b ∗ 2 h ∗ h = 3 ∗ 2 b h 2 = 6 b h 2 3*b*2h*h=3*2bh^2=6bh^2 3b2hh=32bh2=6bh2
  2. Q K T QK^T QKT 矩阵乘法的输入和输出形状为 [b, head_num, 1, per_head_hidden_size]×[b, head_num, per_head_hidden_size, kv_length+1]→[b, head_num, 1, kv_length+1] 。计算量为 𝑏 ∗ 2 h ∗ ( 𝑘 𝑣 𝑙 𝑒 𝑛 𝑔 𝑡 h + 1 ) = 2 b ( k v 𝑙 𝑒 𝑛 𝑔 𝑡 h + 1 ) h 𝑏 * 2h * (𝑘𝑣_{𝑙𝑒𝑛𝑔𝑡ℎ}+1) = 2b(kv_{𝑙𝑒𝑛𝑔𝑡ℎ}+1)ℎ b2h(kvlength+1)=2b(kvlength+1)h
  3. 计算在V上的加权 $score \cdot V $ ,矩阵乘法的输入和输出形状为 [b, head_num, 1, kv_length+1]×[b,head_num,kv_length+1,per_head_hidden_size]→[b,head_num,1,per_head_hidden_size] 。计算量为 2 𝑏 ( 𝑘 𝑣 𝑙 𝑒 𝑛 𝑔 𝑡 h + 1 ) h 2𝑏(𝑘𝑣_{𝑙𝑒𝑛𝑔𝑡ℎ}+1)ℎ 2b(kvlength+1)h
  4. attention后的线性映射,矩阵乘法的输入和输出形状为 [b, 1, h]×[h, h]→[b, 1, h] 。计算量为 2 b h 2 2bh^2 2bh2

接下来分析MLP块的计算,计算公式如下:

x = f g e l u ( 𝑥 o u t W 1 ) W 2 + x o u t x=f_{gelu}(𝑥_{out}W_1)W_2+x_{out} x=fgelu(xoutW1)W2+xout

不使用 KV Cache 时:

  1. 第一个线性层,矩阵乘法的输入和输出形状为 [b, s, h]×[h, 4h]→[b, s, 4h] 。计算量为 8 b s h 2 8bsh^2 8bsh2
  2. 第二个线性层,矩阵乘法的输入和输出形状为 [b, s, 4h]×[4h,h]→[b, s, h] 。计算量为 8 b s h 2 8bsh^2 8bsh2

使用 KV Cache 时:

  1. 第一个线性层,矩阵乘法的输入和输出形状为 [b, 1, h]×[h, 4h]→[b, 1, 4h] 。计算量为 8 b h 2 8bh^2 8bh2
  2. 第二个线性层,矩阵乘法的输入和输出形状为 [b, 1, h]×[4h, h]→[b, 1, h] 。计算量为 8 b h 2 8bh^2 8bh2

将上述self-attention块和MLP块计算量相加,得到:

  • 采用kv cache时,得到每个transformer层的计算量大约为 24 b h 2 + 4 b h ( k v l e n g t h + 1 ) 24bh^2+4bh(kv_{length}+1) 24bh2+4bh(kvlength+1)
  • 不采用kv cache时,得到每个transformer层的计算量大约为 24 b s h 2 + 4 b s 2 h 24bsh^2+4bs^2h 24bsh2+4bs2h

此外,另一个计算量的大头是logits的计算,将隐藏向量映射为词表大小。

  • 采用kv cache时,矩阵乘法的输入和输出形状为 [b, 1, h]×[h,V]→[b,1,V] ,计算量为 2 b h V 2bhV 2bhV
  • 不采用kv cache时为,矩阵乘法的输入和输出形状为 [b, s, h]×[h,V]→[b,s,V] ,计算量为 2 b s h V 2bshV 2bshV

KV Cache 显存占用分析

假设输入序列的长度为s ,输出序列的长度为n ,transformer层数为l,隐藏层维度 h,KV Cache 存储 kv_seq_len 个 KV value,形状为 [b, head_num, kv_seq_len, head_dim], 峰值kv_seq_len为 s+n ,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 b ( s + n ) h ∗ l ∗ 2 ∗ 2 = 4 b l h ( s + n ) b(s+n)h*l*2*2=4blh(s+n) b(s+n)hl22=4blh(s+n) 。这里第一个 2 表示 K/V cache,第二个2表示float16占 2 个 bytes。

以GPT3-175B为例,对比KV cache与模型参数占用显存的大小。模型配置如下:

模型名参数量层数隐藏维度注意力头数
GPT3175B961228896

GPT3 模型占用显存大小为350GB。假设批次大小b=64 ,输入序列长度s=512 ,输出序列长度n=32 ,则KV cache 峰值占用显存为 4blh(s+n) = 164,282,499,072 bytes ≈ 164 𝐺𝐵 ,大约是模型参数显存的0.5倍。

KV Cache 存在的问题以及优化措施

当将LLMs应用于无限输入流时,使用原始的 Dense Attention 会出现两个主要挑战:

  • 上下文越长,那么矩阵占用的内存也会越多,不仅如此还会增加Decoder时候的延迟。
  • 现有模型的长度外推能力有限,即当序列长度超出预训练期间设置的attention窗口大小时,其性能会下降。

因此,目前提出了一些优化方法,比如:使用滑动窗口的注意力机制,主要有如下几种方式。

  • 一种方式是如下图 B 的窗口注意力(Window Attention):只缓存最近的 L 个 Token 的 KV。虽然推理效率很高,但一旦起始Token的键和值被驱逐,性能就会急剧下降。
  • 一种方式是下图 C 的滑动窗口重计算(Sliding Window w/ Re-computation):根据每个新 Token 的 L 个最近 Token 重建 KV 状态。虽然它在长文本上表现良好,但其 O ( T L 2 ) O(TL^2) O(TL2) 的复杂性(源于上下文重新计算中的二次注意力)使其相当慢。

image.png

  • 还有一种方式是StreamingLLM,在当前滑动窗口方法的基础上,重新引入了一些最初的 tokens 的KV在注意力计算中使用。StreamingLLM 中的KV缓存可以概念上分为两部分,如下图所示:(1)attention sink 是 4 个最初的 tokens,稳定了注意力计算;(2)Rolling KV Cache 保留了最近的token,这个窗口值是固定的。此外,还需要有些小改动来给attention注入位置信息,StreamingLLM就可以无缝地融入任何使用相对位置编码的自回归语言模型,如RoPE和ALiBi。

image.png

KV Cache 源码分析

GPT2 中 KV Cache 代码实现:

class GPT2Attention(nn.Module):
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        ...
  
        # 拆分 Q、K、V
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        
        ...
        
        # [batch, sequence_len, embeded_dim] -> [batch, heads, sequence_len, head_dim]
        query = self._split_heads(query, self.num_heads, self.head_dim) # 当前token对应的query
        key = self._split_heads(key, self.num_heads, self.head_dim) # 当前token对应的key
        value = self._split_heads(value, self.num_heads, self.head_dim) # 当前token对应的value

        ##################################
        # KV Cache 核心代码逻辑
        if layer_past is not None: 
            past_key, past_value = layer_past # 从 KV Cache 去数据
            key = torch.cat((past_key, key), dim=-2) # 将当前token的key与历史的K拼接
            value = torch.cat((past_value, value), dim=-2) # 将当前token的value与历史的V拼接

        if use_cache is True:
            present = (key, value) # 将数据存到 KV Cache
        else:
            present = None
        ##################################
        ...
        
        # 使用当前token的query与K和V计算注意力表示
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) # 返回att输出(激活)和权重

        # 合并多头注意力
        # attn_output: [batch, heads, sequence_len, head_dim] -> [batch, heads, embed_dim]
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)

Baichuan2 中 KV Cache 代码实现:

class Attention(nn.Module):

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # 取出 KV Cache 中的值
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        
        # 保存 KV Cache 中的值
        past_key_value = (key_states, value_states) if use_cache else None

Huggingface Transformer 库中 LLaMA 中 KV Cache 代码实现:

class LlamaAttention(nn.Module):
    ...
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    
        ...

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # 将当前 Token 的 kv 值更新到 KV Cache,并返回新的 KV
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        ...

        return attn_output, attn_weights, past_key_value

Huggingface Transformer 库中对Cache进行了抽象,里面实现了各种Cache,如:生成模型默认的动态缓存DynamicCache、StaticCache 和 StreamingLLM 论文中提到的SinkCache。


@dataclass
class Cache:
    """
    所有Cache的基础抽象类。实际数据结构由每个子类决定。
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        #   length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None


class DynamicCache(Cache):
    # 随着生成更多 Token 而动态增长的Cache。这是生成模型的默认设置。
    # 它将键和值状态存储为张量列表,每层一个张量。每个张量的期望形状是
    # [batch_size, num_heads, seq_len, head_dim]。
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
        
        
class StaticCache(Cache):
    """
    与 torch.compile(model) 一起使用的静态 Cache 类
    """
    ...
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        使用张量进行索引是非常重要的,否则你会向设备引入一个副本。
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
                to know how much of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """
        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out
    
class SinkCache(Cache):
    """
    # 正如[Attention Sinks 论文](https://arxiv.org/abs/2309.17453)中所描述的缓存。
    # 它允许模型生成超出其上下文窗口的长度,而不会失去会话的流畅性。
    # 因为它抛弃了过去tokens,模型将失去生成依赖于被丢弃的上下文的tokens的能力。
    # 它将键和值状态存储为张量列表,每层一个张量。每个张量的期望形状是
    # [batch_size, num_heads, seq_len, head_dim]
    """
    ...
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
    

从 GPT2 、 Baichuan2 和 LLaMA 的源码中可以看到 KV Cache 核心代码的实现就几行并不复杂,但是带来的收益却挺大。

结语

本文简要分析了 KV Cache 原理、源码以及计算量和显存占用,这是一种典型的通过空间换时间(计算)的技术,虽然并不复杂,但是现在基本上是仅解码器Transformer架构生成大语言模型必备优化技术。

参考文档:

  • 图解大模型推理优化:KV Cache
  • 大模型推理百倍加速之KV cache篇
  • 大模型推理加速:看图学KV Cache*
  • 大模型推理性能优化之KV Cache解读
  • LLM推理:首token时延优化与System Prompt Caching
  • [LLM]KV cache详解 图示,显存,计算量分析,代码*
  • 分析transformer模型的参数量、计算量、中间激活、KV cache*
  • The Illustrated GPT-2 (Visualizing Transformer Language Models)*
  • Transformers KV Caching Explained*
  • LLM推理技术之StreamingLLM:如何拥有无限长生成能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2239763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣 LeetCode 24. 两两交换链表中的节点(Day2:链表)

解题思路&#xff1a; 暂存节点tmp和tmp1 注意&#xff1a;while (cur.next ! null && cur.next.next ! null)表示为偶数和奇数时的循环停止条件&#xff0c;并且while语句中的顺序不可交换&#xff0c;交换会报空指针异常 class Solution {public ListNode swapPai…

动态规划-背包问题——494.目标和

1.状态表示 题目来源 494.目标和——力扣 测试用例 2.算法原理 1.状态表示 首先我们需要将问题简化&#xff0c;这里需要找到能将数组组合计算成为指定数字target的添加方式&#xff0c;那么我们就可以将数字分为两类&#xff0c;一类是前面添加""的&#xff0c;另…

哪些因素会导致充电器的充电速度变慢?-纳米软件

充电器的充电速度变慢可能由多种原因引起。以下是一些常见的因素&#xff1a; 一、充电器本身的问题 充电头功率不足&#xff1a;不同的充电头有不同的输出功率&#xff0c;如果使用的充电头功率较低&#xff0c;那么充电速度就会变慢。例如&#xff0c;一些老旧的充电头可能…

刷题强训(day06) -- 大数加法、链表相加、大数乘法

目录 1、大数加法 1.1 题目 1.2 思路 1.3 代码实现 2、链表相加&#xff08;二&#xff09; 2.1 题目 2.2 思路 2.3 代码实现 3、大数乘法 3.1 题目 3.2 思路 3.3 代码实现 1、大数加法 1.1 题目 1.2 思路 这道题可以模拟列竖式相加解答&#xff0c; 将每一位都转…

数字后端教程之Innovus report_property和get_property使用方法及应用案例

数字IC后端实现Innovus中使用report_property可以报告出各种各样object的属性&#xff0c;主要有cell&#xff0c;net&#xff0c;PG Net&#xff0c;Pin&#xff0c;时钟clock&#xff0c;时序库lib属性&#xff0c;Design属性&#xff0c;timing path&#xff0c;timin arc等…

网络基础 - 网段划分篇

我们知道&#xff0c;IP 地址(IPv4 地址)由 “网络标识(网络地址)” 和 “主机标识(主机地址)” 两部分组成&#xff0c;例如 192.168.128.10/24&#xff0c;其中的 “/24” 表示从第 1 位开始到多少位属于网络标识&#xff0c;那么&#xff0c;剩余位就属于主机标识了&#xf…

python实战(八)——情感识别(多分类)

一、任务目标 本文使用的是来自Kaggle的一个情感识别数据集&#xff0c;这个数据集的总数据量是5934条&#xff0c;标签为anger、fear、joy三种情感的其中一种&#xff0c;很明显是一个多分类任务。这里&#xff0c;我们将使用微调技巧进行深度学习建模&#xff0c;同时我们会比…

23423234

c语言中的小小白-CSDN博客c语言中的小小白关注算法,c,c语言,贪心算法,链表,mysql,动态规划,后端,线性回归,数据结构,排序算法领域.https://blog.csdn.net/bhbcdxb123?spm1001.2014.3001.5343 给大家分享一句我很喜欢我话&#xff1a; 知不足而奋进&#xff0c;望远山而前行&am…

opencv入门学习总结

opencv学习总结 不多bb&#xff0c;直接上代码&#xff01;&#xff01;&#xff01; 案例一&#xff1a; import cv2 # 返回当前安装的 OpenCV 库的版本信息 并且是字符串格式 print(cv2.getVersionString()) """ 作用&#xff1a;它可以读取不同格式的图像文…

MySQL 中的索引下推功能

看到索引&#xff0c;应该大家都可以联想到这个是和查询效率有关系的&#xff0c;既然有这个功能&#xff0c;那么那句古话说的好啊&#xff1a;存在即合理。那么这个就是说有了这个功能&#xff0c;可以提升查询效率。 什么是索引下推 我们先有一个大概的理解&#xff1a;在…

重拾CSS,前端样式精读-媒体查询

前言 本文收录于CSS系列文章中&#xff0c;欢迎阅读指正 说到媒体查询&#xff0c;大家首先想到的可能是有关响应式的知识点&#xff0c;除此之外&#xff0c;它还可以用于条件加载资源&#xff0c;字体大小&#xff0c;图像和视频的优化&#xff0c;用户界面调整等等方面&am…

物理设备命名规则(Linux网络服务器 15)

Linux系统中的一切都是文件&#xff0c;硬件设备也不例外。既然都是文件&#xff0c;就必须有文件名称。系统内核中udev设备管理器会自动把硬件名称规范化起来&#xff0c;目的是让用户通过设备文件的名字可以大致了解设备属性以及分区信息。这对于陌生的设备来说特别方便。另外…

NVIDIA NIM 开发者指南:入门

NVIDIA NIM 开发者指南&#xff1a;入门 NVIDIA 开发者计划 想要了解有关 NIM 的更多信息&#xff1f;加入 NVIDIA 开发者计划&#xff0c;即可免费访问任何基础设施云、数据中心或个人工作站上最多 16 个 GPU 上的自托管 NVIDIA NIM 和微服务。 加入免费的 NVIDIA 开发者计…

猿创征文|Inscode桌面IDE:打造高效开发新体验

猿创征文&#xff5c;Inscode桌面IDE&#xff1a;打造高效开发新体验 引言 在当今快速发展的软件开发领域&#xff0c;一个高效、易用的集成开发环境&#xff08;IDE&#xff09;是每个开发者必不可少的工具。Inscode 桌面 IDE 作为一款新兴的开发工具&#xff0c;凭借其强大…

力扣 LeetCode 142. 环形链表II(Day2:链表)

解题思路&#xff1a; 使用set判断是否重复添加&#xff0c;如果set加入不进去证明之前到达过该节点&#xff0c;有环 public class Solution {public ListNode detectCycle(ListNode head) {Set<ListNode> set new HashSet<>();ListNode cur head;while (cur …

激活函数解析:神经网络背后的“驱动力”

神经网络中的激活函数&#xff08;Activation Function&#xff09;是其运作的核心组件之一&#xff0c;它们决定了神经元如何根据输入信号进行“激活”&#xff0c;进而影响整个模型的表现。理解激活函数的工作原理对于设计和优化神经网络至关重要。本篇博客将深入浅出地介绍各…

基于表格滚动截屏(表格全部展开,没有滚动条)

import html2canvasPro from html2canvas // 截图&#xff0c;平辅表格 async function resetAgSize() {const allColumns gridApi.value.getColumns()let totalColumnWidth 0let totalColumnHeight 0// 遍历每一个行节点gridApi.value.forEachNode((rowNode) > {totalCo…

vs2015QT项目添加多语言翻译总结

一、简介 当软件有国际化的需求时&#xff0c;就需要多语言翻译功能&#xff0c;最常见的语言就是支持中文和英语&#xff0c;本文介绍在vs2015QT环境下&#xff0c;进行国际化翻译的具体流程。 二、多语言翻译实现流程 1.底层实现原理介绍 QT写的客户端软件&#xff0c;能…

wireshark演进之路——从GTK到Qt

Wireshark 自 1998 年诞生至今&#xff0c;已有超过26年的历史了。它最早由 Gerald Combs 创建&#xff0c;最初名为 Ethereal。2006 年&#xff0c;Ethereal 更名为 Wireshark&#xff0c;并继续发展成了全球领先且人尽皆知的网络协议分析工具&#xff0c;其GUI演变就是其中非…

哈希表的实现--C++

文章目录 一、哈希概念1.1、直接定址法1.2、哈希冲突1.3、负载因子1.4、将关键字转为整数1.5、哈希函数1.5.1、除法散列法/除留余数法1.5.2、乘法散列法1.5.3、全域散列法1.5.4、其他方法 二、处理哈希冲突2.1、开放定址法2.1.1、线性探测2.1.2、二次探测2.1.3、双重散列2.1.4、…