pytorch C++ 移植

news2024/11/19 16:33:18

文章目录

  • 前言
  • 安装 libtorch
  • 安装 opencv(C++)
  • 模型转换
    • 通过跟踪转换为 Torch Script
    • 通过注解转换为 Torch Script
  • 编写 C++ 代码
  • 编译环境搭建
    • C++ 库管理
    • 方法一:手动配置 visual studio 环境
    • 方法二:cmake 配置环境
  • python 调用 C++ 程序

前言

pytorch 神经网络模型来自我的这篇文章:https://blog.csdn.net/weixin_45715159/article/details/108277349

训练好的神经网络模型
链接:https://pan.baidu.com/s/1-cZCLNuYs-1MOu1NI42UOw
提取码:4wsv

pytorch 进行神经网络的搭建和训练固然很方便,但是当部署到实际中来时就显得效率一般。例如网络爬虫需要识别验证码以进行登录,运行 python 程序的话光是加载 torch 库就要花一段时间,这对于不需要一次进行大量图片识别的网络爬虫是一种拖累。所以需要编写 C++ 程序来加载神经网路模型以提高程序运行的效率。

下面以一个 CNN 网络的移植为例介绍 pytorch 的 C++ 移植方法
部分来自官方教程:https://pytorch.org/tutorials/advanced/cpp_export.html

环境

  • 操作系统:Windows 10
  • CPU:Intel i5-9300
  • GPU:GTX 1650
  • 深度学习框架:pytorch 1.6,libtorch 1.6
  • CUDA 11
  • python 3.8
  • opencv 4.4.0(C++)
  • visual studio 2019

安装 libtorch

libtorch 是 torch 的 C++ 版本。libtorch 的版本必须与 pytorch 一致

首先进入官网:https://pytorch.org/

在 pytorch 的安装栏,Package 选择 libtorchLangugae 选择 C++/Java

获得如下的下载链接:

在这里插入图片描述

我选择的是 Release version,点击相应链接下载完毕后解压在合适的目录下。

安装 opencv(C++)

opencv 主要用来进行 C++ 程序的图像处理。

和 libtorch 安装类似,在下载页面 https://opencv.org/releases/ 选择相应的操作系统。例如选择 Windows 系统,跟随安装程序选择合适的安装目录即可完成安装。

模型转换

pytorch 提供了一种统一的模型描述语言 Torch Script 供其他编程语言程序加载,下面介绍两种将我们训练出来的模型转换为 Torch Script 的方法,也可以参考这篇博客:https://blog.csdn.net/xxradon/article/details/86504906

通过跟踪转换为 Torch Script

将模型的实例以及示例输入传递给 torch.jit.trace 函数。

这个方法适用于对任意类型输入有固定格式输出的神经网络。

import torch
from Network import Net

net = Net()
net.load_state_dict(torch.load('.\\model.pt'))
example = torch.rand(1, 1, 45, 45)
scrpit_net = torch.jit.trace(net,example)
script_net.save('.\\model_script.pt')

跟踪器有可能生成警告,因为有就地赋值(torch.rand())。

通过注解转换为 Torch Script

这个方法适用于对不同类型输入有不同格式输出的神经网络。

import torch
from Network import Net

net = Net()
net.load_state_dict(torch.load('.\\model.pt'))
scrpit_net = torch.jit.script(net)
script_net.save('.\\model_script.pt')

编写 C++ 代码

源.cpp 文件内容如下,调用了 opencv 库和 libtorch 库。

  • decode() 解码函数,类似于前面 python 的解码函数。

  • C++ 程序通过主函数的入口参数 argc(参数个数) 和 argv[](参数集) 获取传入程序的参数,例如在windows 命令行中输入:

    program.exe arg1 arg2
    

    即可向程序传入两个参数:arg1arg2,在程序中读取 argv[0]argv[1] 即可知道这两个参数的值。

    在本程序中传入的参数为前面保存的神经网络模型的路径和图片的路径

  • torch::jit::script::Module module = torch::jit::load(argv[1]) 通过模型路径读取模型并创建神经网络,十分简便。

  • cv::imread() 为 opencv 的图像读取函数,输入图片路径即可返回 Mat 类型的图像数据矩阵。

  • cv::resize() 为 opencv 的图像变形函数,在这里使图片变形为 45 x 45。

  • cv::cvtColor() 为 opencv 的图像色域改变函数,opencv 读取的图像通道和我们常见的 RGB 通道不同,它是 BGR,在这里我们只需要将它转变为单通道(即灰度图)。

  • image.convertTo(image, CV_32FC1);Mat 类型的数据类型转换方法,在这里转换为32为浮点型单通道。

  • cv::vconcat() 为 opencv 的图像拼接函数,在这里是在图像个数维上进行拼接。

  • 使用 write 标记 images 变量是否已赋值,因为我们要将所有的输入图片拼接成一个整体再转换为 Tensor 传给神经网络以加快程序运行速度。

  • torch::from_blob() 为 libtorch 提供的 Mat 类型转 Tensor 类型的函数接口,我们可以看到最终输入给神经网络的数据维度为 N x 1 x 45 x 45 (N 为图片个数)。

  • std::vector<torch::jit::IValue> inputs;

    inputs.push_back(input_tensors);

    auto outputs = module.forward(inputs).toTensor();

    这三条语句为我在网上查到的一种固定写法,大概是要将 Tensor 类型的数据放在一种叫向量的数据结构里才能传递给神经网络。

  • 最总通过 std::cout 直接输出识别结果,多图片时输出结果以单空格隔开。

也不知道为什么,发现把模型置于计算模式时(module.eval();)输出结果是错误的。

/*
神经网络调用程序,根据命令行参数直接输出识别结果,可一次识别多张,上限100张
参数顺序:神经网络模型路径,图片1路径,……,图片n路径
*/

#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <memory>

std::string decode(at::Tensor code) //解码函数
{
    int char_num = 4; //字符个数
    int a;
    std::string str = ""; //输出字符串
    std::string table[10] = {"1", "2", "3", "b", "c", "m", "n", "v", "x", "z"}; //编码对照表
    code = code.view({ -1, 10 });
    code = torch::argmax(code, 1);
    for (int i = 0; i < char_num; i++)
    {
        a = code[i].item().toInt();
        str += table[a];
    }
    return str;
}

int main(int argc, const char *argv[])
{
    if (argc < 3 || argc > 102)
    {
        std::cerr << "调用错误!" << std::endl;
        return -1;
    }
    try
    {
        int image_num = argc - 2; //读取的图片数量
        bool write = false;
        torch::jit::script::Module module = torch::jit::load(argv[1]); //读取模型
        cv::Mat images; //保存图片数据
        for (int i = 0; i < image_num; i++)
        {
            cv::Mat image = cv::imread(argv[i + 2]); //读取图片
            cv::resize(image, image, cv::Size(45, 45)); //变形成为45 x 45
            cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); //转成灰度图
            image.convertTo(image, CV_32FC1); //转换为32位浮点型
            if (!write)
            {
                images = image;
                write = true;
            }
            else
            {
                cv::vconcat(images, image, images); //拼接图像
            }
        }
        auto input_tensors = torch::from_blob(images.data, {image_num, 1, 45, 45 }); //将mat转成tensor
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(input_tensors);
        auto outputs = module.forward(inputs).toTensor();
        std::cout << decode(outputs[0]);
        for (int i = 1; i < image_num; i++)
        {
            auto result = decode(outputs[i]);//解码
            std::cout << ' ' << result;
        }
        return 0;
    }
    catch (...)  //捕获任意异常
    {
        std::cerr << "程序执行出错!" << std::endl;
        return -1;
    }
}

编译环境搭建

下面介绍两种编译环境的搭载方法,一种是直接在 visual studio 中配置,一种是官方推荐的用 cmake 配置 visual studio 环境。

确保已安装 visual studio 的 C++ 部件,下面所有方法都以 visual studio 2019 为 IDE,程序编译类型为 x64 release

C++ 库管理

在前面我们已安装了 libtorch 和 opencv,以 libtorch 1.6.0 (release) 和 opencv 4.4.0 为例,假设它们放在一个文件夹:D:\CplusLib\

在这里插入图片描述

程序需要的动态链接库位置:

  • libtorchD:\CplusLib\libtorch\lib
  • opencvD:\CplusLib\opencv\build\x64\vc15\bin

接下来我们需要将这些动态链接库的路径添加进环境变量,以便 C++ 程序通过环境变量找寻这些动态链接库。

以 Windows 10 系统为例,右键 此电脑,选择 属性 -> 高级系统设置 -> 环境变量 打开界面。

在这里插入图片描述

有两种环境变量:用户变量和系统变量,一种只针对一个用户有效,另一种对所有用户都有效。

我们添加系统变量的 path 项,双击 path,点击 新建 输入路径:

在这里插入图片描述

点击 确定 -> 确定

方法一:手动配置 visual studio 环境

直接在 visual studio 环境中配置要一个一个去写包含的库的路径和动态链接库的名称。我是采用属性表的形式直接让工程项目读取,这样每创建一个新工程就不用重复配置了。

以上面的库安装路径为基础提供每个库的属性表:

  • libtorch(release)libtorch.Cpp.x64.user.props

    <?xml version="1.0" encoding="utf-8"?>
    <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
      <ImportGroup Label="PropertySheets" />
      <PropertyGroup Label="UserMacros" />
      <PropertyGroup>
     <IncludePath>D:\CplusLib\libtorch\include;D:\CplusLib\libtorch\include\torch;$(IncludePath)</IncludePath>
        <LibraryPath>D:\CplusLib\libtorch\lib;$(LibraryPath)</LibraryPath>
      </PropertyGroup>
      <ItemDefinitionGroup>
        <Link>
      <AdditionalDependencies>asmjit.lib;c10.lib;c10_cuda.lib;caffe2_detectron_ops_gpu.lib;caffe2_module_test_dynamic.lib;caffe2_nvrtc.lib;clog.lib;cpuinfo.lib;dnnl.lib;fbgemm.lib;libprotobuf-lite.lib;libprotobuf.lib;libprotoc.lib;mkldnn.lib;torch.lib;torch_cpu.lib;torch_cuda.lib;%(AdditionalDependencies)</AdditionalDependencies>
        </Link>
      </ItemDefinitionGroup>
      <ItemGroup />
    </Project>
    
  • opencv (release)opencv.Cpp.x64.user.props

    <?xml version="1.0" encoding="utf-8"?>
    <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
      <ImportGroup Label="PropertySheets" />
      <PropertyGroup Label="UserMacros" />
      <PropertyGroup>
      <IncludePath>D:\CplusLib\opencv\build\include;D:\CplusLib\opencv\build\include\opencv2;$(IncludePath)</IncludePath>
        <LibraryPath>D:\CplusLib\opencv\build\x64\vc15\lib;$(LibraryPath)</LibraryPath>
      </PropertyGroup>
      <ItemDefinitionGroup>
        <Link>
          <AdditionalDependencies>opencv_world440.lib;%(AdditionalDependencies)</AdditionalDependencies>
        </Link>
      </ItemDefinitionGroup>
      <ItemGroup />
    </Project>
    

其中IncludePath 是库的包含路径,LibraryPath 是链接库路径,AdditionalDependencies 是链接库名称。

如何导入属性表

在 visual studio 中点击 属性管理器,右键 Release|x64 -> 添加现有属性表。最后效果如图:

在这里插入图片描述

新建 C++ 项目,将我们前面编写的 C++ 文件添加进来,按照如图配置:

在这里插入图片描述

点击 本地 Windows调试器 即可完成编译。

方法二:cmake 配置环境

首先安装 cmake:https://cmake.org/download/

比如 Windows 64位系统选择 cmake-3.18.2-win64-x64.msi,跟随安装程序即可完成安装。

新建一个目录 D:\CaptchaRecognize\,在目录下新建 CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(CaptchaRecognize)

SET(CMAKE_BUILE_TYPE RELEASE)

INCLUDE_DIRECTORIES(
D:/CplusLib/libtorch/include
D:/CplusLib/libtorch/include/torch
D:/CplusLib/opencv/build/include
D:/CplusLib/opencv/build/include/opencv2
)

SET(TORCH_LIBRARIES D:/CplusLib/libtorch/lib)
SET(OpenCV_LIBS D:/CplusLib/opencv/build/x64/vc15/lib)

LINK_DIRECTORIES(
${TORCH_LIBRARIES}
${OpenCV_LIBS}
)

add_executable(CaptchaRecognize 源.cpp)

target_link_libraries(CaptchaRecognize
asmjit.lib
c10.lib
c10_cuda.lib
caffe2_detectron_ops_gpu.lib
caffe2_module_test_dynamic.lib
caffe2_nvrtc.lib
clog.lib
cpuinfo.lib
dnnl.lib
fbgemm.lib
libprotobuf-lite.lib
libprotobuf.lib
libprotoc.lib
mkldnn.lib
torch.lib
torch_cpu.lib
torch_cuda.lib
opencv_world440.lib
)

set_property(TARGET CaptchaRecognize PROPERTY CXX_STANDARD 14)

该文件中的一些属性和属性表的内容相似。将 源.cpp 复制到该目录下,新建目录 build

在这里插入图片描述

打开 Cmake 进行如下配置:

在这里插入图片描述

点击 configure -> finish,然后点击 Generate 生成 visual studio 工程。

build 目录中找到工程文件并打开,打开 解决方案资源管理器,右键 CaptchaRecognize -> 设为启动项目

在这里插入图片描述

编译属性

在这里插入图片描述
点击 本地 Windows调试器 即可完成编译。

python 调用 C++ 程序

该 C++ 程序的命令行调用示例

CaptchaRecognize.exe model_script.pt pic1.png pic2.png

在 python 中通过 os.popen('command').read() 即可像命令行一样调用 C++ 程序并读取程序的输出结果。

例如

import os

result = os.popen('CaptchaRecognize.exe model_script.pt pic1.png pic2.png').read() #调用 C++ 程序
result = result.split(' ')

由前面可知 C++ 程序输出的结果由单空格隔开,通过 .split(' ') 可得到含所有图片识别结果的列表。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1120493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go语言Array 与 Slice

有的语言会把数组用作常用的基本的数据结构&#xff0c;比如 JavaScript&#xff0c;而 Golang 中的数组(Array)&#xff0c;更倾向定位于一种底层的数据结构&#xff0c;记录的是一段连续的内存空间数据。但是在 Go 语言中平时直接用数组的时候不多&#xff0c;大多数场景下我…

MySQL中查询重复字段的方法和步骤是怎样

示例 accountinfo 表数据如下&#xff1a; 场景一 单个字段重复数据查找 & 去重 我们要把上面这个表中 单个字段 account字段相同的数据找出来。 思路 分三步 简述&#xff1a; 第一步 要找出重复数据&#xff0c;我们首先想到的就是&#xff0c;既然是重复&#xff0c…

【斗破年番】再遭群嘲,美杜莎怀孕之事被魔改,三方联手除萧潇?

【侵权联系删除】【文/郑尔巴金】 斗破苍穹年番第67集已经更新了。和很多人一样&#xff0c;小郑也去看了&#xff0c;只是小郑万万没有想到&#xff0c;我满怀期待的去看这一集&#xff0c;这一集却能魔改成这样。魔改成什么样了呢&#xff1f;下面来分析下吧&#xff01; 一&…

高效表达三步

一、高效表达 高效表达定主题搭架子填素材 第一&#xff1a; 1个核心主题&#xff0c;让别人秒懂你的想法 &#xff08;表达要定主题&#xff09; 第二&#xff1a; 3种经典框架&#xff0c;帮你快速整理表达思路 第三&#xff1a; 2种表达素材&#xff0c;让发言更具说服力…

基础算法相关笔记

排序 最好情况下&#xff1a; 冒泡排序 最坏时间复杂度 O ( n 2 ) O(n^2) O(n2)。 插入排序 最坏时间复杂度为 O ( n 2 ) O(n^2) O(n2)&#xff0c;最优时间复杂度为 O ( n ) O(n) O(n)。 平均情况下&#xff1a; 快速排序 最坏时间复杂度为 O ( n 2 ) O(n^2) O(n2)&…

跟我一起写个虚拟机 .Net 7(四)- LC_3 解析实例

没想到这篇文章持续了这么久&#xff0c;越学越深&#xff0c;愣是又买了一本书《计算机系统概论》&#xff0c;当然&#xff0c;也看完了&#xff0c;受益匪浅。 系统化的学习才是正确的学习方式&#xff0c;我大学就没看到过这本书&#xff0c;如果早点看到&#xff0c;可能…

可视化 | python可视化相关库梳理(自用)| pandas | Matplotlib | Seaborn | Pyecharts | Plotly

文章目录 &#x1f4da;Plotly&#x1f407;堆叠柱状图&#x1f407;环形图&#x1f407;散点图&#x1f407;漏斗图&#x1f407;桑基图&#x1f407;金字塔图&#x1f407;气泡图&#x1f407;面积图⭐️快速作图工具&#xff1a;plotly.express&#x1f407;树形图&#x1f…

MySQL 排名函数 RANK, DENSE_RANK, ROW_NUMBER

文章目录 1 排名函数有哪些?2 SQL 代码实现2.1 RANK2.2 DENSE_RANK2.3 ROW_NUMBER 1 排名函数有哪些? RANK(): 并列跳跃排名, 并列即相同的值, 相同的值保留重复名次, 遇到下一个不同值时, 跳跃到总共的排名DENSE_RANK(): 并列连续排序, 并列即相同的值, 相同的值保留重复名…

图详解第六篇:多源最短路径--Floyd-Warshall算法(完结篇)

文章目录 多源最短路径--Floyd-Warshall算法1. 算法思想2. dist数组和pPath数组的变化3. 代码实现4. 测试观察5. 源码 前面的两篇文章我们学习了两个求解单源最短路径的算法——Dijkstra算法和Bellman-Ford算法 这两个算法都是用来求解图的单源最短路径的算法&#xff0c;区别在…

effective c++学习笔记(后四章)

六 继承与面向对象设计 红色字 \color{FF0000}{红色字} 红色字 32 确定你的public继承塑模出 is-a关系 如果你令class D (“Derived”)以public形式继承class B (“Base”)&#xff0c;你便是告诉C编译器&#xff08;以及你的代码读者&#xff09;说&#xff0c;每一个类型为…

基于目录的ant任务

一些任务利用目录树来执行一些动作 一些任务利用目录树来执行一些动作。例如&#xff0c;javac这个任务就是一个基于目录的任务&#xff0c;它将一个目录中的.java文件编译为.class文件。因为一些这样的任务在目录树上做很多的工作&#xff0c;所以这些任务本身充当了隐含的文…

C# Socket通信从入门到精通(2)——多个同步TCP客户端C#代码实现

前言: 我们在开发Tcp客户端程序的时候,有时候在同一个软件上我们要连接多个服务器,这时候我们开发的一个客户端就不够使用了,这时候就需要我们开发出来的软件要支持连接多个服务器,最好是数量没有限制,这样我们就能应对任意数量的服务器连接,由于我们开发的Tcp客户端程…

7个可能改变AEC行业的AI工具

推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 人工智能&#xff08;AI&#xff09;工具在各个行业中越来越受欢迎&#xff0c;ChatGDP的推出无疑让人们看到了人工智能所能提供的可能性。 然而&#xff0c;人工智能不仅仅是生成文本或图形——它可以用于各种设置。 建筑…

【面试题】JDBC桥接模式如何实现的?

Hello 大家好&#xff0c;我是小米&#xff01;很高兴又和大家见面啦&#xff01;今天的主题是——"面试题&#xff1a;JDBC桥接模式如何实现的&#xff1f;"。 相信大家都听说过JDBC&#xff08;Java Database Connectivity&#xff09;&#xff0c;它是Java中连接…

QT判断平台和生成版本设置输入目录

QT判断平台和生成版本设置输入目录 pro工程文件中常用的宏定义Chapter1 QT判断平台和生成版本设置输入目录Chapter2 Qt pro文件中判断 x86/arm(aarch64)交叉编译环境&#xff0c;区分 linux/windows系统, debug/release版本Chapter3 Qt的版本判断、跨平台选择与pro工程文件输出…

231022|redis_demo

安装 https://github.com/tporadowski/redis https://github.com/redis/redis-py/ 解压后要先配置redis.windows.conf文件&#xff0c;里面有本地端口和密码设置 默认host:127.0.0.1 port:6379 打开命令行到redis文件夹下&#xff0c;redis-server.exe redis.windows.conf输入即…

1024我来利用DOS攻击你的电脑了?(第十三课)

1024我来利用DOS攻击你的电脑了&#xff1f;(第十三课) 本文章设计安全领域的重点问题 学习本文章时 请扎在初学者的角度学习 用于正途 一 国家安全法 1 安全法律法规 《宪法》中的相关规定 案例&#xff1a; 大山破解同事小美私人邮箱密码&#xff0c;读取其往来邮件 邮箱…

Go并发编程之四

一、前言 今天我们介绍一下Go并发编程另外一个重要概念【多路复用】&#xff0c;多路复用最开始是在网络通讯领域&#xff08;硬件&#xff09;应用&#xff0c;指的是用同一条线路承载多路信号进行通信的方式&#xff0c;有频分多路复用、时分多路复用等等技术&#xff0c;然…

组合数(递推版)的初始化

初始考虑为将第一列数和斜对角线上的数进行初始化。 橙色方块由两个绿色方块相加而来&#xff0c;一个为1&#xff0c;一个为0&#xff0c;所以斜对角线都为1&#xff0c;可以通过计算得来&#xff0c;不需要初始化&#xff0c;需要与码蹄集盒子与球 第二类Stirling数&#xf…