大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

news2025/1/18 9:59:19

01算子优化的意义

随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习并了解,对于性能优化领域无非是隔靴搔痒,今天也是抽一点时间读了下网上的一些文档和CUDA的文档,整理了学习材料。

首先说下为什么要自定义算子,无非是两个原因,

(1)TF、PyTorch等提供的原生算子不满足需求,需要通过底层接口,比如CUDA层更灵活的实现个性化算子

(2)目前提供的算子性能不足,没有很好的利用到GPU的并行计算优势,有优化空间

接着性能优化的问题说,因为GPU与CPU最大的区别是计算单元占据了绝大部分的体积(图中绿色部分),而控制单元较少。

自定义手写算子可以更好地利用绿色的计算单元的并行化优势。大的思路是Grid包含Block,Block包含Thread。于是首先算子需要把计算逻辑拆分成Thread,让程序可以并行化的运行起来,然后有机的管理各个Block的执行节奏,解决好异步和同步问题,就可以让芯片的计算效率最大化。

Grid of Thread Blocks

02实现流程

整个自定义算子优化其实可以分为CUDA算子定义、算子编译、正方向梯度实现几个步骤。

1、CUDA算子定义

需要定义以下几个文件:

(1)function.cpp:python层和CUDA层的衔接,实现手写的C++ CUDA算子可以被python代码调用

(2)function.h:CUDA函数声明文件

(3)function.cu:CUDA函数的逻辑实现

这里比较核心的文件就是.cu文件,构建的时候主要做两个事:一个是建设Kernel函数,因为只有Kernel函数是在GPU端执行,执行完之后要将控制权给到控制函数,这里要控制好异步、同步的问题。二是在kernel函数中需要通过循环函数定义每个Thread以及每个Block的工作,真正让计算并行化

.cpp文件可以通过pybind函数实现,这个函数主要解决的是C++代码和Python绑定的问题,项目地址:GitHub - pybind/pybind11: Seamless operability between C++11 and Python

2.编译和执行

import torch
from torch.utils.cpp_extension import load
cuda_module = load(name="function",
                   extra_include_paths=["include"],
                   sources=["function.cpp", "function.cu"],
                   verbose=True)

接着就是执行过程中的编译,通过load函数底层会执行JIT(Just in time)的动态编译模式调用.cpp和.cu文件,底层运行的是Ninjia编译器,通过调用nvcc编译.so文件

[1/2] nvcc -c function.cu -o function.cuda.o
[2/3] c++ -c function.cpp -o function.o
[3/3] c++ function.o function.cuda.o -shared -o function.so

3.正反向梯度实现

以上两步实现了自定义算子的逻辑,可以通过手写CUDA算子并在python框架层调用,如果要满足建模需求,还需要实现正方向梯度。具体的做法是在建模的函数中实现forward和backward函数,定义导数作为输出。

以上大概就是手写算子优化的简单流程,仅当学习笔记。

参考:

【1】熬了几个通宵,我写了份CUDA新手入门代码 - 知乎

【2】CUDA C++ Programming Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1318131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LearnDash LMS ProPanel在线学习系统课程创作者的分析工具

点击阅读LearnDash LMS ProPanel在线学习系统课程创作者的分析工具原文 LearnDash LMS ProPanel在线学习系统课程创作者的分析工具通过整合报告和作业管理来增强您的 LearnDash 管理体验,使您能够发送特定于课程的通信,并显示课程的实时活动&#xff01…

分类信息网商业运营版源码系统:适合各类行业分类站点建站 带安装部署教程

随着互联网的快速发展,信息分类网站在各个行业中得到了广泛应用。为了满足不同行业的需求,罗峰给大家分享一款适合各类行业分类站点建站的商业运营版源码系统。该系统旨在提供一套完整的解决方案,帮助用户快速搭建自己的分类信息网站&#xf…

【最新版】在WSL上运行 Linux GUI (图形用户界面)应用(Gnome 文本编辑器、GIMP、Nautilus、VLC、X11 应用)

文章目录 一、 安装WSL0. 先决条件1. 全新安装2. 现有 WSL 安装3. 注意事项 二、运行 Linux GUI 应用1. 更新发行版中的包2. 安装 Gnome 文本编辑器启动 3. 安装 GIMP启动 4. 安装 Nautilus启动 5. 安装 VLC启动 6. 安装 X11 应用 适用于 Linux 的 Windows 子系统 (WSL) 现在支…

Javaweb考前复习冲刺(不断更新版)

Javaweb考前复习冲刺 第一章: JavaWeb 入门 JavaWeb是指:以Java作为后台语言的项目工程。 javaweb项目创建的过程: 首先集成Tomcat服务器环境新建dynamic web project部署工程运行 路由含义: ​ http://localhost:8080/工程…

Redis 主从集群 —— 超详细操作演示!

五、Redis 主从集群 五、Redis 主从集群5.1 主从集群搭建5.1.1、伪集群搭建与配置5.1.2、分级管理5.1.3、容灾冷处理 5.2 主从复制原理5.2.1、主从复制原理5.2.2、数据同步演变过程 5.3 哨兵机制实现5.3.1 简介5.3.2 Redis高可用集群搭建5.3.3 Redis高可用集群的启动5.3.4 Sent…

ubuntu创建apt-mirror本地仓库

首先创建apt-mirror的服务端,也就是存储所有apt-get下载的文件和依赖。大约需要300G,预留400G左右空间就可以开始了。 安装ubuntu省略,用的是ubuntu202204 ubuntu挂载硬盘(不需要的可以跳过): #下载挂载工具 sudo apt…

Java并发(十九)----Monitor原理及Synchronized原理

1、Java 对象头 以 32 位虚拟机为例 普通对象 |--------------------------------------------------------------| | Object Header (64 bits) | |------------------------------------|-------------------------| | Mark W…

序列生成模型(一):序列概率模型

文章目录 前言1. 序列数据2. 序列数据的潜在规律3. 序列概率模型的两个基本问题 一、序列概率模型1. 理论基础序列的概率分解自回归生成模型 2. 序列生成 前言 深度学习在处理序列数据方面取得了巨大的成功,尤其是在自然语言处理领域。序列数据可以是文本、声音、视…

pytorch和pytorchvision安装

参考https://blog.csdn.net/2301_76863102/article/details/129369549 https://blog.csdn.net/weixin_43798572/article/details/123122477 查看我的版本 右键,nvivdia控制面板,帮助,系统信息 驱动程序版本号为528.49 更新很快的 CUDA版本…

stm32F4——BEEP与按键的实例使用

stm32F4——BEEP与按键的实例使用 蜂鸣器和按键的实质都是GPIO的使用,所以基本原理就不介绍啦,基本寄存器其实都是GPIO的高低电平的赋值,可以参考stm32——LEDGPIO的详细介绍 文章目录 stm32F4——BEEP与按键的实例使用一、BEEP二、 KEY三、通…

MySQL数据库 DML

目录 DML概述 添加数据 修改数据 删除数据 DML概述 DML英文全称是Data Manipulation Language(数据操作语言),用来对数据库中表的数据记录进行增、删、改操作。 添加数据(工NSERT)修改数据(UPDATE)删除数据(DELETE) 添加数据 (1)给指定字段添加数据 INSERT …

Epic 安装失败,错误代码SUPQR1612,必要的先决条件安装失败,弹窗CD-ROM

错误记录 并且弹出来这个窗口the feature you are trying to use is on a CD-ROM or other removable disk that is not available. 让我寻找这个msi文件 正在想办法解决 如果有人能解决 麻烦告知评论区!

网络安全事件分级指南

文章目录 一、特别重大网络安全事件符合下列情形之一的,为特别重大网络安全事件:通常情况下,满足下列条件之一的,可判别为特别重大网络安全事件: 二、重大网络安全事件符合下列情形之一且未达到特别重大网络安全事件的…

C++软件调试与异常排查技术从入门到精通学习路线分享

目录 1、概述 2、全面了解引发C软件异常的常见原因 3、熟练掌握排查C软件异常的常见手段与方法 3.1、IDE调试 3.2、添加打印日志 3.3、分块注释代码 3.4、数据断点 3.5、历史版本比对法 3.6、Windbg静态分析与动态调试 3.7、使用IDA查看汇编代码 3.8、使用常用工具分…

【Linux】cp问题,生产者消费者问题代码实现

文章目录 前言一、 BlockQueue.hpp(阻塞队列)二、main.cpp 前言 生产者消费者模式就是通过一个容器来解决生产者和消费者的强耦合问题。生产者和消费者彼此之间不直接通讯,而通过阻塞队列来进行通讯,所以生产者生产完数据之后不用…

将博客搬至微信公众号了

一、博客搬家通知 各位码友们好,大家是不是基本很少看 CSDN 了呢?平时开发是不都依靠着 chatGPT 来解决工作中的技术问题了,不过我觉得在工作中的使用场景各式各样的,具体问题还得自己具体去梳理逻辑,再考虑用什么样的…

【Vue第6-7章】vue-router与Vue UI组件库_Vue2

目录 第6章 vue-router 6.1 相关理解 6.1.1 vue-router的理解 6.1.2 对SPA应用的理解 6.1.3 路由的理解 6.2 基本路由 6.2.1 效果 6.2.2 总结:编写使用路由的3步 6.2.3 笔记与代码 6.2.3.1 笔记 6.2.3.2 30_src_路由的基本使用 6.3 嵌套(多级…

股票价格预测 | Python实现基于ARIMA和LSTM的股票预测模型(含XGBoost特征重要性衡量)

文章目录 效果一览文章概述模型描述源码设计效果一览 文章概述 Python实现基于ARIMA和LSTM的股票预测模型(Stock-Prediction) Data ExtractionFormatting data for time seriesFeature engineering(Feature Importance using X

MX6ULL学习笔记(十二)Linux 自带的 LED 灯

前言 前面我们都是自己编写 LED 灯驱动,其实像 LED 灯这样非常基础的设备驱动,Linux 内 核已经集成了。Linux 内核的 LED 灯驱动采用 platform 框架,因此我们只需要按照要求在设备 树文件中添加相应的 LED 节点即可,本章我们就来学…

vue-cli创建一个vue3项目

通过命令创建VUE脚手架: npm install -g vue/cli 创建VUE项目: vue create npg***b 剩下的就是一路回车 npm install element-plus --save 入口文件main.ts import { createApp } from vue import App from ./App.vue import router from ./router…