PyTorch 模型转换为 ONNX 格式

news2024/11/29 7:22:56

PyTorch 模型转换为 ONNX 格式

在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH 和 ONNX 格式的区别,并介绍如何使用 Netron 可视化 ONNX 模型。

1. PTH 和 ONNX 的区别

PTH 格式

  • 定义:PTH 是 PyTorch 框架的专有格式,通常用于保存模型的状态字典(state_dict),包括模型的结构和训练好的参数。

  • 兼容性

    • PTH 文件只能在 PyTorch 中使用,无法直接在 C++ 环境中加载。虽然 PyTorch 提供了 C++ API(LibTorch),但 PTH 文件的加载和使用主要依赖于 Python 环境。
    • 在 C++ 中使用 PTH 文件需要将模型转换为 PyTorch 的 C++ 格式,这可能会增加复杂性和开发时间。
  • 用途

    • PTH 格式适合在 Python 环境中进行模型训练和调试,但在 C++ 中进行模型部署时,通常需要将模型转换为其他格式(如 ONNX)以便于跨平台使用。
    • 在 C++ 中,使用 PTH 文件的灵活性较低,尤其是在需要与其他框架或系统集成时。

ONNX 格式

  • 定义:ONNX(Open Neural Network Exchange)是一个开放的深度学习模型交换格式,旨在促进不同深度学习框架之间的互操作性。

  • 兼容性

    • ONNX 文件可以在多个深度学习框架中使用,包括 PyTorch、TensorFlow、Caffe2 等,这使得它在 C++ 环境中的兼容性更强。
    • ONNX 模型可以通过 ONNX Runtime、TensorRT、OpenVINO 等推理引擎在 C++ 中高效运行,支持多种硬件加速。
  • 用途

    • ONNX 格式非常适合模型的部署和推理,特别是在需要跨平台或跨框架使用时。它允许开发者在 C++ 中轻松加载和运行模型,而无需依赖于 Python 环境。
    • 在 C++ 中,使用 ONNX 模型可以简化工程化流程,便于与其他系统集成,提升模型的可移植性和可扩展性。

总结

在 C++ 进行深度学习模型的工程化时,选择 ONNX 格式通常更为合适,因为它提供了更好的跨平台兼容性和灵活性。PTH 格式虽然在 PyTorch 环境中非常方便,但在 C++ 中的使用受到限制,通常需要额外的转换步骤。ONNX 的开放性和广泛支持使其成为在多种环境中部署深度学习模型的首选格式。

2. 训练 MNIST 数据集的 CNN 模型

以下是使用 PyTorch 训练 MNIST 数据集的完整代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 1. 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 2. 定义 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入通道为1,输出通道为32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入通道为32,输出通道为64
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 全连接层
        self.fc2 = nn.Linear(128, 10)  # 输出层

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + 激活 + 池化
        x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + 激活 + 池化
        x = x.view(x.size(0), -1)  # 展平输入
        x = torch.relu(self.fc1(x))  # 第一个全连接层
        x = self.fc2(x)  # 输出层
        return x

# 3. 训练模型
model = SimpleCNN().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器

# 训练过程
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备
        optimizer.zero_grad()  # 清空梯度
        outputs = model(images)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 4. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

# 5. 转换为 ONNX 格式
onnx_file_path = 'mnist_cnn_model.onnx'
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # 示例输入,形状为 [batch_size, channels, height, width]
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True,
                  opset_version=11, do_constant_folding=True,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

print(f'Model has been converted to ONNX format and saved as {onnx_file_path}.')

3. 使用 Netron 可视化 ONNX 模型

一旦您将模型转换为 ONNX 格式,您可以使用 Netron 来可视化模型结构。Netron 是一个开源的模型可视化工具,支持多种深度学习框架的模型文件格式,包括 ONNX。

使用步骤:
  1. 下载 Netron

    • 您可以访问 Netron 的官方网站 在线使用,或者下载桌面版本。
  2. 打开 ONNX 模型

    • 如果使用在线版本,直接将 mnist_cnn_model.onnx 文件拖放到浏览器窗口中。
    • 如果使用桌面版本,打开 Netron 应用,选择“File” > “Open Model”,然后选择您的 ONNX 文件。
  3. 查看模型结构

    • 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249592.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VM Virutal Box的Ubuntu虚拟机与windows宿主机之间设置共享文件夹(自动挂载,永久有效)

本文参考如下链接 How to access a shared folder in VirtualBox? - Ask Ubuntu (1)安装增强功能(Guest Additions) 首先,在网上下载VBoxGuestAdditions光盘映像文件 下载地址:Index of http://…

CA系统(file.h---申请认证的处理)

#pragma once #ifndef FILEMANAGER_H #define FILEMANAGER_H #include <string> namespace F_ile {// 读取文件&#xff0c;返回文件内容bool readFilename(const std::string& filePath);bool readFilePubilcpath(const std::string& filePath);bool getNameFro…

【Git】Git 命令参考手册

目录 Git 命令参考手册1. 创建仓库1.1 创建一个新的本地仓库1.2 克隆一个仓库1.3 克隆仓库到指定目录 2. 提交更改2.1 显示工作目录中已修改的文件&#xff0c;准备提交2.2 将文件添加到暂存区&#xff0c;准备提交2.3 将所有已修改的文件添加到暂存区&#xff0c;准备提交2.4 …

【Linux系列】Chrony时间同步服务器搭建完整指南

1. 简介 Chrony是一个用于Linux系统的高效、精准的时间同步工具&#xff0c;通常用于替代传统的NTP&#xff08;Network Time Protocol&#xff09;服务。Chrony不仅在系统启动时提供快速的时间同步&#xff0c;还能在时钟漂移较大的情况下进行及时调整&#xff0c;因此广泛应…

数据库日志

MySQL中有哪些日志 1&#xff0c;redo log重做日志 redo log是物理机日志&#xff0c;因为它记录的是对数据页的物理修改&#xff0c;而不是SQL语句。 作用是确保事务的持久性&#xff0c;redo log日志记录事务执行后的状态&#xff0c;用来恢复未写入 data file的已提交事务…

【vue for beginner】Vue该怎么学?

&#x1f308;Don’t worry , just coding! 内耗与overthinking只会削弱你的精力&#xff0c;虚度你的光阴&#xff0c;每天迈出一小步&#xff0c;回头时发现已经走了很远。 vue2 和 vue3 Vue2现在正向vue3逐渐更新中&#xff0c;官方vue2已经不再更新。 这个历程和当时的pyt…

【Ubuntu 24.04】How to Install and Use NVM

参考 下载 curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash激活 Activate NVM: Once the installation script completes, you need to either close and reopen the terminal or run the following command to use nvm immediately. exp…

SeggisV1.0 遥感影像分割软件【源代码】讲解

在此基础上进行二次开发&#xff0c;开发自己的软件&#xff0c;例如&#xff1a;【1】无人机及个人私有影像识别【2】离线使用【3】变化监测模型集成【4】个人私有分割模型集成等等&#xff0c;不管是您用来个人学习 还是公司研发需求&#xff0c;都相当合适&#xff0c;包您满…

Python轴承故障诊断 (21)基于VMD-CNN-BiTCN的创新诊断模型

往期精彩内容&#xff1a; Python-凯斯西储大学&#xff08;CWRU&#xff09;轴承数据解读与分类处理 Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客 Pytorch-CNN轴承故障一维信号分类(二)-CSDN博客 Pytorch-Transformer轴承故障一维信号分类(三)-CSDN博客 三十多个开源…

使用docker搭建hysteria2服务端

源链接&#xff1a;https://github.com/apernet/hysteria/discussions/1248 官网地址&#xff1a;https://v2.hysteria.network/zh/docs/getting-started/Installation/ 首选需要安装docker和docker compose 切换到合适的目录 cd /home创建文件夹 mkdir hysteria创建docke…

基于Java实现的潜艇大战游戏

基于Java实现的潜艇大战游戏 一.需求分析 1.1 设计任务 本次游戏课程设计小组成员团队合作的方式&#xff0c;通过游戏总体分析设计&#xff0c;场景画面的绘制&#xff0c;游戏事件的处理&#xff0c;游戏核心算法的分析实现&#xff0c;游戏的碰撞检测&#xff0c;游戏的反…

课题组自主发展了哪些CMAQ模式预报相关的改进技术?

空气污染问题日益受到各级政府以及社会公众的高度重视&#xff0c;从实时的数据监测公布到空气质量数值预报及预报产品的发布&#xff0c;我国在空气质量监测和预报方面取得了一定进展。随着计算机技术的高速发展、空气污染监测手段的提高和人们对大气物理化学过程认识的深入&a…

深入解析下oracle date底层存储方式

之前我们介绍了varchar2和char的数据库底层存储格式&#xff0c;今天我们介绍下date类型的数据存储格式&#xff0c;并通过测试程序快速获取一个日期。 一、环境搭建 1.1&#xff0c;创建表 我们还是创建一个测试表t_code&#xff0c;并插入数据&#xff1a; 1.2&#xff0c;…

【论文复现】SRGAN

1. 项目结构 如何生成文件夹的文件目录呢? 按住shift键,右击你要生成目录的文件夹,选择“在此处打开Powershell窗口” 在命令窗口里输入命令“tree”,按回车。就会显示出目录结构。 ├─.idea │ └─inspectionProfiles ├─benchmark_results ├─data │ ├─test …

Kubernetes 之 Ingress 和 Service 的异同点

1. 概念与作用 1.1 Ingress Ingress 是什么&#xff1f; Ingress主要负责七层负载&#xff0c;将外部 HTTP/HTTPS 请求路由到集群内部的服务。它可以基于域名和路径定义规则&#xff0c;从而将外部请求分配到不同的服务。 ingress作用 提供 基于 HTTP/HTTPS 的路由。 支持 …

结构体详解+代码展示

系列文章目录 &#x1f388; &#x1f388; 我的CSDN主页:OTWOL的主页&#xff0c;欢迎&#xff01;&#xff01;&#xff01;&#x1f44b;&#x1f3fc;&#x1f44b;&#x1f3fc; &#x1f389;&#x1f389;我的C语言初阶合集&#xff1a;C语言初阶合集&#xff0c;希望能…

Springboot项目搭建(7)

1.概要 2.Layout主页布局 文件地址&#xff1a;src\views\Layout.vue 2.1 script行为模块 从elementUI中选取图标图案。 <script setup> import {Management,Promotion,UserFilled,User,Crop,EditPen,SwitchButton,CaretBottom } from "element-plus/icons-vue…

cocos creator 3.8 俄罗斯方块Demo 10

这里的表格是横行数列&#xff0c;也就是x是行&#xff0c;y是列&#xff0c;不要当x/y轴看。 1-1012-1012-1-1[-1,0]0[0,-1][0,0][0,1][0,2]0[0,0]11[1,0]22[2,0] -1012-1012-1-1[-1,0]0[0,-1][0,0][0,1][0,2]0[0,0]11[1,0]22[2,0] 2-1012-1012-1[-1,-1][-1,0]-1[-1,-1][-1…

Java安全—原生反序列化重写方法链条分析触发类

前言 在Java安全中反序列化是一个非常重要点&#xff0c;有原生态的反序列化&#xff0c;还有一些特定漏洞情况下的。今天主要讲一下原生态的反序列化&#xff0c;这部分内容对于没Java基础的来说可能有点难&#xff0c;包括我。 序列化与反序列化 序列化&#xff1a;将内存…

【Java 学习】面向程序的三大特性:封装、继承、多态

引言 1. 封装1.1 什么是封装呢&#xff1f;1.2 访问限定符1.3 使用封装 2. 继承2.1 为什么要有继承&#xff1f;2.2 继承的概念2.3 继承的语法2.4 访问父类成员2.4.1 子类中访问父类成员的变量2.4.2 访问父类的成员方法 2.5 super关键字2.6 子类的构造方法 3. 多态3.1 多态的概…