大模型推理性能优化之KV Cache解读

news2025/1/10 15:42:08

0. 引言

做大模型性能优化的一定对KV Cache不陌生,那么我们对这个技术了解到什么程度呢?请尝试回答如下问题:

  • KV Cache节省了Self-Attention层中哪部分的计算?

  • KV Cache对MLP层的计算量有影响吗?

  • KV Cache对block间的数据传输量有影响吗?本文打算剖析该技术并给出上面问题的答案。

1. KV Cache是啥

大模型推理性能优化的一个常用技术是KV Cache,该技术可以在不影响任何计算精度的前提下,通过空间换时间思想,提高推理性能。网上有一些关于该技术的分析博客,但读过后仍然会很迷糊,甚至可能会被带偏,认为这个Cache过程和数据库读取或CPU Cache加速类似的荒谬结论。刚开始我也有类似误解,直到逐行查阅并运行源码,才清楚了解到其Cache了啥,以及如何节省计算的。

2. 背景

生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

如上描述是我们通常认知的GPT推理过程。代码描述如下:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, _ = model(in_tokens)
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = torch.cat((in_tokens, out_token), 0)
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1

out_text = tokenizer.decode(in_tokens)
print(f' Input: {in_text}')
print(f'Output: {out_text}')

输出:

step 0 input: Lionel Messi is a player
step 1 input: Lionel Messi is a player who
step 2 input: Lionel Messi is a player who has
step 3 input: Lionel Messi is a player who has been
step 4 input: Lionel Messi is a player who has been a
step 5 input: Lionel Messi is a player who has been a key
step 6 input: Lionel Messi is a player who has been a key part
step 7 input: Lionel Messi is a player who has been a key part of
step 8 input: Lionel Messi is a player who has been a key part of the
step 9 input: Lionel Messi is a player who has been a key part of the team
step 10 input: Lionel Messi is a player who has been a key part of the team's
step 11 input: Lionel Messi is a player who has been a key part of the team's success
step 12 input: Lionel Messi is a player who has been a key part of the team's success.
step 13 input: Lionel Messi is a player who has been a key part of the team's success.

 Input: Lionel Messi is a
Output: Lionel Messi is a player who has been a key part of the team's success.

可以看出如上计算的问题吗?每次推理过程的输入tokens都变长了,导致推理FLOPs随之增大。有方法实现推理过程的FLOPs基本恒定不变或变小吗?(埋个伏笔,注意是基本恒定)。

3. 原理

在上面的推理过程中,每step内,输入一个token序列,经过Embedding层将输入token序列变为一个三维张量[b, s, h],经过一通计算,最后经logits层将计算结果映射至词表空间,输出张量维度为[b, s, vocab_size]

当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第i+1轮输入数据只比第 轮i输入数据新增了一个token,其他全部相同!因此第i+1轮推理时必然包含了第i轮的部分计算。KV Cache的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果,就是这么简单,不存在什么Cache miss问题。

4. 实现细节

目前各大模型推理都实现了KV Cache,下面就看如何使用了。我们可以在上面代码基础上修改,主要改动:

  • 在推理时新增了past_key_values参数,该参数就会以追加方式保存每一轮的K V值。kvcache变量内容为((k,v), (k,v), ..., (k,v)),即有L个 k,v 组成的一个元组,其中 k 和 v 的维度均为[b, n_head, s, head_dims]。这里可以顺带计算出每轮推理对应的 cache 数据量为2bshL,这里s值等于当前轮次值。可以随着输出tokens的增长,cache数据量呈现线性增加特点。以GPT3-175B为例,假设以 float16 来保存 KV cache,senquence长度为100,batchsize=1,则KV cache占用显存为 2×100×12288×96×2 Byte = 472MB。

  • 推理输出的token直接作为下一轮的输入,不再拼接,因为上文信息已经在 kvcache 中。

代码示例:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, kvcache = model(in_tokens, past_key_values=kvcache) # 增加了一个 past_key_values 的参数
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = out_token # 输出 token 直接作为下一轮的输入,不再拼接
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1
        out_text += text

print(f' Input: {in_text}')
print(f'Output: {out_text}')

通过上面代码只能看到调用层面的变化,实现细节还需看各框架的底层实现,例如Hugging Face的transformers库代码实现就比较清爽,在modeling_gpt2.py中Attention部分相关代码如下:

query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None: # 当输出第一个token后,layer_past就是非None了
            past_key, past_value = layer_past # 取出之前计算好的 key, value
            key = torch.cat((past_key, key), dim=-2) # past_key 与当前 token 对应的 key 拼接
            value = torch.cat((past_value, value), dim=-2) # past_value 与当前 token 对应的 value 拼接

        if use_cache is True:
            present = (key, value)
        else:
            present = None

在 block 层面也有相关代码,大家有空细品吧。还是那句话,说一千道一万不如阅读并运行源码一次。

其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

  • 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。

  • 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

这里用图可能更有助理解,下图是一个Decoder Block,含有Self-Attention和MLP,标红部分为KV Cache影响到的内容,即KV Cache开启后,标红的序列长度s变为 1,当batch_size=1时,Self-Attention中的2个dense全都变为gemv操作,MLP中的dense也全都变为gemv操作。看懂这个图就可以答对上面的3个问题啦。

3e4e90f36eb7d83649ef2484b96f97cb.png


Decoder Block of GPT

如下链接也有这方面的定量分析,写的很棒,推荐大家看看。

回旋托马斯x:分析transformer模型的参数量、计算量、中间激活、KV cache

5. 总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。本文尝试打开封装分析该技术内部实现,希望对大家有所帮助,文中如有纰漏,欢迎指正。

最后给我们团队的开源项目打个广告,Adlik 是中兴通讯贡献的深度学习推理工具,已获得 Linux AI 基金会支持,目前仍在持续完善中,期待大家的支持和关注。另外,我们也在做深度学习编译器和大模型方向的前沿技术研发工作,欢迎感兴趣的小伙伴加入我们 ~~

f651ac44f91844646e658ea85a5c8f7d.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/561809.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

知识点滴 - 什么是膳食结构

膳食结构是指膳食中各类食物的数量及其在膳食中所占的比重,由于影响膳食结构的这些因素是在逐渐变化的,所以膳食结构不是一成不变的,人们可以通过均衡调节各类食物所占的比重,充分利用食品中的各种营养,达到膳食平衡&a…

适配器模式:代码接口的神奇转换

一、概要 适配器模式(Adapter Pattern)是一种结构型设计模式,它允许将一个类的接口转换成客户端所期望的另一个接口,使得原本由于接⼝不兼容⽽不能⼀起⼯作的那些类可以⼀起⼯作。通俗来讲,就是通过适配器来连接两个不…

Js 如何实现一个类似 chatGPT 打字机效果

在使用chatGPT的时候,会有一个打字机的效果,以下是分别使用原生Js和Vue实现 原生 JS 实现 如下是示例代码 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>Printer 打字机效果</title><style>* {margin: 0;bor…

记录--使用率比较低的10个Web API

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 avaScript中有些API可能使用率比较低&#xff0c;下面我们逐一介绍它们的用法和使用场景。 至于标题&#xff0c;主要是想让你进来看看&#xff0c;兄弟们别打我&#xff01; Blob API Blob API 用于处…

电脑技巧:CopyQ剪切板增强工具介绍(附下载)

目录 1、软件简介 2、主要功能介绍 3、使用说明 4、总结 今天给大家再分享一款剪切板增强工具——CopyQ&#xff0c;感兴趣的朋友可以下载试一试&#xff01; 1、软件简介 CopyQ 是一款开源的、跨平台剪贴板管理工具&#xff0c;支持 Windows、macOS、Linux&#xff0c;可…

项目中遇到的一些问题总结(十)

nacos保护阈值 Nacos 中的保护阈值&#xff08;Protection Threshold&#xff09;是用来保护服务实例的一种机制。当某个服务实例出现故障或异常时&#xff0c;服务注册中心 Nacos 会通过心跳检测等方式将其从服务列表中移除&#xff0c;以避免客户端继续向其发送请求。但是&a…

在Linux中为Simulink添加ROS自定义消息类型

在Linux中为Simulink添加ROS自定义消息类型 基于Matlab/Simulink的ROS自定义消息类型的添加方法 ROS与Simulink联合仿真(三):自定义Message 1、下载 ROS Toolbox Interface for ROS Custom Messages 将 roscustommsg.mlpkginstall 文件放入 MATLAB 工作空间 双击 roscustommsg…

Nature -- 人类首个 “泛基因组”旨在编目人类遗传多样性

在人类基因组项目发布第一个人类基因组草图的20多年后,研究人员发布了人类“泛基因组”草图——这预示着一种新的参考基因组的出现,它能捕获到更多的人类遗传多样性信息。 泛基因组变异图由两个元素组成:序列图&#xff0c;其ode表示定向DNA链&#xff0c;双向边表示连通性关系…

学系统集成项目管理工程师(中项)系列25_计算机网络知识

1. OSI七层协议 1.1. 物理层 1.1.1. RS232、V.35、RJ-45、FDDI 1.2. 数据链路层 1.2.1. 【21上选17】 1.2.2. IEEE802.3/.2、HDLC、PPP、ATM 1.3. 网络层 1.3.1. IP、ICMP、IGMP、IPX、ARP 1.3.2. 路由选择 1.3.2.1. 【20下选17】 1.4. 传输层 1.4.1. TCP、UDP、SPX…

越小越好: Q8-Chat,在英特尔至强 CPU 上体验高效的生成式 AI

大语言模型 (LLM) 正在席卷整个机器学习世界。得益于其 transformer 架构&#xff0c;LLM 拥有从大量非结构化数据 (如文本、图像、视频或音频) 中学习的不可思议的能力。它们在 多种任务类型 上表现非常出色&#xff0c;无论是文本分类之类的抽取任务 (extractive task) 还是文…

又一批令人惊艳的 AI 工具,诞生了!

公众号关注 “GitHubDaily” 设为 “星标”&#xff0c;每天带你逛 GitHub&#xff01; 自 ChatGPT 发布以后&#xff0c;AIGC 行业的热度也一直在持续发酵。几个月过去了&#xff0c;对比之前&#xff0c;各类 AI 工具的热度不减反增&#xff0c;各行各业的人都早开始拥抱 AIG…

SAP工具箱 MR22自定义BAPI

点击蓝字 关注我们 一 前言 标准事务代码MR22 通过调整金额影响物料的成本价,前台界面中单个凭证中允许输入多行物料, 但是对应的BAPI函数仅支持输入单行物料 BAPI_MATVAL_DEBIT_CREDIT 正常库存BAPI_SALESORDSTCK_DEBIT_CREDIT 销售订单库存 这种情况 婶可忍叔不可忍 (感谢用户…

Python 闭包装饰器和多任务--闭包,装饰器,进程,线程

1.闭包案例 在函数嵌套的前提下&#xff0c;内部函数使用了外部函数的变量&#xff0c;并且外部函数返回了内部函数&#xff0c;我们把这个使用外部函数变量的内部函数称为闭包. 外层函数: config_name(),外层函数中的变量是 name 内层函数: inner(),inner()使用了外层函数的变…

34从零开始学Java之构造方法都有哪些特性?

作者&#xff1a;孙玉昌&#xff0c;昵称【一一哥】&#xff0c;另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 在前面的几篇文章中&#xff0c;壹哥给大家介绍了不少关于方法的内容&#xff0c;这些内容是我们日常…

zabbix监控之javasnmp自定义监控

1、客户端开启 java jmxremote 远程监控功能 上传 tomcat 软件包到 /opt 目录中 cd /opt tar zxvf apache-tomcat-9.0.16.tar.gz mv apache-tomcat-9.0.16 /usr/local/tomcat #配置 java jmxremote 远程监控功能 vim /usr/local/tomcat/bin/catalina.sh ...... #位置在 cygw…

嵌入式音视频开发面试过程遇到的问题!

前言&#xff1a; 今天继续给大家分享音视频面试过程会被常问到的一些问题&#xff01; 面试的具体题目&#xff1a; 1、说一下播放器的设计过程&#xff1a; 这里的话主要分以下几步完成&#xff1a; 开启一个线程进行解封装操作 , 这包括&#xff1a;读取音频、视频的压缩数据…

chatgpt赋能Python-python_ps图片

Python PS图片的SEO指南 Python在数字图像处理中广泛应用。其中&#xff0c;Photoshop文件&#xff08;psd&#xff09;是一种常见的图像文件格式。但是&#xff0c;如何在搜索引擎上优化Python PS图片并提高其排名仍然是一个挑战。 什么是Python PS图片&#xff1f; Python…

数据结构和算法基础学习1

​​​​​​​ 网址第01周b--1.1数据结构研究_哔哩哔哩_bilibili

学C的第十九天【实用调试技巧:1. 调试;2. Windows环境调试介绍;3. 一些调试的实例;4. 一些调试的实例】

相关代码gitee自取&#xff1a;C语言学习日记: 加油努力 (gitee.com) 接上期&#xff1a;学C的第十八天【指针初阶&#xff1a;5. 指针和数组、6. 二级指针、7. 指针数组&#xff1b;初识结构体&#xff1a;1. 结构体的声明、2. 结构体成员的访问、3. 结构体传参&#xff1b…

java中的栈、堆、方法区

栈&#xff08;stack&#xff09; Java栈与堆不同每一个线程都有一个stack&#xff0c;栈的区域非常小&#xff0c;大概只有1M左右&#xff0c;但是存储速度非常快&#xff0c;所以我们把快速执行的任务存储在stack。 特点&#xff1a;自动分配&#xff0c;连续空间&#xff0…