torch模型转onnx

news2024/11/17 10:38:20

加载模型

model=torch.load('saved_model/moudle_best_auc.pth', map_location='cpu')
model.eval().cpu()

注:由于导出的模型是用于推理的,因此必须指定模型加载的位置和模型验证的位置,这里我使用了cpu做出导出的硬件

分析模型的输入和输出

这里要看自己的代码,在数据加载的时候输出是什么样子的,比如我这里:
在这里插入图片描述
那输入的list就比较好确定了,是["input_ids","input_mask"],模型的输出list确定,要根据网络的最后一层来看,我这里用的class_loj,那我的输出list就是["class_loj"]

构建一个dummy_input

这是导出是否成功的核心之一,这里给出了模型的输入维度,要求输入是tensor,在我的模型里面,输入的是整数的tensor,input_ids和input_mask都经过了tokenizer的编译,所以这里用的dummy_input一定也要这样子:

input_ids=np.random.randint(10, size=(batch_size, 373))# 373是维度,根据自己的来定,这里是序列问题
inputs = torch.from_numpy(tape_pad(input_ids, 0))#我需要将其转为这样的格式
inputs_mask = torch.from_numpy(tape_pad(np.ones_like(input_ids), 0)) # 构建mask的方式根据自己选择
dummy_input={"input_ids":inputs,
            "input_mask":input_mask} # 得到这样一个dummy_input

导出模型

这些全部准备好了之后,就可以导出了:

torch.onnx.export(
    model, # 前面torch.load导入的模型
    (dummy_input,),#模型的输入
    'tape.onnx',  # 导出的模型名字
    export_params=True, # 如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.
    opset_version=15, # 版本号,写10-15都可以
    do_constant_folding=True, # 是否执行常量折叠优化
    input_names=model_inputs, # 输入模型的张量的名称
    output_names=model_outputs, # 模型输出的张量的名称
    # dynamic_axes将batch_size的维度指定为动态,
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'},
                  'input_mask':{0: 'batch_size', 1: 'sequence'},
                  'class_loj': {0: 'batch_size'}}
    
)

完整导出代码

import torch
from tape.tokenizers import TAPETokenizer
from tape.datasets import pad_sequences as tape_pad
import numpy as np

# 导入token和模型
tokenizer=TAPETokenizer(vocab='unirep')
model=torch.load('saved_model/moudle_best.pth', map_location='cpu')
model.eval().cpu()

# 准备输入格式
batch_size=1
input_ids=np.random.randint(10, size=(batch_size, 373))
inputs = torch.from_numpy(tape_pad(input_ids, 0))
inputs_mask = torch.from_numpy(tape_pad(np.ones_like(input_ids), 0))
dummy_input={"input_ids":inputs,
            "input_mask":input_mask}

# 准备输入和输出张量名称            
model_inputs=['input_ids', 'input_mask']
model_outputs=['class_loj']

# 使用onnx导出
torch.onnx.export(
    model,
    (dummy_input,),
    'tape.onnx',
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=model_inputs,
    output_names=model_outputs,
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'},
                  'input_mask':{0: 'batch_size', 1: 'sequence'},
                  'class_loj': {0: 'batch_size'}}
    
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/888812.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SSH框架实现的管理系统(包含java源码+数据库)

资料下载链接 介绍 基于SSH框架的管理系统 简洁版 ; 实现 登录 、 注册 、 增 、 删 、 改 、 查 ; 可继续完善增加前端、校验、其他功能等; 可作为 SSH(Structs Spring Hibernate)项目 开发练习基础模型&#xf…

维护平衡的艺术:如何与干系人建立和谐关系以确保项目成功

在项目管理领域中,干系人的作用是无法忽视的。他们的支持和参与往往是项目成功的关键。与干系人建立和维护良好的关系成为每一位项目经理必备的技能。接下来,我们将深入探讨如何有效地与干系人互动,从而为项目的成功奠定坚实基础。 干系人的…

中国“诺贝尔奖”未来科学大奖公布2023年获奖名单

未来科学大奖委员会于8月16日公布2023年获奖名单。柴继杰、周俭民因发现抗病小体并阐明其结构和在抗植物病虫害中的功能做出的开创性工作获得“生命科学奖”,赵忠贤、陈仙辉因对高温超导材料的突破性发现和对转变温度的系统性提升所做出的开创性贡献获得“物质科学奖…

2023骨传导耳机推荐,适合运动骨传导耳机推荐

相信很多人跟我一样,随着现在五花八门的耳机品种增多,选耳机的时候真是眼花缭乱,尤其还是网购,只能看,不能试,所以选择起来比较困难, 作为一个运动达人,为了让大家在购买耳机时少走弯…

YOLO系列解读DAY1—YOLOV1预训练模型

一、说在前面 小伙伴们好,博主很久没有写博客了,略感生疏,不到之处敬请谅解,欢迎指出文中错误,大家一起探讨。欲看视频讲解,可转至博主DouYin、B站,欢迎关注,链接如下: …

Cat(1):Cat入门

1 什么是调用链监控 1.1 架构的演进历史 单体应用 架构说明: 全部功能集中在一个项目内(All in one)。 在单体应用的年代,分析线上问题主要靠日志以及系统级别的指标。 微服务架构 架构说明: 将系统服务层完全独立…

亚马逊添加心愿单对卖家有什么好处

在亚马逊平台上,卖家可以从消费者的角度来看待心愿单的好处。消费者可以将自己感兴趣的商品添加到心愿单中,这对卖家来说也是有一些潜在好处的: 1、潜在销售机会增加:当消费者将商品添加到心愿单中,这可能表示他们对这…

mySQL 视图 VIEW

简化版的创建视图 create view 视图名 as select col ...coln from 表create view 视图名(依次别名) as select col ...coln from 表create view 视图名 as select col “别名1”,。。。col "别名n" from 表show tab…

Angular安全专辑之二——‘unsafe-eval’不是以下内容安全策略中允许的脚本源

一:错误出现 这个错误的意思是,拒绝将字符串评估为 JavaScript,因为‘unsafe-eval’不是以下内容安全策略中允许的脚本源。 二:错误场景 testEval() {const data eval("var sum2 new Function(a, b, return a b); sum2(em…

挖掘优质短视频超百万条,火山引擎DataLeap助力电商平台生态治理

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 在人们的日常生活中,网购已经成为人们生活中不可或缺的购物形式。 根据《中国社交电商行业发展白皮书(2022)》的数据显示&#x…

古战策与现代项目: 孙子兵法在项目管理中的应用

项目管理在当今的商业环境中是至关重要的。从初创公司到世界500强,项目管理的策略和工具都在不断地演变。然而,我们是否可以从古老的战争策略中汲取智慧,并将它们应用于现代的项目管理实践中呢? 让我们通过孙子兵法,一个古老而又…

ui设计师工作总结及计划范文模板

ui设计师工作总结及计划范文模板【篇一】 白驹过隙,转眼间某某年已近结尾,时间伴随着我们的脚步急驰而去,到了个人工作总结的时候,蓦然回首,才发现过去的一年不还能画上圆满的句号,内心感慨万千&#xff0c…

【PySide】Pyside QtWebEngine网页浏览器打开Flash网页

说明 QWebEngineView 加载 flash插件,可成功显示Flash,如图 源代码 # -*- coding: utf-8 -*- """ @File : pyside_2.py @Time : 2023/8/17 0:11 @Author : KmBase @Version : 1.0 @Contact : @Desc : None """import…

AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE 摘要模型架构Embedding层Transformer Encoder层MLP Head 整体流程 摘要 虽然Transformer体系结构已经成为自然语言处理任务的事实上的标准,但它在计算机视觉方面的应用仍然有限。在视…

传统算法是如何在销补调计划中发挥作用的

本文分享了一个「传统机器学习算法」在实际业务中的使用场景。 前言 如果嫌麻烦,你可以直接跳到正题观看~ 最近无论是在工作中的交谈,还是在日常刷屏的新闻,铺天盖地的都是大模型。我横竖是看不明白,费了大劲终于从字缝里看到了两…

后端项目打包上传服务器记录

后端项目打包上传服务器记录 文章目录 后端项目打包上传服务器记录1、项目打包2、jar包上传服务器 本文记录打包一个后端项目,上传公司服务器的过程。 1、项目打包 通过IDEA的插件进行打包: 打成一个jar包,jar包的位置在控制台可以看到。 2、…

记录--JS 的垃圾回收机制

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前言 垃圾回收(Garbage Collection)是一种内存管理机制,用于检测和清理不再被程序使用的内存,这些不再被使用的内存就被称为垃圾。垃圾回收器会在 JS 引擎(浏览器或者 nodejs)内…

Baklib是比语雀、Notion、石墨文档更好用的在线知识库管理工具

在当今信息爆炸的时代,如何高效地管理和利用知识成为了每个人都面临的问题。在线知识库管理工具应运而生,帮助用户整理、存储和共享知识。在这篇文章中,我将介绍一个更好用的在线知识库管理工具——Baklib,并探讨它相对于其他知识…

Python爬虫——scrapy_多条管道下载

定义管道类(在pipelines.py里定义) import urllib.requestclass DangDangDownloadPipelines:def process_item(self, item, spider):url http: item.get(src)filename ../books_img/ item.get(name) .jpgurllib.request.urlretrieve(url, filename…

JAVA编程学习笔记

常用代码、特定函数、复杂概念、特定功能……在学习编程的过程中你会记录下哪些内容?快来分享你的笔记,一起切磋进步吧! 一、常用代码 在java编程中常用需要储备的就是工具类。包括封装的时间工具类。http工具类,加解密工具类&am…