CUDA ~ WarpReduce

news2025/1/21 0:56:26

又是一篇关于cuda的 要好好学学哦, CUDA 编程进阶分享,一些 warp 的使用

如何实现一个高效的Softmax CUDA kernel?多少还是有些细节没有理解,恰好最近要做一个类似的 Reduce+Scale Kernel,原理机制还是比较相似的,所以翻出来重新理解一下。

我们定义这么一个ReduceScale操作:假设Tensor是(N, C),首先在C这个维度计算出 absMax 值,我们记作scale,然后将每一行除以各自 行的scale,并最终输出。一段朴素的numpy代码是这样:

import numpy as np  
  
  
N = 1000  
C = 128  
x = np.random.randn(N, C)  
scale = np.expand_dims(np.max(np.abs(x), axis=1), 1)  
out = x / scale  
print(out.shape)  

BaseLine

这里我们BaseLine是直接调用cub库中的 BlockReduce,一个 threadBlock 处理一行数据,计算出AbsMaxVal,然后再缩放,代码如下:

#include "cuda.h"  
#include "cub/cub.cuh"  
  
constexpr int kReduceBlockSize = 128;  
  
template<typename T>  
__device__ T abs_func(const T& a) {  
  return abs(a);  
}  
  
  
template<typename T>  
__device__ T max_func(const T a, const T b) {  
  return a > b ? a : b;  
}  
  
template<typename T>  
struct AbsMaxOp {  
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {  
    return max_func(abs_func(a), abs_func(b));  
  }  
};  
  
template<typename T>  
__inline__ __device__ T BlockAllReduceAbsMax(T val) {  
  typedef cub::BlockReduce<T, kReduceBlockSize> BlockReduce;  
  __shared__ typename BlockReduce::TempStorage temp_storage;  
  __shared__ T final_result;  
  T result = BlockReduce(temp_storage).Reduce(val, AbsMaxOp<T>());  
  if (threadIdx.x == 0) { final_result = result; }  
  __syncthreads();  
  return final_result;  
}  
  
template<typename T, typename IDX>  
__global__ void ReduceScaleBlockKernel(T* x, IDX row_size, IDX col_size) {  
  for(int32_t row = blockIdx.x, step=gridDim.x; row < row_size; row+= step){  
    T thread_scale_factor = 0.0;   
    for(int32_t col=threadIdx.x; col < col_size; col+= blockDim.x){  
      IDX idx = row * col_size + col;   
      T x_val = x[idx];  
      thread_scale_factor = max_func(thread_scale_factor, abs_func(x_val));   
    }  
    T row_scale_factor = BlockAllReduceAbsMax<T>(thread_scale_factor);   
    for(int32_t col=threadIdx.x; col < col_size; col+=blockDim.x){  
      IDX idx = row * col_size + col;   
      x[idx] /= row_scale_factor;  
    }  
  }  
}  

参数中 x 是输入数据,row_size是行的数量,col_size是列的大小测试机器是在 A100 40GB,为了让结果区别比较明显,我们将行数设置的比较大,输入形状为(55296*8, 128),启动的线程块数目根据 如何设置CUDA Kernel中的grid_size和block_size?这篇文章来指定,这里比较粗暴的设置为(55296, 128),数据类型为 Float,然后我们看下ncu的结果:

主要有这几个指标,耗时为577.95us,吞吐量为 748.78Gb/s下面我们就根据 Softmax 优化那篇文章所提及的点来逐步分析:

优化1 数据Pack

在之前的 高效、易用、可拓展我全都要:OneFlow CUDA Elementwise 模板库的设计优化思路 里很详细的描述了如何做向量化读写,cuda里最大支持 128bit的读写,那么在数据类型为 Float 时,我们即可以将连续的4个 Float 打包到一起,一次性读写,提升吞吐。有了解过这方面的读者应该就反应过来,诶 CUDA 里 不是刚好有一个类型叫 float4 就是干这件事的么,没错,但是为了更灵活的支持其他数据类型的向量化,我们利用union共享空间的特性实现了一个 Pack 类:

template<typename T, int N>  
struct GetPackType {  
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;  
};  
  
template<typename T, int N>  
using PackType = typename GetPackType<T, N>::type;  
  
template<typename T, int N>  
union Pack {  
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");  
  __device__ Pack() {  
    // do nothing  
  }  
  PackType<T, N> storage;  
  T elem[N];  
};  

优化2 数据缓存

整个算子逻辑是需要读取一遍数据,计算scale,然后再读取一遍数据,用scale进行缩放。很显然这里我们读取了两遍数据,而数据是放在 Global Memory,带宽比较低,会带来读取耗时。

 

一个很自然的想法是缓存到寄存器/Shared Memory中。由于这里我们只实现 WarpReduce 版本,所以我们是缓存到寄存器(其他版本可以参考开头的优化 Softmax 文章)中,减少一次对 Global Memory 的读取。

template<typename T, typename IDX, int pack_size, int cols_per_thread>  
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {  
    // ...  
    T buf[cols_per_thread];  
    // ...  

优化3 使用Warp处理一行数据

相较 BaseLine,我们这里使用 warp 作为 Reduce 的单位进行操作,首先我们简单看下 WarpReduce 的实现。

template<typename T>  
struct AbsMaxOp {  
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {  
    return max_func(abs_func(a), abs_func(b));  
  }  
};  
  
template<typename T>  
__inline__ __device__ T WarpAbsMaxAllReduce(T val){  
    for(int lane_mask = kWarpSize/2; lane_mask > 0; lane_mask /= 2){  
        val = AbsMaxOp<T>()(val, __shfl_xor_sync(0xffffffff, val, lane_mask));   
    }  
    return val;   
}  

这段代码在别的 BlockReduce 也经常看到,他是借助 __shfl_xor_sync 来实现比较,shuffle 指令允许同一线程束的两个线程直接读取对方的寄存器。

T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);  

其中 mask 是对线程的一个掩码,我们一般所有线程都要参与计算,所以 mask 是 0xffffffffvar 则是寄存器值,laneMask 则是用来做按位异或的掩码

这里引入一个概念叫 Lane,它表示线程束中的第几号线程

示意图如下:

 

当 laneMask = 16 时,其二进制为 0001 0000,然后线程束每个线程与 laneMask 做异或操作如:

  • 0000 0000 xor 0001 0000 = 0001 0000 = 16

  • 0000 0001 xor 0001 0000 = 0001 0001 = 17

  • 0000 0010 xor 0001 0000 = 0001 0010 = 18

以此类推,最终得到一个 Warp 中的 absmax 值。接下来我们开始写Kernel,模板参数分别为:

  • T 数据类型

  • IDX 索引类型

  • pack_size pack数,比如float可以pack成4个,那对应pack_size=4

  • cols_per_thread 每个线程需要处理的元素个数,比如一行大小是128,而我们一个warp有32个线程,那么这里就是128/32 = 4

template<typename T, typename IDX, int pack_size, int cols_per_thread>  
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {  
    // ...      
}  

跟BaseLine一样,我们block大小还是设置为128个线程,一个warp是32个线程,所以我们一个block可以组织成(32, 4),包含4个warp。

根据这个层级划分,我们可以计算出:

  • global_thread_group_id 当前warp的全局index

  • num_total_thread_group warp的总数量

  • lane_id 线程束内的线程id

  • num_packs pack的数目,即每个线程需要处理的元素个数 / pack_size

const int32_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;   
    const int32_t num_total_thread_group = gridDim.x * blockDim.y;   
    const int32_t lane_id = threadIdx.x;   
    using LoadStoreType = PackType<T, pack_size>;  
    using LoadStorePack = Pack<T, pack_size>;  
    T buf[cols_per_thread];   
    constexpr int num_packs = cols_per_thread / pack_size;  

由于存在启动的warp的数量小于行的数量,所以我们要引入一个 for 循环。假设我们 cols = 256,那么线程束里的每个线程需要处理 256 /32 = 8个元素,而4个float可以pack到一起,所以我们线程束里的每个线程要处理2个pack,因此也要引入一个关于 num_packs 的 for 循环,以保证整一行都有被读取到:

一次性读取到一个 pack 后,我们再一个个放到寄存器当中缓存起来,并计算线程上的 AbsMaxVal。

for(IDX row_idx = global_thread_group_id; row_idx < row_size; row_idx += num_total_thread_group){  
        T thread_abs_max_val = 0.0;   
        for(int pack_idx = 0; pack_idx < num_packs; pack_idx++){  
            const int32_t pack_offset = pack_idx * pack_size;   
            const int32_t col_offset = pack_idx * kWarpSize * pack_size + lane_id * pack_size;   
            const int32_t load_offset = (row_idx * col_size + col_offset) / pack_size;   
            LoadStorePack load_pack;   
            load_pack.storage = *(reinterpret_cast<LoadStoreType*>(x)+ load_offset);   
            #pragma unroll   
            for(int i = 0; i < pack_size; i++){  
                buf[pack_offset] = load_pack.elem[i];   
                thread_abs_max_val = max_func(thread_abs_max_val, abs_func(buf[pack_offset]));  
            }   
        }  

接着我们调用 WarpAbsMaxAllReduce 进行reduce,获得线程束中的 AbsMaxVal,并对缓存的数据进行数值缩放。

T warp_max_val = WarpAbsMaxAllReduce<T>(thread_abs_max_val);   
        #pragma unroll  
        for (int col = 0; col < cols_per_thread; col++) {  
            buf[col] = buf[col] / warp_max_val;  
        }  

最后跟一开始读取类似,我们将寄存器里的值再写回去,相关索引的计算逻辑都是一致的:

for(int pack_idx = 0; pack_idx < num_packs; pack_idx++){  
            const int32_t pack_offset = pack_idx * pack_size;   
            const int32_t col_offset = pack_idx * pack_size * kWarpSize + lane_id * pack_size;   
            const int32_t store_offset = (row_idx * col_size + col_offset) / pack_size;   
            LoadStorePack store_pack;   
            #pragma unroll   
            for(int i = 0; i < pack_size; i++){  
                store_pack.elem[i] = buf[pack_offset + i];   
            }   
            *(reinterpret_cast<LoadStoreType*>(x)+ store_offset) = store_pack.storage;   
        }  

完整代码如下:

template<typename T>  
__inline__ __device__ T WarpAbsMaxAllReduce(T val){  
    for(int lane_mask = kWarpSize/2; lane_mask > 0; lane_mask /= 2){  
        val = AbsMaxOp<T>()(val, __shfl_xor_sync(0xffffffff, val, lane_mask));   
    }  
    return val;   
}  
  
template<typename T, typename IDX, int pack_size, int cols_per_thread>  
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {  
    const int32_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;   
    const int32_t num_total_thread_group = gridDim.x * blockDim.y;   
    const int32_t lane_id = threadIdx.x;   
    using LoadStoreType = PackType<T, pack_size>;  
    using LoadStorePack = Pack<T, pack_size>;  
    T buf[cols_per_thread];   
    constexpr int num_packs = cols_per_thread / pack_size;  
    for(IDX row_idx = global_thread_group_id; row_idx < row_size; row_idx += num_total_thread_group){  
        T thread_abs_max_val = 0.0;   
        for(int pack_idx = 0; pack_idx < num_packs; pack_idx++){  
            const int32_t pack_offset = pack_idx * pack_size;   
            const int32_t col_offset = pack_idx * kWarpSize * pack_size + lane_id * pack_size;   
            const int32_t load_offset = (row_idx * col_size + col_offset) / pack_size;   
            LoadStorePack load_pack;   
            load_pack.storage = *(reinterpret_cast<LoadStoreType*>(x)+ load_offset);   
            #pragma unroll   
            for(int i = 0; i < pack_size; i++){  
                buf[pack_offset] = load_pack.elem[i];   
                thread_abs_max_val = max_func(thread_abs_max_val, abs_func(buf[pack_offset]));  
            }   
        }  
        T warp_max_val = WarpAbsMaxAllReduce<T>(thread_abs_max_val);   
        #pragma unroll  
        for (int col = 0; col < cols_per_thread; col++) {  
            buf[col] = buf[col] / warp_max_val;  
        }  
        for(int pack_idx = 0; pack_idx < num_packs; pack_idx++){  
            const int32_t pack_offset = pack_idx * pack_size;   
            const int32_t col_offset = pack_idx * pack_size * kWarpSize + lane_id * pack_size;   
            const int32_t store_offset = (row_idx * col_size + col_offset) / pack_size;   
            LoadStorePack store_pack;   
            #pragma unroll   
            for(int i = 0; i < pack_size; i++){  
                store_pack.elem[i] = buf[pack_offset + i];   
            }   
            *(reinterpret_cast<LoadStoreType*>(x)+ store_offset) = store_pack.storage;   
        }  
    }  
}  

这里我们方便测试,调用的时候就直接写死一些模板参数

constexpr int cols_per_thread = 128 / kWarpSize;   
ReduceScaleWarpKernel<float, int32_t, 4, cols_per_thread><<<55296, block_dim>>>(device_ptr, row_size, col_size);  

最后我们看一下 ncu 的结果:

吞吐量达到了1.3T,时间位333us,相比 BaseLine 快了 73 %。

总结

还有更多特殊情况可以参考 Softmax 优化的代码,这里仅实现了第一个 Warp 计算方式。我感觉看着还行,真自己写起来理解还是有点困难的,希望这篇博客能帮助读者理解到一些 warp 的使用。     whaosoft aiot http://143ai.com  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/105868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS -- 10. 移动WEB开发之rem布局

文章目录移动WEB开发之rem布局1 rem基础2 媒体查询2.1 什么是媒体查询2.2 语法规范2.3 mediatype 查询类型2.4 关键字2.5 媒体特性2.6 案例&#xff1a;根据页面宽度改变背景颜色2.7 媒体查询rem实现元素动态大小变化2.8 针对不同的屏幕尺寸引入不同的样式文件3 Less基础3.1 维…

8000字详解Thread Pool Executor

摘要&#xff1a;Java是如何实现和管理线程池的?本文分享自华为云社区《JUC线程池: ThreadPoolExecutor详解》&#xff0c;作者&#xff1a;龙哥手记 。 带着大厂的面试问题去理解 提示 请带着这些问题继续后文&#xff0c;会很大程度上帮助你更好的理解相关知识点。pdai …

数据泄露成数据安全最大风险,企业如何预防呢?

据《中国政企机构数据安全风险分析报告》显示&#xff0c;2022年1月——2022年10月&#xff0c;安全内参共收录全球政企机构重大数据安全报道180起&#xff0c;其中数据泄露相关安全事件高达93起&#xff0c;占51.7%。与近三年平均每月公开报道频次相比&#xff0c;2022年相较前…

如何在3DMAX中不使用Maxscript或插件破碎物体对象?

在3DMAX中破碎物体我们通常会借助Maxscript或者插件&#xff0c;其实&#xff0c;不借助任何其他工具&#xff0c;3DMAX也可以实现对物体的破碎&#xff0c;下面就给大家介绍一种方法&#xff1a; 1.首先&#xff0c;创建一个破碎对象&#xff0c;比如一个石块&#xff08;或者…

AI趋势下,小布助手的进化论

“要构建人工智能等高精尖产业的新增长引擎”&#xff0c;随着人工智能在未来全球科技经济中的重要作用愈加凸显&#xff0c;当前产业已然获得了有史以来最强的政策建构力量。 随着政策的利好&#xff0c;中国人工智能进入一个前所未有的快速发展阶段。企查查数据显示&#xf…

疫情下的在线教学数据观

由于新型冠状病毒感染的肺炎疫情影响&#xff0c;剧烈增长的市场需求助推了在线教育的发展&#xff0c;同时也暴露了一些问题。 最近我们被客户要求撰写关于疫情的研究报告&#xff0c;包括一些图形和统计输出。 在本文中&#xff0c;我们结合了对100多个高中学生进行的在线教…

快讯 | 嘉为蓝鲸受邀出席汽车新智造数字行业峰会,助力构建数字时代竞争力!

12月9日&#xff0c;第五届GADI汽车新智造数字创新行业峰会暨年度评选盛典于上海圆满落幕&#xff0c;嘉为蓝鲸受邀出席。本届大会以“数智创新 赋能破局”为主题&#xff0c;多方面切入解读新能源汽车的数字化发展趋势&#xff0c;助力车企构建数字时代竞争力。 01 研运一体&a…

数据通信基础 - 信道特性(奈奎斯特定理、香农定理 )

文章目录1 概述1.1 通信系统模型图2 信道特性2.1 信道带宽 W2.2 奈奎斯特定理 - 无噪音2.3 香农定理 - 有噪音2.4 带宽、码元速率、数据速率 关系梳理3 网工软考真题1 概述 1.1 通信系统模型图 通信的目的&#xff1a;传递信息 2 信道特性 2.1 信道带宽 W 模拟信道&#…

数据中台选型必读(六):说说数据服务的七大核心功能

在前面的文章中&#xff0c;我们介绍了数据中台的元数据中心、指标字典与指标体系、数据模型设计、数据质量评估等内容&#xff0c;这些都是One Data理念下数据中台架构的重要部分。 我们今天要讲的One Service——统一数据服务&#xff0c;指的是由数据中台提供统一的数据接入…

搭建自动发卡网站搭建教程(独角数卡)保姆级教程,支付 + 图文

自动发卡网站 程序是开源的独角数卡 我搭建了一个这样的 wooknow自动销售发卡http://ok.54ndd.com/ 一个在线销售虚拟产品的平台。你应该见过这样的发卡平台。一些虚拟产品&#xff0c;如软件、激活码和会员可以放在上面出售。我在这里使用的发卡项目是一个开源的单字符数字…

Matplotlib怎么创建 axes 对象?

在 matplotlib 中&#xff0c;有几种常见的方法来创建 axes 对象&#xff1a; 1.使用 subplots 函数&#xff1a; import matplotlib.pyplot as pltfig, ax plt.subplots()subplots 函数会创建一个新的图形&#xff08;figure&#xff09;并返回一个包含单个子区域&#xff…

二肽Ala-Pro,13485-59-1

Substrate for skin fibroblast prolidase.皮肤成纤维细胞prolida酶的底物。 编号: 199181中文名称: 二肽Ala-Pro英文名: Ala-ProCAS号: 13485-59-1单字母: H2N-AP-OH三字母: H2N-Ala-Pro-COOH氨基酸个数: 2分子式: C8H14N2O3平均分子量: 186.21精确分子量: 186.1等电点(PI): 6…

【git 提交、撤销、回退代码】

git 提交、撤销、回退代码git push后 发现提交分支错误 --> 回退代码git 未push、取消commit(保留代码&#xff09;git 未push、取消commit(不保留代码&#xff09;git push后 发现提交分支错误 --> 回退代码 首先 git log 查看提交记录&#xff0c; 找到需要回退到哪次…

CSRF实战案例—绕过referer值验证

在一个添加管理员的界面引起了我的注意 尝试添加一个管理员,如下添加成功,我们可以观察其请求包中并未存在token字段,可能存在csrf漏洞。但是存在“Referer”和“Origin”字段 我们把referer字段删了只剩origin,查看是否可以请求成功,发现可以请求成功 两个值都删了,请求…

PGL 系列(四)词向量 CBOW

环境 python 3.6.8paddlepaddle-gpu 2.3.0numpy 1.19.5一、CBOW 概念 CBOW:通过上下文的词向量推理中心词 在CBOW中,先在句子中选定一个中心词,并把其它词作为这个中心词的上下文。如 上图 CBOW所示,把“spiked”作为中心词,把“Pineapples、are、and、yellow”作为中心词…

【车载开发系列】UDS诊断---控制DTC设置($0x85)

【车载开发系列】UDS诊断—控制DTC设置&#xff08;$0x85&#xff09; UDS诊断---控制DTC设置&#xff08;$0x85&#xff09; 【车载开发系列】UDS诊断---控制DTC设置&#xff08;$0x85&#xff09;一.概念定义常见汽车故障二.子功能三.报文格式1&#xff09;报文请求2&#xf…

索引的底层实现原理是什么?

索引存储在内存中&#xff0c;为服务器存储引擎为了快速找到记录的一种数据结构。索引的主要作用是加快数据查找速度&#xff0c;提高数据库的性能。 索引的分类 (1) 普通索引&#xff1a;最基本的索引&#xff0c;它没有任何限制。 (2) 唯一索引&#xff1a;与普通索引类似…

计算机毕设Python+Vue研究生培养过程管理系统(程序+LW+部署)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

【SpringCloud-Eureka】Gateway网关

Gateway概念 特征 核心流程 Eureka服务注册 生产端 Gateway网关 验证网关 Gateway概念 路由&#xff08;Route&#xff09;是GateWay中最基本的组件之一&#xff0c;表示一个具体的路由信息载体&#xff0c;主要由下面几个部分组成&#xff1a; id&#xff1a;路由唯一标…

Cadence Allegro在PCB中手动或者自动添加差分对属性

设计PCB过程中&#xff0c;若设计中有差分对信号&#xff0c;则需要将是差分的2个信号设置为差分对&#xff0c;设置差分对有2种方式&#xff1a;手动添加及自动添加一、手动添加差分对&#xff1a;1、点击Setup-Constraints-Constraint Manager调出CM规则管理器&#xff0c;然…