LLM推理优化笔记1:KV cache、Grouped-query attention等

news2024/12/27 13:58:56

KV cache

对于decoder-only 模型比如现在如火如荼的大模型,其在生成内容的过程中,为了避免冗余计算,会将Transformer里的self-attention的K和V矩阵给缓存起来,这个过程即为KV cache。

在这里插入图片描述

decoder-only模型的生成过程是自回归的(auto-regressive),生成过程中先根据输入生成下一个token,再将生成的token与输入一起生成下一个token,重复这个过程直到遇到停止符号或者达到限定的输出token个数。(gif图来自illustrated-gpt2)
在这里插入图片描述

因为decoder-only模型的生成过程是自回归的,并且decoder的self-attention是causal的,即每一个token的attention计算只与其前面的tokens有关,所以我们每生成一个token时都重复计算了前面出现过的token的attention。为了节省计算量,可以将已经计算过的token的attention矩阵存储下来,每生成下一个token时直接使用存储好的attention矩阵并将新计算的token attention存储起来。(下面图片来自博客,不考虑softmax和scale示意对比KV cache使用)

在这里插入图片描述

在每一步计算时,只需要使用到上一步计算过的K和V矩阵,所以KV cache只会缓存K和V。当然缓存的代价就是需要额外的显存存储:

  • 每缓存一个token,其需要的空间为 2 * precision_in_bytes * head_dim * n_heads * n_layers(式中2是因为缓存K和V两个矩阵,precision_in_bytes是token的存储精度占用字节大小,head_dim是attention的head维度,n_head是attention的head个数,n_layers是transformer的层个数)。
  • 对于16-bit精度的模型以最大上下文长度max_context_length进行批量推理要求的缓存大小2 * 2 * head_dim * n_heads * n_layers * max_context_length * batch_size,比如Llama-2-13B模型对应最大上下文窗口为4096,batch大小为8时要求的缓存显存最多高达25GB左右。

transformers包生成时默认使用KV cache(use_cache=True),我们可以用如下代码去测试一下使用了KV cache以及不使用时的性能差异。

## 代码来自 https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
    times.append(time.time() - start)
  print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

Multi-query attention 和Grouped-query attention

Multi-query attention

Multi-query attention(MQA)出自2019年11月的论文《Fast Transformer Decoding: One Write-Head is All You Need》,它让multi-head attention里的多个head共享K和V矩阵,并做试验验了修改之后模型的性能下降不明显,但因为减少了参数,推理时KV cache占用的存储和读取时间都会少很多。

Grouped-query attention

在这里插入图片描述

Grouped-query attention(GQA)出自2023年5月的论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》, 如上图所示,它的共享K和V矩阵介于Multi-query attention(MQA)和Multi-head attention(MHA)之间,通过实验证明GQA可达到类似MQA的速度以及MHA的性能。

Grouped-query attention将query heads划分为G个groups,每一组query heads共享一个key head和value head,将 G Q A − G GQA_{-G} GQAG 记为有G个groups的grouped-query attention,则 G Q A − 1 GQA_{-1} GQA1为Multi-query attention, G Q A − H GQA_{-H} GQAH则等价于Multi-head attention。

论文还提出了一个将Multi-head attention模型转变MQA或GQA模型的方法,其分为两步:

  • 将MHA模型的checkpoint转变成MQA或GQA模型,使用如下图示意的mean pooling将多个K和V矩阵变成单个矩阵(论文做了试验比较选取第一个head、随机初始化、mean pooling,mean pooling的效果是最好的)。
  • 使用少量比例(5%左右)的预训练数据来继续预训练使模型适应新结构。

在这里插入图片描述

关于GQA的组个数选取,论文做了消融实验后对于总head个数为64时G选取的是8,而在Llama2-70B模型也是8(总heads数也为64)。

在这里插入图片描述

实现

不考虑性能的代码示意如下:

from dataclasses import dataclass
import math
import torch
import torch.nn as nn 
from torch.nn import functional as F

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    n_kv_heads: int = 12 # grouped-query的group个数

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Perform repeat of kv heads along a particular dimension.
    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
    n_rep: amount of repetitions of kv_n_heads
    Unlike torch.repeat_interleave, this function avoids allocating new memory.
    from https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py#L47
    llama2里的写法差不多https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py#L164C1-L165C1
    """
    if n_rep == 1:
        return hidden
    (b, s, kv_n_heads, d) = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)


## adapt from https://github.com/karpathy/nanoGPT/blob/master/model.py
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
    

### multi-query
class MultiQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_embd//self.n_head, self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, 1, C // self.n_head), self.n_head).transpose(1, 2) # (B, nh, T, hs)
                # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
    
### grouped-query attention
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, config.n_embd + 2*config.n_kv_heads*config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_kv_heads = config.n_kv_heads
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_embd, self.n_kv_heads*self.n_embd//self.n_head, self.n_kv_heads*self.n_embd//self.n_head], dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = repeat_kv(k.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)
        v = repeat_kv(v.view(B, T, self.n_kv_heads, C // self.n_head), self.n_head//self.n_kv_heads).transpose(1, 2) # (B, nh, T, hs)

                # attention (materializes the large (T,T) matrix for all the queries and keys)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

Sliding Window Attention

Mistral 7B使用Sliding Window Attention(SWA)来减少KV cache的内存占用,每次计算attention时,只考虑固定窗口大小W内的信息。对于位置i的隐状态,只会考虑在其前面i-W到i的窗口内的隐状态信息,如下图示意所示,所以对于在第k层的位置i来说,最多可以访问到 W × k W\times k W×k个tokens。在Mistral 7B里,W=4096,层数为32,所以理论上的attention范围近似为131K。

在这里插入图片描述

因为使用固定attention窗口,所以Mistral 7B使用滚动(rolling) buffer cache, cache大小固定为W,在时刻t的K和V存储在cache的第i mod W个位置,也就是说如果位置i比W大,cache中原先存储的值会被覆盖掉。下图是W=3时的示意。
在这里插入图片描述

参考资料

  1. 看图学KV Cache

  2. Transformer Inference Arithmetic

  3. Transformers KV Caching Explained(其gif动画有助于加深理解)

  4. KV caching内存增长

  5. KV cache 是chatbot 规模化的一大工程挑战

  6. Techniques for KV Cache Optimization in Large Language Models

  7. KV cache quantization

  8. Inference Optimization

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924377.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

让人工智能为你的旋律填词,开启音乐新章

在音乐的世界里&#xff0c;旋律如同灵动的精灵&#xff0c;飞舞在我们的心间。但有时&#xff0c;为这美妙的旋律找到最贴切、最动人的歌词&#xff0c;却成为了创作者们的难题。如今&#xff0c;随着科技的进步&#xff0c;人工智能正逐渐成为我们在音乐创作道路上的得力助手…

【香橙派 Orange pi AIpro】| 开发板深入使用体验

目录 一. &#x1f981; 写在前面二. &#x1f981; 愉快的安装流程2.1 安装前准备2.2 流程准备2.2.1 烧录镜像2.2.2 开机2.2.3 连网2.2.4 SSH远程连接开发板 2.3 体验 AI 应用样例 三. &#x1f981; 写在最后 一. &#x1f981; 写在前面 大家好&#xff0c;我是狮子呀&…

Windows下vscode配置C++环境

一、vscode下载及安装 vscode官网 选安装位置。 勾选这几项。 1.vscode界面中文配置 &#xff08;1&#xff09;点击扩展小图标&#xff0c;搜索插件&#xff0c;找到插件Chinese (Simplified) (简体中文) Language Pack&#xff0c;点击install。 &#xff08;2&#xf…

3dmax-vray5大常用材质设置方法

3dmax云渲染平台——渲染100 以高性价比著称&#xff0c;是预算有限的小伙伴首选。 15分钟0.2,60分钟内0.8;注册填邀请码【7788】可领30元礼包和免费渲染券 提供了多种机器配置选择(可以自行匹配环境)最高256G大内存机器&#xff0c;满足不同用户需求。 木纹材质 肌理调整&…

红酒的艺术之旅:品味、鉴赏与生活的整合

在繁忙的都市生活中&#xff0c;红酒如同一道不同的风景线&#xff0c;将品味、鉴赏与日常生活巧妙地整合在一起。它不仅仅是一种饮品&#xff0c;更是一种艺术&#xff0c;一种生活的态度。今天&#xff0c;就让我们一起踏上这趟红酒的艺术之旅&#xff0c;探寻雷盛红酒如何以…

json-server服务使用教程

目录标题 安装 json-server启动 json-server 本地服务 安装 json-server npm install -g json-server0.17.4json-server -v报错请参考&#xff1a;执行json-server -v报错 因为在此系统上禁止运行脚本。 启动 json-server 本地服务 查看本机IP&#xff1a;ipconfig Shift右…

【简历】安徽某二本学院:Java简历指导,简历通过率接近为0

注&#xff1a;为保证用户信息安全&#xff0c;姓名和学校等信息已经进行同层次变更&#xff0c;内容部分细节也进行了部分隐藏 简历说明 这是一份二本独立院校的Java简历&#xff0c;那么一般来说&#xff0c;我们的独立院校包括专升本&#xff0c;在目前的it的投递中&#x…

NI VST 毫米波测试仪器创新

目录 概览​从UHF至V频段的频率覆盖范围&#xff1a;54 GHz远程测量模块​PXIe-5842&#xff1a;VST架构的扩展54 GHz扩频PXIe-5842功能​​宽频覆盖范围​IF和毫米波测试端口可满足多频带需求​高达2 GHz瞬时带宽误差矢量幅度测量性能相位相干同步基于PXI平台集成多种仪器 互补…

在centos7.9下静默安装Oracle19c详细流程

文章目录 下载安装包1.下载Oracle19c的安装包2.下载Oracle19c的预安装包3.拖到Linux中 一、安装依赖二、创建用户和组三、修改Linux相关内核参数四、修改用户限制五、关闭防火墙和SELinux六、创建安装目录并解压安装包七、设置环境变量八、安装Oracle数据库九、创建实例十、配置…

智慧校园毕业管理:全面解读毕业批次功能

在智慧校园的毕业管理系统中&#xff0c;毕业批次模块通过其精心设计的毕业批次功能&#xff0c;为即将离校的学子们提供了一个高效、便捷的过渡平台。这一特色功能聚焦于特定时间段内的毕业生群体&#xff0c;巧妙融合数字技术&#xff0c;从信息核实到最终的离校程序&#xf…

Leetcode3200. 三角形的最大高度

Every day a Leetcode 题目来源&#xff1a;3200. 三角形的最大高度 解法1&#xff1a;模拟 枚举第一行是红色还是蓝色&#xff0c;再按题意模拟即可。 代码&#xff1a; /** lc appleetcode.cn id3200 langcpp** [3200] 三角形的最大高度*/// lc codestart class Solutio…

5.1 软件工程基础知识-软件工程概述

软件工程诞生原因 软件工程基本原理&#xff08;容易被考到&#xff09; 软件生存周期 能力成熟度模型 - CMM 能力成熟度模型 - CMMI 真题

Neo4j:图数据库的革命性力量

Neo4j 首席技术官 prathle 撰写了一篇出色的博文&#xff0c;总结最近围绕 GraphRAG 的热议、我们从一年来帮助用户使用知识图谱 LLM 构建系统中学到的东西&#xff0c;以及我们认为该领域的发展方向。Neo4j一时间又大火起来&#xff0c;本文将带你快速入门这神奇的数据库。 前…

Run LoongArch64 Alpine VM on x86_64

一、Build from source(build on x86_64) Obtain the latest libvirt, virt-manager, and qemu source code, compile and install them. 1.1 Build libvirt from source sudo apt-get update sudo apt-get install augeas-tools bash-completion debhelper-compat dh-apparm…

昇思25天学习打卡营第13天 | ResNet50迁移学习

昇思25天学习打卡营第13天 | ResNet50迁移学习 文章目录 昇思25天学习打卡营第13天 | ResNet50迁移学习数据集加载数据集数据集可视化 模型训练固定特征 总结打卡 在实际应用场景中&#xff0c;由于训练数据集不足&#xff0c;很少从头开始训练整个网络。普遍做法是在一个非常大…

全自主巡航无人机项目思路:STM32/PX4 + ROS + AI 实现从传感融合到智能规划的端到端解决方案

1. 项目概述 本项目旨在设计并实现一款高度自主的自动巡航无人机系统。该系统能够按照预设路径自主飞行&#xff0c;完成各种巡航任务&#xff0c;如电力巡线、森林防火、边境巡逻和灾害监测等。 1.1 系统特点 基于STM32F4和PX4的高性能嵌入式飞控系统多传感器融合技术实现精…

【Caffeine】⭐️SpringBoot 项目整合 Caffeine 实现本地缓存

目录 &#x1f378;前言 &#x1f37b;一、Caffeine &#x1f37a;二、项目实践 2.1 环境准备 2.2 项目搭建 2.3 接口测试 ​&#x1f49e;️三、章末 &#x1f378;前言 小伙伴们大家好&#xff0c;缓存是提升系统性能的一个不可或缺的工具&#xff0c;通过缓存可以避免大…

【答疑】8080或其他端口被占用如何解决?

我们在做项目时总会遇到各式各样千奇百怪的问题&#xff0c;但基本上每个刚接触tomcat的小白早晚都会遇到一个问题——8080端口被占用&#xff1a; 报错信息很容易理解&#xff0c;端口8080已经被使用了&#xff0c;那么这时我们该如何知道是谁使用了这个端口并关掉它呢&#x…

c++基础语法之内联函数

引言&#xff1a; 在C编程中&#xff0c;性能优化是一个永恒的话题。内联函数&#xff08;Inline Functions&#xff09;作为提高程序执行效率的一种重要手段&#xff0c;在编译器优化过程中扮演着关键角色。 一、内联函数的基本概念 定义&#xff1a;内联函数是C中一种特殊…

C#可空类型与数组

文章目录 可空类型NULL合并运算符&#xff08;??&#xff09;数组数组声明数组初始化数组赋值数组访问多维数组交错数组数组类数组类的常用属性数组类的常用方法 可空类型 C#提供了一种特殊的数据类型&#xff0c;nullable类型&#xff08;可空类型&#xff09;&#xff0c;可…