【深度学习】【图像分类】【OnnxRuntime】【C++】ResNet模型部署

news2024/9/21 0:38:14

【深度学习】【图像分类】【OnnxRuntime】【C++】ResNet模型部署

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习】【图像分类】【OnnxRuntime】【C++】ResNet模型部署
  • 前言
  • 模型转换--pytorch转onnx
  • Windows平台搭建依赖环境
  • ONNXRuntime推理代码
  • 总结


前言

本期将讲解深度学习图像分类网络ResNet模型的部署,对于该算法的基础知识,可以参考博主【ResNet模型算法Pytorch版本详解】博文。
读者可以通过学习【onnx部署】部署系列学习文章目录的onnxruntime系统学习–c++篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。


模型转换–pytorch转onnx

首先需要搭建pytorch环境,推荐基础的读者参考博主的博文搭建环境【PyTorch环境搭建】

import torch
import torchvision as tv
def resnet2onnx():
    # 使用torch提供的预训练权重 1000分类
    model = tv.models.resnet50(pretrained=True)
    model.eval()
    model.cpu()
    dummy_input1 = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, (dummy_input1), "resnet50.onnx", verbose=True, opset_version=11)
if __name__ == "__main__":
    resnet2onnx()


如下图,torchvision本身提供了不少经典的网络,为了减少教学复杂度,这里博主直接使用了torchvision提供的ResNet网络,并下载和加载了它提供的训练权重。这里可以替换成自己的搭建的ResNet网络以及自己训练的训练权重。


Windows平台搭建依赖环境

在【入门基础篇】中详细的介绍了onnxruntime环境的搭建以及ONNXRuntime推理核心流程代码,不再重复赘述。


ONNXRuntime推理代码

需要配置imagenet_classes.txt【百度云下载,提取码:rkz7 】文件存储1000类分类标签,假设是用户自定的分类任务,需要根据实际情况作出修改,并将其放置到工程目录下(推荐)。

这里需要将resnet50.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的cpp文件中,并执行查看结果。

#include "onnxruntime_cxx_api.h"
#include "cpu_provider_factory.h"
#include <opencv2/opencv.hpp>
#include <fstream>

// 加载标签文件获得分类标签
std::string labels_txt_file = "D:/C++_demo/onnxruntime_onnx/imagenet_classes.txt";
std::vector<std::string> readClassNames();
std::vector<std::string> readClassNames()
{
	std::vector<std::string> classNames;

	std::ifstream fp(labels_txt_file);
	if (!fp.is_open())
	{
		printf("could not open file...\n");
		exit(-1);
	}
	std::string name;
	while (!fp.eof())
	{
		std::getline(fp, name);
		if (name.length())
			classNames.push_back(name);
	}
	fp.close();
	return classNames;
}

int main(int argc, char** argv) {
	// 预测的目标标签数
	std::vector<std::string> labels = readClassNames();

	// 测试图片
	cv::Mat image = cv::imread("D:/C++_demo/onnxruntime_onnx/lion.jpg");
	cv::imshow("输入图", image);

	// 初始化ONNXRuntime环境
	Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ResNet-onnx");

	// 设置会话选项
	Ort::SessionOptions session_options;
	// 优化器级别:基本的图优化级别
	session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
	// 线程数:4
	session_options.SetIntraOpNumThreads(4);
	// 设备使用优先使用GPU而是才是CPU
	std::cout << "onnxruntime inference try to use GPU Device" << std::endl;
	OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
	OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1);

	// onnx训练模型文件
	std::string onnxpath = "D:/C++_demo/onnxruntime_onnx/resnet50.onnx";
	std::wstring modelPath = std::wstring(onnxpath.begin(), onnxpath.end());

	// 加载模型并创建会话
	Ort::Session session_(env, modelPath.c_str(), session_options);

	// 获取模型输入输出信息
	int input_nodes_num = session_.GetInputCount();			// 输入节点输
	int output_nodes_num = session_.GetOutputCount();		// 输出节点数
	std::vector<std::string> input_node_names;				// 输入节点名称
	std::vector<std::string> output_node_names;				// 输出节点名称
	Ort::AllocatorWithDefaultOptions allocator;		
	// 输入图像尺寸
	int input_h = 0;		
	int input_w = 0;

	// 获取模型输入信息
	for (int i = 0; i < input_nodes_num; i++) {
		// 获得输入节点的名称并存储
		auto input_name = session_.GetInputNameAllocated(i, allocator);
		input_node_names.push_back(input_name.get());
		// 显示输入图像的形状
		auto inputShapeInfo = session_.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
		int ch = inputShapeInfo[1];
		input_h = inputShapeInfo[2];
		input_w = inputShapeInfo[3];
		std::cout << "input format: " << ch << "x" << input_h << "x" << input_w << std::endl;
	}

	// 获取模型输出信息
	int num = 0;
	int nc = 0;
	for (int i = 0; i < output_nodes_num; i++) {
		// 获得输出节点的名称并存储
		auto output_name = session_.GetOutputNameAllocated(i, allocator);
		output_node_names.push_back(output_name.get());
		// 显示输出结果的形状
		auto outShapeInfo = session_.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
		num = outShapeInfo[0];
		nc = outShapeInfo[1];
		std::cout << "output format: " << num << "x" << nc << std::endl;
	}

	// 输入数据预处理
	cv::Mat rgb, blob;
	// 默认是BGR需要转化成RGB
	cv::cvtColor(image, rgb, cv::COLOR_BGR2RGB);
	// 对图像尺寸进行缩放
	cv::resize(rgb, blob, cv::Size(input_w, input_h));
	blob.convertTo(blob, CV_32F);
	// 对图像进行标准化处理
	blob = blob / 255.0;	// 归一化
	cv::subtract(blob, cv::Scalar(0.485, 0.456, 0.406), blob);	// 减去均值
	cv::divide(blob, cv::Scalar(0.229, 0.224, 0.225), blob);	//除以方差
	// CHW-->NCHW 维度扩展
	cv::Mat timg = cv::dnn::blobFromImage(blob);
	std::cout << timg.size[0] << "x" << timg.size[1] << "x" << timg.size[2] << "x" << timg.size[3] << std::endl;
	// 占用内存大小,后续计算是总像素*数据类型大小
	size_t tpixels = input_h * input_w * 3;
	std::array<int64_t, 4> input_shape_info{ 1, 3, input_h, input_w };

	// 准备数据输入
	auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, timg.ptr<float>(), tpixels, input_shape_info.data(), input_shape_info.size());
	
	// 模型输入输出所需数据(名称及其数量),模型只认这种类型的数组
	const std::array<const char*, 1> inputNames = { input_node_names[0].c_str() };
	const std::array<const char*, 1> outNames = { output_node_names[0].c_str() };
	
	// 模型推理
	std::vector<Ort::Value> ort_outputs;
	try {
		ort_outputs = session_.Run(Ort::RunOptions{ nullptr }, inputNames.data(), &input_tensor_, 1, outNames.data(), outNames.size());
	}
	catch (std::exception e) {
		std::cout << e.what() << std::endl;
	}
	// 1x5 获取输出数据并包装成一个cv::Mat对象,为了方便后处理
	const float* pdata = ort_outputs[0].GetTensorMutableData<float>();
	cv::Mat prob(num, nc, CV_32F, (float*)pdata);

	// 后处理推理结果
	cv::Point maxL, minL;		// 用于存储图像分类中的得分最小值索引和最大值索引(坐标)
	double maxv, minv;			// 用于存储图像分类中的得分最小值和最大值
	cv::minMaxLoc(prob, &minv, &maxv, &minL, &maxL); 

	int max_index = maxL.x;		// 获得最大值的索引,只有一行所以列坐标既为索引
	std::cout << "label id: " << max_index << std::endl;
	// 在测试图像上加上预测的分类标签
	cv::putText(image, labels[max_index], cv::Point(50, 50), cv::FONT_HERSHEY_SIMPLEX, 1.0, cv::Scalar(0, 0, 255), 2, 8);
	cv::imshow("输入图像", image);
	cv::waitKey(0);

	// 释放资源
	session_options.release();
	session_.release();
	return 0;
}

图片正确预测为狮子(lion):

其实图像分类网络的部署代码基本是一致的,几乎不需要修改,只需要修改传入的图片数据已经训练模型权重即可。


总结

尽可能简单、详细的讲解了C++下onnxruntime环境部署ResNet模型的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2127867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这种钛镍合金不简单!强度高且模量低,制造过程也不难

大家好&#xff0c;今天我们要来聊聊一种神奇的合金——《A polymer-like ultrahigh-strength metal alloy》发表于《Nature》。在许多新兴技术中&#xff0c;比如变形飞机和超人型人工肌肉&#xff0c;都需要一种既强又灵活的金属合金。但长久以来&#xff0c;要实现这种“强而…

电商品牌假货要怎么处理

在电商蓬勃发展的今日&#xff0c;假货问题如影随形&#xff0c;严重威胁着品牌的声誉与市场的健康。力维网络以专业打假服务&#xff0c;为品牌保驾护航。 一、精准监测&#xff0c;揪出假货端倪 力维网络的数据监测系统犹如一张严密的大网&#xff0c;覆盖全网。通过全面采集…

828华为云征文 | 华为云Flexus X实例上实现Docker容器的实时监控与可视化分析

前言 华为云Flexus X&#xff0c;以顶尖算力与智能调度&#xff0c;引领Docker容器管理新风尚。828企业上云节之际&#xff0c;Flexus X携手前沿技术&#xff0c;实现容器运行的实时监控与数据可视化&#xff0c;让管理变得直观高效。无论是性能瓶颈的精准定位&#xff0c;还是…

揭晓2024年上半年热门跨境电商平台排行榜完整版,排在第二的居然是它!

随着全球电商市场的持续发展和融合&#xff0c;跨境电商平台已成为众多商家拓展国际市场的重要渠道。面对琳琅满目的平台选择&#xff0c;卖家如何做出明智的决策&#xff0c;成为了关注的焦点。本文将从今年上半年GMV这个维度来盘点一下热门电商平台的最新排行榜&#xff0c;有…

qwen2 VL 多模态图文模型;图像、视频使用案例

参考&#xff1a; https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct 模型&#xff1a; export HF_ENDPOINThttps://hf-mirror.comhuggingface-cli download --resume-download --local-dir-use-symlinks False Qwen/Qwen2-VL-2B-Instruct --local-dir qwen2-vl安装&#x…

你不得不知的日志级别

前言 写日志是一项具有挑战性的任务&#xff0c;在工作中我们常常面临一些困境&#xff0c;比如&#xff1a; 开发人员在编写代码时常常陷入纠结&#xff0c;不确定在何处打印日志才是最有意义的&#xff1b;SRE人员在调查生产问题时可能因为缺乏必要的日志信息而束手无策&am…

基于SSM的“高校学生社团管理系统”的设计与实现(源码+数据库+文档)

基于SSM的“高校学生社团管理系统”的设计与实现&#xff08;源码数据库文档) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SSM 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 系统结构图 首页 注册 登录 后台首页界面 社团公告页面 留…

Engage2024用户大会成功举办,数聚股份携手销售易共绘数字化转型新篇章

2024年9月5日&#xff0c;销售易第六届用户大会Engage2024在上海盛大举行。销售易&#xff0c;作为唯一一家入选Gartner SFA魔力象限、且产品能力全球前四的国产CRM软件&#xff0c;当之无愧是国产CRM软件的龙头&#xff0c;其用户大会自然就是CRM领域盛会&#xff0c;汇聚了众…

生命周期函数

所有继承MonoBehavior的脚本 最终都会挂载到Gameobiject游戏对象上 1.生命周期西数 就是该脚本对象依附的Gameobject对象从出生到消亡整个生命周期中 会通过反射自动调用的一些特殊函数 2.Unity帮助我们记录了一个Gameobject对象依附了哪些脚本 会自动的得到这些对象&#x…

视频监控系统中的云镜控制PTZ详细介绍,以及视频监控接入联网平台相关协议对PTZ的支持

目录 一、PTZ概述 二、PTZ 控制的应用场景 1、公共场所 2、安全监控 3、交通监控 4、工业生产环境中的质量监控 5、大型活动的现场直播或录制 三、PTZ摄像的优缺点 1、优点 2、缺点 四、PTZ控制的基本原理 1、云台控制 2、镜头控制 五、 PTZ 控制协议 1. Pelco-…

深度学习时遇到tensor([0.], device=‘cuda:0‘)等输出

更改了数据集后进行训练遇到了以下输出&#xff0c;精度正常提升&#xff0c;训练正常&#xff0c;就是精度和map之间又很多输出&#xff0c;如下&#xff1a; tensor([0.], devicecuda:0), tensor([0.], devicecuda:0), tensor([0.], devicecuda:0), tensor([0.], devicecuda…

NAT技术+代理服务器+内网穿透

NAT技术 IPv4协议中&#xff0c;会存在IP地址数量不充足的问题&#xff0c;所以不同的子网中会存在相同IP地址的主机。那么就可以理解为私有网络的IP地址并不是唯一对应的&#xff0c;而公网中的IP地址都是唯一的&#xff0c;所以NAT&#xff08;Network Address Translation&…

往复密封问题的两个问题

&#x1f3c6;本文收录于《CSDN问答解惑-专业版》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收…

使用ChatGPT高质量撰写文献综述全攻略实操指南,五步轻松搞定!

大家好,感谢关注。我是七哥,一个在高校里不务正业,折腾学术科研AI实操的学术人。关于使用ChatGPT等AI学术科研的相关问题可以和作者七哥(yida985)交流,多多交流,相互成就,共同进步,为大家带来最酷最有效的智能AI学术科研写作攻略。 在学术研究中,文献综述很重要,但…

无线感知会议系列【2】【智能无感感知 特征,算法,数据集】

前言&#xff1a; 这篇来自 2022 泛在可信智能感知 论坛 作者&#xff1a; 清华大学杨铮教授 视频&#xff1a; 2.智能无线感知&#xff1a;特征、算法、数据集&#xff1b; 杨峥 清华大学 副教授_哔哩哔哩_bilibili 这篇论文前面有讲过,我前面的博客也有基于提供的数据集做了…

关于打不开SOAMANAGER如何解决

参考文章&#xff1a;https://blog.csdn.net/yannickdann/article/details/115396035 打开SE93

Python字典实战题目练习,巩固知识、检查技术

本文主要是作为Python中列表的一些题目&#xff0c;方便学习完Python的列表之后进行一些知识检验&#xff0c;感兴趣的小伙伴可以试一试&#xff0c;含选择题、判断题、实战题&#xff0c;答案在第四章。 在做题之前可以先学习或者温习一下Python的列表&#xff0c;推荐阅读下面…

沃尔玛测评防关联技术:自养号攻略全面解析

防关联技术 1.使用国外的服务器和防火墙&#xff1a;为了确保测评活动的隐蔽性和安全性&#xff0c;卖家应选择使用国外的服务器&#xff0c;并通过远程搭建一个安全终端防火墙。这样可以阻断硬件参数的关联问题&#xff0c;降低被沃尔玛平台检测到的风险。 2.创建住宅专线IP…

《食品安全导刊》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问&#xff1a;《食品安全导刊》是不是核心期刊&#xff1f; 答&#xff1a;不是&#xff0c;是知网收录的正规学术期刊。 问&#xff1a;《食品安全导刊》级别&#xff1f; 答&#xff1a;国家级。主管单位&#xff1a; 中国商业联合会 主办单…

解析DNS查询报文,探索DNS工作原理

目录 1. 用 tcpdump工具监听抓包 2. 用 host 工具获取域名对应的IP地址 3. 分析DNS以太网查询数据帧 3.1 linux下查询DNS服务器IP地址 3.2 DNS以太网查询数据帧 &#xff08;1&#xff09;数据链路层 &#xff08;2&#xff09;网络层 &#xff08;3&#xff09;传输层…