【pytorch扩展】CUDA自定义pytorch算子(简单demo入手)

news2024/7/7 16:46:40

Pytorch作为一款优秀的AI开发平台,提供了完备的自定义算子的规范。我们用torch开发时,经常会因为现有算子的不足限制我们idea的迸发。于是,CUDA/C++自定义pytorch算子是不得不磕了。

今天通过一个小实验来梳理自定义pytorch算子都需要做哪些准备。比如,我们做一个张量加法。
vim test_add.py

from add import sum_double_op
import torch
import time

class Timer:
    def __init__(self, op_name):
        self.begin_time = 0
        self.end_time = 0
        self.op_name = op_name

    def __enter__(self):
        torch.cuda.synchronize()
        self.begin_time = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        self.end_time = time.time()
        print(f"Average time cost of {self.op_name} is {(self.end_time - self.begin_time) * 1000:.4f} ms")


if __name__ == '__main__':
    n = 1000000
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tensor1 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    tensor2 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    with Timer("sum_double"):
        ans = sum_double_op(tensor1, tensor2)

这里的"sum_double_op"就是我们用CUDA写的算子。那这个可以直接调用,并且可以传递梯度的算子,需要怎么做呢?


众所周知,CUDA/C++都是编译性语言,编译以后再调用会比python这种解释性语言更快。所以,我们需要对CUDA有一个编译过程。这个编译过程用setuptools来实现(可以pip安装)。
先vim setup.py

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='myAdd',
    packages=find_packages(),
    version='0.1.0',
    author='muzhan',
    ext_modules=[
        CUDAExtension(
            'sum_double',
            ['./add/add.cpp',
             './add/add_cuda.cu',]
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

直接“python setup.py install”即可完成cuda算子的编译和安装。等等,你的add.cpp和add_cuda.cu还没呢?
vim add_cuda.cu

#include <cstdio>
#define THREADS_PER_BLOCK 256
#define WARP_SIZE 32
#define DIVUP(m, n) ((m + n - 1) / n)


__global__ void two_sum_kernel(const float* a, const float* b, float * c, int n){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n){
        c[idx] = a[idx] + b[idx];
    }
}


void two_sum_launcher(const float* a, const float* b, float* c, int n){
    dim3 blockSize(DIVUP(n, THREADS_PER_BLOCK));
    dim3 threadSize(THREADS_PER_BLOCK);
    two_sum_kernel<<<blockSize, threadSize>>>(a, b, c, n);
}

vim add.cpp

#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)


void two_sum_launcher(const float* a, const float* b, float* c, int n);


void two_sum_gpu(at::Tensor a_tensor, at::Tensor b_tensor, at::Tensor c_tensor){
    CHECK_INPUT(a_tensor);
    CHECK_INPUT(b_tensor);
    CHECK_INPUT(c_tensor);

    const float* a = a_tensor.data_ptr<float>();
    const float* b = b_tensor.data_ptr<float>();
    float* c = c_tensor.data_ptr<float>();
    int n = a_tensor.size(0);
    two_sum_launcher(a, b, c, n);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &two_sum_gpu, "sum two arrays (CUDA)");
}

我们看一下文件结构:

.
├── add
│   ├── add.cpp
│   ├── add_cuda.cu
│   ├── __init__.py
│   └── sum.py
├── README.md
├── setup.py
└── test_add.py

有了add.cpp和add_cuda.cu以后,我们就可以用"python setup.py install"来进行编译和安装了。编译和安装以后,我们需要用python类封装一下:

vim __init__.py
from .sum import *

vim sum.py

from torch.autograd import Function
import sum_double


class SumDouble(Function):

    @staticmethod
    def forward(ctx, array1, array2):
        """sum_double function forward.
        Args:
            array1 (torch.Tensor): [n,]
            array2 (torch.Tensor): [n,]
        
        Returns:
            ans (torch.Tensor): [n,]
        """
        array1 = array1.float()
        array2 = array2.float()
        ans = array1.new_zeros(array1.shape)
        sum_double.forward(array1.contiguous(), array2.contiguous(), ans)

        # ctx.mark_non_differentiable(ans) # if the function is no need for backpropogation

        return ans

    @staticmethod
    def backward(ctx, g_out):
        # return None, None   # if the function is no need for backpropogation

        g_in1 = g_out.clone()
        g_in2 = g_out.clone()
        return g_in1, g_in2


sum_double_op = SumDouble.apply

最后,直接

python test_add.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1894361.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

视频监控平台web客户端的免密查看视频页:在PC浏览器上如何调试手机上的前端网页(PC上的手机浏览器的开发者工具)

目录 一、手机上做前端页面开发调试 1、背景 2、视频监控平台AS-V1000的视频分享页 3、调试手机前端页面代码的条件 二、手机端的准备工作 1、手机准备 2、手机的开发者模式 3、PC和手机的连接 &#xff08;1&#xff09;进入调试模式 &#xff08;2&#xff09;选择…

【大数据】—量化交易实战案例(海龟交易策略)

声明&#xff1a;股市有风险&#xff0c;投资需谨慎&#xff01;本人没有系统学过金融知识&#xff0c;对股票有敬畏之心没有踏入其大门&#xff0c;今天用另外一种方法模拟炒股&#xff0c;后面的模拟的实战全部用同样的数据&#xff0c;最后比较哪种方法赚的钱多。 海龟交易…

初试成绩占比百分之70!计算机专硕均分340+!华中师范大学计算机考研考情分析!

华中师范大学&#xff08;Central China Normal University&#xff09;简称“华中师大”或“华大”&#xff0c;位于湖北省会武汉&#xff0c;是中华人民共和国教育部直属重点综合性师范大学&#xff0c;国家“211工程”、“985工程优势学科创新平台”重点建设院校&#xff0c…

智慧消防视频监控烟火识别方案,筑牢安全防线

一、方案背景 在现代化城市中&#xff0c;各类小型场所&#xff08;简称“九小场所”&#xff09;如小餐馆、小商店、小网吧等遍布大街小巷&#xff0c;为市民生活提供了极大的便利。然而&#xff0c;由于这些场所往往规模较小、人员流动性大、消防安全意识相对薄弱&#xff0…

分布式计算、异构计算与算力共享

目录 算力 算力共享的技术支撑 云计算技术 边缘计算技术 区块链技术 分布式计算、异构计算与算力共享 分布式计算:计算力的“集团军作战” 异构计算:计算力的“多兵种协同” 算力共享:计算力的“共享经济” 深入融合,共创计算新纪元 算力共享对科研领域的影响 …

JavaScript懒加载图像

懒加载图像是一种优化网页性能的技术&#xff0c;它将页面中的图像延迟加载&#xff0c;即在用户需要查看它们之前不会立即加载。这种技术通常用于处理大量或大尺寸图像的网页&#xff0c;特别是那些包含长页面或大量媒体内容的网站。 好处 **1. 加快页面加载速度&#xff1a…

《昇思25天学习打卡营第9天|保存与加载》

文章目录 今日所学&#xff1a;一、构建与准备二、保存和加载模型权重三、保存和加载MindIR总结 今日所学&#xff1a; 在上一章节主要学习了如何调整超参数以进行网络模型训练。在这一过程中&#xff0c;我们通常会想要保存一些中间或最终的结果&#xff0c;以便进行后续的模…

在Windows 11上更新应用程序的几种方法,总有一种适合你

序言 让你安装的应用程序保持最新是很重要的,而Windows 11使更新Microsoft应用商店和非Microsoft应用商店的应用程序变得非常容易。我们将向你展示如何使用图形方法以及命令行方法来更新你的应用程序。 如何更新Microsoft Store应用程序 如果你的一个或多个应用程序是从Mic…

底层软件 | 十分详细,为了学习设备树,我写了5w字笔记!

0、设备树是什么&#xff1f;1、DTS 1.1 dts简介1.2 dts例子 2、DTC&#xff08;Device Tree Compiler&#xff09;3、DTB&#xff08;Device Tree Blob&#xff09;4、绑定&#xff08;Binding&#xff09;5、Bootloader compatible属性 7、 #address-cells和#size-cells属性8…

64.函数参数和指针变量

目录 一.函数参数 二.函数参数和指针变量 三.视频教程 一.函数参数 函数定义格式&#xff1a; 类型名 函数名(函数参数1,函数参数2...) {代码段 } 如&#xff1a; int sum(int x&#xff0c;int y) {return xy; } 函数参数的类型可以是普通类型&#xff0c;也可以是指针类…

基于 Windows Server 2019 部署域控服务器

文章目录 前言1. 域控服务器设计规划2. 安装部署域控服务器2.1. 添加 Active Directory 域服务2.2. 将服务器提升为域控制器2.3. 检查域控服务器配置信息 3. 管理域账号3.1. 新建域管理员账号3.2. 新建普通域账号 4. 服务器加域和退域4.1. 服务器加域操作4.2. 服务器退域操作 总…

实战教程:如何用JavaScript构建一个功能强大的音乐播放器,兼容本地与在线资源

项目地址&#xff1a;Music Player App 作者&#xff1a;Reza Mehdikhanlou 视频地址&#xff1a;youtube 我将向您展示如何使用 javascript 编写音乐播放器。我们创建一个项目&#xff0c;您可以使用 javascript 从本地文件夹或任何 url 播放音频文件。 项目目录 assets 1…

安卓请求服务器[根据服务器的内容来更新spinner]

根据服务器的内容来更新spinner 本文内容请结合如下两篇文章一起看: 腾讯云函数node.js返回自动带反斜杠 腾讯云函数部署环境[使用函数URL] 现在有这样一个需求,APP有一个下拉选择框作为版本选择,因为改个管脚就变成一个版本,客户需求也很零散,所以后期会大量增加版本,这时候每…

数据结构——树的基础概念

目录 1.树的概念 2.树的相关概念 3.树的表示 &#xff08;1&#xff09;直接表示法 &#xff08;2&#xff09;双亲表示法 (3)左孩子右兄弟表示法 4.树在实际中的运用&#xff08;表示文件系统的目录树结构&#xff09; 1.树的概念 树是一种非线性的数据结构&#xff0…

找不到msvcp120.dll无法继续执行的原因分析及解决方法

在计算机使用中&#xff0c;经常会遇到msvcp120.dll文件丢失的情况&#xff0c;很多人对这个文件不是很熟悉&#xff0c;今天就来给大家讲解一下msvcp120.dll文件的丢失以及这个文件的重要性&#xff0c;让大家更好地了解计算机&#xff0c;同时也可以帮助我们更好地掌握这个文…

EVM-MLIR:以MLIR编写的EVM

1. 引言 EVM_MLIR&#xff1a; 以MLIR编写的EVM。 开源代码实现见&#xff1a; https://github.com/lambdaclass/evm_mlir&#xff08;Rust&#xff09; 为使用MLIR和LLVM&#xff0c;将EVM-bytecode&#xff0c;转换为&#xff0c;machine-bytecode。LambdaClass团队在2周…

leetcode216.组合总和III、40.组合总和II、39.组合总和

216.组合总和III 找出所有相加之和为 n 的 k 个数的组合&#xff0c;且满足下列条件&#xff1a; 只使用数字1到9 每个数字 最多使用一次 返回 所有可能的有效组合的列表 。该列表不能包含相同的组合两次&#xff0c;组合可以以任何顺序返回。 示例 1: 输入: k 3, n 7 输出…

【Spring Boot 源码学习】初识 ConfigurableEnvironment

《Spring Boot 源码学习系列》 初识 ConfigurableEnvironment 一、引言二、主要内容2.1 Environment2.1.1 配置文件&#xff08;profiles&#xff09;2.1.2 属性&#xff08;properties&#xff09; 2.2 ConfigurablePropertyResolver2.2.1 属性类型转换配置2.2.2 占位符配置2.…

单调栈(左小大,右小大)

①寻找每个数左边第一个比它小的数 给定一个长度为 N 的整数数列&#xff0c;输出每个数左边第一个比它小的数&#xff0c;如果不存在则输出 −1。 输入样例&#xff1a; 3 4 2 7 5 输出样例&#xff1a; -1 3 -1 2 2 从左到右遍历&#xff0c;用单调递增&#xff08;栈底到栈顶…

Spring MVC 中 使用 RESTFul 实现用户管理系统

1. Spring MVC 中 使用 RESTFul 实现用户管理系统 文章目录 1. Spring MVC 中 使用 RESTFul 实现用户管理系统2. 静态页面准备2.1 user.css2.2 user_index.html2.3 user_list.html2.4 user_add.html2.5 user_edit.html 3. SpringMVC环境搭建3.1 创建module&#xff1a;usermgt3…