1.环境安装
conda create -n LYT_Torch python=3.9 -y
conda activate LYT_Torch
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm tensorboard
pip install einops gdown addict future lmdb numpy pyyaml requests scipy yapf lpips thop timm
2. Onnx模型推理
2.1 Onnxruntime
ONNX Runtime(Open Neural Network Exchange Runtime)是一个用于推理(inference)的高性能、跨平台的开源引擎,它由微软公司发起并维护。ONNX Runtime 支持多种机器学习模型,包括但不限于 PyTorch、TensorFlow、SciKit-Learn 等,允许开发者在不同框架之间无缝迁移模型,并在生产环境中高效运行。
- 跨平台:支持 Windows、Linux 和 macOS 等多种操作系统。
- 高性能:优化了模型的推理速度,特别是在使用 GPU 加速时。
- 多框架支持:原生支持 ONNX 格式的模型,同时也支持从其他框架转换来的模型。
- 实时性:适用于需要实时或近实时推理的应用场景。
- 易于集成:提供了 C++ 和 Python API,方便集成到不同语言的应用程序中。
- 自动混合精度:支持自动混合精度(AMP),可以在不损失模型精度的情况下提高推理速度。
- 量化支持:支持模型量化,可以进一步减小模型大小并提高推理速度。
- 微服务:支持作为微服务部署,易于在云环境或边缘设备上使用。
使用 ONNX Runtime 进行模型推理的基本步骤如下:
- 转换模型:如果模型不是 ONNX 格式,需要使用相应的转换工具将其转换为 ONNX 格式。
- 加载模型:使用 ONNX Runtime 的 API 加载转换后的模型。
- 准备输入:准备模型推理所需的输入数据,并确保其符合模型的输入要求。
- 运行推理:调用 ONNX Runtime 的推理 API 运行模型并获取输出。
- 处理输出:根据应用需求处理模型的推理输出。
2.2 C++推理
#include "LYTNet.h"
LYTNet::LYTNet(std::string model_path)
{
OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); ///cuda
sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); windows
ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions); windows*/
//ort_session = new Session(env, model_path.c_str(), sessionOptions); linux
size_t numInputNodes = ort_session->GetInputCount();
size_t numOutputNodes = ort_session->GetOutputCount();
Ort::AllocatorWithDefaultOptions allocator;
for (int i = 0; i < numInputNodes; i++)
{
input_names.push_back(ort_session->GetInputName(i, allocator));
Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
auto input_dims = input_tensor_info.GetShape();
input_node_dims.push_back(input_dims);
}
for (int i = 0; i < numOutputNodes; i++)
{
output_names.push_back(ort_session->GetOutputName(i, allocator));
Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
auto output_dims = output_tensor_info.GetShape();
output_node_dims.push_back(output_dims);
}
this->inpHeight = input_node_dims[0][1];
this->inpWidth = input_node_dims[0][2];
this->outHeight = output_node_dims[0][1];
this->outWidth = output_node_dims[0][2];
}
/***************** Mat转vector **********************/
template<typename _Tp>
std::vector<_Tp> convertMat2Vector(const cv::Mat& mat)
{
return (std::vector<_Tp>)(mat.reshape(1, 1));//通道数不变,按行转为一行
}
/****************** vector转Mat *********************/
template<typename _Tp>
cv::Mat convertVector2Mat(std::vector<_Tp> v, int channels, int rows)
{
cv::Mat mat = cv::Mat(v).clone();//将vector变成单列的mat,这里需要clone(),因为这里的赋值操作是浅拷贝
cv::Mat dest = mat.reshape(channels, rows);
return dest;
}
cv::Mat LYTNet::detect(cv::Mat srcimg)
{
cv::Mat dstimg;
resize(srcimg, dstimg, cv::Size(this->inpWidth, this->inpHeight));
dstimg.convertTo(dstimg, CV_32FC3, 1 / 127.5, -1.0);
this->input_image_ = (std::vector<float>)(dstimg.reshape(1, 1));
// const size_t area = this->inpWidth * this->inpHeight * 3;
// this->input_image_.resize(area);
// memcpy(this->input_image_.data(), (float*)dstimg.data, area*sizeof(float));
std::array<int64_t, 4> input_shape_{ 1, this->inpHeight, this->inpWidth, 3 };
auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(),
input_image_.size(), input_shape_.data(), input_shape_.size());
std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],
&input_tensor_, 1, output_names.data(), output_names.size());
float* pred = ort_outputs[0].GetTensorMutableData<float>();
cv::Mat output_image(outHeight, outWidth, CV_32FC3, pred);
output_image = (output_image + 1.0) * 127.5;
output_image.convertTo(output_image, CV_8UC3);
resize(output_image, output_image, cv::Size(srcimg.cols, srcimg.rows));
return output_image;
}
2.3 Python推理部署
import argparse
import cv2
import numpy as np
import onnxruntime
class LYTNet:
def __init__(self, modelpath):
# Initialize model
# self.net = cv2.dnn.readNet(modelpath) ####opencv-dnn读取失败
so = onnxruntime.SessionOptions()
so.log_severity_level = 3
self.net = onnxruntime.InferenceSession(modelpath, so)
self.input_height = self.net.get_inputs()[0].shape[1] ####(1,h,w,3)
self.input_width = self.net.get_inputs()[0].shape[2]
self.input_name = self.net.get_inputs()[0].name
def detect(self, srcimg):
input_image = cv2.resize(srcimg, (self.input_width, self.input_height))
input_image = input_image.astype(np.float32) / 127.5 - 1.0
blob = np.expand_dims(input_image, axis=0).astype(np.float32)
result = self.net.run(None, {self.input_name: blob})
# Post process:squeeze, RGB->BGR, Transpose, uint8 cast
output_image = np.squeeze(result[0])
output_image = (output_image + 1.0 ) * 127.5
output_image = output_image.astype(np.uint8)
output_image = cv2.resize(output_image, (srcimg.shape[1], srcimg.shape[0]))
return output_image
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--imgpath', type=str,
default='testimgs/4_1.JPG', help="image path")
parser.add_argument('--modelpath', type=str,
default='weights/lyt_net_lolv2_real_320x240.onnx', help="model path")
args = parser.parse_args()
mynet = LYTNet(args.modelpath)
srcimg = cv2.imread(args.imgpath)
dstimg = mynet.detect(srcimg)
if srcimg.shape[0] >= srcimg.shape[1]:
boundimg = np.zeros((srcimg.shape[0], 10, 3), dtype=srcimg.dtype)+255
combined_img = np.hstack([srcimg, boundimg, dstimg])
else:
boundimg = np.zeros((10, srcimg.shape[1], 3), dtype=srcimg.dtype)+255
combined_img = np.vstack([srcimg, boundimg, dstimg])
winName = 'Deep Learning use OpenCV-dnn'
cv2.namedWindow(winName, 0)
cv2.imshow(winName, combined_img) ###原图和结果图也可以分开窗口显示
cv2.waitKey(0)
cv2.destroyAllWindows()
2.4 实现效果
代码与模型下载地址:https://download.csdn.net/download/matt45m/89600913