【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

news2024/9/21 20:31:16

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署
  • 前言
  • Windows平台搭建依赖环境
  • 模型转换--pytorch转onnx
  • ONNXRuntime推理代码
  • 总结


前言

本期将讲解深度学习图像分类网络VggNet模型的部署,对于该算法的基础知识,可以参考博主【VggNet模型算法Pytorch版本详解】博文。
读者可以通过学习 【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。


Windows平台搭建依赖环境

在【入门基础篇】中详细的介绍了onnxruntime环境的搭建以及ONNXRuntime推理核心流程代码,不再重复赘述。


模型转换–pytorch转onnx

import torch
import torchvision as tv
def resnet2onnx():
    # 使用torch提供的预训练权重 1000分类
    model = tv.models.vgg16(pretrained=True)
    model.eval()
    model.cpu()
    dummy_input1 = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, (dummy_input1), "vgg16.onnx", verbose=True, opset_version=11)
if __name__ == "__main__":
    resnet2onnx()


如下图,torchvision本身提供了不少经典的网络,为了减少教学复杂度,这里博主直接使用了torchvision提供的ResNet网络,并下载和加载了它提供的训练权重。这里可以替换成自己的搭建的ResNet网络以及自己训练的训练权重。


ONNXRuntime推理代码

需要配置imagenet_classes.txt【百度云下载,提取码:rkz7 】文件存储1000类分类标签,假设是用户自定的分类任务,需要根据实际情况作出修改,并将其放置到工程目录下(推荐)。

这里需要将vgg16.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。

import onnxruntime as ort
import cv2
import numpy as np

# 加载标签文件获得分类标签
def read_class_names(file_path="./imagenet_classes.txt"):
    class_names = []
    try:
        with open(file_path, 'r') as fp:
            for line in fp:
                name = line.strip()
                if name:
                    class_names.append(name)
    except IOError:
        print("could not open file...")
        import sys
        sys.exit(-1)
    return class_names

# 主函数
def main():
    # 预测的目标标签数
    labels = read_class_names()

    # 测试图片
    image_path = "./lion.jpg"
    image = cv2.imread(image_path)
    # cv2.imshow("输入图", image)
    # cv2.waitKey(0)

    # 设置会话选项
    sess_options = ort.SessionOptions()
    # 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATAL
    sess_options.log_severity_level = 3
    # 优化器级别:基本的图优化级别
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    # 线程数:4
    sess_options.intra_op_num_threads = 4
    # 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

    # onnx训练模型文件
    onnxpath = "./vgg16.onnx"

    # 加载模型并创建会话
    session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)

    input_nodes_num = len(session.get_inputs())     # 输入节点输
    output_nodes_num = len(session.get_outputs())   # 输出节点数
    input_node_names = []                           # 输入节点名称
    output_node_names = []                          # 输出节点名称

    # 获取模型输入信息
    for i in range(input_nodes_num):
        # 获得输入节点的名称并存储
        input_name = session.get_inputs()[i].name
        input_node_names.append(input_name)
        # 显示输入图像的形状
        input_shape = session.get_inputs()[i].shape
        ch, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]
        print(f"input format: {ch}x{input_h}x{input_w}")

    # 获取模型输出信息
    for i in range(output_nodes_num):
        # 获得输出节点的名称并存储
        output_name = session.get_outputs()[i].name
        output_node_names.append(output_name)
        # 显示输出结果的形状
        output_shape = session.get_outputs()[i].shape
        num, nc = output_shape[0], output_shape[1]
        print(f"output format: {num}x{nc}")

    input_shape = session.get_inputs()[0].shape
    input_h, input_w = input_shape[2], input_shape[3]
    print(f"input format: {input_shape[1]}x{input_h}x{input_w}")

    # 预处理输入数据
    # 默认是BGR需要转化成RGB
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # 对图像尺寸进行缩放
    blob = cv2.resize(rgb, (input_w, input_h))
    blob = blob.astype(np.float32)
    # 对图像进行标准化处理
    blob /= 255.0   # 归一化
    blob -= np.array([0.485, 0.456, 0.406])  # 减去均值
    blob /= np.array([0.229, 0.224, 0.225])  # 除以方差
    #CHW-->NCHW 维度扩展
    timg = cv2.dnn.blobFromImage(blob)
    # ---blobFromImage 可以用以下替换---
    # blob = blob.transpose(2, 0, 1)
    # blob = np.expand_dims(blob, axis=0)
    # -------------------------------

    # 模型推理
    try:
        ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})
    except Exception as e:
        print(e)
        ort_outputs = None

    # 后处理推理结果
    prob = ort_outputs[0]
    max_index = np.argmax(prob)     # 获得最大值的索引
    print(f"label id: {max_index}")
    # 在测试图像上加上预测的分类标签
    label_text = labels[max_index]
    cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)
    cv2.imshow("输入图像", image)
    cv2.waitKey(0)

if __name__ == '__main__':
    main()

图片预测为猎豹(cheetah),没有准确预测出狮子(lion),但是这个图片难度很大,在1000分类中预测的比较接近的。

其实图像分类网络的部署代码基本是一致的,几乎不需要修改,只需要修改传入的图片数据已经训练模型权重即可。


总结

尽可能简单、详细的讲解了Python下onnxruntime环境部署VggNet模型的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2134343.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Understanding the model of openAI 5 (1024 unit LSTM reinforcement learning)

题意:理解 OpenAI 5(1024 单元 LSTM 强化学习)的模型 问题背景: I recently came across openAI 5. I was curious to see how their model is built and understand it. I read in wikipedia that it "contains a single l…

计算机网络29——Linux基本命令vim,gcc编译命令

1、创建新用户 2、给用户设置密码 3、切换到新用户 切换到root用户 4、删除用户 5、查看ip 6、ping 查看物理上两台主机是否联通 7、netstatus 8、nslookup 查看网址的地址 9、负载均衡与容灾备份 负载均衡:指将负载(工作任务)进行平衡、分…

为什么mac打不开rar文件 苹果电脑打不开rar压缩文件怎么办

你是否遇到过这样的情况,下载了一个rar文件,想要查看里面的内容,却发现Mac电脑无法打开。rar文件是一种常见的压缩文件格式,它可以将多个文件或文件夹压缩成一个文件,节省空间和传输时间。如此高效实用的压缩文档&…

JavaEE:网络初识

文章目录 网络初识网络中的重要概念IP地址端口号认识协议(最核心概念)OSI七层模型TCP/IP五层(或四层)网络模型网络设备所在分层封装和分用 网络初识 网络中的重要概念 网络互联的目的是进行网络通信,也是网络数据传输,更具体一点,是网络主机中的不同进程间,基于网络传输数据.…

论文解读《LaMP: When Large Language Models Meet Personalization》

引言:因为导师喊我围绕 “大语言模型的个性化、风格化生成” 展开研究,所以我就找相关论文,最后通过 ACL 官网找到这篇,感觉还不错,就开始解读吧! “说是解读,其实大部分都是翻译哈哈哈&#x…

域控操作十七点五:域用户无管理员权限下安装IT打包的软件

1,需要软件Runasspcadmin三件套和winrar压缩软件 2,将需要打包的软件放进这个文件夹内,使用播放器举个例子 3,打开runasspcadmin.exe 按图片写就行了 文件夹现在是这样的然后全选右击,用WinRAR添加到压缩包 这个可以自…

量化交易backtrader实践(一)_数据获取篇(4)_通达信数据应用

在第2节实践了从金融数据接口包例如tushare.pro或akshare获取数据,在第3节实践了直接从网页上爬取股票数据。其实,我们的电脑上怎么可能没有几个股票软件,在这些股票软件里,历史行情,实时行情都有,我们能否…

Windows环境本地部署Oracle 19c及卸载实操手册

前言: 一直在做其他测试,貌似都忘了Windows环境oracle 19c的部署,这是一个很早很早的安装记录了,放上来做个备录给到大家参考。 Oracle 19c‌:进一步增强了自动化功能,并提供了更好的性能和安全性。这个版本在自动化、性能和安全性方面进行了重大改进,以满足现代企业对数…

运维人员转行 AI 大模型全攻略:史上最详尽总结,一篇在手,转行无忧!

前言 做运维的苦,谁做谁懂。有时候真感觉自己就像个杂役,在公司都快成修电脑的了。不装了,我要转行!在此给大家分享点经验,希望能帮到你们。 运维工程师若要转行至大模型领域,需要学习一系列全新的技能与…

开放式耳机原理?五款超强单品推荐!

开放式耳机的原理其实挺直观的,它们不像那些把耳朵完全罩住或者塞住的封闭式耳机。开放式耳机通常就是轻轻地挂在耳朵上,声音通过空气传播,直接送到你的耳朵里。 这种设计有几个好处。首先,因为耳朵没有被完全封闭,所…

【开源项目】数字孪生公园~云南某湿地公园—开源工程及源码

飞渡科技数字孪生湿地公园管理平台,基于园林行业定制硬件以及传感器、摄像头等终端采集数据,借助自主研发国产渲染引擎,以人工智能、物联网、数字孪生技术作为核心,还原公园内外的真实场景,同时实现海量数据处理、系统…

DB-GPT部署和试用

前言 DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Development framework with AWEL(Agentic Workflow Expression Language) and Agents)。 目的是构建大模型领域的基础设施,通过开发多模型管理(SMMF)、Text2SQL效果优化、RAG框架以及优化、Mu…

哇!原来vscode的终端可以这么美

相信很多开发小伙伴经常可以看到,为什么别人的vscode的终端可以这么美,又有提示,还有git是提示,时刻告诉你现在正在处于哪个分支,接下来,就让我为大家告诉一个美化vscode终端的方法 先看效果 只要来到这个网…

linux_L2_linux删除文件

linux 删除文件 在Linux下删除文件有多种实现方法,以下是其中几种常见的方法: 方法一:使用rm命令删除单个文件 rm 文件路径例如,删除当前目录下的文件file.txt: rm file.txtQuestion :当你在Linux系统中使用rm命令删…

【视频教程】基于PyTorch深度学习无人机遥感影像目标检测、地物分类及语义分割实践技术应用

随着无人机自动化能力的逐步升级,它被广泛的应用于多种领域,如航拍、农业、植保、灾难评估、救援、测绘、电力巡检等。但同时由于无人机飞行高度低、获取目标类型多、以及环境复杂等因素使得对无人机获取的数据处理越来越复杂。最近借助深度学习方法&…

无线领夹麦克风哪个牌子好,口碑最好的麦克风品牌,领夹麦推荐

在数字化时代的浪潮中,无线领夹麦克风作为现代通讯与创意表达不可多得的工具,正迅速渗透至各类专业及日常场景。在其便捷性与高效性备受推崇的背后,行业内不为人知的秘密也正逐渐浮出水面。近期,五大无线领夹麦克风行业痛点被曝光…

VPSA制氧机与PSA制氧机的差异

制氧机在现代工业及环保等多个领域具有广泛应用,其中VPSA(变压吸附)制氧机和PSA(压力吸附)制氧机是两种常见的制氧设备。尽管两者在基本原理上相似,但在实际应用中却存在诸多显著差异。 工作原理 VPSA制氧机采用变压吸附技术,通过改变吸附剂的…

无线麦克风哪个好,领夹麦克风哪个品牌音质最好,无线麦克风推荐

随着科技的进步,无线领夹麦克风市场迎来了智能化浪潮,各种功能宣传铺天盖地。然而,在这场技术革新的盛宴中,也不乏商家利用信息不对称,设置“智商税”陷阱。从夸大其词的降噪效果到实际使用中的频频失效,再…

Mac上的rar文件怎么解压?Mac上解压RAR文件超实用的方法

rar文件是一种常见的压缩文件格式,它可以将多个文件或文件夹打包成一个文件,从而节省空间和方便传输。但是,mac系统并没有自带的工具可以直接打开或解压rar文件。在这篇文章中,我们将详细解答关于mac解压rar文件的问题。希望我们能…

CI/CD中gitlab和jenkins讲解

一 CICD是什么 CI/CD 是指持续集成(Continuous Integration)和持续部署(Continuous Deployment)或持续交付(Continuous Delivery) 1.1 持续集成(Continuous Integration) 持续集成…