导出谷歌gemma模型为ONNX

news2025/1/23 6:08:08

参考代码如下(从GitHub - luchangli03/export_llama_to_onnx: export llama to onnx修改而来,后面会合入进去)

模型权重链接参考:

https://huggingface.co/google/gemma-2b-it

可以对modeling_gemma.py进行一些修改(transformers升级为最新版本内置该模型代码),从而提升导出的onnx性能:

1,GemmaForCausalLM中原始的logits计算为:

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

修改为:

        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)

这样使得降低prefill阶段lm_head的计算量。

2,模型使用了GemmaSdpaAttention,导出的onnx模型从一个很大的张量中索引向量仅仅用作attention mask:

causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
    causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

这里即增加了存储又增加了计算。实际上可以直接把扩展后的attention mask作为onnx输入传入进来,从而完全消除这个存储和计算。

不知为何很多模型(例如千问等)都输入一个[1, seq_len]的向量,然后内部扩展为一个[1,1, seq_len, sumN]的mask,这些操作都可以直接替换为模型直接采用[1,1, seq_len, sumN]的mask输入。

这里对modeling_gemma.py修改方法为:

class GemmaModel(GemmaPreTrainedModel):
    def forward(
        # causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
        causal_mask = attention_mask

class GemmaSdpaAttention(GemmaAttention):
    def forward(
        # if attention_mask is not None and cache_position is not None:
        #     causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

模型导出代码(进行了上述修改,如果不想修改的话,修改下这里面的atten mask的shape,dtype即可):

import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMForCausalLMWrapper(nn.Module):
    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,
    ):
        """
        Note: you can modify modeling_gemma.py to make the converted model more efficient:
        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=None,
            use_cache=True,
        )

        logits = outputs.logits
        kv_caches_out = []
        for past_kv in outputs.past_key_values:
            kv_caches_out.extend(past_kv)

        topk_outputs = []
        if self.args.add_topk_warper > 0:
            logging.warning("add topk to glm model")
            if self.args.topk < 0:
                raise ValueError("topk {} is invalid")
            topk_outputs = torch.topk(logits, k=self.args.topk, dim=-1)

        return logits, *kv_caches_out, *topk_outputs


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)

    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    layer_num = len(model.model.layers)

    hidden_size = config.hidden_size
    head_num = config.num_attention_heads
    head_dim = config.head_dim

    batch = 1
    N = 1
    sumN = 32
    lastSum = sumN - N

    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)
    # Note: orig atten_mask shape is [1, sumN]
    attention_mask = torch.randn([batch, 1, N, sumN], dtype=dtype).to(args.device)
    position_ids = torch.ones([batch, N], dtype=torch.int64).to(args.device)

    in_names = ["input_ids", "attention_mask", "position_ids"]

    dynamic_axes = {
        'input_ids': {1: 'N', },
        'attention_mask': {2: 'N', 3: 'sumN'},
        "position_ids": {1: 'N', },
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"
        dynamic_axes['attention_mask'][0] = "batch"
        dynamic_axes['position_ids'][0] = "batch"

    kv_caches_in = []
    out_names = ["lm_logits"]

    kv_cache_in_shape = [1, 1, lastSum, head_dim]
    kv_cache_dyn_axes = {2: "sumN-N"}

    if args.dyn_batch:
        kv_cache_dyn_axes[0] = "batch"

    past_key_values = []

    for i in range(layer_num):
        past_key_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)
        past_value_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)

        kv_caches_in.extend([past_key_in, past_value_in])
        in_names.extend([f"past_key_in{i}", f"past_value_in{i}"])
        out_names.extend([f"past_key{i}", f"past_value{i}"])

        dynamic_axes[f"past_key_in{i}"] = kv_cache_dyn_axes
        dynamic_axes[f"past_value_in{i}"] = kv_cache_dyn_axes

        past_key_values.append((past_key_in, past_value_in))

    input_datas = (input_ids, attention_mask, position_ids, past_key_values)

    torch.onnx.export(
        llama_model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_llama(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()

    # model.model.layers = model.model.layers[:1]  # only export one layer for debug

    print(f"finish load model from {args.model_path}")
    config = model.config
    print("config:", config)

    print(f"begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='export llm',
    )
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--add_topk_warper', required=False, type=int, default=0)
    parser.add_argument('--topk', required=False, type=int, default=4)
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()
    export_llama(args)

导出的onnx文件onnxsim:

GitHub - luchangli03/onnxsim_large_model: simplify >2GB large onnx model

导出的onnx模型推理示例(依赖文件在GitHub - luchangli03/export_llama_to_onnx: export llama to onnx)

import numpy as np
from onnx_rt_utils import OnnxRuntimeModel, get_random_data
from sample_utils import sample_topk
from transformers import AutoTokenizer


def prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum):
    """
    only used at the first time
    in round 0, actually the lastSum is 0, thus past_key_in, past_value_in are empty tensor
    """
    for i in range(layer_num):
        past_key_in = get_random_data([1, 1, lastSum, 256], "float16")
        past_value_in = get_random_data([1, 1, lastSum, 256], "float16")

        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"
        glm_model_inputs[past_key_in_name] = past_key_in
        glm_model_inputs[past_value_in_name] = past_value_in
    return glm_model_inputs


def prepare_kv_cache_from_outputs(glm_model_inputs, decoder_outputs, layer_num):
    offset = 1
    for i in range(layer_num):
        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"

        glm_model_inputs[past_key_in_name] = decoder_outputs[offset + i * 2]
        glm_model_inputs[past_value_in_name] = decoder_outputs[offset + i * 2 + 1]
    return glm_model_inputs


def get_atten_mask(N,  sumN,  padded_len):
    attention_mask = np.zeros(shape=[N * padded_len], dtype="float16")

    pad_num = padded_len - sumN
    if (N == sumN):
        for i in range(N):
            mask_num = N - 1 - i + pad_num
            start = padded_len - mask_num
            for j in range(start, padded_len):
                attention_mask[i * padded_len + j] = -65504
    else:
        if (N != 1):
            raise ValueError("N is not 1")
        lastSum = sumN - N
        for i in range(pad_num):
            attention_mask[lastSum + i] = -65504

    attention_mask = attention_mask.reshape([N, padded_len])
    return attention_mask


# all decoder layer num
layer_num = 18
eos_token_id = 2

pt_model_path = r"E:\test_models\llama\gemma-2b-it"
onnx_model_path = "llm_onnx.onnx"

prompt = "Write me a poem about Machine Learning."
tokenizer = AutoTokenizer.from_pretrained(pt_model_path, trust_remote_code=True)
input_ids = tokenizer(prompt)['input_ids']

print(input_ids)

input_ids = np.array(input_ids).reshape([1, -1]).astype("int64")

N = input_ids.shape[1]
sumN = N
lastSum = sumN - N
print("N:", N, sumN, lastSum)

position_ids = np.arange(sumN).reshape([1, -1]).astype("int64")

input_ids = input_ids.astype("int64")
position_ids = position_ids.astype("int64")

glm_model = OnnxRuntimeModel(onnx_model_path)

max_seq = 32

glm_model_inputs = {}

gen_tokens = []

for i in range(max_seq):
    print("input_ids:", input_ids)
    print("position_ids:", position_ids)

    attention_mask = get_atten_mask(N, sumN, padded_len=sumN).astype("float16")
    print("attention_mask:", attention_mask)
    attention_mask = attention_mask.reshape([1, 1, N, sumN])

    glm_model_inputs["input_ids"] = input_ids
    glm_model_inputs["attention_mask"] = attention_mask
    glm_model_inputs["position_ids"] = position_ids

    if i == 0:
        glm_model_inputs = prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum)

    glm_model_outputs = glm_model(**glm_model_inputs)
    lm_logits = glm_model_outputs[0]
    print("lm_logits:", lm_logits)

    next_token = sample_topk(lm_logits, topk=1)
    gen_tokens.append(next_token)
    print("next_token:", next_token)

    if next_token == eos_token_id:
        break

    input_ids = np.array([next_token]).astype("int64").reshape([-1, 1])
    position_ids = np.array([sumN]).astype("int64").reshape([-1, 1])

    N = 1
    sumN += 1
    prepare_kv_cache_from_outputs(glm_model_inputs, glm_model_outputs, layer_num)

gen_text = tokenizer.decode(gen_tokens)
print("Q:", prompt)
print("A:", gen_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1501275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LLCC68与SX1278 LoRa模块的优势对比?

LLCC68和SX1278都是Semtech公司推出的LoRa调制解调器模块&#xff0c;属于LoRa模块家族。它们在无线通信领域都有着广泛的应用&#xff0c;但具体的优势会取决于具体的应用场景和需求。下面是对LLCC68和SX1278 LoRa模块的一些优势对比&#xff1a; LLCC68 LoRa模块的优势&#…

qt自定义时间选择控件窗口

效果如图&#xff1a; 布局如图&#xff1a; 参考代码&#xff1a; //DateTimeSelectWidget #ifndef DATETIMESELECTWIDGET_H #define DATETIMESELECTWIDGET_H#include <QWidget> #include <QDateTime>namespace Ui { class DateTimeSelectWidget; }class DateTim…

【手游联运平台搭建】游戏平台的作用

随着科技的不断发展&#xff0c;游戏行业也在不断壮大&#xff0c;而游戏平台作为连接玩家与游戏的桥梁&#xff0c;发挥着越来越重要的作用。游戏平台不仅为玩家提供了便捷的游戏体验&#xff0c;还为游戏开发者提供了广阔的市场和推广渠道。本文将从多个方面探讨游戏平台的作…

扩展CArray类,增加Contain函数

CArray不包含查找类的函数&#xff0c;使用不便。考虑扩展CArray类&#xff0c;增加Contain函数&#xff0c;通过回调函数暴露数组元素的比较方法&#xff0c;由外部定义。该方法相对重载数组元素的“”符号更加灵活&#xff0c;可以根据需要配置不同的回调函数进行比较 //类型…

继深圳后,重庆与鸿蒙展开原生应用开发合作

截至2023年底&#xff0c;开源鸿蒙开源社区已有250多家生态伙伴加入&#xff0c;开源鸿蒙项目捐赠人达35家&#xff0c;通过开源鸿蒙兼容性测评的伙伴达173个&#xff0c;累计落地230余款商用设备&#xff0c;涵盖金融、教育、智能家居、交通、数字政府、工业、医疗等各领域。 …

底层day3作业

思维导图 作业&#xff1a;1.总结任务的调度算法&#xff0c;把实现代码再写一下 算法&#xff1a;抢占式调度时间片轮转 1.抢占式调度&#xff1a;任务优先级高的可以打断任务优先级低的执行&#xff08;适用于不同优先级&#xff09; 2.时间片轮转&#xff1a;每一个任务拥…

react的diff源码

react 的 render 阶段&#xff0c;其中 begin 时会调用 reconcileChildren 函数&#xff0c; reconcileChildren 中做的事情就是 react 知名的 diff 过程 diff 算法介绍 react 的每次更新&#xff0c;都会将新的 ReactElement 内容与旧的 fiber 树作对比&#xff0c;比较出它们…

电脑小问题:Windows更新后黑屏

Windows 更新后黑屏解决方法 在 Windows 更新后&#xff0c;伴随了一个小问题&#xff0c;电脑启动后出现了桌面黑屏。原因可能是火绒把 explorer.exe 当病毒处理了。 下面讲解 Windows 更新后黑屏的解决方法&#xff0c;步骤如下&#xff1a; 1. 按 ctrl alt delete 组合键…

基于Python3的数据结构与算法 - 12 数据结构(列表和栈)

目录 一、引入 二、分类 三、列表 1. C语言中数组的存储方式 2. Python中列表的存储方式 四、栈 1. 栈的应用 -- 括号匹配问题 一、引入 定义&#xff1a;数据结构是指相互之间存在着一种或多种关系的数据元素的集合和该集合中数据元素之间的关系组成。简单来说&#x…

portainer管理远程docker和docker-swarm集群

使用前请先安装docker和docker-compose&#xff0c;同时完成docker-swarm集群初始化 一、portainer-ce部署 部署portainer-ce实时管理本机docker&#xff0c;使用docker-compose一键拉起 docker-compose.yml version: 3 services:portainer:container_name: portainer#imag…

Docker上部署LPG(loki+promtail+grafana)踩坑复盘

Docker上部署LPG&#xff08;lokipromtailgrafana&#xff09;踩坑复盘 声明网上配置部署踩坑 声明 参考掘金文章&#xff1a;https://juejin.cn/post/7008424451704356872 版本高的用docker compose命令&#xff0c;版本低的用docker-compose 按照文章描述&#xff0c;主要准备…

UVA378 Intersecting Lines 题解

UVA378 Intersecting Lines 题解 怎么这么多点斜式邪教啊。 解法 在计算几何中&#xff0c;我们应该尽可能地避免使用浮点数的计算&#xff0c;尽可能地使用向量计算。 本篇题解默认读者具有向量基础。 为了方便讲解&#xff0c;我们将输入的四个点分别记作 A , B , C , …

本鲸多方位助力创业者高效对接创新创业机遇

在科技创新的浪潮中&#xff0c;创业者们不断探索着新的商业机会&#xff0c;寻求着创新创业的道路。然而&#xff0c;面对复杂多变的市场环境和激烈的竞争压力&#xff0c;如何高效对接创新创业机遇成为了摆在创业者面前的重要课题。 本鲸依托海南本鲸投资有限公司和重庆本鲸…

关于Vivado的实施过程、SDC和XDC约束支持、Vivado实施子流程、Tcl API支持脚本

关于Vivado的实施过程 AMD Vivado™设计套件可实现以下AMD设备体系结构&#xff1a;AMD Versal™自适应计算加速平台&#xff08;自适应SoC&#xff09;&#xff0c;AMDUltraScale™、AMD UltraScale™和AMD 7系列FPGA。各种设计来源如下支持&#xff0c;包括&#xff1a; •…

WebGPU vs. 像素流

在构建 Bzar 之前&#xff0c;我们讨论过我们的技术栈是基于在云上渲染内容的像素流&#xff0c;还是基于使用设备自身计算能力的本地渲染技术。 由于这种选择会极大地影响项目的成本、可扩展性和用户体验&#xff0c;因此在开始编写一行代码之前&#xff0c;从一开始就采取正确…

typeorm-入门

简述 typeorm是一个数据库orm框架&#xff0c;在nestjs官网中有提到&#xff0c;可以充分发挥利用typescript的特性&#xff0c;当然也支持js其中涉及的概念包括 DataSource 数据源&#xff0c;Connection 连接数据库Entity 实体&#xff0c;实体类映射数据库表Relation 关系…

30个炫酷光效视频转场PR模板剪辑素材下载

视频转场Premiere模板&#xff0c;包含30个炫酷光效视频转场过渡效果PR项目模板下载。 适用软件&#xff1a;Premiere Pro 2023 | 分辨率&#xff1a;3840x2160 (4K) | 无需插件 | 文件大小&#xff1a;56.33MB 来自PR转场&#xff0c;下载地址&#xff1a;https://prmuban.com…

Windows下Node.js安装保姆级教程

一、Node.js 下载 访问Node.js官网&#xff0c;点击下载Node.js 下载完成后即可在下载文件中查看安装包 二、安装 一&#xff09;点击安装包开始安装&#xff0c;进入Weclcome界面点击Next 二&#xff09;勾选同意协议&#xff0c;点击Next 三&#xff09;根据需要选择安装路…

Neo4j安装 Linux:CentOS、openEuler 适配langchain应用RAG+知识图谱开发 适配昇腾910B

目录 Neo4j下载上传至服务器后进行解压运行安装JAVA再次运行在windows端打开网页导入数据 Neo4j下载 进入Neo4j官网下载页面 向下滑动找到 Graph Database Self-Managed 选择 社区版&#xff08;COMMUNITY&#xff09; 选择 Linux / Mac Executable Neo4j 5.17.0 (tar) 单机下…

Linux第72步_使用“新字符设备的一般模板”编写LED驱动

使用“新字符设备的一般模板”编写LED驱动&#xff0c;使用寄存器直接开关灯。 1、创建LED目录 输入“cd /home/zgq/linux/Linux_Drivers/回车” 切换到“/home/zgq/linux/Linux_Drivers/” 输入“ls回车”&#xff0c;查看“/home/zgq/linux/Linux_Drivers/” 输入“mkdi…