Custom C++ and CUDA Extensions - PyTorch

news2025/1/10 11:44:32

0. Abstract

经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C++/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C++ 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容:

  • LibTorch 初步;
  • C++ Extension 例子;

1. LibTorch 初步

在 PyTorch 的首页安装指引中就可以看到 PyTorch 是支持 C++/Java 的:

下载后解压到一个地方, 如 /opt/libtorch. 然后就可以使用 C++ 编写 PyTorch 程序了. 官方给的有相关例子, 我们选择最经典的 MNIST 手写数字识别项目来看一看:

mnist/
├── CMakeLists.txt
├── README.md
└── mnist.cpp

1.1 CMake 项目

CMakeLists.txt 是构建 cpp 项目的说明文件:

cmake_minimum_required(VERSION 3.5)
project(mnist)
set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
	message(STATUS "Downloading MNIST dataset")
	execute_process(
		COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py -d ${CMAKE_BINARY_DIR}/data
		ERROR_VARIABLE DOWNLOAD_ERROR
	)
	if (DOWNLOAD_ERROR)
		message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
	endif()
endif()

add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

为了下载 MNIST 数据集, 这里用到了一个 Python 文件 ../tools/download_mnist.py, 执行 cmake 后, 编译根目录(build)会出现一个 data 数据文件夹.

  • find_package(Torch REQUIRED) 查找 libtorch 时可能需要指定路径:
    find_package(Torch REQUIRED PATHS "path/to/libtorch/")
  • make 时, Ubuntu18.04 下出现错误: undefined reference to symbol ‘pthread_create@@GLIBC_2.2.5’.
    => 经查阅资料, 说: pthread 不是 linux 下的默认的库, 也就是在链接的时候, 无法找到 phread 库中线程函数的入口地址, 于是链接会失败.
    => 解决方案: target_link_libraries(mnist ${TORCH_LIBRARIES} -lpthread -lm)

make 之后, 执行 ./mnist 就能进行训练与测试了:

CUDA available! Training on GPU.
Train Epoch: 1 [59584/60000] Loss: 0.2078
Test set: Average loss: 0.2062 | Accuracy: 0.935
Train Epoch: 2 [59584/60000] Loss: 0.2039
Test set: Average loss: 0.1304 | Accuracy: 0.959
...

1.2 PyTorch C++ API

接下来看 C++ 代码:

struct Net : torch::nn::Module
{
	Net() : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
			conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
			fc1(320, 50),
			fc2(50, 10)
	{
		register_module("conv1", conv1);
		register_module("conv2", conv2);
		register_module("conv2_drop", conv2_drop);
		register_module("fc1", fc1);
		register_module("fc2", fc2);
	}

	torch::Tensor forward(torch::Tensor &x)
	{
		x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
		x = torch::relu(torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
		x = x.view({-1, 320});
		x = torch::relu(fc1->forward(x));
		x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
		x = fc2->forward(x);
		return torch::log_softmax(x, /*dim=*/1);
	}

	torch::nn::Conv2d conv1;
	torch::nn::Conv2d conv2;
	torch::nn::Dropout2d conv2_drop;
	torch::nn::Linear fc1;
	torch::nn::Linear fc2;
};

template<typename DataLoader>
void train(
	size_t epoch,
	Net &model,
	torch::Device device,
	DataLoader &data_loader,
	torch::optim::Optimizer &optimizer,
	size_t dataset_size
)
{
	model.train();
	size_t batch_idx = 0;
	for (auto &batch: data_loader)
	{
		auto data = batch.data.to(device), targets = batch.target.to(device);
		auto output = model.forward(data);
		auto loss = torch::nll_loss(output, targets);
		AT_ASSERT(!std::isnan(loss.template item<float>()));
		optimizer.zero_grad();
		loss.backward();
		optimizer.step();
		...
	}
}

可以看到, 代码非常简单, 几乎和 Python 接口一致, 如果把 :: 换成 ., 就更像了. 不一样的是多了些类型限制以及一些语法. 具体的我们不多研究, 终究还是没有 Python 简洁好用. 但简单了解一下 PyTorch C++ API 的文档说明还是有必要的:

所以, 这个 LibTorch 既能用来写 C++ 项目, 也能用来给 PyTorch 写扩展. 不过官方还是推荐使用 Python 接口:

2. C++ Extension 例子

官方文档给的例子比较复杂, 这里举一个简单的例子, 把计算:

y = torch.relu(torch.matmul(x, w.t()) + b)

整合到一个操作里, 也就是使用 LibTorch C++ 编写一个等价的运算, 并导出 Python 接口. 这么做的理由是:

大概意思就是 Python 比较慢, 由 Python 一次次调用操作而频繁启动 CUDA 核会拖慢速度.

其实我觉得只有用 CUDA 编程把序列操作整合起来才能真正减少 CUDA 核的频繁启动, LibTorch 能加速可能就是因为 C++ 更快而已.

直接上代码吧, 整个项目的解构是这样子的:

LinearAct/
├── linearfun.py
├── linearact.cpp
└── setup.py

linearact.cpp 包含了组合操作的 forward 过程和 backward 过程, 前者计算正向的正常计算, 后者计算反向的梯度计算:

#include <torch/extension.h>  // 注意这里头文件和直接写 C++ 项目不一样
#include <vector>

std::vector<at::Tensor> forward(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias)
{
	auto relu_input = input.mm(weight.t()) + bias;
	auto output = torch::relu(relu_input);
	return {relu_input, output};  // relu_input 会在梯度计算时用到
}

std::vector<torch::Tensor>
backward(torch::Tensor &grad_output, torch::Tensor &relu_input, torch::Tensor &input, torch::Tensor &weight)
{   // 求导链式法则
	auto grad_relu = grad_output.masked_fill(relu_input < 0, 0);
	auto grad_input = grad_relu.mm(weight);
	auto grad_weight = grad_relu.t().mm(input);
	auto grad_bias = grad_relu.sum(0);
	return {grad_input, grad_weight, grad_bias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
	m.def("forward", &forward, "Custom forward");
	m.def("backward", &backward, "Custom backward");
}

这涉及到 pybind11 的用法, 详情见《pybind11 学习笔记》, 还涉及到使用 torch.autograd.Function 自定义运算的梯度计算, 详情见《PyTorch 中的 apply [autograd.Function]》. 总之, 现在我们使用 LibTorch 写了组合操作, 并写了其参数的梯度计算. linearfun.py 是利用 torch.autograd.Functionforwardbackward 整合到一起, 组成一个完整的可以进行反向梯度传播的组合运算:

import torch  # 注意, 导入 linearact 前, 应先导入 torch
import linearact

class LinearActFunction(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, weights, bias):
		relu_input, output = linearact.forward(input, weights, bias)  # c++ 函数
		variables = [relu_input, input, weights]
		ctx.save_for_backward(*variables)
		return output

	@staticmethod
	def backward(ctx, grad_output):
		outputs = linearact.backward(grad_output, *ctx.saved_tensors)  # c++ 函数
		grad_x, grad_w, grad_b = outputs
		return grad_x, grad_w, grad_b

mylinear = LinearActFunction.apply

LibTorch C++ 代码由 setuptools 导出 Python 接口:

from setuptools import setup
from torch.utils import cpp_extension

setup(
	name='linearact',
	ext_modules=[cpp_extension.CppExtension('linearact', ['linearact.cpp'])],
	cmdclass={'build_ext': cpp_extension.BuildExtension}  # 整合了 pybind11 的功能
)

在命令行执行:

python setup.py install

就可以将 linearact 包安装到 Python 系统中, 任务完成. 下面进行验证:

import torch
from linearfun import mylinear

x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
# 复制一份一样的参数
x1 = torch.from_numpy(x.detach().numpy())
w1 = torch.from_numpy(w.detach().numpy())
b1 = torch.from_numpy(b.detach().numpy())
x1.requires_grad_(True)
w1.requires_grad_(True)
b1.requires_grad_(True)

# %% pytorch
y = torch.relu(torch.matmul(x, w.t()) + b)
y = y.norm(p=2)
print(y)

y.backward()
print(x.grad)
print(w.grad)
print(b.grad)

# %% custom
print('---------------------------')
y = mylinear(x1, w1, b1)
y = y.norm(p=2)
print(y)

y.backward()
print(x1.grad)
print(w1.grad)
print(b1.grad)

执行一次:

tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],
        [ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],
        [-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])
---------------------------
tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],
        [ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],
        [-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])

可以看见两者一模一样. 至于测速什么的不在本博文的考虑范围之内, 只是想了解 PyTorch 如何进行 C++ 扩展.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2185578.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python-鸡尾酒疗法/图像相似度/第n小的质数

一&#xff1a;鸡尾酒疗法 题目描述 鸡尾酒疗法&#xff0c;原指“高效抗逆转录病毒治疗”&#xff08;HAART&#xff09;&#xff0c;由美籍华裔科学家何大一于 1996 年提出&#xff0c;是通过三种或三种以上的抗病毒药物联合使用来治疗艾滋病。该疗法的应用可以减少单一用药产…

什么是ETL?什么是ELT?怎么区分它们使用场景

ELT和ETL这两种模式从字面上来看就是一个顺序颠倒的问题&#xff0c;每个单词拆开来看其实都是一样的。E代表的是Extract&#xff08;抽取&#xff09;&#xff0c;也就是从源端拉取数据&#xff1b;T代表的是Transform&#xff08;转换&#xff09;&#xff0c;对一些结构化或…

Visual Studio2017编译GDAL3.0.2源码过程

一、编译环境 操作系统&#xff1a;Windows 10企业版 编译工具&#xff1a;Visual Studio 2017旗舰版 源码版本&#xff1a;gdal3.0.2 二、生成解决方案 打开Visual Studio 2017的x64本机生成工具&#xff0c;切换到gdal3.0.2源码根目录&#xff1b;执行generate_vcxproj.b…

D25【 python 接口自动化学习】- python 基础之判断与循环

day25 for 循环 学习日期&#xff1a;20241002 学习目标&#xff1a;判断与循环&#xfe63;-35 for 循环&#xff1a;如何遍历一个对象里的所有元素&#xff1f; 学习笔记&#xff1a; for 循环与while循环的区别 for循环的定义 使用for循环遍历序列 使用for循环遍历字典…

【理论科学与实践技术】数学与经济管理中的学科与实用算法

在现代商业环境中&#xff0c;数学与经济管理的结合为企业提供了强大的决策支持。包含一些主要学科&#xff0c;包括数学基础、经济学模型、管理学及风险管理&#xff0c;相关的实用算法和这些算法在中国及全球知名企业中的实际应用。 一、数学基础 1). 发现人及著名学者 发…

目标检测评价指标

混淆矩阵&#xff08;Confusion Matrix&#xff09; 准确率&#xff08;accuracy&#xff09; 准确率&#xff1a;预测正确的样本数 / 样本数总数 &#xff08;正对角线 / 所有&#xff09; 精度&#xff08;precision&#xff09; 精度&#xff1a;预测正确里面有多少确实是…

深入理解MySQL中的MVCC原理及实现

目录 什么是MVCC&#xff1f; MVCC实现原理 Undo Log 日志 InnoDB行格式 undo日志格式 1. insert undo log格式 2. update undo log格式 事务回滚机制 Read View MVCC案例分析 案例01-读已提交RC隔离级别下的可见性分析 案例02-可重复读RR隔离级别下的可见性分析 什…

英语词汇小程序小程序|英语词汇小程序系统|基于java的四六级词汇小程序设计与实现(源码+数据库+文档)

英语词汇小程序 目录 基于java的四六级词汇小程序设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕设布道师&a…

【优选算法】(第十六篇)

目录 连续数组&#xff08;medium&#xff09; 题目解析 讲解算法原理 编写代码 矩阵区域和&#xff08;medium&#xff09; 题目解析 讲解算法原理 编写代码 连续数组&#xff08;medium&#xff09; 题目解析 1.题目链接&#xff1a;. - 力扣&#xff08;LeetCode&a…

【重学 MySQL】五十一、更新和删除数据

【重学 MySQL】五十一、更新和删除数据 更新数据删除数据注意事项 在MySQL中&#xff0c;更新和删除数据是数据库管理的基本操作。 更新数据 为了更新&#xff08;修改&#xff09;表中的数据&#xff0c;可使用UPDATE语句。UPDATE语句的基本语法如下&#xff1a; UPDATE ta…

前端学习第二天笔记 CSS选择 盒子模型 浮动 定位 CSS3新特性 动画 媒体查询 精灵图雪碧图 字体图标

CSS学习 CSS选择器全局选择器元素选择器类选择器ID选择器合并选择器 选择器的优先级字体属性背景属性文本属性表格属性表格边框折叠边框表格文字对齐表格填充表格颜色 关系选择器后代选择器子代选择器相邻兄弟选择器通用兄弟选择器 CSS盒子模型弹性盒子模型父元素上的属性flex-…

Linux 安装 yum

第一步&#xff1a;下载安装包 这里以 CentOS 7 为例 wget https://vault.centos.org/7.2.1511/os/x86_64/Packages/yum-3.4.3-132.el7.centos.0.1.noarch.rpm wget https://vault.centos.org/7.2.1511/os/x86_64/Packages/yum-metadata-parser-1.1.4-10.el7.x86_64.rpm wget…

计算机网络(十) —— IP协议详解,理解运营商和全球网络

目录 一&#xff0c;关于IP 1.1 什么是IP协议 1.2 前置认识 二&#xff0c;IP报头字段详解 三&#xff0c;网段划分 3.1 IP地址的构成 3.2 网段划分 3.3 子网划分 3.4 IP地址不足问题 四&#xff0c;公网IP和私有IP 五&#xff0c;理解运营商和全球网络 六&#xff…

硬件面试(一)

网上别人的硬件面试记录&#xff0c;察漏补缺&#xff1a; 1.骄傲容易被打脸&#xff01; 励磁电感和谐振电感的比值K大小有什么含义: 励磁电感和谐振电感的比值 KKK 通常用来衡量电路的特性。当 KKK 较大时&#xff0c;表示励磁电感相对于谐振电感较强&#xff0c;可能导致…

力扣题解1870

这道题是一个典型的算法题&#xff0c;涉及计算在限制的时间内列车速度的最小值。这是一个优化问题&#xff0c;通常需要使用二分查找来求解。 题目描述&#xff08;中等&#xff09; 准时到达的列车最小时速 给你一个浮点数 hour &#xff0c;表示你到达办公室可用的总通勤时…

基于SSM的坚果金融投资管理系统、坚果金融投资管理平台的设计与开发、智慧金融投资管理系统的设计与实现、坚果金融投资管理系统的设计与应用研究(源码+定制+开发)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

如何在 Kubernetes 集群中安装和配置 OpenEBS 持久化块存储?

在 Kubernetes 集群中安装和配置 OpenEBS 持久化块存储是一项常见的任务&#xff0c;特别是在需要提供高可用和动态扩展的存储解决方案时。OpenEBS 是一个基于容器的存储解决方案&#xff0c;它允许你在 Kubernetes 集群中实现持久化存储卷&#xff08;Persistent Volumes&…

Microsoft 发布 PyRIT - 生成式 AI 的红队工具

微软发布了一个名为PyRIT&#xff08;Python风险识别工具的缩写&#xff09;的开放访问自动化框架&#xff0c;用于主动识别生成式人工智能&#xff08;AI&#xff09;系统中的风险。 这个红队工具旨在“使全球的每个组织都能够负责任地利用最新的人工智能进步进行创新”&…

ros2 自定义工作空间添加source

新建一个工作空间&#xff1a;ros2 create pkg~~~~~~~~~~~~ colcon build之后 &#xff0c;在install文件夹里面有一个 setup,bash文件 将这个文件添加到 bashrc gedit .bashrc 这样 在一个新终端中可以直接运行ros2 run package name &#xff08;包名&#xff09; 可执行…

消息中间件---初识(Kafka、RocketMQ、RabbitMQ、ActiveMQ、Redis)

1. 简介 消息中间件是一种支撑性软件系统&#xff0c;它在网络环境中为应用系统提供同步或异步、可靠的消息传输。消息中间件利用高效可靠的消息传递机制进行与平台无关的数据交流&#xff0c;并基于数据通信来进行分布式系统的集成。它支持多种通信协议和数据格式&#xff0c;…