使用推测解码提高 LLM 推理速度

news2024/9/25 1:13:27

使用尖端优化技术加速推理的实用指南

       欢迎来到雲闪世界。大型语言模型非常耗电,需要大量 GPU 资源才能发挥良好性能。然而,Transformer 架构并没有充分利用 GPU。

从设计上讲,GPU 可以并行处理,但 Transformer 架构是自回归的。为了生成下一个 token,它必须查看之前的所有 token。Transformer 不允许您并行预测下一个 token。最终,这会使 LLM 的生成阶段非常缓慢,因为必须n按顺序生成每个新 token 。推测解码是一种旨在解决此问题的新型优化技术。

每次前向传递都会产生一个由 LLM 生成的新 token

推测解码有几种不同的方法。本文介绍的技术采用双模型方法。

推测解码

推测解码的工作原理是使用两个模型,一个大型模型和一个较小的辅助模型。较小的辅助模型首先生成一个由 n 个 token 组成的序列。然后,主模型在一次前向传递中验证 token 序列。

这个想法是,由于辅助模型很小,因此它会快速生成 token。主模型更大、更准确,不需要生成每一个 token。它只需要验证辅助模型生成的 token。

例如,假设助手模型产生以下 5 个 token。

辅助模型自动回归生成 token,而主模型一次性验证所有 token

主模型将对所有 5 个 token 执行一次前向传递。主模型的目的是验证每个 token 的正确性。如果根据主模型,其中一个 token 不正确,它会丢弃错误 token 后的整个序列。然后,主模型会以自回归方式用正确的 token 填充序列的其余部分。

主模型丢弃从第一个不正确的 token 开始的序列输出

使用推测解码,您一定能获得与单独运行主模型完全相同的输出。这是因为主模型会验证每个 token 并替换它认为不正确的 token。在理想情况下,如果辅助模型正确生成了大多数 token,主模型将能够快速验证 token,从而缩短端到端生成时间。

推测解码的实践

许多流行的推理服务(如 TGI 和 vLLM)都支持开箱即用的推测解码。挑战在于找到正确的辅助模型和主模型对。一般来说,最好选择具有相同架构和词汇表的模型。

在本教程中,我们将使用Llama 3.1 70B Instruct作为主模型,使用Llama 3.1 8B Instruct作为辅助模型。vLLM 为这两种模型提供支持,因此我们将使用它作为推理服务。

安装 Python 要求并导入包

!pip install vllm== 0.5 .4

从vllm导入SamplingParams
从vllm.engine.arg_utils导入AsyncEngineArgs
从vllm.engine.async_llm_engine导入AsyncLLMEngine

下载模型

model_args = AsyncEngineArgs(
            模型 = “meta-llama/Meta-Llama-3.1-70B-Instruct”,
            speculative_model = “meta-llama/Meta-Llama-3.1-8B-Instruct”,
            trust_remote_code = True,
            tensor_parallel_size = 4,
            max_num_seqs = 8,
            dtype = “half”,
            use_v2_block_manager = True,
            enforce_eager = True
         )

llm_engine = AsyncLLMEngine.from_engine_args(model_args)
  • 在引擎参数中,指定作为助手model时要使用的主模型。speculative_model
  • 由于 70B 型号占用约 140 GB 的 VRAM,我决定使用四个 GPU(A100 80GB),因此tensor_parallel_size设置为 4。
  • 使用dtype=halfFP16 精度加载权重,与完整的 FP32 位精度相比,其内存量只有一半。
  • 该推测模型需要参数use_v2_block_managerenforce_eager才能正确运行。
  • 使用 vLLM 而非常规 LLM 类的主要原因AsyncLLMEngine是 AsyncEngine 可以处理并发请求。

使用推测解码运行推理

import uuid
import asyncio

prompt = "Why is the earth flat?"
stream = True
max_tokens = 512
model_input = {"max_tokens": max_tokens}
sampling_params = SamplingParams(**model_input)

idx = str(uuid.uuid4().hex)
vllm_generator = llm_engine.generate(prompt, sampling_params, idx)

async def stream_tokens():
    full_text = ""
    async for request_output in llm_engine.generate(prompt, sampling_params, idx):
        if len(request_output.outputs) > 0:
            text = request_output.outputs[0].text
            delta = text[len(full_text):]
            full_text = text
            print(delta, end='', flush=True)
    print()

await stream_tokens()
  • 在顶部,我们指定常规参数,例如promptmax_tokensstream。为了轻松直观地看到性能提升,我们将设置stream=True
  • stream_tokens函数中,我们有一个从调用返回的生成器对象llm_engine.generate。为了打印生成器的输出,我们需要异步循环它,因为我们的 LLM 引擎被定义为异步。最后,我们获取生成的标记并打印它们。

性能结果

vLLM Llama 3.1 70B 基准

TPSA

  • 在较低的并发性下,我们看到了近 2 倍的改进。
  • 在更高的并发性下,性能提升约30%。

生成时间

  • 与 TPS 的性能提升类似,使用推测解码时,端到端延迟会减少一半。
  • 随着并发用户数量的增加,生成时间确实会慢慢增加,但仍然比不使用 spec-dec 时要低得多。

薄膜晶体管

  • TTFT 在推测解码的情况下实际上表现更差。
  • 原因是,在主模型验证生成的 token 之前,辅助模型必须生成 token。这会浪费一些时间,导致 TTFT 更高。

潜在挑战

推测解码似乎是一个本垒打,特别是因为它使用 vLLM 非常容易设置。然而,在使用这种推理加速技术时,需要注意一些事项。

  1. 选择合适的助手和主模型

使用 spec-dec 时,并非所有大型语言模型都相互兼容。例如,您不能使用 llama 2 7B 作为辅助模型,而使用 llama 3 70B 作为主模型。

这是因为辅助模型和主模型必须共享相同的词汇表。Llama 2 的词汇表大小为 32K 个标记,而 llama 3 的词汇表大小为 128K 个标记。

2.助手模型的尺寸很重要

由于推测解码依赖辅助模型来完成大部分繁重的工作,因此选择更快的辅助模型可以获得更高的性能。

理想情况下,辅助模型应该具有很少的参数,以便能够快速生成 token。但代价是较小的模型往往不太准确,这引出了我的第三点。

3.主要型号合格率

主模型的高接受率至关重要。如果主模型继续拒绝助手生成的大部分标记,它(较大的模型)现在必须自动回归生成其余的序列。

这会导致 TPS 和整体生成时间的性能大幅下降,因为相同的工作必须重复两次。最好尝试各种模型,看看哪一对模型的接受率最高。

对于上例中的 llama 3.1 70B 和 8B 对,主模型的接受率约为 70%。

结论

对于 LLM 推理而言,速度是一个重要因素,因为没有人愿意让用户久等。推测解码之类的工具可以成为一种很好的解决方案,同时保持较高的输出质量。

在这篇博文中,我们介绍了推测解码的工作原理以及如何使用 vLLM 实现它。虽然它不是每个 LLM 用例的完美解决方案,但它总是好工具。

感谢关注雲闪世界。(Aws解决方案架构师vs开发人员&GCP解决方案架构师vs开发人员)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2084138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哈希表两数求和

leetcode题目链接 这道题思路可以说easy 首先想到的就是两层for循环 代码如下 class Solution { public:vector<int> twoSum(vector<int>& nums, int target) {vector<int>result;int lengthnums.size();for(int i0;i<length;i){for(int ji1;j<…

【drools】电影推荐搭建

同步用了很久很久 反复了很多了次,最终: drools用的都是9.*** Y9KP 代码 D:\Future\06_movie-recommendation-system-springboot-drools-rule-engine\recommendation\src\main\resources\com.recommendation.movie.rules\medicine_symptoms.drl在spring boot 中集成drools做电…

16.神经网络 - 卷积层

神经网络 - 卷积层 pytorch官网网站卷积层&#xff08;Convolution Layers&#xff09;&#xff1a;Convolution Layers nn.Conv1d 一维卷积Applies a 1D convolution over an input signal composed of several input planes.nn.Conv2d 二维卷积Applies a 2D convolution ov…

扁线电机介绍

相比于圆线&#xff0c;扁线因为扁平矩形的特殊性能够让线圈缠绕更加紧密&#xff0c;槽满率由原先的40%提升到70%。 这意味着相同体积下线圈中的导线更多&#xff0c;电流的传导效率更高&#xff0c;能够减少电阻损耗&#xff0c;产生的磁场更强&#xff0c;电机功率也就更大&…

【软件工程】软件与软件危机

考点1 软件与软件危机 一、软件 1. 定义 在计算机系统支持下&#xff0c;能够完成特定功能和性能的程序、数据和相关的文档。 2. 软件的分类 二、软件危机 在计算机软件开发和维护过程中所遇到的一系列严重问题。 1. 包含两个方面&#xff1a; 如何开发软件如何维护软件…

Rocm-HIP kernel language

HIP的内核启动语法hipLaunchKernelGGL是一个宏&#xff0c;可以作为启动内核的替代方式&#xff0c;它接受启动配置的参数&#xff08;网格维度、分组维度、流、动态共享大小&#xff09;以及任意数量的内核参数。这个宏可以替代CUDA中的三连字符&#xff08;<<< >…

[记录] linux 虚拟机装 windows10

简介 本机系统&#xff1a;Ubuntu22.04 虚拟机&#xff1a;gnome-boxes 相关资料&#xff1a;度盘 安装流程 安装 gnome-boxes sudo apt install gnome-boxes安装 windows10 打开 Boxes, 选择准备好的 windows10 ISO 文件 可以从官网下载&#xff0c;也可以从我给的资料里获…

OpenCV小练习:身份证号码识别

目标&#xff1a;针对一张身份证照片&#xff0c;把身份证号码识别出来&#xff08;转成数字或字符串&#xff09;。 实现思路&#xff1a;需要将目标拆分成两个子任务&#xff1a;(1) 把身份证号码区域从整张图片中检测/裁剪出来&#xff1b;(2) 将图片中的数字转化成文字。第…

【python】OpenCV—Multi Human Pose Estimation

文章目录 1、背景介绍2、关键点检测模型3、源码与结果4、源码解读——检测关键点5、源码解读——找到有效对6、源码解读——组装个人关键点7、涉及到的库cv2.dnn.blobFromImage 8、参考 1、背景介绍 【python】OpenCV—Single Human Pose Estimation 本文以 COCO 格式为例&am…

低代码门户技术:赋能业务灵活性与创新的新时代

随着数字化转型的深入推进&#xff0c;各行各业对灵活、高效的技术解决方案的需求日益增长。在这个背景下&#xff0c;低代码门户技术应运而生&#xff0c;为企业提供了一种新颖的应用开发方式。今天&#xff0c;我们将探讨低代码门户技术的基本概念、优势以及如何在实际应用中…

uni-app启动本地开发环境,修改默认端口号

vite.config.js: import { defineConfig } from "vite"; import uni from "dcloudio/vite-plugin-uni";// https://vitejs.dev/config/ export default defineConfig({server: {port: 3006,},plugins: [uni()], });人工智能学习网站 https://chat.xutong…

YoloV8实战:使用YoloV8实现OBB框检测

定向边框&#xff08;OBB&#xff09;数据集概述 使用定向边界框&#xff08;OBB&#xff09;训练精确的物体检测模型需要一个全面的数据集。本文解释了与Ultralytics YOLO 模型兼容的各种 OBB 数据集格式&#xff0c;深入介绍了这些格式的结构、应用和格式转换方法。数据集使…

【C++】list的使用和list的模拟实现和迭代器失效问题

一、list 的简单介绍 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向代。 2. list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点中通过指针指向其前一个元素和后一个…

三级_网络技术_52_应用题

一、 请根据下图所示网络结构回答下列问题。 1.填写路由器RG的路由表项。 目的网络/掩码长度输出端口__________S0&#xff08;直接连接&#xff09;__________S1&#xff08;直接连接&#xff09;__________S0__________S1__________S0__________S1 2.如果在不改变路由表项…

npm install报错解决指南:清理缓存与重建依赖

问题描述 在执行npm install命令时&#xff0c;npm install报错&#xff0c;导致依赖无法正常安装。 具体步骤 清理npm缓存&#xff1a; 使用npm cache clean --force命令来强制清理npm缓存&#xff0c;以排除缓存导致的问题。 检查Node.js和npm版本&#xff1a; 执行node -v和…

面试经典算法150题系列-反转字符串中的单词

反转字符串中的单词 给你一个字符串 s &#xff0c;请你反转字符串中 单词 的顺序。 单词 是由非空格字符组成的字符串。s 中使用至少一个空格将字符串中的 单词 分隔开。 返回 单词 顺序颠倒且 单词 之间用单个空格连接的结果字符串。 注意&#xff1a;输入字符串 s中可能…

HarmonyOS--合理使用动画

一、概述 动画是应用开发中必不可少的部分&#xff0c;它可以使应用程序更加生动和易于互动&#xff0c;一方面可以提升用户体验、增强视觉吸引力&#xff0c;另一方面可以引导用户操作、提高信息传达效率。应用程序中&#xff0c;页面层级间的转场、点击交互、手势操控都可以添…

一刷代码随想录(图论8)

拓扑排序 软件构建 题意&#xff1a; 题目描述&#xff1a; 某个大型软件项目的构建系统拥有 N 个文件&#xff0c;文件编号从 0 到 N - 1&#xff0c;在这些文件中&#xff0c;某些文件依赖于其他文件的内容&#xff0c;这意味着如果文件 A 依赖于文件 B&#xff0c;则必须…

Semantic Kernel/C#:一种通用的Function Calling方法,文末附经测试可用的大模型

Funcion Calling介绍 函数调用允许您将模型如gpt-4o与外部工具和系统连接起来。这对于许多事情都很有用&#xff0c;比如为AI助手赋能&#xff0c;或者在你的应用程序与模型之间建立深度集成。 如果您了解或者使用过Semantic Kernel可能会发现除了OpenAI支持Function Calling…

cenos 7 安装 golang

1、下载地址 All releases - The Go Programming Languagehttps://golang.google.cn/dl/ 2、解压 tar -C /usr/local -zxf go1.14.3.linux-amd64.tar.gz 3、配置PATH 文件 /etc/profile&#xff08;全局&#xff09; 或 $HOME/.profile&#xff08;用户&#xff09; 或 ~/…