如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

news2024/11/20 11:41:22

如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

目录

      • 如何使用C++调用Pytorch模型进行推理测试:使用libtorch库
        • 一、环境准备
          • 1,linux:以ubuntu 22.04系统为例
            • 1. 准备CUDA和CUDNN
            • 2. 准备C++环境
            • 3, 下载libtorch文件
            • 4, 编写测试libtorch是否安装成功
          • 2, windows: 以win10系统为例
            • 1, 准备CUDA和CUDNN
            • 2,准备C++编译环境
            • 3,下载安装libtorch
            • 4. 注意事项
          • 二、C++代码封装Pytorch模型测试:以resnet-18分类为例
          • 1, 安装opencv用于读取图像
          • 2,用python导出训练好的pytorch模型
          • 3,编写C++代码测试

一、环境准备
1,linux:以ubuntu 22.04系统为例
1. 准备CUDA和CUDNN

有两种方式配置cuda和cudnn,一种是在系统环境安装,可以参考:深度学习环境配置——ubuntu安装CUDA与CUDNN

还有一种是在conda虚拟环境使用cudatoolkit-dev包,具体可以参考:Installing-and-Test-PyTorch-C-API-on-Ubuntu-with-GPU-enabled

我选择的方式是在系统环境安装cuda12.1和cudnn8.9.2。

可使用如下命令查看是否安装成功:

NVCC -V
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2

image-20240625103837610

2. 准备C++环境

安装gcc, cmake和GLIBC,用apt install即可

可使用如下命令是否查看是否安装成功:

gcc --version
cmake --version
ldd --version

image-20240625103749911

3, 下载libtorch文件

去pytoch官网https://pytorch.org/下载即可:

image-20240625103946244

可使用如下命令下载并解压:

wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu121.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.3.1+cu121.zip

将libtorch路径配置到path变量:

vim ~/.bashrc

最后一行加入:

export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意将/path/to/libtorch替换为实际的path,我这里是/mnt/data1/zq/libtorch

查看是否成功:

source ~/.bashrc
echo $LD_LIBRARY_PATH

image-20240625110447696

4, 编写测试libtorch是否安装成功

创建main.cpp文件,内容如下:

#include <torch/torch.h>
#include <iostream>

int main() {
    if (torch::cuda::is_available()) {
        std::cout << "CUDA is available! Running on GPU." << std::endl;
        // 创建一个随机张量并将其移到GPU上
        torch::Tensor tensor_gpu = torch::rand({2, 3}).cuda();
        std::cout << "Tensor on GPU:\n" << tensor_gpu << std::endl;
    } else {
        std::cout << "CUDA not available! Running on CPU." << std::endl;
        // 创建一个随机张量并保持在CPU上
        torch::Tensor tensor_cpu = torch::rand({2, 3});
        std::cout << "Tensor on CPU:\n" << tensor_cpu << std::endl;
    }
    return 0;
}

编译和运行

创建CMakeLists.txt文件,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(test_project)

# Setting the C++ standard to C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# If additional compiler flags are needed
add_compile_options(-Wall -Wextra -pedantic)

# Setting the location of LibTorch
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)

# Specify the name of the executable and the corresponding source file
add_executable(test_project main.cpp)

# Linking LibTorch libraries
target_link_libraries(test_project "${TORCH_LIBRARIES}")

# Set the output directory for the executable
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)

/path/to/libtorch替换为实际的path

编译并测试:

mkdir build
cd build
cmake ..
make 

编译完成之后,应该会出现一个bin目录,其中有一个test_project文件,直接运行即可看到输出。

image-20240625111448917

出现CUDAFloatType说明,libtorch的GPU版本安装成功。

2, windows: 以win10系统为例
1, 准备CUDA和CUDNN

可参考:Windows10下CUDA与cuDNN的安装

2,准备C++编译环境

这一步需要配置cmake, mingw。可参考:Windows 配置 C/C++ 开发环境

建议直接安装Visual Studio这个IDE,可参考:Windows libtorch C++部署GPU版

3,下载安装libtorch

参考这个视频:

win10系统上LibTorch的安装和使用(cuda10.1版本)

一个很水的LibTorch教程(1)

4. 注意事项

windows环境我没有做测试,不保证一定可以成功。linux环境是亲自测试的,保证可以复现

二、C++代码封装Pytorch模型测试:以resnet-18分类为例
1, 安装opencv用于读取图像

需要使用opencv来读取图像数据,可通过如下命令安装:

sudo apt install libopencv-dev
dpkg -l | grep libopencv # 查看是否安装成功
2,用python导出训练好的pytorch模型

在将PyTorch模型应用于C++环境之前,需要将其转换为TorchScript。这可以通过两种方式实现:tracingscripting。可以通过如下代码导出训练好的ResNet-18模型:

import torch
import torchvision

# 加载预训练的模型
model = torchvision.models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

# 创建一个示例输入
example_input = torch.rand(1, 3, 224, 224)  # 模型输入的大小

# 使用tracing导出模型
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet18.pt")
3,编写C++代码测试

创建main.cpp文件,内容如下:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <filesystem>

// Function to transform image to tensor
torch::Tensor transform_image(const cv::Mat& image) {
    cv::Mat img_transformed;
    cv::cvtColor(image, img_transformed, cv::COLOR_BGR2RGB);
    cv::resize(img_transformed, img_transformed, cv::Size(224, 224));
    img_transformed.convertTo(img_transformed, CV_32FC3, 1.0/255);
    auto img_tensor = torch::from_blob(img_transformed.data, {img_transformed.rows, img_transformed.cols, 3}, torch::kFloat);
    img_tensor = img_tensor.permute({2, 0, 1});
    img_tensor = torch::data::transforms::Normalize<torch::Tensor>({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(img_tensor);
    img_tensor = img_tensor.unsqueeze(0);
    return img_tensor;
}

// Load the model and classify an image
void classify_image(const std::string& model_path, const std::string& image_path) {
    // Load the model
    torch::jit::script::Module model = torch::jit::load(model_path);
    model.eval(); // Switch to evaluation mode

    // Load and transform the image
    cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
    if (image.empty()) {
        std::cerr << "Could not read the image: " << image_path << std::endl;
        return;
    }
    torch::Tensor tensor_image = transform_image(image);

    // Perform inference
    torch::Tensor output = model.forward({tensor_image}).toTensor();
    int64_t pred = output.argmax(1).item<int64_t>();

    std::cout << "The image is classified as class index: " << pred << std::endl;
}

int main(int argc, char* argv[]) {
    std::string model_path = "resnet18.pt"; // Default model path
    std::string image_path = "default_image.jpg"; // Default image path
	
    // 从命令行接受两个参数, 分别作为model_path和image_path
    if (argc >= 3) {
        model_path = argv[1];
        image_path = argv[2];
    } else {
        std::cout << "Using default model and image paths." << std::endl;
    }

    classify_image(model_path, image_path);
    return 0;
}

创建CMakeLists.txt,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(ImageClassification)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# 设置LibTorch的位置, /path/to/libtorch替换为实际路径
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)

find_package(OpenCV REQUIRED)

add_executable(ImageClassification main.cpp)
target_link_libraries(ImageClassification "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

编译并运行:

mkdir build && cd build
cmake ..
make

在build目录下会出现ImageClassification这个可执行文件,直接运行传入model_path和image_path即可。

image-20240625114911739

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1899536.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开发者评测|操作系统智能助手OS Copilot

操作系统智能助手OS Copilot 文章目录 操作系统智能助手OS CopilotOS Copilot 是什么优势功能 操作步骤创建实验重置密码创建Access Key配置安全组安装 os-copilot环境变量配置功能评测命令行模式多轮交互模式 OS Copilot 产品体验评测反馈OS Copilot 产品功能评测反馈 参考文档…

【鸿蒙学习笔记】Stage模型工程目录

官方文档&#xff1a;应用配置文件概述&#xff08;Stage模型&#xff09; 目录标题 FA模型和Stage模型工程级目录模块级目录app.json5module.json5程序执行流程程序基本结构开发调试与发布流程 FA模型和Stage模型 工程级目录 模块级目录 app.json5 官方文档&#xff1a;app.j…

【笔记】记一次在linux上通过在线安装mysql报错 CentOS 7 的官方镜像已经不再可用的解决方法+mysql配置

报错&#xff08;恨恨恨恨恨恨恨&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff09;&#xff1a; [rootlocalhost ~]# sudo yum install mysql-server 已加载插件&#xff1a;fastestmirror, langpacks Determining fastest mirrors Could not retrie…

MWC上海展 | 创新微MinewSemi携ME54系列新品亮相Nordic展台

6月28日&#xff0c; 2024MWC上海圆满落幕&#xff0c;此次盛会吸引了来自全球124个国家及地区的近40,000名与会者。本届大会以“未来先行&#xff08;Future First&#xff09;”为主题&#xff0c;聚焦“超越5G”“人工智能经济”“数智制造”三大子主题&#xff0c;探索讨论…

AI语音工具——Fish Speech:使用简单,可训练专属语音模型!

今天给大家介绍一款超好用的AI语音工具——Fish Speech&#xff0c;使用简单&#xff0c;还可以训练自己的语音模型&#xff01; 工具介绍 Fish Speech是由 Fish Audio 开发的免费开源文本转语音模型。经过十五万小时的数据训练&#xff0c;Fish Speech能够熟练掌握中文、日语…

【Docker系列】Docker 镜像构建中的跨设备移动问题及解决方案

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

前端实现无缝自动滚动动画

1. 前言: 前端使用HTMLCSS实现一个无缝滚动的列表效果 示例图: 2. 源码 html部分源码: <!--* Author: wangZhiyu <w3209605851163.com>* Date: 2024-07-05 23:33:20* LastEditTime: 2024-07-05 23:49:09* LastEditors: wangZhiyu <w3209605851163.com>* File…

数据分析:基于STAR+FeatureCounts的RNA-seq分析全流程流程

流程主要包含两部分组成&#xff1a; 第一部分&#xff1a;二代测序数据的Raw data的fastq文件转换成Gene Count或者Features Counts表&#xff08;行是Features&#xff0c;列是样本名&#xff09;&#xff1b;第二部分&#xff1a;对counts 表进行统计分析&#xff0c;并对其…

印尼网络安全治理能力观察

在全国国际机场的移民服务完全瘫痪 100 多个小时后&#xff0c;印尼政府承认其新成立的国家数据中心 (PDN) 遭受了网络攻击。 恶意 Lockbit 3.0 勒索软件加密了存储在中心的重要数据&#xff0c;其背后的黑客组织要求支付 800 万美元的赎金。 不幸的是&#xff0c;大多数数据…

纸电混合阶段,如何在线上实现纸电会档案的协同管理?

随着国家政策的出台和引导&#xff0c;电子会计档案的管理越来越规范&#xff0c;电子会计档案建设成为打通财务数字化最后一公里的重要一环。但是&#xff0c;当前很多企业的财务管理仍处于电子档案和纸质档案并行的阶段&#xff0c;如何能将其建立合理清晰关联&#xff0c;统…

Hugging Face 全球政策负责人首次参加WAIC 2024 前沿 AI 安全和治理论坛

Hugging Face 全球政策负责人艾琳-索莱曼 &#xff08; Irene Solaiman &#xff09;将参加7月5日在上海举办的WAIC-前沿人工智能安全和治理论坛&#xff0c;并在现场进行主旨演讲和参加圆桌讨论。具体时间信息如下&#xff1a;主旨演讲&#xff1a;开源治理的国际影响时间 &am…

Spring框架Mvc(2)

1.传递数组 代码示例 结果 2.集合参数存储并进行存储类似集合类 代码示例 postman进行测试 &#xff0c;测试结果 3.用Json来对其进行数据的传递 &#xff08;1&#xff09;Json是一个经常使用的用来表示对象的字符串 &#xff08;2&#xff09;Json字符串在字符串和对象…

Java项目:基于SSM框架实现的共享客栈管理系统分前后台【ssm+B/S架构+源码+数据库+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的共享客栈管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单、功能…

20行代码写一个简单 Blazor 时钟组件

群里有些同学询问怎么实现定时刷新界面, 那咱们写点试试看能不能达到需求. 代码比较简单, 就是Task每秒刷新页面. 然后封装为组件实现只局部刷新. Demo TimerAme.razor <p>Label DateTime.Now.ToLongTimeString()</p>TimerAme.razor.cs using Microsoft.AspNet…

计算机应用数学--第三次作业

第三次作业计算题编程题1 基于降维的机器学习2 深度学习训练方法总结 第三次作业 计算题 (15 分&#xff09;对于给定矩阵 A A A&#xff08;规模为 42&#xff09;&#xff0c;求 A A A 的 SVD&#xff08;奇异值分解&#xff09;&#xff0c;即求 U U U&#xff0c; Σ …

YOLOX算法实现血细胞检测

原文:YOLOX算法实现血细胞检测 - 知乎 (zhihu.com) 目标检测一直是计算机视觉中比较热门的研究领域。本文将使用一个非常酷且有用的数据集来实现YOLOX算法,这些数据集具有潜在的真实应用场景。 问题陈述 数据来源于医疗相关数据集,目的是解决血细胞检测问题。任务是通过显微…

vue require引入静态文件报错

如果是通过向后端发送请求&#xff0c;动态的获取对应的文件数据流很容易做到文件的显示和加载。现在研究&#xff0c;一些不存放在后端而直接存放在vue前端项目中的静态媒体文件如何加载。 通常情况下&#xff0c;vue项目的图片jpg&#xff0c;png等都可以直接在/ass…

库存监控和自动通知工具(用来抢商品)

这段代码是一个使用 Python 编写的简单库存监控脚本&#xff0c;其目的是定期检查某个网页上的商品是否缺货&#xff0c;并通过电子邮件通知用户。 这段代码作为库存监控和自动通知工具&#xff0c;对于想要购买如富士相机这类可能经常缺货的商品的用户来说&#xff0c;具有以…

shell脚本awk中使用for循环

今天想使用shell脚本处理一ini文件下的ip地址&#xff0c;也就是INTRANET&#xff0c;前面的ip地址&#xff0c;折腾挺久。文件格式如下&#xff1a; 正确代码&#xff1a; grep -E INTRANET /home/aaaa/bbbb/hostinfo.ini | awk -F , {for(i1; i<NF; i) if($i~"INT…

谷粒商城学习-09-配置Docker阿里云镜像加速及各种docker问题记录

文章目录 一&#xff0c;配置Docker阿里云镜像加速二&#xff0c;Docker安装过程中的几个问题1&#xff0c;安装报错&#xff1a;Could not resolve host: mirrorlist.centos.org; Unknown error1.1 检测虚拟机网络1.2 重设yum源 2&#xff0c;报错&#xff1a;Could not fetch…