通过 AI Edge Torch 生成式 API 在设备上使用自定义大语言模型

news2024/12/27 10:43:02

67bb3e711571e6c5f0fb404ddf23cff9.png

作者 / 首席工程师 Cormac Brick,软件工程师 Haoliang Zhang

我们很高兴地发布 AI Edge Torch 生成式 API,它能将开发者用 PyTorch 编写的高性能大语言模型 (LLM) 部署至 TensorFlow Lite (TFLite) 运行时,从而无缝地将新的设备端生成式 AI 模型部署到边缘设备上。本文是 Google AI Edge 博客连载的第二篇。上一篇文章为大家介绍了 Google AI Edge Torch,该产品可以在使用 TFLite 运行时的设备上实现高性能的 PyTorch 模型推理。

AI Edge Torch 生成式 API 使开发者能够在设备上引入强大的新功能,例如摘要生成、内容生成等。我们之前已经通过 MediaPipe LLM Inference API 让开发者们能够将一些最受欢迎的 LLM 部署到设备上。现在,我们很高兴能进一步拓展对模型的支持范围,并让大家部署到设备,而且具备优秀的性能表现。今天发布的 AI Edge Torch 生成式 API 是初始版本,提供以下功能:

  • 简单易用的模型创作 API,支持自定义 Transformer。

  • 在 CPU 上性能表现出色,并即将支持 GPU 和 NPU。

  • 作为 AI Edge Torch 的扩展,支持 PyTorch。

  • 完全兼容现有的 TFLite 部署流程,包括量化和运行时。

  • 支持 TinyLlama、Phi-2 和 Gemma 2B 等模型。

  • 兼容 TFLite 运行时和 Mediapipe LLM 运行时接口,支持 Android、iOS 和 Web。

  • MediaPipe LLM Inference API

    https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference

  • AI Edge Torch

    https://ai.google.dev/edge/lite/models/convert_pytorch

我们将在本文中为大家深入介绍该 API 的性能、可移植性、创作开发体验、端到端推理流水线和调试工具链。更具体的文档和示例请查看:

https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/examples

性能表现

为了让 MediaPipe LLM Inference API 顺利支持最受欢迎的一些 LLM,我们的团队手工打造了几款在设备上拥有最佳性能的 Transformer 模型。通过这项工作,我们确定了几个主要课题: 如何有效地表示注意力机制、量化的使用以及良好的 KV 缓存。生成式 API 很好地完成了这些课题 (本文后面会具体提到),而且依然能达到之前手写版本性能的 90% 以上,并大大提高开发速度。

  • 通过 MediaPipe 和 TensorFlow Lite 在设备上运行大语言模型

    https://developers.googleblog.com/en/large-language-models-on-device-with-mediapipe-and-tensorflow-lite/

下表显示了三种模型样本的关键基准测试结果:

d692b53e324cba281ba0e0ca04eec04f.png

  • 三种模型样本

    https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/README.md

这些基准测试是在大核上运行,使用 4 个 CPU 线程,并且使用了这些模型在所列设备上目前所知最快的 CPU 实现。

创作体验

核心创作库提供了常见 Transformer 模型 (仅编码器、仅解码器或编码-解码器等样式) 的基本构建模块。您可以用它从头开始创作模型,或重新创作现有模型以提高性能。我们建议大多数用户采用重新创作的方式,因为这样就不需要训练 / 微调的步骤了。使用生成式 API 创作的核心优势如下:

  • 一组针对可转换性、性能和平台可移植性进行了优化的核心 Transformer 构建模块,可以轻松与常规 PyTorch 算子进行混合和匹配。

  • 一个简单的权重重映射机制。

  • 直观的量化 API。

  • 支持多签名导出,包括预填充、解码或自定义签名,并能无缝接入现成的 MP 任务 / LLM Inference API。

作为示例,下面展示如何使用新的生成式 API 以约 50 行 Python 代码重新创作 TinyLLama (1.1B) 的核心功能。

  • TinyLLama (1.1B)

    https://github.com/jzhang38/TinyLlama

步骤 1: 定义模型结构

import torch
import torch.nn as nn


from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg




class TinyLLamma(nn.Module):


  def __init__(self, config: cfg.ModelConfig):
    super().__init__()


    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

步骤 2: 定义模型的前向函数

@torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]


    # forward the model itself
    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)


    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)


    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)
    return res

步骤 3: 映射旧模型权重

您可以使用库中的 ModelLoader API 轻松映射权重,就像这样:

import ai_edge_torch.generative.utilities.loader as loading_utils




# This map will associate old tensor names with the new model.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

完成这些步骤后,您可以运行一些示例输入来验证重新创作过的模型的数值正确性。如果数值检查达标,您就可以继续进行后续的转换和量化操作。

  • 验证重新创作的模型

    https://github.com/google-ai-edge/ai-edge-torch/blob/59946008def0ab867c2f4cd8931eaf607ac0d768/ai_edge_torch/generative/test/test_model_conversion.py#L132

转换和量化

通过 ai_edge_torch 提供的转换 API,您可以将 (重新创作的) Transformer 模型转换为高度优化的 TensorFlow Lite 模型。转换过程包含以下关键步骤:

  1. 导出到 StableHLO。通过 torch dynamo 编译器对 PyTorch 模型进行追踪和编译,生成带有 Aten 算子的 FX 计算图,然后由 ai_edge_torch 将其降为 StableHLO 计算图。

  2. ai_edge_torch 在 StableHLO 上执行进一步的编译器操作,包括算子融合 / 折叠等,生成高性能的 TFLite flatbuffer (包含用于 SDPA、KVCache 的融合算子)。

  • StableHLO

    https://github.com/openxla/stablehlo

量化

核心生成式 API 库还提供了一组量化 API,涵盖了常见的 LLM 量化模式。这些模式作为额外参数传递给 ai_edge_torch 转换器 API,由该 API 自动完成量化。我们会在未来的版本中提供更多的量化模式。

多签名导出

我们发现在实际推理场景中,LLM 模型需要有明确分离 (细分) 的推理函数 (预填充、解码),才能实现最佳的服务性能。这部分基于这样的观察: 预填充 / 解码可能需要采用不同的 tensor 形状,预填充受到算力限制,而解码则受到内存限制。对于大型 LLM,避免在预填充 / 解码之间重复模型权重至关重要。我们使用 TFLite 和 ai_edge_torch 中现有的多签名特性来实现这一点,使得开发者能轻松地为模型定义多个入口,如下所示:

def convert_tiny_llama_to_tflite(
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  pytorch_model = tiny_llama.build_model(kv_cache_max_len=kv_cache_max_len)
  
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)


  # Set up Quantization for model.
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')

针对 LLM 的性能优化

我们在性能调查阶段发现了几个改善 LLM 性能的关键要素:

  1. 高性能的 SDPA 和 KVCache: 我们发现,如果没有足够的编译器优化 / 融合,转换后的 TFLite 模型会因为这些函数中算子的粒度问题,性能不会很好。为了解决这个问题,我们引入了高级函数边界和 StableHLO 复合算子。

  2. 利用 TFLite 的 XNNPack 代理进一步加速 SDPA: 确保大量 MatMul / 矩阵-向量计算得到很好的优化至关重要。XNNPack 库能在广泛的移动 CPU 上以出色的性能完成这些基础计算。

  3. 避免不必要的计算: 静态形状模型如果在预填充阶段有长且固定的输入消息大小,或者在解码阶段有大的固定序列长度,则带来的计算量会大于该模型需要的最小计算量。

  4. 运行时内存消耗: 我们在 TFLite 的 XNNPack 代理中引入了权重缓存 / 预打包机制,显著降低了内存的峰值使用量。

  • SDPA

    https://github.com/google-ai-edge/ai-edge-torch/blob/7f52f70709bc12cf041b3b1fd4a49bc0d52c889a/ai_edge_torch/generative/layers/attention.py#L74

部署

LLM 推理通常涉及许多预处理 / 后处理步骤和复杂的编排,例如令牌化、采样和自回归解码逻辑。为此,我们提供了基于 MediaPipe 的解决方案以及一个纯 C++ 推理示例。

  • 基于 MediaPipe 的解决方案

    https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference

  • 纯 C++ 推理示例

    https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/examples/c%2B%2B

使用 MediaPipe LLM Inference API

MediaPipe LLM Inference API 是一个高级 API,支持使用 prompt-in / prompt-out 接口进行 LLM 推理。它负责处理底层所有的 LLM 复杂流水线操作,让模型得以更轻松和顺畅地部署。要使用 MediaPipe LLM Inference API 进行部署,您需要使用给定的预填充和解码签名来转换模型,并创建一个任务包,如下方代码所示:

def bundle_tinyllama_q8():
  output_file = "PATH/tinyllama_q8_seq1024_kv1280.task"
  tflite_model = "PATH/tinyllama_prefill_decode_hlfb_quant.tflite"
  tokenizer_model = "PATH/tokenizer.model"
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<s>",
      stop_tokens=["</s>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
  )
  llm_bundler.create_bundle(config)

在 TFLite 运行时使用纯 C++ 推理

我们还提供了一个简单易用的 C++ 示例 (无需 MediaPipe 依赖),来展示如何运行端到端的文本生成。如果您需要将导出的模型与自己独有的生产流水线和需求进行集成,这个示例是一个很好的起点,来帮助您实现更好的定制和灵活性。

跨平台支持

由于核心推理运行时都支持 TFLite,所以整个流水线都可以轻松集成到您的 Android (包括在 Google Play 中) 或 iOS 应用中,无需进行任何修改。这意味着用新的生成式 API 转换的模型只需添加几个自定义算子依赖即可立即部署。在未来的版本中,我们将为 Android 和 iOS 带来 GPU 支持,并支持 ML 加速器 (TPU、NPU)。

工具

最近发布的模型探索器 (Model Explorer) 是一款很好用的工具,可用于可视化诸如 Gemma 2B 之类的大型模型。分层查看和并排比较可以让您轻松查看和比较原始、重新创作和转换后的模型。我们也准备了专门的文章为您进一步介绍该工具,以及如何通过可视化基准信息来优化模型性能。

  • 模型探索器

    https://ai.google.dev/edge/model-explorer

  • 模型探索器: 大模型开发的计算图可视化工具

    https://research.google/blog/model-explorer/

以下是我们在编写 PyTorch TinyLlama 模型时使用该工具的示例。我们并排显示了 PyTorch export() 模型与 TFLite 模型。通过使用模型探索器,我们可以轻松比较每个层级 (如 RMSNorms、SelfAttention) 的表达情况。

d6a3aaba051b7068070714decc35c648.gif

△ 并排比较 TinyLlama PyTorch 和转换后的 TFLite

总结以及下一步

AI Edge Torch 生成式 API 是为 MediaPipe LLM Inference API 预构建优化模型的强大补充,适用于希望在设备上运行自己的生成式 AI 模型的开发者。我们会在接下来的几个月继续带来更新,包括 Web 支持、更好的量化和对 CPU 之外的硬件的支持。我们也会尝试探索更好的框架集成方案。

目前发布的是开发库的早期预览版本,该版本依然处于实验阶段,旨在与开发者社区进行开放互动。API 可能会发生变化,且存在不完善之处,并且对量化和模型的支持有限。但我们已经在 GitHub repo 中为大家提供了很多用于上手的内容,欢迎大家测试和体验,并随时和我们分享 PR、问题和功能需求。

  • GitHub repo

    https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative

在本次连载的第三篇文章中,我们将深入探讨模型探索器可视化工具,了解该工具如何帮助开发者们可视化、调试和探索模型。

  • 模型探索器

    https://ai.google.dev/edge/model-explorer


5e13c25da8e2c39d5e536bc00ad4d718.png

9a02db4d383a5f417307aa954cf086eb.png

a3f99636cb6d86070191d849b1b92855.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1792564.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

申请医疗设备注册变更时,需要补充考虑网络安全的情况有哪些?

在申请医疗器械设备注册变更时&#xff0c;需要补充网络安全的情况主要包括以下几点&#xff1a; 网络安全功能更新&#xff1a;如果医疗器械的自研软件发生网络安全功能更新&#xff0c;或者合并网络安全补丁更新的情形&#xff0c;需要单独提交一份自研软件网络安全功能更新…

计算机网络ppt和课后题总结(下)

常用端口总结 计算机网络中&#xff0c;端口是TCP/IP协议的一部分&#xff0c;用于标识运行在同一台计算机上的不同服务。端口号是一个16位的数字&#xff0c;范围从0到65535。通常&#xff0c;0到1023的端口被称为“熟知端口”或“系统端口”&#xff0c;它们被保留给一些标准…

springboot项目中第三方jar包打包进jar包

springboot项目中&#xff0c;如果手动引入了jar包&#xff0c;打包时不会将手动引入的第三方jar包打包进价包里&#xff0c;如何处理&#xff1f; 若第三方的jar包的lib和src同级&#xff0c;则maven打包时默认不会将lib下的jar包打包进jar包&#xff0c;处理方式有两种&#…

康谋技术 | 自动驾驶:揭秘高精度时间同步技术(二)

在自动驾驶中&#xff0c;对车辆外界环境进行感知需要用到很多传感器的数据&#xff08;Lidar&#xff0c;Camera&#xff0c;GPS/IMU&#xff09;&#xff0c;如果计算中心接收到的各传感器消息时间不统一&#xff0c;则会造成例如障碍物识别不准等问题。 为了对各类传感器进…

数据结构与算法-12_二叉搜索树

文章目录 1.概述2.实现定义节点查询Comparable最小最大新增前驱后继删除找小的找大的找之间小结 3.习题E01. 删除节点-Leetcode 450E02. 新增节点-Leetcode 701E03. 查询节点-Leetcode 700E04. 验证二叉搜索树-Leetcode 98E05. 求范围和-Leetcode 938E06. 根据前序遍历结果构造…

【面试题】创建两个线程交替打印100以内数字(一个打印偶数一个打印奇数)

阅读导航 一、问题概述二、解决思路三、代码实现四、代码优化 一、问题概述 面试官&#xff1a;C多线程了解吗&#xff1f;你给我写一下&#xff0c;起两个线程交替打印0~100的奇偶数。就是有两个线程&#xff0c;一个线程打印奇数另一个打印偶数&#xff0c;它们交替输出&…

读AI未来进行式笔记04数字医疗与机器人

1. 数字医疗 1.1. 20世纪的“现代医学”得益于史无前例的科学突破&#xff0c;使得医疗的方方面面都得到改善&#xff0c;让人类预期寿命从1900年的31岁提高到2017年的72岁 1.2. 现有的医疗数据库和流程将实现数字化 1.2.1. 患者记录 1.2.…

泛微开发修炼之旅--06自定义Action接口开发示例、源码及使用场景

文章链接&#xff1a;泛微开发修炼之旅--06自定义Action接口开发示例、源码及使用场景

创新实训2024.06.02日志:SSE、流式输出以及基于MTPE技术的MT-SSE技术

1. Why SSE&#xff1f; 之所以要做SSE&#xff0c;是因为在开发、调试以及使用我们开发的软件时&#xff0c;我发现消息的响应时间会很长。之所以会这样最主要的原因是&#xff0c;MTPE这项基于CoT的技术&#xff0c;本质上是多个单一的提示工程有机地组合在一起对大模型生成…

Java中常见错误-泛型擦除及桥接方法问题及解决方案

Java中泛型擦除及桥接方法 泛型擦除无界擦除上界擦除下界擦除 桥接方法演示案例wrong1wrong2wrong3right 原理总结 泛型擦除 ​ 泛型擦除是Java泛型机制的一个特性&#xff0c;它意味着**在编译期间&#xff0c;所有的泛型信息都会被移除&#xff0c;而在运行时&#xff0c;所…

html+CSS+js部分基础运用15

1、完成输入框内容的实时反向输出。 2、银行账户余额变动自动通知项目。 设计要求&#xff1a;单击按钮后&#xff0c;余额按照输入框的数额减少&#xff0c;同时将按钮式的提示信息&#xff08;金额&#xff09;同步改变。利用侦听属性实现余额发生变化时发出提示信息&#x…

python-flask项目的服务器线上部署

在部署这部分我首先尝试了宝塔面板&#xff0c;始终连接失败 换了一种思路选择了Xshell成功连接 首先我们需要下载个免费版本的Xshell 免费的&#xff1a;家庭/学校免费 - NetSarang Website 下载完毕打开 1新建-> 输入服务器的账号密码&#xff1a; 在所有会话中点击自…

NDIS Filter开发-PNP响应和安装

NDIS filter驱动可能是最容易生成的驱动之一&#xff0c;如果你安装了VS 2015 WDK之后&#xff0c;你可以直接生成一个能运行的Filter驱动&#xff0c;它一般是ndislwf。 和大部分硬件不同&#xff0c;NDIS Filter驱动介于软件和硬件抽象层之上&#xff0c;它和硬件相关&…

工业无线wifi系统搭配高速路由,解决联网及数据传输

​面对日益复杂的工业应用场景,企业对无线网络的高速、可靠和安全提出了更高要求。星创易联SR600系列多网口4G路由器应运而生,为工业无线WiFi系统提供了一个性能卓越的高速路由方案。&#xff08;key-iot.com/iotlist/sr600-5.html&#xff09; SR600路由器集4G LTE、虚拟专用…

c++(内存分配,构造,析构)

#include <iostream>using namespace std; class Per { private:string name;int age;double *height;double *weigh; public://无参构造Per(){cout << "Per::无参构造" << endl;}//有参构造Per(string name,int age,double height,double weigh):…

C++候捷stl-视频笔记4

一个万用的hash function 哈希函数的形式&#xff0c;一种是一般函数(右边)&#xff0c;一种是成员函数(左边)&#xff0c;类的对象将成为函数对象 具体做法例子。直接把属性的所有hash值加起来&#xff0c;会在hashtable中会产生很多的碰撞&#xff0c;放在同一个bucket中的元…

Nginx的https功能

一.HTTPS功能简介 Web网站的登录页面都是使用https加密传输的&#xff0c;加密数据以保障数据的安全&#xff0c;HTTPS能够加密信息&#xff0c;以免敏感信息被第三方获取&#xff0c;所以很多银行网站或电子邮箱等等安全级别较高的服务都会采用HTTPS协议&#xff0c;HTTPS其实…

优化家庭网络,路由器无线中继配置全攻略(中兴E1600无线中继设置/如何解决没有预埋有线网络接口的问题/使用闲置路由实现WIFI扩展)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 网络优化 📒📒 操作步骤 📒💡适用场景🚨 常见问题及解决方案⚓️ 相关链接 ⚓️📖 介绍 📖 在现代家庭生活中,WiFi已经渗透到我们生活的每一个角落,成为了日常生活中不可或缺的一部分。然而,不少用户常常遇到W…

Bytebase 作为唯一数据库工具厂商,亮相亚马逊云科技中国峰会

作为云计算行业的风向标&#xff0c;亚马逊云科技中国峰会每年都吸引着全球顶尖企业和行业精英。此次峰会不仅展示了最新的 AI 技术趋势和解决方案&#xff0c;还为参展商和与会者提供了一个卓越的交流与合作平台。 Bytebase 作为全场唯一的数据库工具厂商亮相数据区&#xff0…

Windows下Qt5.14.2连接华为IoTDA平台

一、华为IoTDA简介 华为云物联网平台&#xff08;IoT 设备接入云服务&#xff09;提供海量设备的接入和管理能力&#xff0c;将物理设备联接到云&#xff0c;支撑设备数据采集上云和云端下发命令给设备进行远程控制&#xff0c;配合华为云其他产品&#xff0c;帮助您快速构筑物…