DepthAnything(2): 基于ONNXRuntime在ARM(aarch64)平台部署DepthAnything

news2024/7/30 5:59:35

DepthAnything(1): 先跑一跑Depth Anything_depth anything离线怎么跑-CSDN博客
 


目录

1. 写在前面

2. 安装推理组件

3. 生成ONNX

4. 准备ONNXRuntime库

5. API介绍

6. 例程


1. 写在前面

        DepthAnything是一种能在任何情况下处理任何图像的简单却又强大的深度估计模型。

2. 安装推理组件

        针对有GPU加持的场景,以NVIDIA显卡为例,首先需要安装GPU驱动。

        然后再安装CUDA、cuDNN和ONNX Runtime,一般情况下,在有显卡的系统中,我们选择GPU版本。

        进入如下链接,可以查看CUDA、cuDNN、ONNX Runtime库的版本对应关系。

        NVIDIA - CUDA | onnxruntime

        如下所示,ONNX Runtime与CUDA、cuDNN的版本需要匹配,否则可能出现我发调用GPU进行推理。

        针对没有GPU的场景,我们使用CPU进行推理。可直接下载相应平台的onnx-runtime库。

3. 生成ONNX

        使用DepthAnything训练工程下的export_onnx.py文件导出ONNX模型。

        注意,导出的时候,需要注意分辨率,DepthAnything到处分辨率必须是14的整数倍。

        另外,如果导出分辨率过小,可能会导致识别的深度图失效,因此一般建议导出时,分辨率选择518*518。

        导出onnx模型参考代码如下。

import argparse


import torch

from onnx import load_model, save_model

from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

from depth_anything.dpt import DPT_DINOv2

def parse_args() -> argparse.Namespace:

    parser = argparse.ArgumentParser()

    parser.add_argument(

        "--model",

        type=str,

        choices=["s", "b", "l"],

        required=True,

        help="Model size variant. Available options: 's', 'b', 'l'.",

    )

    parser.add_argument(

        "--output",

        type=str,

        default=None,

        required=False,

        help="Path to save the ONNX model.",

    )

    return parser.parse_args()

def export_onnx(model: str, input: str, output: str = None):

    # Handle args

    if output is None:

        output = f"weights/depth_anything_vit{model}14_ori.onnx"

    # Device for tracing (use whichever has enough free memory)

    device = torch.device("cpu")

    # Sample image for tracing (dimensions don't matter)

    image = torch.rand(1, 3, 518, 518).to(device) # [SAI-KEY] 必须是14的倍数

    # Load model params

    if model == "s":

        depth_anything = DPT_DINOv2(

            encoder="vits", features=64, out_channels=[48, 96, 192, 384]

        )

    elif model == "b":

        depth_anything = DPT_DINOv2(

            encoder="vitb", features=128, out_channels=[96, 192, 384, 768]

        )

    else:  # model == "l"

        depth_anything = DPT_DINOv2(

            encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]

        )

    weights = torch.load(input)

    depth_anything.to(device).load_state_dict(weights)

    depth_anything.eval()

    torch.onnx.export(

        depth_anything,

        image,

        output,

        input_names=["image"], # 列表,如果有多个输入,应按照顺序,依次

        output_names=["depth"],

        opset_version=12,

    )

    save_model(

        SymbolicShapeInference.infer_shapes(load_model(output), auto_merge=True),

        output,

    )

if __name__ == "__main__":

    export_onnx("l", "/2T/001_AI/8001_DepthAnything/003_Models/checkpoints/depth_anything_vitl14.pth", "/2T/001_AI/8001_DepthAnything/003_Models/checkpoints/depth_anything_vitl14.onnx")

4. 准备ONNXRuntime库

        登录链接https://github.com/microsoft/onnxruntime/releases下载相应的ONNX Runtime推理库。如下所示,可选择linux、windows、osx系统下的库,以及选择x86或aarch64架构的库。需要说明的是,aarch64目前仅支持CPU版本。

        不同系统的库,引入方式也不同。

        例如,下载onnxruntime-linux-aarch64-1.18.1.tgz库,解压后可以将其中lib文件夹下的libonnxruntime.so和libonnxruntime.so.1.18.1复制到/usr/lib路径下。

        include头文件可移动可不移动,也可重命名后,加入到工程中。如果需要添加到工程中,需要注意Makefile中的包含项。

        如果是下载的windows平台的库,一般是include文件和dll链接库,按照不同IDE的引入方式来就可以。

5. API介绍

        本小节简单介绍几个API,具体使用可以参照后续小节的例程加以探索和理解。

(1)Ort::Env(ORT_LOGGING_LEVEL_WARNING, "depthAnything_mono");

        ORT_LOGGING_LEVEL_VERBOSE:最详细的日志信息,包括所有信息。

        ORT_LOGGING_LEVEL_INFO:一般的信息,例如模型加载和推理进度。

        ORT_LOGGING_LEVEL_WARNING:警告级别的日志,例如潜在的问题或性能下降,仅输出警告日志。

        ORT_LOGGING_LEVEL_ERROR:错误级别的日志,例如无法恢复的错误。

        ORT_LOGGING_LEVEL_FATAL:致命错误,通常是程序无法继续执行的错误。

(2)OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0); ///< 相当于指定执行后端,如果不指定,则默认使用CPU

        参数1:

        参数2:设备号

(3)GetInputCount()和GetOutputCount()

        获得输入和输出的数量。

(4)GetInputTypeInfo(i)

        获取第i个输入的信息。

(5)GetShape()

        获取输入或输出的shape信息。

6. 例程

        以下例程以DepthAnything利用ONNXRuntime与OpenCV实现对一张图片的深度估计、并将结果存储到本地。

#include <assert.h>

#include <vector>

#include <ctime>

#include <iostream>

#include <chrono>

#include <onnxruntime_cxx_api.h>

#include <opencv2/core.hpp>

#include <opencv2/imgproc.hpp>

#include <opencv2/highgui.hpp>

#include <opencv2/videoio.hpp>


using namespace cv;

using namespace std;

int input_width = 518;

int input_height = 518;

class DepthAnything

{

public:

    DepthAnything(std::string onnx_model_path);

    std::vector<float> predict(std::vector<float>& input_data, int batch_size = 1, int index = 0);

    cv::Mat predict(cv::Mat& input_tensor, int batch_size = 1, int index = 0);

private:

    Ort::Env env;

    Ort::Session session;

    Ort::AllocatorWithDefaultOptions allocator;

    std::vector<const char*>input_node_names = {"image"}; ///< 生成onnx时的输入节点名

    std::vector<const char*>output_node_names = {"depth"}; ///< 生成onnx时的输出节点名

    std::vector<int64_t> input_node_dims;

    std::vector<int64_t> output_node_dims;

};

DepthAnything::DepthAnything(std::string onnx_model_path) :session(nullptr), env(nullptr)

{

    /** 初始化ORT环境. */

    this->env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "DepthAnything_ORT");

    /** 初始化ORT会话选项. */

    Ort::SessionOptions session_options;

    // session_options.SetInterOpNumThreads(1);

    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC); ///< ORT_ENABLE_ALL

    /** 初始化ORT会话. */

    this->session = Ort::Session(env, onnx_model_path.data(), session_options);

    /** 输入输出节点数量. */

    size_t num_input_nodes = session.GetInputCount();

    size_t num_output_nodes = session.GetOutputCount();

    for (int i = 0; i < num_input_nodes; i++){

        Ort::TypeInfo type_info = session.GetInputTypeInfo(i);

        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

        ONNXTensorElementDataType type = tensor_info.GetElementType();

        this->input_node_dims = tensor_info.GetShape();

        for(int i=0; i<this->input_node_dims.size(); i++){

            printf("shape[%d]: %d\n", i, this->input_node_dims[i]);

        }

    }

    for (int i = 0; i < num_output_nodes; i++){

        Ort::TypeInfo type_info = session.GetOutputTypeInfo(i);

        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

        ONNXTensorElementDataType type = tensor_info.GetElementType();

        this->output_node_dims = tensor_info.GetShape();

    }

}

std::vector<float> DepthAnything::predict(std::vector<float>& input_tensor_values, int batch_size, int index)

{

    this->input_node_dims[0] = batch_size;

    this->output_node_dims[0] = batch_size;

    float* floatarr = nullptr;

    std::vector<const char*>output_node_names;

    if (index != -1){

        output_node_names = { this->output_node_names[index] };

    }else{

        output_node_names = this->output_node_names;

    }

    this->input_node_dims[0] = batch_size;

    auto input_tensor_size = input_tensor_values.size();

    /** 创建Tensor对象. */

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); ///< 创建CPU内存信息

    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), 4); ///< 创建输入张量

    /** 执行推理. */

    auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor, 1, output_node_names.data(), 1);

    assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());

    floatarr = output_tensors[0].GetTensorMutableData<float>(); ///< 获取输出张量

    int64_t output_tensor_size = 1;

    for (auto& it : this->output_node_dims){

        output_tensor_size *= it;

    }

    std::vector<float>results(output_tensor_size);

    for (unsigned i = 0; i < output_tensor_size; i++){

        results[i] = floatarr[i];

    }

    return results;

}

cv::Mat DepthAnything::predict(cv::Mat& input_tensor, int batch_size, int index)

{

    int input_tensor_size = input_tensor.cols * input_tensor.rows * 3;

    std::size_t counter = 0;

    std::vector<float>input_data(input_tensor_size);

    std::vector<float>output_data;

    /** 转换RGB Planar, 归一化. */

    for (unsigned k = 0; k < 3; k++){

        for (unsigned i = 0; i < input_tensor.rows; i++){

            for (unsigned j = 0; j < input_tensor.cols; j++){

                input_data[counter++] = static_cast<float>(input_tensor.at<cv::Vec3b>(i, j)[k]) / 255.0;

            }

        }

    }

    /** 推理. */

    output_data = this->predict(input_data);

   

    /** 后处理. */

    cv::Mat output_tensor(output_data);

    output_tensor =output_tensor.reshape(1, {input_width, input_height});

    double minVal, maxVal;

    cv::minMaxLoc(output_tensor, &minVal, &maxVal); ///< 获取最大值、最小值.

    output_tensor.convertTo(output_tensor, CV_32F); ///< 转换数据类型,float32类型.

    if (minVal != maxVal) {

        output_tensor = (output_tensor - minVal) / (maxVal - minVal);

       

    }

    output_tensor *= 255.0;

    output_tensor.convertTo(output_tensor, CV_8UC1); ///< 转单通道(灰度图).

    cv::applyColorMap(output_tensor, output_tensor, cv::COLORMAP_HOT); ///< 伪彩映射.

    return output_tensor;

}

std::chrono::time_point<std::chrono::high_resolution_clock> tic;

std::chrono::time_point<std::chrono::high_resolution_clock> toc;

std::chrono::milliseconds elapsed;

int main(int argc, char* argv[])

{

    std::string model_path = "/zqpe/8001_DepthAnything_OnnxRuntime/out/bin/depth_anything_vits14.onnx";

    std::string image_path = "/zqpe/8001_DepthAnything_OnnxRuntime/out/bin/204995.jpg";

   

    printf("Construct depth anything inference engine.\n");

    tic = std::chrono::high_resolution_clock::now();

    DepthAnything model(model_path);

    toc = std::chrono::high_resolution_clock::now();

    elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(toc - tic);

    printf("Construct depth anything inference engine, takes %ld ms\n", elapsed.count());

    printf("Prepare sample.\n");

    tic = std::chrono::high_resolution_clock::now();

    cv::Mat image = cv::imread(image_path);

    auto ori_h = image.cols;

    auto ori_w = image.rows;

    // cv::imshow("image", image);

    cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

    cv::resize(image, image, {input_width, input_height}, 0.0, 0.0, cv::INTER_CUBIC);

    toc = std::chrono::high_resolution_clock::now();

    elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(toc - tic);

    printf("Prepare sample, takes %ld ms\n", elapsed.count());

    cv::Mat result;

    // while(1){

        printf("Do inference.\n");

        tic = std::chrono::high_resolution_clock::now();

        result = model.predict(image);

        toc = std::chrono::high_resolution_clock::now();

        elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(toc - tic);

        printf("Do inference, takes %ld ms\n", elapsed.count());

    // }

    cv::resize(result, result, {ori_h, ori_w}, 0.0, 0.0, cv::INTER_CUBIC);

   

    printf("Save result.\n");

    int pos = image_path.rfind(".");

    image_path.insert(pos, "_depth");

    cv::imwrite(image_path, result);

}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1917066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

汽车预约维修小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;技师管理&#xff0c;技师信息管理&#xff0c;用户预约管理&#xff0c;取消预约管理&#xff0c;订单信息管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;技师信息&a…

揭秘焦虑症的“隐形杀手“:这些并发症可能悄悄来袭!

揭秘焦虑症的"隐形杀手"&#xff1a;这些并发症可能悄悄来袭&#xff01;在快节奏的现代生活中&#xff0c;焦虑症已经成为越来越多人面临的心理健康挑战。然而&#xff0c;除了广为人知的焦虑、紧张、失眠等症状外&#xff0c;焦虑症还可能引发一系列看似与焦虑无关…

每天五分钟计算机视觉:目标检测算法之R-CNN

本文重点 在计算机视觉领域,目标检测一直是一个核心问题,旨在识别图像中的物体并定位其位置。随着深度学习技术的发展,基于卷积神经网络(CNN)的目标检测算法取得了显著的进步。其中,R-CNN(Regions with CNN features)是一种开创性的目标检测框架,为后续的研究提供了重…

【高中数学/指数、对数】已知9^m=10,a=10^m-11,b=8^m-9,则ab两数和0的大小关系是?(2022年全国统考高考真题)

【问题】 已知9^m10,a10^m-11,b8^m-9,则&#xff08;&#xff09; A.a>0>b B.a>b>0 C.b>a>0 D.b>0>a 【解答】 首先注意到10^log10_11-110,8^log8_9-90&#xff0c; 问题就转化为log8_9,log9_10,log10_11谁大谁小的问题&#xff0c; 再进一步…

maven高级1——一个项目拆成多个

把原来一个项目&#xff0c;拆成多个项目。 &#xff01;&#xff01;他们之间&#xff0c;靠接口通信。 以ssm整合好的项目为例&#xff1a; 如何看拆的ok不ok 只要compile通过就ok。 拆分pojo 先新建一个项目模块&#xff0c;再把内容复制进去。 拆分dao 1.和上面一样…

可控学习综述:信息检索中的方法、应用和挑战

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

统信UOS桌面操作系统上删除系统升级后GRUB中的回滚条目与备份

原文链接&#xff1a;统信UOS删除升级后GRUB中的回滚条目与备份 Hello&#xff0c;大家好啊&#xff01;今天给大家带来一篇关于在统信UOS桌面操作系统上删除系统升级后GRUB中的回滚条目与备份的文章。在进行系统升级后&#xff0c;GRUB引导菜单中可能会出现多个回滚条目和备份…

【MySQL】常见的MySQL日志都有什么用?

MySQL日志的内容非常重要&#xff0c;面试中经常会被问到。同时&#xff0c;掌握日志相关的知识也有利于我们理解MySQL 底层原理&#xff0c;必要时帮助我们排查解决问题。 MySQL中常见的日志类型主要有下面几类(针对的是InnoDB 存储引擎): 错误日志(error log):对 MySQL 的启…

51单片机:电脑通过串口控制LED亮灭(附溢出率和波特率详解)

一、功能实现 1.电脑通过串口发送数据&#xff1a;0F 2.点亮4个LED 二、注意事项 1.发送和接受数据的文本模式 2.串口要对应 3.注意串口的波特率要和程序中的波特率保持一致 4.有无校验位和停止位 三、如何使用串口波特率计算器 1.以本程序为例 2.生成代码如下 void Uar…

【漏洞复现】Crocus系统——Download——文件读取

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 Crocus系统旨在利用人工智能、高清视频、大数据和自动驾驶技术&…

[论文笔记]涨点近5%! 以内容中心的检索增强生成可扩展的级联框架:Pistis-RAG

引言 今天带来一篇较新RAG的论文笔记&#xff1a;Pistis-RAG: A Scalable Cascading Framework Towards Content-Centric Retrieval-Augmented Generation。 在希腊神话中&#xff0c;Pistis象征着诚信、信任和可靠性。受到这些原则的启发&#xff0c;Pistis-RAG是一个可扩展…

详细分析Java中的@EventListener事件监听器(附Demo)

目录 前言1. 基本知识2. Demo 前言 Java的基本知识推荐阅读&#xff1a; java框架 零基础从入门到精通的学习路线 附开源项目面经等&#xff08;超全&#xff09;Spring框架从入门到学精&#xff08;全&#xff09; 1. 基本知识 用于标注一个方法为事件监听器 事件监听器方…

前端面试题43(JavaScript几种排序)

JavaScript 中有多种排序算法可供使用&#xff0c;每种算法都有其特点和适用场景。下面是一些常见的排序算法&#xff0c;它们可以手动实现&#xff0c;也可以通过 JavaScript 内置的 Array.prototype.sort() 方法简化操作。 1. 冒泡排序&#xff08;Bubble Sort&#xff09; …

beyond Compare连接 openWrt 和 VsCode

连接步骤总结 1. 新建会话 -> 文件夹比较 2.点击浏览文件夹 3.在弹出页面 配置 ftp 3.1&#xff09;选中ftp 配置文件 3.2)选中ssh2 3.3)填写我们需要远端连接的主机信息 先点击连接并浏览 得到下方文件夹 弹出无效登录&#xff0c;说明需要密码 我们返回右键刚刚创建的新 …

qt 用数据画一个图,并表示出来

1.概要 想用数据绘制一个画面&#xff0c;看有相机到播放的本质是啥。 要点 // 创建一个QImage对象&#xff0c;指定图像的宽度、高度和格式 QImage image(width, height, QImage::Format_Grayscale8); // 将像素数据复制到QImage对象中 memcpy(image.bits(), pixelD…

Milvus Cloud:重塑向量数据管理新纪元的强大引擎

在大数据与人工智能技术日新月异的今天,数据不再仅仅是简单的数字堆砌,而是蕴含着无限价值的信息宝藏。随着深度学习、自然语言处理、计算机视觉等技术的飞速发展,由这些高级机器学习模型产生的向量数据正以惊人的速度增长,成为了驱动行业创新和业务智能化转型的关键力量。…

[leetcode]kth-smallest-element-in-a-sorted-matrix 有序矩阵中第k小元素

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool check(vector<vector<int>>& matrix, int mid, int k, int n) {int i n - 1;int j 0;int num 0;while (i > 0 && j < n) {if (matrix[i][j] < mid) {num i 1;j;…

Linux 调试命令记录

查看CPU信息 cat /proc/cpuinfo 显示当前电源功耗 top 命令能够清晰的展现出系统的状态&#xff0c;而且它是实时的监控&#xff0c;按 q 退出。 uptime 与 w 这两个命令只是单纯的反映出负载&#xff0c;所表示的是过去的1分钟、5分钟和15分钟内进程队列中的平均进程数量。…