DLA :pytorch添加算子

news2024/11/14 3:41:36

pytorch的C++ extension写法

        这部分主要介绍如何在pytorch中添加自定义的算子,需要以下cuda基础。就总体的逻辑来说正向传播需要输入数据,反向传播需要输入数据和上一层的梯度,然后分别实现这两个kernel,将这两个kernerl绑定到pytorch即可。

add

  • 但实际上来说,这可能不是一个很好的教程,因为加法中没有对输入的grad_out进行继续的操作(不用写cuda的操作)。所以实际上只需要正向传播的launch_add2函数。更重要的是作者大佬写了博客介绍。
// https://github.com/godweiyang/NN-CUDA-Example/blob/master/kernel/add2_kernel.cu

__global__ void add2_kernel(float* c,
                            const float* a,
                            const float* b,
                            int n) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
            i < n; i += gridDim.x * blockDim.x) {
        c[i] = a[i] + b[i];
    }
}

void launch_add2(float* c,
                 const float* a,
                 const float* b,
                 int n) {
    // 创建 [(n + 1023) / 1024 ,1 ,1]的三维向量数据
    dim3 grid((n + 1023) / 1024);//dim3 为CUDA中三维向量结构体
    // 创建 [1024 ,1 ,1]的三维向量数据
    dim3 block(1024);
    // 函数add2_kernel实现两个n维向量相加
    // 共有(n + 1023) / 1024*1*1个block , 每个block有1024*1*1个线程
    add2_kernel<<<grid, block>>>(c, a, b, n);
}

在这里插入图片描述

binary activation function

  • 正向计算为:
x > 1 ? 1 : -1;// 也可以使用sign() 函数(求符号函数)实现
  • 这篇文章作者没有自己写正向传播的算子,使用的是at::sign
// https://github1s.com/jxgu1016/BinActivateFunc_PyTorch/blob/master/src/cuda/BinActivateFunc_cuda.cpp#L17-L22
at::Tensor BinActivateFunc_forward(
    at::Tensor input) 
{
    CHECK_INPUT(input);
    return at::sign(input);
}
  • 这篇文章用的Setuptools将写好的算子和pytorch链接起来,运行时需要安装一下(JIT运行时编译也很香,代码直接运行,就是cmakelist.txt需要各种环境配置很麻烦)。绑定部分见链接。以下是作者实现的反向传播的kernel:
// https://github.com/jxgu1016/BinActivateFunc_PyTorch/blob/master/src/cuda/BinActivateFunc_cuda_kernel.cu
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

namespace {
template <typename scalar_t>
__global__ void BinActivateFunc_cuda_backward_kernel(
    const int nthreads,
    const scalar_t* __restrict__ input_data,
    scalar_t* __restrict__ gradInput_data) 
{
    CUDA_KERNEL_LOOP(n, nthreads) {
        if (*(input_data + n) > 1 || *(input_data + n) < -1) {
            *(gradInput_data + n) = 0;
        }
    }
}
} // namespace

int BinActivateFunc_cuda_backward(
    at::Tensor input,
    at::Tensor gradInput) 
{
    const int nthreads = input.numel();
    const int CUDA_NUM_THREADS = 1024;
    const int nblocks = (nthreads + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;

    AT_DISPATCH_FLOATING_TYPES(input.type(), "BinActivateFunc_cuda_backward", ([&] {
        BinActivateFunc_cuda_backward_kernel<scalar_t><<<nblocks, CUDA_NUM_THREADS>>>(
            nthreads,
            input.data<scalar_t>(),
            gradInput.data<scalar_t>());
    }));
    return 1;
}

swish

// https://github1s.com/thomasbrandon/swish-torch/blob/HEAD/csrc/swish_kernel.cu
#include <torch/types.h>
#include <cuda_runtime.h>
#include "CUDAApplyUtils.cuh"

// TORCH_CHECK replaces AT_CHECK in PyTorch 1,2, support 1.1 as well.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifndef __CUDACC_EXTENDED_LAMBDA__
#error "please compile with --expt-extended-lambda"
#endif

namespace kernel {
#include "swish.h"

using at::cuda::CUDA_tensor_apply2;
using at::cuda::CUDA_tensor_apply3;
using at::cuda::TensorArgType;

template <typename scalar_t>
void
swish_forward(
  torch::Tensor &output,
  const torch::Tensor &input
) {
  CUDA_tensor_apply2<scalar_t,scalar_t>(
    output, input,
    [=] __host__ __device__ (scalar_t &out, const scalar_t &inp) {
      swish_fwd_func(out, inp);
    },
    TensorArgType::ReadWrite, TensorArgType::ReadOnly
  );
}

template <typename scalar_t>
void
swish_backward(
  torch::Tensor &grad_inp,
  const torch::Tensor &input,
  const torch::Tensor &grad_out
) {
  CUDA_tensor_apply3<scalar_t,scalar_t,scalar_t>(
    grad_inp, input, grad_out,
    [=] __host__ __device__ (scalar_t &grad_inp, const scalar_t &inp, const scalar_t &grad_out) {
      swish_bwd_func(grad_inp, inp, grad_out);
    },
    TensorArgType::ReadWrite, TensorArgType::ReadOnly, TensorArgType::ReadOnly
  );
}

} // namespace kernel

void
swish_forward_cuda(
    torch::Tensor &output, const torch::Tensor &input
) {
  auto in_arg  = torch::TensorArg(input,  "input",  0),
       out_arg = torch::TensorArg(output, "output", 1);
  torch::checkAllDefined("swish_forward_cuda", {in_arg, out_arg});
  torch::checkAllSameGPU("swish_forward_cuda", {in_arg, out_arg});
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "swish_forward_cuda", [&] {
      kernel::swish_forward<scalar_t>(output, input);
  });
}

void
swish_backward_cuda(
  torch::Tensor &grad_inp, const torch::Tensor &input, const torch::Tensor &grad_out
) {
  auto gi_arg = torch::TensorArg(grad_inp, "grad_inp", 0),
       in_arg = torch::TensorArg(input,    "input",    1),
       go_arg = torch::TensorArg(grad_out, "grad_out", 2);
  torch::checkAllDefined("swish_backward_cuda", {gi_arg, in_arg, go_arg});
  torch::checkAllSameGPU("swish_backward_cuda", {gi_arg, in_arg, go_arg});
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_inp.scalar_type(), "swish_backward_cuda", [&] {
      kernel::swish_backward<scalar_t>(grad_inp, input, grad_out);
  });
}

cg

  • ScatWave是使用CUDA散射的Torch实现,主要使用lua语言https://github.com/edouardoyallon/scatwave

  • https://github.com/huangtinglin/PyTorch-extension-Convolution

  • This is a tutorial to explore how to customize operations in PyTorch.

  • https://pytorch.org/tutorials/advanced/cpp_extension.html

  • 台湾博主 Pytorch+cpp/cuda extension 教學 tutorial 1 - English CC - B站搬运地址

  • pytorch的C++ extension写法

  • https://github.com/salinaaaaaa/NVIDIA-GPU-Tensor-Core-Accelerator-PyTorch-OpenCV

  • https://github.com/MariyaSha/Inference_withTorchTensorRT

  • 项目介绍了简单的CUDA入门,涉及到CUDA执行模型、线程层次、CUDA内存模型、核函数的编写方式以及PyTorch使用CUDA扩展的两种方式。通过该项目可以基本入门基于PyTorch的CUDA扩展的开发方式。

RWKV CUDA

  • 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍(某特殊算子) https://zhuanlan.zhihu.com/p/476297195
  • https://github.com/BlinkDL/RWKV-CUDA
  • The CUDA version of the RWKV language model

数据加速

  • 用于在 Pytorch 中更快地固定 CPU <-> GPU 传输的库

环境

  • Docker images and github actions for building packages containing PyTorch C++/CUDA extensions.
    一个构建系统,用于生成(相对)轻量级和便携式的 PyPI 轮子,其中包含 PyTorch C++/CUDA 扩展。使用Torch Extension Builder构建的轮子动态链接到用户PyTorch安装中包含的Torch和CUDA库。最终用户计算机上不需要安装 CUDA。

CG

  • 又发现一个部署工具
研究人员很难将机器学习模型交付到生产环境。

解决方案的一部分是Docker,但要让它工作非常复杂:Dockerfiles,预/后处理,Flask服务器,CUDA版本。通常情况下,研究人员必须与工程师坐下来部署该死的东西。

安德烈亚斯和本创造了Cog。Andreas曾经在Spotify工作,在那里他构建了使用Docker构建和部署ML模型的工具。Ben 曾在 Docker 工作,在那里他创建了 Docker Compose。

我们意识到,除了Spotify之外,其他公司也在使用Docker来构建和部署机器学习模型。Uber和其他公司也建立了类似的系统。因此,我们正在制作一个开源版本,以便其他人也可以这样做。

如果您有兴趣使用它或想与我们合作,请与我们联系。我们在 Discord 上或给我们发电子邮件 team@replicate.com.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/806817.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS开发-聊天emoji表情与自定义动图表情左右滑动控件

iOS开发-聊天emoji表情与自定义动图表情左右滑动控件 之前开发中遇到需要实现聊天emoji表情与自定义动图表情左右滑动控件。使用UICollectionView实现。 一、效果图 二、实现代码 UICollectionView是一种类似于UITableView但又比UITableView功能更强大、更灵活的视图&#x…

window.location.href is not a function

在使用uniapp跳转到外部页面时&#xff0c;使用window.location.href报错 解决&#xff1a; 当出现"window.location.href is not a function"的错误时&#xff0c;这通常是因为在某些浏览器中&#xff0c;window.location.href被视为只读属性&#xff0c;而不是函…

时频分析方法的matlab实现

傅里叶变换 function [ output_args ] example3_7( input_args ) %EXAMPLE3_7 Summary of this function goes here % Detailed explanation goes here clc; clear; fs12800;%采样频率 s1load(Sig1.txt); s2load(Sig2.txt); lslength(s1); figure(1) subplot(211) plot…

c++11 标准模板(STL)(std::basic_filebuf)(八)

定义于头文件 <fstream> template< class CharT, class Traits std::char_traits<CharT> > class basic_filebuf : public std::basic_streambuf<CharT, Traits> std::basic_filebuf 是关联字符序列为文件的 std::basic_streambuf 。输入序…

【力扣每日一题】2023.7.29 环形链表

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们一个链表&#xff0c;让我们判断这个链表是否有环。我们可以直接遍历这个链表&#xff0c;最后能走到链表末尾也就是空指针那就…

Go语言进阶语法八万字详解,通俗易懂

文章目录 File文件操作FileInfo接口权限打开模式File操作文件读取 I/O操作io包 文件复制io包下的Read()和Write()io包下的Copy()ioutil包总结 断点续传Seeker接口断点续传 bufio包bufio包原理Reader对象Writer对象 bufio包bufio.Readerbufio.Writer ioutil包ioutil包的方法示例…

wps图表怎么改横纵坐标,MLP 多层感知器和CNN卷积神经网络区别

目录 wps表格横纵坐标轴怎么设置&#xff1f; MLP (Multilayer Perceptron) 多层感知器 CNN (Convolutional Neural Network) 卷积神经网络 多层感知器MLP&#xff0c;全连接网络&#xff0c;DNN三者的关系 wps表格横纵坐标轴怎么设置&#xff1f; 1、打开表格点击图的右侧…

LeetCode559. N 叉树的最大深度

559. N 叉树的最大深度 文章目录 [559. N 叉树的最大深度](https://leetcode.cn/problems/maximum-depth-of-n-ary-tree/)一、题目二、题解方法一&#xff1a;迭代方法二&#xff1a;递归 一、题目 给定一个 N 叉树&#xff0c;找到其最大深度。 最大深度是指从根节点到最远叶…

互联网医院建设|互联网医疗系统|互联网医院开发功能及特点

互联网医院是一种通过互联网技术与医疗服务相结合的创新模式。它可以让患者在家中或办公室&#xff0c;通过智能终端与医生进行在线咨询、预约挂号、开具电子处方等。这种模式打破了传统医疗的时间与空间限制&#xff0c;为患者提供了更为便捷、高效的医疗服务。 互联网医院功能…

LibreSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443

1、问题&#xff1a; https://github.com/CocoaPods/Specs.git/&#xff1a;LibreSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443的解决办法 出现这个问题的原因基本都是代理的问题&#xff1a; 只需要加上代理就可以了&#xff1a; #http代理 git conf…

算法39:Excel 表列序号

一、需求 给你一个字符串 columnTitle &#xff0c;表示 Excel 表格中的列名称。返回 该列名称对应的列序号 。 例如&#xff1a; A -> 1 B -> 2 C -> 3 … Z -> 26 AA -> 27 AB -> 28 … 示例 1&#xff1a; 输入: columnTitle “A” 输出: 1 示例 2&…

Elasticsearch:通过动态修剪实现更快的基数聚合

作者&#xff1a;Adrien Grand Elasticsearch 8.9 通过支持动态修剪&#xff08;dynamic pruning&#xff09;引入了基数聚合加速。 这种优化需要满足特定的条件才能生效&#xff0c;但一旦实现&#xff0c;通常会产生惊人的结果。 我们观察到&#xff0c;通过此更改&#xff0…

向量vector与sort()

运行代码&#xff1a; //向量与sort() #include"std_lib_facilities.h" //声明Item类 struct Item {string name;int iid;double value;friend istream& operator>>(istream& is, Item& ii);friend ostream& operator<<(ostream& o…

阿里 120W 年薪架构师力荐 750 页微服务架构深度解析笔记

前言 当前&#xff0c;微服务架构在国内正处于蓬勃发展的阶段&#xff0c;无论是大型互联网公司还是传统的 IT 企业&#xff0c;纷纷采用微服务架构构建系统。 在过去几年里&#xff0c;DevOps、云原生、面向演进式架构等理念已经深入人心&#xff0c;围绕微服务生态也出现了…

导致内存泄漏的因素及解决方式

导致内存泄漏的因素 一、全局变量 因为全局变量不被js垃圾回收机制所回收&#xff0c;所以在使用全局变量时要小心。避免在想使用局部变量因为疏忽导致该变量流失到全局&#xff0c;如未声明变量&#xff0c;却直接对其进行赋值&#xff0c;就会导致该变量在全局创建&#xf…

Python Numpy入门基础(一)创建数组

入门基础&#xff08;一&#xff09; 创建数组 1- np.array() 参数众多&#xff0c;初学时只要关注基本用法。 array(...)array(object, dtypeNone, *, copyTrue, orderK, subokFalse, ndmin0,likeNone)Create an array.Parameters----------object : array_likeAn array, …

CAN通信协议

CAN 物理电平 以高速CAN为例 有电压差&#xff08;2.5V&#xff09;为显性&#xff0c;逻辑0无电压差为隐性&#xff0c;逻辑1 帧结构 SOF 恒为显性&#xff0c;逻辑0 仲裁段 当有多个设备发送数据&#xff0c;产生总线冲突时&#xff0c;来判断一个先后顺序由于总线是线与…

重学C++系列之友元

一、什么是友元 在C中&#xff0c;为了提高程序的效率&#xff0c;在一些场景下&#xff0c;引入友元&#xff0c;但同时类的封装性就会被破坏。 二、怎么实现友元 友元关键字&#xff08;friend&#xff09; // 在类中声明另一个类的成员函数来做为友元函数 // 以关键字&…

golangd\pycharm-ai免费代码助手安装使用gpt4-免费使用--[推荐]

golangd-ai免费代码助手安装使用,pycharm可以使用&#xff0c;估计只要是xx的ide都是可以使用这个插件 目前GPT4以及gpt的大规模使用&#xff0c;如何快速掌握以及在ide中快速使用的办法&#xff0c;今天安装一款golangd编辑器的插件已经使用 一、安装以及使用 1.在golangd中…

texshop mac中文版-TeXShop for Mac(Latex编辑预览工具)

texshop for mac是一款可以在苹果电脑MAC OS平台上使用的非常不错的Mac应用软件&#xff0c;texshop for mac是一个非常有用的工具&#xff0c;广泛使用在数学&#xff0c;计算机科学&#xff0c;物理学&#xff0c;经济学等领域的合作&#xff0c;这些程序的标准tetex分布特产…