记录pytorch实现自定义算子并转onnx文件输出

news2024/10/7 6:45:50

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
    // BEGIN image_mat
    cv::Mat image_mat(/*rows=*/image.size(0),
        /*cols=*/image.size(1),
        /*type=*/CV_32FC1,
        /*data=*/image.data_ptr<float>());
    // END image_mat

    // BEGIN warp_mat
    cv::Mat warp_mat(/*rows=*/warp.size(0),
        /*cols=*/warp.size(1),
        /*type=*/CV_32FC1,
        /*data=*/warp.data_ptr<float>());
    // END warp_mat

    // BEGIN output_mat
    cv::Mat output_mat;
    cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });
    // END output_mat

    // BEGIN output_tensor
    torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });
    return output.clone();
    // END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0


 torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {
    m.def("warp_perspective", warp_perspective);
}


2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)

set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")


set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)

# Opencv
find_package(OpenCV REQUIRED)

# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)

# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)

# libtorch库文件
target_link_libraries(warp_perspective 
    # CPU
    c10 
    torch_cpu
    # GPU
    # c10_cuda 
    # torch_cuda
    
)


# opencv库文件
target_link_libraries(warp_perspective
    ${OpenCV_LIBS}
)

add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as np

torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")

im=cv2.imread("/data/xxx/mylib/cat.jpg",0)

pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)


in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)

out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)

dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as np

torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):
    def __init__(self, name):
        super(MyNet, self).__init__()
        self.model_name = name

    def forward(self, in_data, warp_data):
        return torch.ops.my_ops.warp_perspective(in_data, warp_data)


def my_custom(g, in_data, warp_data):
    return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)


if __name__ == "__main__":
    net = MyNet("my_ops")
    in_data = torch.randn((32, 32))
    warp_data = torch.rand((3, 3))

    out = net(in_data, warp_data)
    print("out: ", out)

    # export onnx
    torch.onnx.export(net,
            (in_data, warp_data),
            "./my_ops_export_model2.onnx",
            input_names=["img_data", "warp_mat"],
            output_names=["out_img"],
            custom_opsets={"cus_ops": 11},
            )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1206287.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux可以投屏到电视吗?用网页浏览器就能投屏到电视!

Linux系统的电脑如果要投屏到安卓电视屏幕上&#xff0c;可以使用投屏工具AirDroid Cast的网页版和TV版一起实现。 首先&#xff0c;在Linux系统的电脑里用chrome浏览器或edge浏览器打开webcast.airdroid.com。这就是AirDroid Cast的网页版。你可以看到中间白色框框的右上角有个…

12358748257

问题一&#xff1a;.浮点数打印问题 float red_increment (target_red_value - initial_red_value) / STEPS; u8 STEPS 100; printf("绿色值每一次增量------%f\n", red_increment); 后面三个参数均为u8类型 希望采用 %f打印出每次的步进值。但是结果为空白 希…

聚观早报 |滴滴发布Q3财报;小鹏G9连续销量排行第一

【聚观365】11月14日消息 滴滴发布Q3财报 小鹏G9连续销量排行第一 XREAL双11实现7倍增长 真我GT5 Pro真机图 2024年智能手机AI功能竞争激烈 滴滴发布Q3财报 滴滴在其官网发布2023年三季度业绩报告。报告显示&#xff0c;三季度滴滴实现总收入514亿元&#xff0c;同比增长…

【Mysql系列】Mysql基础篇

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

YOLOv8-Seg改进:小目标涨点系列篇 | TPC-YOLO-seg不同场景小目标分割均能提升 | 23年顶刊最新成果

🚀🚀🚀本文改进:轻量级的基于注意力的网络 TPC-YOLO-seg用于微小物体分割 🚀🚀🚀TPC-YOLO-seg 小目标分割首选,暴力涨点 🚀🚀🚀YOLOv8-seg创新专栏:http://t.csdnimg.cn/KLSdv 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1)手把手教你如何…

SAP中销售业务的查询修改及冲销操作手册

目的 物流在销售订单发货开票出问题时进行查询分析及处理冲销的相关操作 触发条件 销售业务出现变更导致需要重新做销售或人为错误 必要条件 订单&#xff0c;交货单&#xff0c;发票己完成并过账 有用提示 在实际冲销业务过程中需要去分析&#xff0c;在了解业务的情况下去…

JSP详细

一.JSP简介 JSP&#xff08;全称Java Server Pages&#xff09;java服务器页面。 是一种动态网页技术标准。JSP部署于网络服务器上&#xff0c;可以响应客户端发送的请求&#xff0c;并根据请求内容动态地生成HTML、XML或其他格式文档的Web网页&#xff0c;然后返回给请求者。…

探索高效智能:AI 模型的优化工具盘点 | 开源专题 No.43

openai/evals Stars: 12.3k License: NOASSERTION OpenAI Evals 是一个用于评估 LLMs (大型语言模型) 或使用 LLMs 作为组件构建的系统的框架。它还包括一个具有挑战性 evals 的开源注册表。Evals 现在支持通过 Completion Function Protocol 评估任何系统&#xff0c;包括 p…

代码随想录算法训练营第五十三天丨 动态规划part14

1143.最长公共子序列 思路 本题和动态规划&#xff1a;718. 最长重复子数组 (opens new window)区别在于这里不要求是连续的了&#xff0c;但要有相对顺序&#xff0c;即&#xff1a;"ace" 是 "abcde" 的子序列&#xff0c;但 "aec" 不是 &quo…

LeetCode(12)时间插入、删除和获取随机元素【数组/字符串】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 380. O(1) 时间插入、删除和获取随机元素 1.题目 实现RandomizedSet 类&#xff1a; RandomizedSet() 初始化 RandomizedSet 对象bool insert(int val) 当元素 val 不存在时&#xff0c;向集合中插入该项&#xff0c;并返回…

xss学习笔记

跨站脚本攻击 掌握XSS 的原理 掌握XSS 的场景 掌握XSS 的危害 掌握XSS 漏洞验证 掌握XSS 的分类跨站脚本攻击 漏洞概述 ​ 跨站点脚本&#xff08;Cross Site Scripting&#xff0c; XSS&#xff09;是指客户端代码注入攻击&#xff0c;攻击者可以在合法网站或Web 应用程…

百度文心一言

1分钟了解一言是谁&#xff1f; 一句话介绍【文心一言】 我是百度研发的人工智能模型&#xff0c;任何人都可以通过输入【指令】和我进行互动&#xff0c;对我提出问题或要求&#xff0c;我能高效地帮助你们获取信息、知识和灵感哦 什么是指令&#xff1f;我该怎么和你互动&am…

模拟接口数据之使用Fetch方法实现

文章目录 前言一、package.json配置mock执行脚本二、封装接口&#xff0c;区分走ajax还是fetch三、创建mock目录&#xff0c;及相关接口文件四、定义接口五、使用mock数据使用模拟数据优化fetch返回数据 六、不使用模拟数据七、对比其他需要使用依赖相关配置如有启发&#xff0…

什么叫做云安全?云安全有哪些要求?

云安全(Cloud Security)是一种基于云计算的安全防护策略&#xff0c;旨在保护企业数据和应用程序的安全性和完整性。云安全利用云计算的分布式处理和存储能力&#xff0c;以更高效、更灵活的方式提供安全服务。 云安全的要求主要包括以下几个方面&#xff1a; 数据安全和隐私保…

k8s的service自动发现服务:实战版

Service服务发现的必要性: 对于kubernetes整个集群来说&#xff0c;Pod的地址也可变的&#xff0c;也就是说如果一个Pod因为某些原因退出了&#xff0c;而由于其设置了副本数replicas大于1&#xff0c;那么该Pod就会在集群的任意节点重新启动&#xff0c;这个重新启动的Pod的I…

【python自动化】Playwright基础教程(四)事件操作①高亮元素匹配器鼠标悬停

本文目录 文章目录 前言高亮显示元素定位 - highlighthighlight实战highlight定位多个元素 元素匹配器 - nthnth实战演示 元素匹配 - first&last 综合定位方式时间操作进行实战&#xff0c;巩固之前我们学习的定位方式。 这一部分内容对应官网 : https://playwright.dev/py…

⑦【MySQL】什么是约束?如何使用约束条件?主键、自增、外键、非空....

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ 约束 ⑦【MySQL】约束条件1. 约束的基本使用2.…

5.运行时数据区-字符串常量池、程序计数器、直接内存

目录 概述字符串常量池字符串常量池存储数据的方式三种常量池字面量与符号引用 哈希表实战 程序计数器直接内存直接内存与堆内存比较 结束 概述 相关文章在此总结如下&#xff1a; 文章地址jvm基本知识地址jvm类加载系统地址双亲委派模型与打破双亲委派地址运行时数据区地址 …

Spring事务之AOP导致事务失效问题

情况说明 首先开启了AOP&#xff0c;并且同时开启了事务。下面这个TransactionAspect就是一个简单的AOP切面&#xff0c;有一个Around通知。 Aspect Component public class TransactionAspect {Pointcut("execution(* com.qhyu.cloud.datasource.service.TransactionSe…