【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程

news2025/3/9 10:56:49

【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习】【OnnxRuntime】【Python】模型转化、环境搭建以及模型部署的详细教程
  • 前言
  • 模型转换--pytorch转onnx
  • Windows平台搭建依赖环境
  • onnxruntime调用onnx模型
    • ONNXRuntime推理核心流程
    • ONNXRuntime推理代码
  • 总结


前言

ONNXRuntime是微软推出的一款高性能的机器学习推理引擎框架,用户可以非常便利的用其运行一个onnx模型,专注于加速机器学习模型的预测阶段。ONNXRuntime设计目的是为了提供一个高效的执行环境,使机器学习模型能够在各种硬件上快速执行,支持多种运行后端包括CPU,GPU,TensorRT,DML等,使得开发者可以灵活选择最适合其应用场景的硬件平台。
ONNXRuntime是对ONNX模型最原生的支持。

读者可以通过学习【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。


模型转换–pytorch转onnx

Pytorch模型转onnx并推理的步骤如下:

  1. 将PyTorch预训练模型文件( .pth 或 .pt 格式)转换成ONNX格式的文件(.onnx格式),这一转换过程在PyTorch环境中进行。
  2. 将转换得到的 .onnx 文件随后作为输入,调用ONNXRuntime的C++ API来执行模型的推理。

博主使用AlexNet图像分类(五种花分类)进行演示,需要安装pytorch环境,对于该算法的基础知识,可以参考博主【AlexNet模型算法Pytorch版本详解】博文

conda create --name AlexNet python==3.10
conda activate AlexNet
# 根据自己主机配置环境
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 假设模型转化出错则降级为指定1.16.1版本
pip install onnx==1.16.1

然后把训练模型好的AlexNet.pth模型转成AlexNet.onnx模型,pyorch2onnx.py转换代码如下:

import torch
from model import AlexNet
model = AlexNet(num_classes=5)
weights_path = "./AlexNet.pth"
# 加载模型权重
model.load_state_dict(torch.load(weights_path))
# 模型推理模式
model.eval()
model.cpu()
# 虚拟输入数据
dummy_input1 = torch.randn(1, 3, 224, 224)
# 模型转化函数
torch.onnx.export(model, (dummy_input1), "AlexNet.onnx", verbose=True, opset_version=11)


【AlexNet.pth百度云链接,提取码:ktq5 】直接下载使用即可。


Windows平台搭建依赖环境

需要在anaconda虚拟环境安装onnxruntime,需要注意onnxruntime-gpu, cuda, cudnn三者的版本要对应,具体参照官方说明。

博主是win11+cuda12.1+cudnn8.8.1,对应onnxruntime-gpu==1.18.0

import torch
# 查询cuda版本
print(torch.version.cuda)
# 查询cudnn版本
print(torch.backends.cudnn.version())

# 激活环境
activate AlexNet
# 安装onnx
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnx
# 安装GPU版
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime-gpu==1.18.0
# 或者可以安装CPU版本:没有版本对应要求
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxruntime
# 安装opencv
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python

onnxruntime调用onnx模型

ONNXRuntime推理核心流程

设置会话选项
通常包括配置优化器级别、线程数和设备(GPU/CPU)使用等。

sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.intra_op_num_threads = 4
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
会换选项日志严重性级别优化器级别线程数设备使用
选项log_severity_levelgraph_optimization_levelgraph_optimization_levelCUDAExecutionProvider;CPUExecutionProvider
作用决定了哪些级别的日志信息将被记录下来,运行时提供了几个预定义的宏来表示不同的日志级别。在模型加载到ONNXRuntime之前对其进行图优化的过程,提高执行效率设置每个运算符内部执行时的最大线程数CUDA/CPU设备选择。
参数整形,1:Info, 2:Warning. 3:Error, 4:Fatal,默认是2。ORT_ENABLE_BASIC:基本的图优化; ORT_DISABLE_ALL:禁用所有优化;ORT_ENABLE_EXTENDED:启用扩展优化;ORT_ENABLE_ALL:启用所有优化。整型列表中的顺序决定了执行提供者的优先级。

加载模型并创建会话
加载预训练的ONNX模型文件,使用运行时环境、会话选项和模型创建一个Session对象。

session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)
ort.InferenceSession参数path_or_bytessess_optionsproviders
内容模型的位置或者模型的二进制数据会话选项设备选择

获取模型输入输出信息
从Session对象中获取模型输入和输出的详细信息,包括数量、名称、类型和形状。

input_nodes_num = len(session.get_inputs()) 
output_nodes_num = len(session.get_outputs()) 
input_name = session.get_inputs()[i].name
output_name = session.get_outputs()[i].name
input_shape = session.get_inputs()[i].shape
output_shape = session.get_outputs()[i].shape

预处理输入数据
对输入数据进行颜色空间转换,尺寸缩放、标准化以及形状维度扩展操作。

rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
blob = cv2.resize(rgb, (input_w, input_h))
blob = blob.astype(np.float32)
blob /= 255.0
blob -= np.array([0.485, 0.456, 0.406])
blob /= np.array([0.229, 0.224, 0.225])
timg = cv2.dnn.blobFromImage(blob)

这部分不是OnnxRuntime核心部分,根据任务需求不同,代码略微不同。

执行推理
调用Session.run方法,传入输入张量、输出张量名和其他必要的参数,执行推理。

ort_outputs = session.run(output_names=input_node_names, input_feed={output_node_names[0]: timg})
Session.run参数output_namesinput_feed
含义输出节点名称的列表。输入节点名称和输入数据的键值对字典,可能有多个输入。

后处理推理结果
推理完成后,从输出张量中获取结果数据,根据需要对结果进行后处理,以获得最终的预测结果。

prob = ort_outputs[0]
max_index = np.argmax(prob)

这部分不是OnnxRuntime核心部分,根据任务需求不同,代码基本不同。


ONNXRuntime推理代码

需要配置flower_classes.txt文件存储五种花的分类标签,并将其放置到工程目录下(推荐)。

daisy
dandelion
roses
sunflowers
tulips

这里需要将AlexNet.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。

import onnxruntime as ort
import cv2
import numpy as np

# 加载标签文件获得分类标签
def read_class_names(file_path="./flower_classes.txt"):
    class_names = []
    try:
        with open(file_path, 'r') as fp:
            for line in fp:
                name = line.strip()
                if name:
                    class_names.append(name)
    except IOError:
        print("could not open file...")
        import sys
        sys.exit(-1)
    return class_names

# 主函数
def main():
    # 预测的目标标签数
    labels = read_class_names()

    # 测试图片
    image_path = "./sunflowers.jpg"
    image = cv2.imread(image_path)
    cv2.imshow("输入图", image)
    cv2.waitKey(0)

    # 设置会话选项
    sess_options = ort.SessionOptions()
    # 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATAL
    sess_options.log_severity_level = 3
    # 优化器级别:基本的图优化级别
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    # 线程数:4
    sess_options.intra_op_num_threads = 4
    # 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

    # onnx训练模型文件
    onnxpath = "./AlexNet.onnx"

    # 加载模型并创建会话
    session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)

    input_nodes_num = len(session.get_inputs())     # 输入节点输
    output_nodes_num = len(session.get_outputs())   # 输出节点数
    input_node_names = []                           # 输入节点名称
    output_node_names = []                          # 输出节点名称

    # 获取模型输入信息
    for i in range(input_nodes_num):
        # 获得输入节点的名称并存储
        input_name = session.get_inputs()[i].name
        input_node_names.append(input_name)
        # 显示输入图像的形状
        input_shape = session.get_inputs()[i].shape
        ch, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]
        print(f"input format: {ch}x{input_h}x{input_w}")

    # 获取模型输出信息
    for i in range(output_nodes_num):
        # 获得输出节点的名称并存储
        output_name = session.get_outputs()[i].name
        output_node_names.append(output_name)
        # 显示输出结果的形状
        output_shape = session.get_outputs()[i].shape
        num, nc = output_shape[0], output_shape[1]
        print(f"output format: {num}x{nc}")

    input_shape = session.get_inputs()[0].shape
    input_h, input_w = input_shape[2], input_shape[3]
    print(f"input format: {input_shape[1]}x{input_h}x{input_w}")

    # 预处理输入数据
    # 默认是BGR需要转化成RGB
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # 对图像尺寸进行缩放
    blob = cv2.resize(rgb, (input_w, input_h))
    blob = blob.astype(np.float32)
    # 对图像进行标准化处理
    blob /= 255.0   # 归一化
    blob -= np.array([0.485, 0.456, 0.406])  # 减去均值
    blob /= np.array([0.229, 0.224, 0.225])  # 除以方差
    #CHW-->NCHW 维度扩展
    timg = cv2.dnn.blobFromImage(blob)
    # ---blobFromImage 可以用以下替换---
    # blob = blob.transpose(2, 0, 1)
    # blob = np.expand_dims(blob, axis=0)
    # -------------------------------

    # 模型推理
    try:
        ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})
    except Exception as e:
        print(e)
        ort_outputs = None

    # 后处理推理结果
    prob = ort_outputs[0]
    max_index = np.argmax(prob)     # 获得最大值的索引
    print(f"label id: {max_index}")
    # 在测试图像上加上预测的分类标签
    label_text = labels[max_index]
    cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)
    cv2.imshow("输入图像", image)
    cv2.waitKey(0)

if __name__ == '__main__':
    main()

图片正确预测为向日葵:


总结

尽可能简单、详细的介绍了pytorch模型到onnx模型的转化,python下onnxruntime环境的搭建以及ONNX模型的OnnxRuntime部署。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文件系统(磁盘 磁盘文件 inode)

文章目录 磁盘看看物理磁盘磁盘的存储结构 对磁盘的储存进行逻辑抽象inode号文件名 -> inode判断文件在哪个分区 磁盘 电脑中存在非常多的文件,被打开的文件只是少量的。 没有被打开的文件,在磁盘中放着,那么文件是如何存取? …

Unity 之 【Android Unity FBO渲染】之 [Unity 渲染 Android 端播放的视频] 的一种方法简单整理

Unity 之 【Android Unity FBO渲染】之 [Unity 渲染 Android 端播放的视频] 的一种方法简单整理 目录 Unity 之 【Android Unity FBO渲染】之 [Unity 渲染 Android 端播放的视频] 的一种方法简单整理 一、简单介绍 二、FBO 简单介绍 三、案例实现原理 四、注意事项 五、简…

深度盘点PLM 项目管理系统哪家强?优缺点一目了然!

本文将盘点10款知名的PLM 项目管理系统,为企业选型提供参考! 想象一下,在一个企业的产品研发过程中,各种数据、文档四处散落,不同部门之间沟通不畅,项目进度难以把控。这时,PLM 项目管理系统就如…

在线压缩图片地址

https://squoosh.app/editor这个是免费的,并且不限制图片数量 https://tinypng.com/ 这个限制图片的大小,如果单张图片超过5M需要收费 https://www.jpeg-optimizer.com/ https://imagecompressor.com/

再次进阶 舞台王者 第八季完美童模全球赛形象大使【于洪森】赛场秀场超燃合集!

7月20-23日,2024第八季完美童模全球总决赛在青岛圆满落幕。在盛大的颁奖典礼上,一位才能出众的少年——于洪森,迎来了他舞台生涯的璀璨时刻。 形象大使——于洪森,以璀璨童星之姿,优雅地踏上完美童模盛宴的绚丽舞台&am…

WPF实现Hammer 3D入门学习

代码下载:https://download.csdn.net/download/bjhtgy/89748674

springboot Web基础开发

Spring Boot 是一个用于简化 Spring 应用开发的框架,它通过自动配置和开箱即用的功能,使得创建和部署 Spring 应用变得更为高效。以下是 Spring Boot 基础 Web 开发的一些关键点和实操总结: 1. 项目搭建 使用 Spring Initializr: 访问 Spring…

代码随想录刷题day31丨56. 合并区间,738.单调递增的数字,总结

代码随想录刷题day31丨56. 合并区间,738.单调递增的数字,总结 1.题目 1.1合并区间 题目链接:56. 合并区间 - 力扣(LeetCode) 视频讲解:贪心算法,合并区间有细节!LeetCode&#x…

源代码加密软件有哪些?2024常用的10款好用的企业源代码加密软件分享!

源代码作为企业的核心资产,一旦泄露,将可能导致技术被窃取、产品被复制,甚至引发法律纠纷。 一、企业源代码泄密的危害详情描述 企业源代码泄密事件频发,其危害不容小觑。 一方面,源代码的泄露可能导致企业的核心技术…

国内领先线上运动平台:如何借助AI技术实现业务腾飞与用户体验升级

“ 从智能训练到身体分析,再到辅助判决,AI技术正以惊人的速度渗透进体育和健身领域,为运动员和健身爱好者提供了前所未有的个性化体验。 ” AI,运动的智能伴侣 在巴黎奥运会上,AI技术的运用成为了焦点。它不仅为运动…

人脸关键点数据集WFLW

数据集:Look at Boundary: A Boundary-Aware Face Alignment Algorithm 论文:Look at Boundary: A Boundary-Aware Face Alignment Algorithm 发表:CVPR2018 1. 标注点位 官方说有98个点,但是配图只有0-95,咋回事&…

java设计模式(持续更新中)

1 设计模式介绍 设计模式代表了代码的最佳实践,被有经验的开发人员使用。设计模式是很多被反复使用并知晓的,主要是对代码和经验的总结。使用设计模式是为了重用代码,并让代码更容易被人理解,保证代码的可靠性。对接口编程而不是…

单考一个OCP认证?还是OCP和OCM认证都要考?

​ Oracle的OCP认证是数据库行业非常经典的一个认证,从事数据库行业的人都建考一个 Oracle OCP 认证。 OCP认证内容包括: OCA部分:数据库基础知识、SQL 语言使用、基本的数据库管理技能等,如数据库安装与配置、理解数据库架构、…

Web APIs - DOM节点操作

Web APIs - DOM节点操作 第9天 目标: 了解DOM节点的增删改查,掌握利用数据操作页面,完成移动端通讯录案例 日期对象节点操作M端事件JS插件综合案例 1、日期对象 日期对象:用来表示日期和时间的对象 作用:可以得到当前系统日期和…

力扣139-单词拆分(Java详细题解)

题目链接:139. 单词拆分 - 力扣(LeetCode) 前情提要: 因为本人最近都来刷dp类的题目所以该题就默认用dp方法来做。 最近刚学完背包,所以现在的题解都是以背包问题为基础再来写的。 如果大家不懂背包问题的话&#…

深度盘点:2024年企业最喜欢用的WMS仓库管理系统有哪些?

本文将列举国内外知名的仓库管理系统,从每个系统的适用范围、核心功能、特点来为大家解读。为企业选型提供参考! WMS系统是Warehouse Management System(仓库管理系统)的简称,它是一个帮助企业和仓库管理者高效管理仓库…

NMOS与PMOS原理图

重点关注续流二极管方向和电流流向: NMOS应用: PMOS 应用:

BASM引领2024国家网络安全宣传周:智能守护,打造全方位业务与应用安全监测平台

在这个信息泛滥的时代,网络安全已不再是可有可无的选项。 随着技术的飞速发展,新型网络攻击层出不穷,数据泄露、恶意攻击频发,保护个人与企业的数字安全显得尤为重要。 2024年国家网络安全宣传周期间,通付盾给大家带…

Cortex-R52+的PE mode详解--Abort

目录 1.R52 AArch32通用寄存器描述 2.Abort模式是什么 3.实例详解 1.R52 AArch32通用寄存器描述 上篇文章我们阐述了关于R52异常如何定位,其中详细说明了发生异常后应该在什么模式下去观察寄存器。 今天就以Abort异常为例,详解下如何精准定位Abort异…

一文读懂网络安全等级保护

网络安全等级保护(简称“等保”)是我国为了保护信息安全而推出的一项制度,旨在通过对信息系统分等级实施安全保护,确保信息安全。它涵盖了信息和存储、传输、处理这些信息的信息系统,以及使用的信息安全产品。等级保护…