PyTorch 添加 C++ 拓展

news2025/1/11 21:41:56

参考内容:pytorch添加C++拓展简单实战编写及基本功能测试

文章目录

  • 第一步:编写 C++ 模块
    • test.h
    • test.cpp
  • 第二步:编写 setup.py
  • 第三步:安装 C++ 模块
  • 第四步:验证安装
  • 第五步:C++ 模块使用
    • test_cpp1.py
    • test_cpp2.py
  • 运行结果
  • 扩展阅读

编译安装前的文件目录:

这里的 csrc 应该不是指 pytorch 项目中的 /torch/csrc

csrc
├─ cpu
│    ├─ test.cpp
│    └─ test.h
└─ setup.py

第一步:编写 C++ 模块

test.h

#include <torch/extension.h>
#include <vector>

// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);

// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);

test.cpp

#include "test.h"

// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y){
    AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y");
    torch::Tensor z = torch::zeros(x.sizes());
    z = 2 * x + y;
    return z;
}

// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput){
    torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());
    torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());
    return {gradOutputX, gradOutputY};
}

// pybind11 绑定
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("forward", &Test_forward_cpu, "TEST forward");
    m.def("backward", &Test_backward_cpu, "TEST backward");
}

第二步:编写 setup.py

from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension

# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
# 源代码目录
source_cpu = glob.glob(os.path.join(include_dirs, 'cpu', '*.cpp'))

setup(
    name='test_cpp', # 模块名称,需要在 python 中调用
    version="0.1",
    ext_modules=[
        CppExtension('test_cpp', sources=source_cpu, include_dirs=[include_dirs]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

第三步:安装 C++ 模块

在 csrc 文件夹下运行命令

python setup.py install

第一次尝试的报错信息:

/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()

参考 SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip 后得知是 setuptools 版本太高,于是降低 setuptools 版本,pip install setuptools==58.2.0

第二次尝试的运行结果:

running install
running bdist_egg
running egg_info
writing test_cpp.egg-info/PKG-INFO
writing dependency_links to test_cpp.egg-info/dependency_links.txt
writing top-level names to test_cpp.egg-info/top_level.txt
reading manifest file 'test_cpp.egg-info/SOURCES.txt'
writing manifest file 'test_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'test_cpp' extension
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu
Emitting ninja build file /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] c++ -MMD -MF /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o.d -pthread -B /home/zjma/.conda/envs/debugtest/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/zjma/pytorch_v1.13.1/csrc -I/home/zjma/pytorch_v1.13.1/torch/include -I/home/zjma/pytorch_v1.13.1/torch/include/torch/csrc/api/include -I/home/zjma/pytorch_v1.13.1/torch/include/TH -I/home/zjma/pytorch_v1.13.1/torch/include/THC -I/home/zjma/.conda/envs/debugtest/include/python3.8 -c -c /home/zjma/pytorch_v1.13.1/csrc/cpu/test.cpp -o /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=test_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
cc1plus: warning: command-line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.8
g++ -pthread -shared -B /home/zjma/.conda/envs/debugtest/compiler_compat -L/home/zjma/.conda/envs/debugtest/lib -Wl,-rpath=/home/zjma/.conda/envs/debugtest/lib -Wl,--no-as-needed -Wl,--sysroot=/ /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -L/home/zjma/pytorch_v1.13.1/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for test_cpp.cpython-38-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/test_cpp.py to test_cpp.cpython-38.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.test_cpp.cpython-38: module references __file__
creating 'dist/test_cpp-0.1-py3.8-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing test_cpp-0.1-py3.8-linux-x86_64.egg
removing '/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg' (and everything under it)
creating /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Extracting test_cpp-0.1-py3.8-linux-x86_64.egg to /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages
test-cpp 0.1 is already the active version in easy-install.pth

Installed /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Processing dependencies for test-cpp==0.1
Finished processing dependencies for test-cpp==0.1

编译安装后的文件目录:

csrc
├─ build
│    ├─ bdist.linux-x86_64
│    ├─ lib.linux-x86_64-3.8
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ lib.linux-x86_64-cpython-38
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ temp.linux-x86_64-3.8
│    │    ├─ .ninja_deps
│    │    ├─ .ninja_log
│    │    ├─ build.ninja
│    │    └─ home
│    └─ temp.linux-x86_64-cpython-38
│           ├─ .ninja_deps
│           ├─ .ninja_log
│           ├─ build.ninja
│           └─ home
├─ cpu
│    ├─ test.cpp
│    └─ test.h
├─ dist
│    └─ test_cpp-0.1-py3.8-linux-x86_64.egg
├─ setup.py
└─ test_cpp.egg-info
       ├─ PKG-INFO
       ├─ SOURCES.txt
       ├─ dependency_links.txt
       └─ top_level.txt

第四步:验证安装

1、在虚拟环境的路径 /lib/python3.8/site-packages 下看到 test_cpp-0.1-py3.8-linux-x86_64.egg 文件
在这里插入图片描述
2、conda list 查看当前虚拟环境下已经安装的包
在这里插入图片描述3、进入 python 的交互模式,import test_cpp 后报错:

>>> import test_cpp
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: libc10.so: cannot open shared object file: No such file or directory

参考 通过Python setup.py install的第三方包,import时却无法导入是什么问题呢? - 神经的网络里挣扎的回答 - 知乎,因为编译的 test_cpp 包需要依赖 torch 包,导致无法导入。所以,在 import test_cpp 前要先 import torch

第五步:C++ 模块使用

test_cpp1.py

import torch
import test_cpp
from torch.autograd import Function

class TestFunction(Function):

    @staticmethod
    def forward(ctx, x, y):
        return test_cpp.forward(x, y)

    @staticmethod
    def backward(ctx, gradOutput):
        gradX, gradY = test_cpp.backward(gradOutput)
        return gradX, gradY

class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
    def forward(self, inputA, inputB):
        return TestFunction.apply(inputA, inputB)

test_cpp2.py

import torch
from torch.autograd import Variable

from test_cpp1 import Test

x = Variable(torch.Tensor([1,2,3]), requires_grad=True)
y = Variable(torch.Tensor([4,5,6]), requires_grad=True)

test = Test()
z = test(x, y)
z.sum().backward()

print('x: ', x)
print('y: ', y)
print('z: ', z)
print('x.grad: ', x.grad)
print('y.grad: ', y.grad)

运行结果

/home/zjma/.conda/envs/debugtest/bin/python /home/zjma/PycharmProjects/pythonProject/test_cpp2.py 
x:  tensor([1., 2., 3.], requires_grad=True)
y:  tensor([4., 5., 6.], requires_grad=True)
z:  tensor([ 6.,  9., 12.], grad_fn=<TestFunctionBackward>)
x.grad:  tensor([2., 2., 2.])
y.grad:  tensor([1., 1., 1.])

进程已结束,退出代码为 0

运行结果符合预期。

扩展阅读

  • pytorch之c++/cuda拓展(讲得很详细,举的例子和上文基本一样,但用到了CUDA,很多内容可以扩展去看)
  • 官方教程 相关内容的笔记(后面可以复现一下)
    • PyTorch进阶1:C++扩展
    • pytorch 的C++扩展

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1412369.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何使用preact开始一个前端项目?

本篇文章对于preact不做过深介绍&#xff0c;仅仅介绍其基础的使用方法。使用Preact&#xff0c;我们可以通过组装组件和元素的树来创建用户界面。组件是返回其树应该输出的内容的描述的函数或类。这些描述通常是用JSX&#xff08;如下所示&#xff09;或HTML编写的&#xff0c…

CSS之粘性定位

让我为大家介绍一下粘性定位吧&#xff01; 大家应该都了解过绝对定位&#xff0c;它是相对于父级定位 那么粘性定位相对于谁呢&#xff1f; 它相对于overflow:hidden; 如果没找到就会跟fixed固定定位一样&#xff0c;相对于视口 <!DOCTYPE html> <html lang"en…

地图在游戏中的应用案例:王者荣耀

腾讯位置服务&#xff0c;作为国内地图导航的领头羊&#xff0c;在各行业中应用广泛&#xff0c;包括&#xff1a;网约车、智能物流、车用地图、智能穿戴、智能景区、运输安全监控、金融地图、运动健康、房产服务、智慧交通、时空大数据慧眼、专网地图等等。 腾讯地图与其他竞…

视频监控方案设计:EasyCVR视频智能监管系统方案技术特点与应用

随着科技的发展&#xff0c;视频监控平台在各个领域的应用越来越广泛。然而&#xff0c;当前的视频监控平台仍存在一些问题&#xff0c;如视频质量不高、监控范围有限、智能化程度不够等。这些问题不仅影响了监控效果&#xff0c;也制约了视频监控平台的发展。 为了解决这些问…

分享几种常见的OCR图形识别API接口

VIN识别 支持对车辆挡风玻璃处和行驶证车架号码进行识别。 银行卡识别 识别出该卡的银行卡号、所属银行、卡片类型以及银行邮编等信息。 通用文字识别 自动提取及快速识别出图像中文字内容&#xff0c;适用于多场景图像文字识别。 身份证识别 识别及提取身份证正反面所有字段…

EG-2102CB 表面声波(SAW)振荡器

表面声波&#xff08;SAW&#xff09;振蒎器&#xff0c;简称声表晶振&#xff0c;其频率范围非常广泛&#xff0c;可实现从100MHz到700MHz的精度频率输出。其标准工作电源电压为3.3V&#xff0c;具有高稳定性。输出特性稳定&#xff0c;具有低抖动、高精度、高线性等优点。其输…

ChatGPT 和文心一言 | 两大AI助手哪个更胜一筹

欢迎来到英杰社区&#xff1a; https://bbs.csdn.net/topics/617804998 欢迎来到阿Q社区&#xff1a; https://bbs.csdn.net/topics/617897397 &#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff…

您有一份OpenHarmony开发者论坛2023年度总结,请查收~

2023年11月&#xff0c;OpenHarmony开发者论坛1.0版本正式上线。感谢各位开发者对OpenHarmony的大力支持和热爱&#xff0c;成为OpenHarmony开发者论坛的第一批体验用户&#xff0c;并迅速在论坛开启了OpenHarmony技术交流。 通过开发者们在论坛进行提问、答疑、分享技术文章、…

Flask 之旅 (二):表单

背景 上一篇帖子我们使用 Flask 创建了最基本的 web 服务。使用 bootstrap 对页面进行装点&#xff0c;使用 JQuery Ajax 实现了在页面上实时显示 log 的功能。趁着周末&#xff0c;我继续开始学习更多的东西以满足这个 web 服务的需求。 模板继承 之前我们有了首页&#xf…

如何使用Docker安装Spug并实现远程访问本地运维管理界面

文章目录 前言1. Docker安装Spug2 . 本地访问测试3. Linux 安装cpolar4. 配置Spug公网访问地址5. 公网远程访问Spug管理界面6. 固定Spug公网地址 前言 Spug 面向中小型企业设计的轻量级无 Agent 的自动化运维平台&#xff0c;整合了主机管理、主机批量执行、主机在线终端、文件…

算法基础学习|离散化与区间合并

位运算 代码模板 求n的第k位数字: n >> k & 1 返回n的最后一位1&#xff1a;lowbit(n) n & -n 题目&#xff1a;二进制中1的个数 题目 给定一个长度为 的数列&#xff0c;请你求出数列中每个数的二进制表示中 1 的个数。 输入格式 第一行包含整数 。 第…

kubeSphere DevOps自定义容器 指定nodejs版本

✨✨✨✨✨✨ &#x1f380;前言&#x1f381;基于内置镜像构建&#x1f381;把镜像添加基础容器中&#x1f381;检查容器是否配置成功&#x1f381;不生效的原因排查&#x1f381;按步骤执行如下命令 &#x1f380;前言 由于我本地的开发环境node是16.18.1,而自带容器node的版…

体验华为云对话机器人服务 CBS

&#x1f3e1;浩泽学编程&#xff1a;个人主页 &#x1f525; 推荐专栏&#xff1a;《深入浅出SpringBoot》《java对AI的调用开发》 《RabbitMQ》《Spring》《SpringMVC》 &#x1f6f8;学无止境&#xff0c;不骄不躁&#xff0c;知行合一 文章目录 前言一、开通…

【虚拟化 VS 容器化】

目录 1. 虚拟化1.1什么是虚拟化&#xff1f;1.2虚拟化的特点1.3虚拟化主流技术1.4虚拟化的应用场景 2. 容器化2.1什么是容器化&#xff1f;2.2容器化的特点2.3容器化主流技术2.4容器化的应用场景 3. 虚拟化VS容器化3.1图解区别3.2架构区别3.3表式区别 4. 虚拟化的发展趋势参考链…

MSTP协议

目录 MSTP 基本原则 MSTP术语 BPDU变化 三种生成树的比较 MSTP MSTP&#xff08;802.1s&#xff09;多生成树。 多生成树(MSTP)解决&#xff1a; &#xff08;1&#xff09;去掉环 &#xff08;2&#xff09;负载均衡&#xff08;重点&#xff09; &#xff08;3&#xf…

本地Vscode使用SSH连接Linux虚拟机循环输入密码,无法登陆

今天在工作的时候没有在本地关闭Vscode的前提下&#xff0c;重启了虚拟机后&#xff0c;发现ssh连接不上了&#xff0c;症状就是反复输入密码就是进不去系统&#xff0c;查了很多网上的教程都没啥用&#xff1b; 最后就一招彻底解决问题&#xff1a; 第一步&#xff1a;打开虚…

发生内存泄漏后

内存泄漏是指程序在运行过程中分配的内存无法被释放&#xff0c;导致内存使用量不断增加&#xff0c;最终可能导致程序崩溃或系统崩溃。 产生内存泄漏的原因 内存泄漏可能是由多种原因造成的&#xff0c;例如&#xff1a; 忘记释放内存。由于项目比较大&#xff0c;一般申请内…

电脑自动开机播放PPT的解决方案

客户有个需求&#xff0c;要求与LED大屏幕连接的电脑定时自动播放PPT。为了安全电脑在不播放的时段&#xff0c;必须关机。 目录 1、使用“时控插座”并进行设置 2、戴尔电脑BIOS设置&#xff08;上电开机&#xff09; 3、设置Windows自动登录 4、任务计划设置 5、启动Au…

数据结构与算法-二叉树-路径总和lll

路径总和lll 给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始&#xff0c;也不需要在叶子节点结束&#xff0c;但是路径方向必须是向下的&#xff08;只能从父节…

华为配置ACL限制用户通过Telnet登录设备

配置ACL限制用户通过Telnet登录设备示例 组网需求 如图1所示&#xff0c;PC与设备之间路由可达&#xff0c;用户希望简单方便的配置和管理远程设备&#xff0c;可以在服务器端配置Telnet用户使用AAA验证登录&#xff0c;并配置安全策略&#xff0c;保证只有符合安全策略的用户才…