PyTorch的ONNX结合MNIST手写数字数据集的应用(.pth和.onnx的转换与onnx运行时)

news2025/1/23 14:51:47

在PyTorch以前的模型都是.pth格式,后面Meta跟微软一起做了一个.onnx的通用格式。这里对这两种格式文件,分别做一个介绍,依然使用MNIST数据集来做示例

1、CUDA下的pth文件

那pth文件里面是什么结构呢?其实在以前的文章就有介绍过,属于字典类型,而且是有序字典类型,这样就可以按照一定的顺序进行处理。

1.1、了解pth结构

先来查看一下pth文件的内容:
MNIST预训练模型.pth文件

import torch
model=torch.load("lenet_mnist_model.pth",map_location=torch.device('cpu'))
print(type(model),len(model))
for k,v in model.items():
    print(k,v.size())
'''
<class 'collections.OrderedDict'> 8
conv1.weight torch.Size([10, 1, 5, 5])
conv1.bias torch.Size([10])
conv2.weight torch.Size([20, 10, 5, 5])
conv2.bias torch.Size([20])
fc1.weight torch.Size([50, 320])
fc1.bias torch.Size([50])
fc2.weight torch.Size([10, 50])
fc2.bias torch.Size([10])
'''

可以看到类型是OrderedDict,两个卷积层加上两个全连接层。每个层都带有权重和偏置,简单显示了它们的形状。

1.2、torch.device

这里有一个需要注意的地方就是,如果将

model=torch.load("lenet_mnist_model.pth",map_location=torch.device('cpu'))

修改为

model=torch.load("lenet_mnist_model.pth")

就会报如下错误:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

也就是说想在CUDA中做反序列化操作,而CUDA是不可用的,从这里可以看到这个模型的训练是在有CUDA的环境下进行的,所以我们这里指定到CPU设备上。

2、CPU下的pth文件

 我们来看一个在CPU的环境下的加载方法,mnist.pth文件下载地址:mnist.pth

import torch
model=torch.load("mnist.pth")
print(type(model['net']),len(model['net']))
for k,v in model['net'].items():
    print(k,v.size())

'''
<class 'collections.OrderedDict'> 10
conv1.weight torch.Size([6, 1, 3, 3])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 3, 3])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
'''

这里可以不指定map_location参数,默认是cpu设备,可以看到这个pth文件结构是两个卷积层加三个全连接层。

3、pth转onnx

我们根据上面的mnist.pth结构,自己来构造一个模型:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3,stride=1,padding=0)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)
        self.fc1   = nn.Linear(400, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)
 
    def forward(self, x):
        out = self.conv1(x) # torch.Size([1, 6, 26, 26])
        out = F.max_pool2d(F.relu(out), 2) # [1, 6, 13, 13]
        out = self.conv2(out) # [1, 16, 11, 11]
        out = F.max_pool2d(F.relu(out), 2)  # [1, 16, 5, 5]
        out = out.view(out.size(0), -1) # [1, 400]
        out = self.fc1(out) # [1, 120]
        out = self.fc2(F.relu(out)) # [1, 84]
        out = self.fc3(F.relu(out))  # [1, 10]
        return out

net = LeNet()
net = net.to('cpu')
checkpoint = torch.load('mnist.pth')
net.load_state_dict(checkpoint['net'])
batch_size = 1
input_shape = (1,28,28)
x = torch.randn(batch_size,*input_shape)
net.eval()
torch.onnx.export(net,x,"mnist.onnx")

构造一样的结构,加载mnist.pth,然后就可以通过export转换成onnx格式的文件了。我们上传到https://netron.app/ 站点,可视化整个模型图,然后点击每个节点,将在右边出现它们的属性值:

4、onnx运行时

onnxruntime主要是拿来推理,当然在ir7的版本也增加了训练等功能,我们来了解下这个东西 

4.1、安装模块

如果缺少onnxruntime模块,就会报错:

ModuleNotFoundError: No module named 'onnxruntime'

这里在JupyterLab中,所以在前面加一个叹号安装

!pip install onnxruntime -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

import torch
import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("mnist.onnx")
x = np.random.rand(1, 1, 28, 28).astype(np.float32)
outputs = session.run(None, {"input": x})
print(outputs[0])

4.2、名称一致 

这里容易出错:InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:input

也就是说这个sess.run([output_name], {input_name: x})中的输入名称错误,所以名称要一样,这里的输入名称是input.1,修改成outputs = session.run(None, {"input.1": x})就可以了
怎么查看名称,可以通过上面站点可视化直接看到名称,也可以使用下面代码获取

input_name = session.get_inputs()
print(input_name[0].name)#input.1

同样的,如果输出名称也想指定,可以使用下面代码获取

out_name = session.get_outputs()[0].name

4.3、三通道转一通道

彩色三通道的图片转成灰色的单通道图片:

import cv2
import numpy as np
img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE)
cv2.imwrite('1.jpg',img)
print(img.shape)#(28, 28)

5、转成json格式

有时候的需求需要可读文件,一般json是很常见的,也可以进行转换:

import onnx
import json
from google.protobuf.json_format import MessageToJson
 
onnx_model = onnx.load("mnist.onnx")
s = MessageToJson(onnx_model)
onnx_json = json.loads(s)
 
output_json_path = 'mnist2.json'
 
with open(output_json_path, 'w') as f:
    json.dump(onnx_json, f, indent=2)

这样就将onnx文件转成了json格式的文件了

引用来源
github:https://github.com/onnx/onnx
可视化模型:https://netron.app/
ONNX实践:http://www.icfgblog.com/index.php/software/227.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703301.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0基础学习VR全景平台篇 第50篇:高级功能-自定义右键

本期为大家带来蛙色VR平台&#xff0c;高级功能—自定义右键功能操作。 功能位置示意 一、本功能将用在哪里&#xff1f; 自定义右键功能&#xff0c;观看者可通过电脑端右键和手机端长按屏幕&#xff0c;出现作者配置的自定义内容&#xff0c;使VR全景玩法变得多样化。 二、…

欧科云链2023年报:毛利达1.55亿港元,数字资产业务成最大增长点

据香港商报报道&#xff0c;2023年6月28日&#xff0c;欧科云链控股有限公司&#xff08;以下简称“欧科云链”&#xff09;及其附属公司&#xff08;股份代号&#xff1a;1499.HK&#xff0c;以下简称“集团”&#xff09;发布了截至2023年3月31日的年度报告。报告期内&#x…

工业读码器的选择和使用注意事项有哪些?

工业读码器是一种能够读取条形码、二维码等信息的设备&#xff0c;广泛应用于物流、生产制造、零售等行业。如何选择和使用工业读码器呢?下面是一些注意事项。 选择工业读码器 要根据应用场景选择合适的读码器类型&#xff0c;如手持式、固定式、手动旋转式等。 要考虑读取码的…

【C++】详解多态

目录 一、多态的概念二、多态的定义及实现1、多态的构成条件2、虚函数3、虚函数的重写1、虚函数重写的两个例外 4、C11 override 和 final5、重载、覆盖(重写)、隐藏(重定义)的对比 三、抽象类1、概念2、接口继承和实现继承 四、多态的原理1、虚函数表2、多态的原理3、动态绑定…

Mysql架构篇--Mysql(M-S) 主从同步

文章目录 前言&#xff1a;一、主从同步是什么&#xff1f;二、主从同步实现&#xff1a;1.准备工作&#xff1a;2.开启主从复制&#xff1a;2.1 mysql 服务端配置文件修改&#xff1a;2.2 mysql master 节点用户创建&#xff1a;2.3 mysql slave 节点开启数据复制&#xff1a;…

突破传统设计灵感,虚拟展厅设计方案

导语&#xff1a; 随着科技的不断发展&#xff0c;虚拟展厅设计方案正成为现代设计行业的新宠。这种创新的设计形式不仅突破了传统设计的局限&#xff0c;还为传统设计公司带来了诸多优势和特点&#xff0c;从而提高了设计产量和创意灵感。 在这篇软文中&#xff0c;我们将深入…

雅迪、爱玛谁是“新宠”?

电动两轮车下半场&#xff0c;谁是“新王”&#xff1f; 6月15日&#xff0c;爱玛科技有限公司&#xff08;下称“爱玛”&#xff0c;603529.SH)迎来了上市两周年。 作为电动两轮车的头部玩家&#xff0c;雅迪控股有限公司&#xff08;下称“雅迪”&#xff0c;01585.HK&…

HJ101 输入整型数组和排序标识,对其元素

描述 输入整型数组和排序标识&#xff0c;对其元素按照升序或降序进行排序 数据范围&#xff1a; 1≤n≤1000 1≤n≤1000 &#xff0c;元素大小满足 0≤val≤100000 0≤val≤100000 输入描述&#xff1a; 第一行输入数组元素个数 第二行输入待排序的数组&#xff0c;每个…

python实现九九乘法表

九九乘法表 i 1 while i < 9:j 1while j < i:print(f{j}*{i}{i * j}, end\t)j 1print()i 1结果&#xff1a;

window10 查看本机TCP协议进程

1. netstat 是一个常见的网络工具&#xff0c;用于显示网络连接状态、路由表、接口统计信息等网络相关的信息&#xff0c;可以帮助诊断和解决网络问题。 其中&#xff0c;各参数的含义为&#xff1a; -a&#xff1a;显示所有的网络连接和监听端口。 -e&#xff1a;显示以太网…

CDH yarn Fair 队列最大资源使用限制,任务无法提交

一、问题背景描述 1.任务提交异常日志 2023-06-29 15:48:20,877 INFO org.apache.flink.yarn.YarnClusterDescriptor [] - Deployment took more than 60 seconds. Please check if the requested resources are available in the YARN cluster 2023-06-29 15:48:21,129 IN…

1-什么是NumPy?【视频版】

目录 问题解答观看视频 问题 解答 NumPy&#xff0c;全称Numerical Python&#xff0c;是一个开源的Python科学计算库。它为Python提供了大量的数学库&#xff0c;包括&#xff1a; 强大的N维数组对象成熟的广播功能集成C/C和Fortran代码的工具有用的线性代数、傅里叶变换和随…

第一个spring程序

我们今天写第一个spring程序 我们采用maven形式创建工程。 我们首先在pom.xml中加入引用。 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSch…

(6)蜂鸣器(又称音调报警)

文章目录 6.1 使用有源蜂鸣器而不是无源蜂鸣器 6.2 安装蜂鸣器 6.3 使蜂鸣器安静 蜂鸣器&#xff08;或音调报警器&#xff09;可用于以声音指示飞行器的状态变化。根据电路板的能力&#xff0c;它可以是一个有源设备&#xff08;只需要施加电压来产生一个单一频率的音调&am…

给定一组数据样本,计算:【样本的平均值】, 【样本的标准差】, 【样本的变异系数】,【样本的标准误差】

一、指标含义 样本的平均值&#xff1a;指样本中所有数据的总和除以样本大小&#xff0c;是样本的中心趋势的度量。平均值常用于描述数据的集中程度&#xff0c;具有良好的代表性和易于计算的优点。 样本的标准差&#xff1a;指样本中每个数据与平均值的偏差的平方和的平均值的…

openssl版本升级与降级

openssl版本升级与降级 flyfish 环境 Ubuntu 22.04 1.1.1升级3.1.1 查看openssl版本 openssl versionOpenSSL 1.1.1t 7 Feb 2023https://www.openssl.org/source/ 编译和安装 ./config --prefix/usr/local/openssl311 make -j8 make install进入/usr/local/openssl311/l…

JavaWeb两大组件FILTERLISTENER

一.Filter&#xff1a;过滤器 是什么&#xff1a; 当访问服务器的资源时&#xff0c;过滤器可以将请求拦截下来&#xff0c;完成一些特殊的功能。 作用&#xff1a; 一般用于完成通用的操作。如&#xff1a;登录验证、统一编码处理、敏感字符过滤… 具体流程&#xff1a; 原始…

本地部署 FastChat

本地部署 FastChat 1. 什么是 FastChat2. Github 地址3. 安装 Miniconda34. 创建虚拟环境5. 安装 FastChat6. 使用命令行进行推理7. 使用 Web GUI 服务进行推理8. 使用 Lora 进行训练9. 其他 文章还在创作中。。。 1. 什么是 FastChat FastChat 是一个开放平台&#xff0c;用…

同步和异步、同步复位、异步复位、同步释放(Verilog、Verdi、DC综合)

文章目录 1.同步和异步2. 同步复位、异步复位、同步释放2.1 同步复位2.1.1 Verilog code2.1.2 Verdi waveform2.1.3 DC Synthesis 2.2 异步复位2.2.1 Verilog code 2.3 同步释放&#xff08;异步信号和CLK信号存在时序检查、Recover time&Removel time&#xff09;2.4 异步…

初识Docker:(7)查询Docker镜像的DockerFile

1. 前言 我们知道了根据dockerfile来制作镜像&#xff0c;如果给你一个现成的镜像&#xff0c;你能逆向查看出dockerfile吗&#xff1f; 否则&#xff0c;你怎么知道该镜像使用的是CMD还是ENTRYPOINT &#xff0c;使用的是shell格式还是CMD格式&#xff1f;由于格式决定了doc…