使用PyTorch导出JIT模型:C++ API与libtorch实战

news2024/9/23 5:29:42

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as models

model = models.resnet50(pretrained=True)
model.eval()

example_input = torch.rand(1, 3, 224, 224)

jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录
demo
下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.py

import torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image

# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([
                    transforms.Resize(224),
                    transforms.ToTensor()])
                    # normalize])

model = models.resnet50(pretrained=True)
model.eval()

img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread()
    默认读取为三通道BGR,需要进行B/R通道交换,这里采用
    cv::cvtColor()
    实现。
  2. 图像尺寸需要调整到

224

×

224

224\times 224

2

2

4

×

2

2

4

,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。

// test_model.cpp
#include <vector>

#include <torch/torch.h>
#include <torch/script.h>

#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>

int main(int argc, char* argv[]) {
  // 加载JIT模型
  auto module = torch::jit::load(argv[1]);

  // 加载图像
  auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);
  cv::Mat image_transfomed;
  cv::resize(image, image_transfomed, cv::Size(224, 224));
  cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);

  // 图像转换为Tensor
  torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);
  tensor_image = tensor_image.permute({2, 0, 1});
  // tensor_image = tensor_image.toType(torch::kFloat);
  tensor_image = tensor_image.div(255.);
  // tensor_image = tensor_image.sub(0.5);
  // tensor_image = tensor_image.div(0.5);

  tensor_image = tensor_image.unsqueeze(0);

  // 运行模型
  torch::Tensor output = module.forward({tensor_image}).toTensor();

  // 结果处理
  int result = output.argmax().item<int>();
  std::cout << "The classifiction index is: " << result << std::endl;
  return 0;
}

Step 6:构建运行验证

我们先来写一下
CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)

find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)

add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录
demo
下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1950606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AR 眼镜之-蓝牙电话-实现方案

目录 &#x1f4c2; 前言 AR 眼镜系统版本 蓝牙电话 来电铃声 1. &#x1f531; 技术方案 1.1 结构框图 1.2 方案介绍 1.3 实现方案 步骤一&#xff1a;屏蔽原生蓝牙电话相关功能 步骤二&#xff1a;自定义蓝牙电话实现 2. &#x1f4a0; 屏蔽原生蓝牙电话相关功能 …

深入解读 Java 中的 `StringUtils.isNotBlank` 与 `StringUtils.isNotEmpty`

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119@qq.com] 📱个人微信:15279484656 🌐个人导航网站:www.forff.top 💡座右铭:总有人要赢。为什么不能是我呢? 专栏导…

【支持语言模型和视觉语言模型的推理引擎sglang】

介绍 sglang是一个AI推理引擎&#xff0c;是一个专门为大语言模型和视觉语言模型设计的高效服务框架。 就像F1赛车需要顶级发动机一样&#xff0c;大语言模型也需要高效的推理引擎来发挥潜力。 而sglang正是这样一个性能怪兽。 根据LMSys组织的官方公告&#xff0c;最新的s…

Docker(十)-Docker运行elasticsearch7.4.2容器实例以及分词器相关的配置

1.下载镜像 1.1存储和检索数据 docker pull elasticsearch:7.4.2 1.2可视化检索数据 docker pull kibana:7.4.22.创建elasticsearch实例 创建本地挂载数据卷配置目录 mkdir -p /software/elasticsearch/config 创建本地挂载数据卷数据目录 mkdir -p /software/elasticse…

【React】JSX:从基础语法到高级用法的深入解析

文章目录 一、什么是 JSX&#xff1f;1. 基础语法2. 嵌入表达式3. 使用属性4. JSX 是表达式 二、JSX 的注意事项1. 必须包含在单个父元素内2. JSX 中的注释3. 避免注入攻击 三、JSX 的高级用法1. 条件渲染2. 列表渲染3. 内联样式4. 函数作为子组件 四、最佳实践 在 React 开发中…

20240724----idea的Java环境卸载与安装

1.删除旧有的jdk https://blog.csdn.net/weixin_42168713/article/details/112162099 &#xff08;补充&#xff1a;我把用户变量和java有关的都删了&#xff09; 2.下载新的jdk百度网盘链接 链接&#xff1a;https://pan.baidu.com/s/1gkuLoxBuRAtIB1IzUTmfyQ 提取码&#xf…

第二代欧洲结构设计标准简介

文章目录 0、背景1、总览2、更新及变化2.1 抗震2.2 地基基础2.3 防火 0、背景 本篇文章来自微信公众号土木吧&#xff0c;原作者李立昌&#xff08;北京鑫美格工程设计有限公司&#xff09;。对原文感兴趣的可以点击这里。 新的欧标滚滚而来&#xff0c;提前做好准备很有必要…

人工智能视频大模型:重塑视频处理与理解的未来

目录 一、人工智能视频大模型概述 1.1 定义与特点 1.2 技术基础 二、关键技术解析 2.1 视频特征提取 2.2 时空建模 2.3 多任务学习 三、应用场景展望 3.1 视频内容分析 3.2 视频编辑与生成 3.3 交互式视频体验 四、未来发展趋势 4.1 模型轻量化与移动端部署 4.2 …

前端面试项目细节重难点分享(十三)

面试题提问&#xff1a;分享你最近做的这个项目&#xff0c;并讲讲该项目的重难点&#xff1f; 答&#xff1a;最近这个项目是一个二次迭代开发项目&#xff0c;迭代周期一年&#xff0c;在做这些任务需求时&#xff0c;确实有很多值得分享的印象深刻的点&#xff0c;我讲讲下面…

【C语言】队列的实现(数据结构)

前言&#xff1a; 相信大家在生活中经常排队买东西&#xff0c;今天学习的队列就跟排队买东西一样&#xff0c;先来买的人就买完先走&#xff0c;也就是先进先出。废话不多说&#xff0c;进入咱们今天的学习吧。 目录 前言&#xff1a; 队列的概念 队列的实现 队列的定义 …

【8月EI会议推荐】第四届区块链技术与信息安全国际会议

一、会议信息 大会官网&#xff1a;http://www.bctis.nhttp://www.icbdsme.org/ 官方邮箱&#xff1a;icbctis126.com 组委会联系人&#xff1a;杨老师 19911536763 支持单位&#xff1a;中原工学院、西安工程大学、齐鲁工业大学&#xff08;山东省科学院&#xff09;、澳门…

git 学习总结

文章目录 一、 git 基础操作1、工作区2、暂存区3、本地仓库4、远程仓库 二、git 的本质三、分支git 命令总结 作者: baron 一、 git 基础操作 如图所示 git 总共有几个区域 工作区, 暂存区, 本地仓库, 远程仓库. 1、工作区 存放项目代码的地方&#xff0c;他有两种状态 Unm…

RK3588+MIPI+GMSL+AI摄像机:自动车载4/8通道GMSL采集/边缘计算盒解决方案

RK3588作为目前市面能买到的最强国产SOC&#xff0c;有强大的硬件配置。在智能汽车飞速发展&#xff0c;对图像数据矿场要求越来越多的环境下&#xff0c;如何高效采集数据&#xff0c;或者运行AI应用&#xff0c;成为刚需。 推出的4/8通道GMSL采集/边缘计算盒产品满足这些需求…

MinIO存储桶通知 - Kafka小测

概述 公司的某个项目需要用上这玩意&#xff0c;所以在本地搭建测试环境&#xff0c;经过一番折腾&#xff0c;测试通过&#xff0c;博文记录&#xff0c;用以备忘 MinIO安装 该节不做说明&#xff0c;网络有很多现成的帖子&#xff0c;自行搜索去 配置步骤 控制台添加事件…

瑞芯微芯片资料中关于图像处理相关的知识点

目录 MPI层模块介绍IPC的应用像素格式排布系统绑定API接口 MPI层 文件&#xff1a;Rockchip_Developer_Guide_MPI.pdf RK MPI&#xff1a;Rockchip Media Process Interface&#xff0c;媒体处理接口。 模块介绍 RK MPI层的模块介绍&#xff1a; IPC的应用 VI 模块捕获视频…

工业三防平板电脑助力工厂产线管理的智能化转型

在当今高度数字化和智能化的工业时代&#xff0c;工厂产线管理正经历着前所未有的变革。其中&#xff0c;工业三防平板电脑作为一种创新的技术工具&#xff0c;正发挥着日益重要的作用&#xff0c;有力地推动着工厂产线管理向智能化转型。 一、工业三防平板电脑具有出色的防水、…

微信小程序-本地部署(前端)

遇到问题&#xff1a;因为是游客模式所以不能修改appID. 参考链接&#xff1a;微信开发者工具如何从游客模式切换为开发者模式&#xff1f;_微信开发者工具如何修改游客模式-CSDN博客 其余参考&#xff1a;Ego微商项目部署&#xff08;小程序项目&#xff09;&#xff08;全网…

大语言模型是什么,该如何去学习呢

什么是 LLM**&#xff1f;** LLM(大型语言模型&#xff0c; Large Lanage Modle)是一种计算机程序&#xff0c;它可以理解和生成类似人类的文本&#xff1b;它能够像我们人类一样阅读、写作和理解语言。你可以把它想象成一个超级聪明的博学的不知疲惫的24小时全年无休的助手。…

使用代理IP进行本地SEO优化:如何吸引附近的客户?

在今天竞争激烈的互联网时代&#xff0c;如何利用代理IP进行本地SEO优化并吸引附近的客户已经成为许多企业和网站面临的关键挑战。本文将探讨使用代理IP的策略和技巧&#xff0c;以帮助公司提高在本地市场的可见性和吸引力&#xff0c;从而扩大本地客户群体。 1. 代理IP在本地…

小型内衣裤洗衣机哪个牌子好?五款万分翘楚机型任你挑选!

在日常生活中&#xff0c;内衣洗衣机已成为现代家庭必备的重要家电之一。选择一款耐用、质量优秀的内衣洗衣机&#xff0c;不仅可以减少洗衣负担&#xff0c;还能提供高效的洗涤效果。然而&#xff0c;市场上众多内衣洗衣机品牌琳琅满目&#xff0c;让我们往往难以选择。那么&a…