【深度学习框架格式转化】【CPU】Pytorch模型转ONNX模型格式流程详解【入门】

news2025/1/13 8:06:39

【深度学习框架格式转化】【GPU】Pytorch模型转ONNX模型格式流程详解【入门】

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习框架格式转化】【GPU】Pytorch模型转ONNX模型格式流程详解【入门】
  • 前言
  • PyTorch模型环境搭建(CPU)
  • 安装onnx和onnxruntime(CPU)
  • pytorch2onnx
  • 总结


前言

神经网络的模型通常在深度学习框架(PyTorc、TensorFlow和Caffe等)下训练得到,这些特定环境的深度学习框架依赖较多,规模较大,不适合在生产环境中安装,onnx支持大多数框架下模型的转换,便于整合模型,并且深度学习模型需要大量的算力才能满足实时运行需求,需要优化模型的运行效率,onnx并则能带来稳定的提速。
onnx还能再转化成TensorRT(GPU)格式和OpenVINO(CPU)格式进行推理,进一步提升速度

CPU模式下的格式转化,无论Pytorch还是ONNX搭建流程都十分简便,适合入门学习,也对极其适合对硬件要求很低的轻量级模型的运行。

后续可以学习【GPU】Pytorch模型转ONNX格式流程详解


PyTorch模型环境搭建(CPU)

博主以伪装对象分割(COS)之PFNet算法为例进行详解:【PFNet-pytorch代码】。
用PyTorch运行一个伪装对象分割模型PFNet,并把模型部署到ONNX Runtime这个推理引擎上。
博主在win10环境下装anaconda环境,搭建PFNet模型运行的PyTorch环境(官网下载地址)

# 创建虚拟环境
conda create -n pytorch2onnx_cpu python=3.10 -y
# 激活环境
activate pytorch2onnx_cpu 
# 下载githup源代码到合适文件夹,并cd到代码文件夹内(科学上网)
git clone https://github.com/Mhaiyang/CVPR2021_PFNet.git
# 安装pytorch(cpu)
pip3 install torch torchvision torchaudio

博主在这里不会详细讲解代码内容,只关注代码的使用,即代码的测试过程。源码作者提供了预训练权重和测试数据,博主整理到了【百度云,提取码:a660】上供大家下载。
下载resnet50-19c8e357.pth放置到CVPR2021_PFNet\backbone\resnet下:

下载PFNet.pth放置到CVPR2021_PFNet下:

下载测试数据集CAMO_TestingDataset.zip、CHAMELEON_TestingDataset.zip和COD10K_TestingDataset.zip解压重命名放置到CVPR2021_PFNet\data\test中:

使用预训练权重进行测试,修改infer.py文件内容

# 1.修改infer.py,只保留在test中有的数据集
to_test = OrderedDict([
                       ('CHAMELEON', chameleon_path),
                       ('CAMO', camo_path),
                       ('COD10K', cod10k_path),
                       # ('NC4K', nc4k_path)
                       ])
                       
# 2.修改infer.py,删除/注释所有使用gpu相关代码
# device_ids = [0]
# torch.cuda.set_device(device_ids[0])

# net = PFNet(backbone_path).cuda(device_ids[0])
net = PFNet(backbone_path)

# img_var = Variable(img_transform(img).unsqueeze(0)).cuda(device_ids[0])
img_var = Variable(img_transform(img).unsqueeze(0))

# 3.修改config.py中的内容
# datasets_root = '../data/NEW'修改成datasets_root = './data              

在CVPR2021_PFNet\results可以查看效果:

数据量比较大,运行速度也不算快。

到这里PyTorch模型环境搭建(CPU)完毕。


安装onnx和onnxruntime(CPU)

需要在anaconda虚拟环境安装onnx和onnxruntime

# 激活环境
activate pytorch2onnx_cpu 
# 安装onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
# 安装CPU版
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime

获取ONNX Runtime的版本信息

import onnxruntime as ort
print("ONNX Runtime version:", ort.__version__)

pytorch2onnx

在CVPR2021_PFNet目录下新建pytorch2onnx.py文件并执行文件

import onnx
from onnx import numpy_helper
import torch
from PFNet import PFNet
backbone_path = './backbone/resnet/resnet50-19c8e357.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
example = torch.randn(1,3, 416, 416).to(device)     # 1 3 416 416
print(example.dtype)
model = PFNet(backbone_path)                        # PFNet网络模型

model.load_state_dict(torch.load(r'PFNet.pth'))     # 加载训练好的模型
model = model.to(device)                            # 模型放到cpu上
model.eval()

torch.onnx.export(model, example, r"PFNet.onnx")     	# 导出模型
model_onnx = onnx.load(r"PFNet.onnx")                   # onnx加载保存的onnx模型
onnx.checker.check_model(model_onnx)                    # 检查模型是否有问题
print(onnx.helper.printable_graph(model_onnx.graph))    # 打印onnx网络

pytorch模型转化成onnx模型成功。
现在抛开任何pytorch相关的依赖,使用onnx模型完成测试,新建run_onnx.py,代码是参考源代码的推理部分infer.py改写来的。

import onnxruntime as ort
import numpy as np
from collections import OrderedDict
from config import *
from PIL import Image
from numpy import mean
import time
import datetime

def composed_transforms(image):
    mean = np.array([0.485, 0.456, 0.406])  # 均值
    std = np.array([0.229, 0.224, 0.225])  # 标准差
    # transforms.Resize是双线性插值
    resized_image = image.resize((args['scale'], args['scale']), resample=Image.BILINEAR)
    # onnx模型的输入必须是np,并且数据类型与onnx模型要求的数据类型保持一致
    resized_image = np.array(resized_image)
    normalized_image = (resized_image/255.0 - mean) / std
    return np.round(normalized_image.astype(np.float32), 4)

def check_mkdir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

to_test = OrderedDict([
                       # ('CHAMELEON', chameleon_path),
                       # ('CAMO', camo_path),
                       ('COD10K', cod10k_path),
                       ])
args = {
    'scale': 416,
    'save_results': True
}

def main():
    # 保存检测结果的地址
    results_path = './results2'
    exp_name = 'PFNet'
    providers = ["CPUxecutionProvider"]
    ort_session = ort.InferenceSession("PFNet.onnx", providers=providers)  # 创建一个推理session
    input_name = ort_session.get_inputs()[0].name
    # 输出有四个
    output_names = [output.name for output in ort_session.get_outputs()]
    start = time.time()
    for name, root in to_test.items():
        time_list = []
        image_path = os.path.join(root, 'image')
        if args['save_results']:
            check_mkdir(os.path.join(results_path, exp_name, name))
        img_list = [os.path.splitext(f)[0] for f in os.listdir(image_path) if f.endswith('jpg')]
        for idx, img_name in enumerate(img_list):
            img = Image.open(os.path.join(image_path, img_name + '.jpg')).convert('RGB')
            w, h = img.size
            #  对原始图像resize和归一化
            img_var = composed_transforms(img)
            # np的shape从[w,h,c]=>[c,w,h]
            img_var = np.transpose(img_var, (2, 0, 1))
            # 增加数据的维度[c,w,h]=>[bathsize,c,w,h]
            img_var = np.expand_dims(img_var, axis=0)
            start_each = time.time()
            prediction = ort_session.run(output_names, {input_name: img_var})
            time_each = time.time() - start_each
            time_list.append(time_each)
            # 除去多余的bathsize维度,NumPy变会PIL同样需要变换数据类型
            # *255替换pytorch的to_pil
            prediction = (np.squeeze(prediction[3])*255).astype(np.uint8)
            if args['save_results']:
               (Image.fromarray(prediction).resize((w, h)).convert('L').save(os.path.join(results_path, exp_name, name, img_name + '.png')))
        print(('{}'.format(exp_name)))
        print("{}'s average Time Is : {:.3f} s".format(name, mean(time_list)))
        print("{}'s average Time Is : {:.1f} fps".format(name, 1 / mean(time_list)))
    end = time.time()
    print("Total Testing Time: {}".format(str(datetime.timedelta(seconds=int(end - start)))))
if __name__ == '__main__':
    main()

在CVPR2021_PFNet\results2可以查看效果:
在这里插入图片描述

到这里读者将代码迁移到新机器时,可以不再安装pytorch相关依赖就能使用模型的预测功能,这可以极大的减少所依赖环境的大小。


总结

尽可能简单、详细的介绍CPU模式下Pytorch模型转ONNX格式的流程,后续介绍GPU版本的格式转化,学习难度只是有略微提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1023378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LCP 50. 宝石补给(每日一题)

欢迎各位勇者来到力扣新手村,在开始试炼之前,请各位勇者先进行「宝石补给」。 每位勇者初始都拥有一些能量宝石, gem[i] 表示第 i 位勇者的宝石数量。现在这些勇者们进行了一系列的赠送,operations[j] [x, y] 表示在第 j 次的赠送…

解决VS Code安装远程服务器插件慢的问题

解决VS Code安装远程服务器插件慢的问题 最近想在服务器上做juypter notebook的代码运行,发现要给服务器安装Jupyter插件,但是安装速度奇慢无比(因为服务器不连外网),一开始查看从VS Code插件市场下载插件的博客&…

网络编程day02(socket套接字)

今日任务&#xff1a; TCP\UDP服务端客户端通信 TCP&#xff1a;代码 服务端&#xff1a; #include <stdio.h> #include <string.h> #include <stdlib.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #in…

项目提交按钮没防抖,差点影响了验收

前端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 表妹一键制作自己的五星红旗国庆头像&#xff0c;超好看 前言 一个运行了多年的ToB的项目&#xff0c;由于数据量越来越大&#xff0c;业务越来越复杂&…

【HarmonyOS】【DevEco Studio】盘点DevEco Studio日志获取途径

【关键词】 DevEco Studio、日志获取 【问题背景】 在收到IDE工单的时候&#xff0c;很多时候开发者出现的问题都需要提供一些日志&#xff0c;然后根据日志分析&#xff0c;那么你知道IDE各种日志的获取方式么&#xff1f;往下看 【获取方法】 一、idea.log获取 IDE界面H…

滴滴一面:说说MySQL主从数据同步机制

说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如滴滴、阿里、汽车之家、极兔、有赞、希音、百度、网易、滴滴的面试资格&#xff0c;遇到一几个很重要的主从同步面试题&#xff1a; 说说MySQL主从同步的流程说说MySQL主从同步…

添加一个仅管理员可见的页面

例如我新加一个页面 申请一个路由 《插播》 前端是如何知道我们是管理员的呢&#xff0c;ant-design框架会帮我们存到InitialState里&#xff0c;做为全局变量 在access.ts里我们获取到了用户是否为管理员 &#xff08;用户存在且为管理员&#xff09; 框架为我们打通了个路由…

【深度学习实验】前馈神经网络(二):使用PyTorch实现不同激活函数(logistic、tanh、relu、leaky_relu)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 0. 导入必要的工具包 1. 定义激活函数 logistic(z) tanh(z) relu(z) leaky_relu(z, gamma0.1) 2. 定义输入、权重、偏置 3. 计算净活性值 4. 绘制激活函数的图像 5. 应用激活函数并…

MySQL基础—从零开始学习MySQL

01.MySQL课程介绍_哔哩哔哩_bilibili 1、MySQL安装 以管理员身份运行cmd net start mysql80net stop mysql80 客户端连接 1). 方式一&#xff1a;使用MySQL提供的客户端命令行工具 2). 方式二&#xff1a;使用系统自带的命令行工具执行指令 mysql [-h 127.0.0.1] [-P 3…

mysql知识大全

MySQL知识大全&#xff08;2&#xff09; MySqL 基础为1—7&#xff08;增删改查基础语法&#xff09;&#xff0c;MySQL进阶知识为8—11&#xff08;约束、数据库设计、多表查询、事务&#xff09; 1、数据库相关概念 以前我们做系统&#xff0c;数据持久化的存储采用的是文件…

【二叉树】二叉树展开为链表-力扣 114 题

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

【云原生】k8s-----集群调度

目录 1.k8s的list-watch机制 1.1 list-watc机制简介 1.2 根据list-watch机制&#xff0c;pod的创建流程 2.scheduler的调度策略 2.1 scheduler的调度策略简介 2.2 Scheduler预选策略的算法 2.3 Scheduler优选策略的算法 3. k8s中的标签管理及nodeSelector和nodeName的 调…

win10 安装 Langchain-Chatchat 避坑指南(2023年9月18日v0.2.4版本,包含全部下载内容!)

网上教程都是基于外网或者翻墙的&#xff0c;而且细节极其不清晰&#xff0c;尤其是最关键的模型下载。 另外提一句&#xff0c;我的显卡是&#xff1a;3080Ti 16GB版本&#xff0c;运行之后&#xff0c;显存占用13-14GB 1、安装Anaconda&#xff08;这个就不啰嗦了&#xff0c…

【SpringMVC】JSON注解全局异常处理机制

&#x1f389;&#x1f389;欢迎来到我的CSDN主页&#xff01;&#x1f389;&#x1f389; &#x1f3c5;我是Java方文山&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f31f;在这里&#xff0c;我要推荐给大家我的专栏《Spring MVC》。&#x1f3…

Nue JS 造全新的 Web 生态

Nue JS 是最近开源的 Web 前端项目&#xff0c;用于构建用户界面&#xff0c;体积非常小&#xff08;压缩后 2.3kb&#xff09;。Nue JS 支持服务器端渲染 (SSR)、反应式组件和 “同构” 组合 ("isomorphic" combinations)。 Vue.js、React.js 或 Svelte&#xff0c;…

Day 01 python学习笔记

1、引入 让我们先写第一个python程序&#xff08;如果是纯小白的话&#xff09; 因为我们之前安装了python解释器 所以我们直接win r ---->输入cmd&#xff08;打开运行终端&#xff09; >python #&#xff08;在终端中打开python解释器&#xff09;>>>pri…

CSDN博客可以添加联系方式了

csdn博客一直不允许留一些联系方式&#xff0c;结果是官方有联系方式路径 在首页&#xff0c;往下拉&#xff0c;左侧就有 点击这个即可添加好友了~ 美滋滋&#xff0c;一起交流&#xff0c; 学习技术 ~

详细介绍如何微调 YOLOv8 姿势模型以进行动物姿势估计--附完整源码

动物姿势估计是计算机视觉的一个研究领域,是人工智能的一个子领域,专注于自动检测和分析图像或视频片段中动物的姿势和位置。目标是确定一只或多只动物身体部位的空间排列,例如头部、四肢和尾巴。这项技术具有广泛的应用,从研究动物行为和生物力学到野生动物保护和监测。 …

CS 创世SD NAND FLASH 存储芯片,比TF卡更小巧轻便易用的大容量存储,TF卡替代方案

文章目录 介绍创世SD卡引脚与NOR Flash存储比较 介绍 SD NAND FLASH&#xff08;Secure Digital NAND Flash&#xff09;是一种安全数字 NAND 闪存技术&#xff0c;通常用于存储数据&#xff0c;并且具有一些额外的安全特性。这种技术结合了 NAND 闪存的高密度存储能力和安全性…

JavaScript 期约与异步函数的学习笔记

同步与异步的概念 JavaScript 是一门单线程的语言&#xff0c;这意味着它在任何给定的时间只能执行一个任务。 然而&#xff0c;JavaScript 通过异步编程技术来处理并发操作&#xff0c;以避免阻塞主线程的情况。 在上图中&#xff0c;同步行为的进程 A 因为等待进程 B 执行完…