模型转换 PyTorch转ONNX 入门

news2024/9/23 11:17:00

前言

本文主要介绍如何将PyTorch模型转换为ONNX模型,为后面的模型部署做准备。转换后的xxx.onnx模型,进行加载和测试。最后介绍使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

目录

前言

一、PyTorch模型转ONNX模型

1.1 转换为ONNX模型且加载权重

1.2 转换为ONNX模型但不加载权重

1.3 torch.onnx.export() 函数

二、加载ONNX模型

三、可视化ONNX模型


一、PyTorch模型转ONNX模型

将PyTorch模型转换为ONNX模型,通常是使用torch.onnx.export( )函数来转换的,基本的思路是:

  1. 加载PyTorch模型,可以选择只加载模型结构;也可以选择加载模型结构和权重。
  2. 然后定义PyTorch模型的输入维度,比如(1, 3, 224, 224),这是一个三通道的彩色图,分辨率为224x224。
  3. 最后使用torch.onnx.export( )函数来转换,生产xxx.onnx模型。

下面有一个简单的例子:

import torch
import torch.onnx

# 加载 PyTorch 模型
model = ...

# 设置模型输入,包括:通道数,分辨率等
dummy_input = torch.randn(1, 3, 224, 224, device='cpu')

# 转换为ONNX模型
torch.onnx.export(model, dummy_input, "model.onnx", export_params=True)

1.1 转换为ONNX模型且加载权重

这里举一个resnet18的例子,基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后将其设置为评估模式。接下来定义一个与模型输入张量形状相同的输入张量,并使用 torch.randn() 函数生成了一个随机张量。
  3. 最后,使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式,并将其保存到指定的输出文件中。

程序如下:

import torch
import torchvision.models as models

# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

# 定义输入张量,需要与模型的输入张量形状相同
input_shape = (1, 3, 224, 224)
x = torch.randn(input_shape)


# 需要指定输入张量,输出文件路径和运行设备
# 默认情况下,输出张量的名称将基于模型中的名称自动分配
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将 PyTorch 模型转换为 ONNX 格式
output_file = "resnet18.onnx"
torch.onnx.export(model, x.to(device), output_file, export_params=True)

1.2 转换为ONNX模型但不加载权重

举一个resnet18的例子:基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式;指定参数do_constant_folding=False,不加载模型的权重。
import torch
import torchvision.models as models

# 加载 PyTorch 模型
model = models.resnet18()

# 将模型转换为 ONNX 格式但不加载权重
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx", do_constant_folding=False)

下面构建一个简单网络结构,并转换为ONNX

import torch
import torchvision
import numpy as np

# 定义一个简单的PyTorch 模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x

# 创建模型实例
model = MyModel()

# 指定模型输入尺寸
dummy_input = torch.randn(1, 3, 32, 32)

# 将PyTorch模型转为ONNX模型
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

1.3 torch.onnx.export() 函数

看一下这个函数的参数

torch.onnx.export(
            model, 
            args, 
            f, 
            export_params=True, 
            opset_version=10, 
            do_constant_folding=True, 
            input_names=['input'], 
            output_names=['output'], 
            dynamic_axes=None, 
            verbose=False, 
            example_outputs=None, 
            keep_initializers_as_inputs=None)
  • model:需要导出的 PyTorch 模型
  • args:PyTorch模型输入数据的尺寸,指定通道数、长和宽。可以是单个 Tensor 或元组,也可以是元组列表。
  • f:导出的 ONNX 文件路径和名称,mymodel.onnx。
  • export_params:是否导出模型参数。如果设置为 False,则不导出模型参数。
  • opset_version:导出的 ONNX 版本。默认值为 10。
  • do_constant_folding:是否对模型进行常量折叠。如果设置为 True,不加载模型的权重。
  • input_names:模型输入数据的名称。默认为 'input'。
  • output_names:模型输出数据的名称。默认为 'output'。
  • dynamic_axes:动态轴的列表,允许在导出的 ONNX 模型中创建变化的维度。
  • verbose:是否输出详细的导出信息。
  • example_outputs:用于确定导出 ONNX 模型输出形状的样本输出。
  • keep_initializers_as_inputs:是否将模型的初始化器作为输入导出。如果设置为 True,则模型初始化器将被作为输入的一部分导出。

下面是只是一个常用的模板

import torch.onnx 

# 转为ONNX
def Convert_ONNX(model): 

    # 设置模型为推理模式
    model.eval() 

    # 设置模型输入的尺寸
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # 导出ONNX模型  
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "xxx.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX')


if __name__ == "__main__": 

    # 构建模型并训练
    # xxxxxxxxxxxx

    # 测试模型精度
    #testAccuracy() 

    # 加载模型结构与权重
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 
 
    # 转换为ONNX 
    Convert_ONNX(model)

二、加载ONNX模型

加载ONNX模型,通常需要用到ONNX、ONNX Runtime,所以需要先安装。

pip install onnx
pip install onnxruntime

加载ONNX模型可以使用ONNX Runtime库,以下是一个加载ONNX模型的示例代码:

import onnxruntime as ort

# 加载 ONNX 模型
ort_session = ort.InferenceSession("model.onnx")

# 准备输入信息
input_info = ort_session.get_inputs()[0]
input_name = input_info.name
input_shape = input_info.shape
input_type = input_info.type


# 运行ONNX模型
outputs = ort_session.run(input_name, input_data)

# 获取输出信息
output_info = ort_session.get_outputs()[0]
output_name = output_info.name
output_shape = output_info.shape
output_data = outputs[0]

print("outputs:", outputs)
print("output_info :", output_info )
print("output_name :", output_name )
print("output_shape :", output_shape )
print("output_data :", output_data )

以下是一个示例程序,将 resnet18 模型从 PyTorch 转换为 ONNX 格式,然后加载和测试 ONNX 模型的过程:

import torch
import torchvision.models as models
import onnx
import onnxruntime

# 加载 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()

# 定义输入和输出张量的名称和形状
input_names = ["input"]
output_names = ["output"]
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)

# 将 PyTorch 模型转换为 ONNX 格式
torch.onnx.export(
    model,  # 要转换的 PyTorch 模型
    torch.randn(input_shape),  # 模型输入的随机张量
    "resnet18.onnx",  # 保存的 ONNX 模型的文件名
    input_names=input_names,  # 输入张量的名称
    output_names=output_names,  # 输出张量的名称
    dynamic_axes={input_names[0]: {0: "batch_size"}, output_names[0]: {0: "batch_size"}}  # 动态轴,即输入和输出张量可以具有不同的批次大小
)

# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx_model_graph = onnx_model.graph
onnx_session = onnxruntime.InferenceSession(onnx_model.SerializeToString())

# 使用随机张量测试 ONNX 模型
x = torch.randn(input_shape).numpy()
onnx_output = onnx_session.run(output_names, {input_names[0]: x})[0]

print(f"PyTorch output: {model(torch.from_numpy(x)).detach().numpy()[0, :5]}")
print(f"ONNX output: {onnx_output[0, :5]}")

上述代码中,首先加载预训练的 resnet18 模型,并定义了输入和输出张量的名称和形状。

然后,使用 torch.onnx.export() 函数将模型转换为 ONNX 格式,并保存为 resnet18.onnx 文件。

接着,使用 onnxruntime.InferenceSession() 函数加载 ONNX 模型,并使用随机张量进行测试。

最后,将 PyTorch 模型和 ONNX 模型的输出进行比较,以确保它们具有相似的输出。

三、可视化ONNX模型

使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

这里简单介绍一下

Netron是一个轻量级、跨平台的模型可视化工具,支持多种深度学习框架的模型可视化,包括TensorFlow、PyTorch、ONNX、Keras、Caffe等等。它提供了可视化网络结构、层次关系、输出尺寸、权重等信息,并且可以通过鼠标移动和缩放来浏览模型。Netron还支持模型的导出和导入,方便模型的分享和交流。

 Netron的网页在线版本,直接在网页中打开和查看ONNX模型Netron

开源地址:GitHub - lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

支持多种操作系统:

macOS: Download 

Linux: Download 

Windows: Download 

Browser: Start 

Python Server: Run pip install netron and netron [FILE] or netron.start('[FILE]').

 

下面是可视化模型截图:

还能查看某个节点(运算操作)的信息,比如下面MaxPool,点击一下,能看到使用的3x3的池化核,是否有填充pads,步长strides等参数。

 

分享完毕~

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/351365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络第1章(概述)学习笔记

❤ 作者主页:欢迎来到我的技术博客😎 ❀ 个人介绍:大家好,本人热衷于Java后端开发,欢迎来交流学习哦!( ̄▽ ̄)~* 🍊 如果文章对您有帮助,记得关注、点赞、收藏、…

PPS文件如何转换成PPT?附两种方法

在工作中,PPS文件的使用还是很广泛的,因为作为幻灯片放映文件,点击后就能直接播放,十分方便。但如果想要修改PPS里的内容,PPS是无法编辑的,我们需要把文件转换成PPT,再进行修改。 那PPS文件如何…

详细解读ChatGPT:如何调用ChatGPT的API接口到官方例子的说明以及GitHub上的源码应用和csdn集成的ChatGPT

文章目录1. 解读ChatGPT1.1 词语解释1.2 功能解读2. GitHub上ChatGPT的应用源码3. 调用ChatGPT的API4. 官方例子说明5. 集成ChatGPT自ChatGPT出来到如今,始终走在火热的道路上,如今日活用户破亿,他为何有如此大的魅力,深受广大用户…

通用 GPU 领先企业登临科技加入龙蜥社区,完成与龙蜥操作系统的兼容适配

近日,上海登临科技有限公司(以下简称“登临科技”)签署了 CLA(Contributor License Agreement,贡献者许可协议),正式加入龙蜥社区(OpenAnolis)。作为国内通用 GPU 领先企…

深入浅出带你学习GlassFish中间件漏洞

前文 上文给大家带来了WEBLOGIC常见的漏洞不知道大家理解了没有,今天给大家带来一个新的中间件漏洞的讲解——glassfish,本文会先介绍该中间件的简单信息然后解析一下该中间件可能存在的漏洞类型,下面我们展开文章来讲。 GlassFish GlassF…

2023美国大学生数学建模竞赛E题思路解析

背景:光污染是指任何过多或不当使用人造光的表现。我们所称为光污染的一些现象包括光侵入、过亮、以及光混乱。这些现象最容易在大城市太阳落山后观察到天空中的发光;然而,它们也可能发生在更偏远的地区。光污染改变了我们对夜空的看法&#…

(三十四)Vue之新生命周期钩子nextTick

文章目录普通实现的一个问题解决问题nextTick上一篇:(三十三)Vue之消息订阅与发布 首先先看这一个需求,给每个任务项新增一个编辑按钮 当编辑按钮点击时,任务项就会变成文本框,并且自动获取焦点 普通实…

中国天气——对流性天气过程复习笔记

对流性天气过程 对流性天气十分激烈,影响范围相对较小,持续时间短,通常是局部灾害性天气 雷暴结构 产生雷暴的积雨云叫雷暴云,也叫雷暴单体,水平尺度约为十几千米多个雷暴单体成群聚集在一起叫做雷暴群,…

【C语言】数组的声明和使用(一维数组、多维数组)

数组一、什么是数组?二、一维数组(一)一维数组声明(二)一维数组初始化(三)一维数组的引用三、多维数组(以二维数组为例)(一)二维数组声明&#xf…

PyQt5数据库开发1 4.3 QSqlTableModel 之 Qt项目的创建

目录 一、新建Qt项目 1. 编辑资源文件 2. 添加前缀 3. 新建放资源文件的目录 4. 添加图标文件 二、Action 1. 新建打开数据库Action 2. 添加其他Action 三、工具栏 1. 添加工具栏 2. 拖动actOpenDB到工具栏 3. 设置工具栏属性 4. 添加分隔符 5. 添加其他工具 6.…

精选案例 |《金融电子化》:光大银行云原生背景下的运维监控体系建设

顺应“十四五”规划中关于“加快金融机构数字化转型”要求,中国人民银行印发了《金融科技发展规划(2022-2025年)》。近几年来,金融行业牢牢占据着国内产业数字化转型市场投入的榜首位置。IDC调查显示,2022上半年&#…

北斗卫星信号类型及卫星颗数

文章目录一、北斗系统现阶段提供的公开服务信号二、北斗二号、三号卫星个数三、GNSS模块中的北斗信号参考来源这篇博客主要是整理一下北斗卫星现阶段提供的公开服务信号、二号和三号卫星个数,以及简单看看市场的GNSS模块对北斗信号的支持情况。一、北斗系统现阶段提…

智云通CRM:引起流单的三个问题,你了解了吗?

销售人员一般都会了解基本的销售流程,但是为什么还是出现了各种流单的问题?智云通CRM总结以下三个问题: 第一,采购流程是会发生反复的,不会一直向下走。 从整体上看,客户的采购流程遵循着上述规律&#x…

C++ 修改防火墙firewall设置(Windows)

文章目录1、简介1.1 防火墙概述1.2 入站,还是出站?1.3 防火墙规则优先级2、系统界面方式3、命令行方式3.1 防火墙基本状态设置3.2 入站出站规则设置3.3 其他设置3.4 telnet检测端口4、C方式4.1 注册表4.2 COM(Windows XP)4.3 COM&…

深度学习模型概念

Big data features: 5V--volume, velocity, variety, value, veracity.Big data challenges:高维、multi-modal、complexity、privacy 1. Federated Learning 联邦学习 Federated Learning:Server将model分散到各个用户user,clients利用本地…

不同相机之间图片像素对应关系求解(单应性矩阵求解)

一、场景 相机1和相机2相对位置不变,相机拍摄图片有重叠,求他们交叠部分的一一对应关系。数学语言描述为已知相机1图片中P点像素(u1, v1),相机1中P点在相机2图片中像素值为(u2, v2),它们存在某种变换,求变换矩阵。 因为…

计算机存储数字的本质,正码,反码,补码

计算机-原码 就是二进制定点表示法,即最高位为符号位:“0”表示正,“1”表示负,其余位表示数值的大小。 该数字不进行其他操作时数字最原始的二进制表示, 对于原码来说,绝对值相等的正数和负数只有符号位不…

高通平台开发系列讲解(USB篇)libuvc详解

文章目录 一、什么是UVC二、UVC拓扑结构三、libuvc的预览时序图沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇文章将介绍libuvc。 一、什么是UVC UVC,全称为:USB video(device) class。 UVC是微软与另外几家设备厂商联合推出的为USB视频捕获设备定义的协议标…

缓存雪崩 缓存击穿-总结

目录 缓存雪崩 缓存击穿-总结 缓存雪崩 出现场景: 解决方案: 缓存击穿 出现场景: 举例如图: 缓存击穿的三个前提: 解决方案: 缓存雪崩 缓存击穿-总结 缓存雪崩 出现场景: (1) 对于R…

用于隔离PWM的光耦合器选择和使用

光耦合器(或光隔离器)是一种将电路电隔离的器件,不仅在隔离方面非常出色,而且允许您连接到具有不同接地层或在不同电压电平下工作的电路。光耦合器具有“故障安全”功能,因为如果受到高于最大额定值的电压,…