The previous post introduced the paper Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU; this post analyzes the corresponding code.
CUTLASS's examples/47_ampere_gemm_universal_streamk demonstrates the GEMM Stream-K algorithm on the Ampere architecture and compares the performance of plain GEMM, Split-K, and Stream-K:
- At the device level, GemmUniversal uniformly supports the GEMM, Split-K, and Stream-K algorithms; the bulk of the implementation lives in its base class GemmUniversalBase;
- At the kernel level, GemmUniversal implements GEMM and Split-K, while GemmUniversalStreamk implements Stream-K:
  - The two share many components and configurations via DefaultGemmUniversal, DefaultGemm, Gemm, and so on; that is, a Gemm is constructed, but only its components are used;
  - The generic kernel template function Kernel2 calls GemmUniversal::invoke and GemmUniversalStreamk::invoke, whose main work is done in GemmUniversal::run_with_swizzle and GemmUniversalStreamk::gemm respectively;
- At the threadblock level, there are likewise two branches, GemmIdentityThreadblockSwizzle and ThreadblockSwizzleStreamK:
  - The mainloop is a 4-stage MmaMultistage whose A- and B-matrix iterators are PredicatedTileAccessIterator;
  - The Epilogue inherits from both EpilogueBase and EpilogueBaseStreamK, and DefaultEpilogueTensorOp specifies PredicatedTileIterator as the output iterator.
The hierarchy shown in CUTLASS GEMM Components is a helpful reference for understanding these layers.
Notes:
- The cutlass repository currently contains both old and new code; this example uses the 2.x API.
- In the paper, both Data-parallel and Fixed-split map to the kGemm mode, while the kGemmSplitKParallel mode corresponds to GemmSplitKParallel.
ampere_gemm_universal_streamk.cu
Check the CUDA Toolkit version.
/// Program entrypoint
int main(int argc, const char **argv)
{
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaGetDevice is a CUDA Runtime API that returns the device currently in use.
cudaGetDeviceProperties returns information about the compute device in a cudaDeviceProp.
Check the device compute capability; SM80 or higher is required here.
// Current device must have compute capability at least 80
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!((props.major * 10 + props.minor) >= 80))
{
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
Create an Options struct.
Options::parse parses the command-line arguments via the CommandLine struct.
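For reference, a typical invocation of the built example looks roughly like the following; the flag names (--m, --n, --k, --split, --iterations) are recalled from the example's print_usage and should be treated as an assumption:
./47_ampere_gemm_universal_streamk --m=4096 --n=4096 --k=4096 --split=2 --iterations=100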
// Parse commandline options
Options options("ampere_streamk_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
std::cout <<
options.iterations << " timing iterations of " <<
options.problem_size.m() << " x " <<
options.problem_size.n() << " x " <<
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
HostTensor::resize changes the size of the logical tensor.
HostTensor::host_view returns a TensorView object.
The TensorFillRandomUniform function generates random numbers via std::rand.
//
// Initialize GEMM datasets
//
// Initialize tensors using CUTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
// Fill matrix A on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2),
0);
// Fill matrix B on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2),
0);
// Fill matrix C on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c.host_view(),
1,
ElementC(2),
ElementC(-2),
0);
HostTensor::sync_device copies data to the device.
HostTensor::device_ref returns a TensorRef object.
DeviceGemmReference is the reference Gemm; it is invoked to compute the reference result.
HostTensor::sync_host copies data back to the host.
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c.sync_device();
// Zero-initialize reference output matrix D
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
When options.split_k_factor == 1, Basic-DP is compared against StreamK.
The run template function is called to execute the kernel instantiated with these parameters.
DeviceGemmBasic and DeviceGemmStreamK are both GemmUniversal; the only difference is that the former uses GemmIdentityThreadblockSwizzle while the latter uses ThreadblockSwizzleStreamK, as sketched below.
options.split_k_factor is then incremented.
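To make that difference concrete, here is a condensed sketch of the two device-level types. This is not the example's literal code: the element types, layouts, tile shapes, stage count, and include paths are illustrative assumptions; only the point that the two aliases differ solely in the ThreadblockSwizzle argument is taken from the example.
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h"
#include "cutlass/epilogue/thread/linear_combination.h"
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementAccumulator = float;
// Epilogue functor: D = alpha * accumulator + beta * C
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator, ElementAccumulator>;
// Everything except the threadblock swizzle is shared between the two variants.
template <typename ThreadblockSwizzle>
using DeviceGemmWith = cutlass::gemm::device::GemmUniversal<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // tensor core instruction
    EpilogueOp,
    ThreadblockSwizzle,
    4>;                                       // pipeline stages
using DeviceGemmBasic   = DeviceGemmWith<cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>;
using DeviceGemmStreamK = DeviceGemmWith<cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>;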
//
// Evaluate CUTLASS kernels
//
// Test default operation
if (options.split_k_factor == 1)
{
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing)
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
options.split_k_factor++; // Increment splitting factor for next evaluation
}
When options.split_k_factor is greater than 1, Basic Split-K is compared against StreamK emulating Split-K.
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
Result basic_splitk = run<DeviceGemmBasic>(
std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
Result streamk_splitk = run<DeviceGemmStreamK>(
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
return 0;
}
run
TensorFill fills a tensor with a scalar element.
/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
Create a GemmUniversal object.
args_from_options has two overloads, one for DeviceGemmBasic and one for DeviceGemmStreamK. From the Options it constructs a device-level GemmUniversal::Arguments, i.e. GemmUniversalBase::Arguments, which is the kernel-level GemmUniversal::Arguments (or GemmUniversalStreamk::Arguments); a sketch follows below.
GemmUniversalBase::get_workspace_size returns the workspace size (in bytes) needed for the problem geometry expressed by these arguments.
allocation is DeviceAllocation; its constructor calls allocate to reserve memory.
GemmUniversalBase::can_implement checks whether the grid would exceed launch limits and whether the problem shape meets the alignment requirements.
GemmUniversalBase::initialize initializes the parameters.
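A hedged sketch of what the DeviceGemmStreamK overload of args_from_options might assemble. The argument order follows the GemmUniversalStreamk::Arguments constructor quoted near the end of this post; the concrete values passed (batch strides, epilogue parameters) are illustrative assumptions rather than the example's literal code.
DeviceGemmStreamK::Arguments args(
    cutlass::gemm::GemmUniversalMode::kGemm,             // mode
    options.problem_size,                                // m, n, k
    options.split_k_factor,                              // 1 => StreamK heuristic, >1 => emulated Split-K
    {ElementAccumulator(options.alpha),                  // epilogue alpha
     ElementAccumulator(options.beta)},                  // epilogue beta
    options.tensor_a.device_data(),                      // ptr_A
    options.tensor_b.device_data(),                      // ptr_B
    options.tensor_c.device_data(),                      // ptr_C
    options.tensor_d.device_data(),                      // ptr_D
    int64_t(options.problem_size.m()) * options.problem_size.k(),  // batch_stride_A
    int64_t(options.problem_size.k()) * options.problem_size.n(),  // batch_stride_B
    int64_t(options.problem_size.m()) * options.problem_size.n(),  // batch_stride_C
    int64_t(options.problem_size.m()) * options.problem_size.n(),  // batch_stride_D
    options.tensor_a.layout().stride(),                  // stride_a
    options.tensor_b.layout().stride(),                  // stride_b
    options.tensor_c.layout().stride(),                  // stride_c
    options.tensor_d.layout().stride(),                  // stride_d
    options.avail_sms);                                  // SMs to load-balance across (-1 = all)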
// Instantiate CUTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
CUTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
Run the functional check.
The no-argument GemmUniversalBase::operator() is invoked.
TensorEquals checks that every element of the output equals the reference. Can strict equality really be expected here?
// Correctness / Warmup iteration
CUTLASS_CHECK(device_gemm());
// Copy output data from CUTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = cutlass::reference::host::TensorEquals(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
Performance test.
GpuTimer measures time with cudaEvent.
gflops is the achieved compute throughput (see the formula below).
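Presumably Options::gflops uses the standard GEMM operation count, where each of the M·N·K multiply-accumulates counts as two floating-point operations: GFLOPs ≈ 2·M·N·K / (runtime in seconds) / 10⁹.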
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(device_gemm());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
if (!result.passed) {
exit(-1);
}
return result;
}
GemmUniversal
GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation (problem geometry and data references), it can be reused across different GEMM problems with the same geometry. (Once initialized, the problem geometry and the references to workspace memory cannot be updated.) The universal GEMM supports serial reductions, parallel reductions, batched strided, and batched array variants.
The bulk of the implementation lives in GemmUniversalBase.
DefaultGemmUniversal::GemmKernel is either the kernel-level GemmUniversal or GemmUniversalStreamk.
DefaultGemmConfiguration::EpilogueOutputOp is LinearCombination.
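As a reminder of LinearCombination's standard semantics: for each output element it computes d = alpha · accumulator + beta · c, vectorized over kCount elements per access, with alpha and beta supplied through EpilogueOutputOp::Params.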
/*!
GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation
(problem geometry and data references), it can be reused across different GEMM problems having the
geometry. (Once initialized, details regarding problem geometry and references to workspace memory
cannot be updated.)
The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone,
/// Gather operand A by using an index array
bool GatherA = false,
/// Gather operand B by using an index array
bool GatherB = false,
/// Scatter result D by using an index array
bool ScatterD = false,
/// Permute result D
typename PermuteDLayout_ = layout::NoPermute,
/// Permute operand A
typename PermuteALayout_ = layout::NoPermute,
/// Permute operand B
typename PermuteBLayout_ = layout::NoPermute
>
class GemmUniversal :
public GemmUniversalBase<
typename kernel::DefaultGemmUniversal<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_,
SharedMemoryClearOption::kNone,
GatherA,
GatherB,
ScatterD,
PermuteDLayout_,
PermuteALayout_,
PermuteBLayout_
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
using PermuteDLayout = PermuteDLayout_;
using PermuteALayout = PermuteALayout_;
using PermuteBLayout = PermuteBLayout_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
GemmUniversal::GemmKernel is GemmUniversalBase::GemmKernel, i.e. DefaultGemmUniversal::GemmKernel; the latter is selected according to the ThreadblockSwizzle template parameter that is passed in.
using Base = GemmUniversalBase<
typename kernel::DefaultGemmUniversal<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_,
SharedMemoryClearOption::kNone,
GatherA,
GatherB,
ScatterD,
PermuteDLayout_,
PermuteALayout_,
PermuteBLayout_
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
GemmUniversalBase
It consumes type information from the kernel-level GemmUniversal or GemmUniversalStreamk.
template <typename GemmKernel_>
class GemmUniversalBase {
public:
using GemmKernel = GemmKernel_;
/// Boolean indicating whether the CudaHostAdapter is enabled
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
/// Numerical accumulation element type
using ElementAccumulator = typename GemmKernel::Mma::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
Arguments
Arguments uses the type defined by the GemmKernel passed in, i.e. GemmUniversal::Arguments or GemmUniversalStreamk::Arguments.
device_ordinal_ is initialized to -1.
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
/// Index of the GEMM Kernel within the CudaHostAdapter
static int32_t const kGemmKernelIndex = 0;
/// Kernel dynamic shared memory allocation requirement
/// Update the kernel function's shared memory configuration for the current device
static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage);
protected:
//
// Device properties (uniform across all instances of the current thread)
//
// Device ordinal
CUTLASS_THREAD_LOCAL static int device_ordinal_;
/// Device SM count
CUTLASS_THREAD_LOCAL static int device_sms_;
/// Kernel SM occupancy (in thread blocks)
CUTLASS_THREAD_LOCAL static int sm_occupancy_;
GemmUniversalBase::init_device_props
Initializes device_sms_ and sm_occupancy_, and configures dynamic shared memory.
If necessary, it initializes the static thread-local members for the thread's current device.
CUTLASS_TRACE_HOST prints the file name and line number in debug builds.
protected:
/// Initialize static thread-local members for the thread's current device,
/// if necessary.
static Status init_device_props()
{
CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()");
cudaGetDevice returns the device currently in use.
If the current device has already been initialized, return immediately.
cudaError_t cudart_result;
// Get current device ordinal
int current_ordinal;
cudart_result = cudaGetDevice(&current_ordinal);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
// Done if matches the current static member
if (current_ordinal == device_ordinal_) {
// Already initialized
return Status::kSuccess;
}
cudaDeviceGetAttribute returns information about the device.
// Update SM count member
cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
cudaFuncSetAttribute sets attributes of the given function.
If the shared memory requirement exceeds 48 KB, the maximum dynamically allocated shared memory size for the kernel function is raised.
// If requires more than 48KB: configure for extended, dynamic shared memory
if constexpr (kSharedStorageSize >= (48 << 10))
{
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
kSharedStorageSize);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
}
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags is a CUDA Runtime API that returns the maximum number of active thread blocks per SM when running this kernel function.
// Update SM occupancy member
cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
&sm_occupancy_,
Kernel2<GemmKernel>,
GemmKernel::kThreadCount,
kSharedStorageSize,
cudaOccupancyDisableCachingOverride);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
// Update device ordinal member on success
device_ordinal_ = current_ordinal;
CUTLASS_TRACE_HOST(" "
"device_ordinal: (" << device_ordinal_ << "), "
"device_sms: (" << device_sms_ << "), "
"sm_occupancy: (" << sm_occupancy_ << ") "
"smem_size: (" << kSharedStorageSize << ") "
"GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
return Status::kSuccess;
}
Depending on the kernel, this is either GemmUniversal::Params or GemmUniversalStreamk::Params.
protected:
//
// Instance data members
//
/// Kernel parameters
typename GemmKernel::Params params_;
GemmUniversalBase::init_params
Initializes params_.
/// Initialize params member
Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
int32_t device_sms = 0;
int32_t sm_occupancy = 0;
kEnableCudaHostAdapter takes its value from the CUTLASS_ENABLE_CUDA_HOST_ADAPTER macro, which is not enabled here.
Nor is the CudaHostAdapter class implemented.
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
//
// Occupancy query using CudaHostAdapter::query_occupancy().
//
if (cuda_adapter) {
Status status = cuda_adapter->query_occupancy(
&device_sms,
&sm_occupancy,
kGemmKernelIndex,
GemmKernel::kThreadCount,
kSharedStorageSize);
CUTLASS_ASSERT(status == Status::kSuccess);
if (status != Status::kSuccess) {
return status;
}
}
else {
return Status::kErrorInternal;
}
}
Therefore, GemmUniversalBase::init_device_props is called to obtain the number of SMs and the maximum number of thread blocks per SM.
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
// Initialize static device properties, if necessary
Status result = init_device_props();
if (result != Status::kSuccess) {
return result;
}
//
// Use thread-local static members for occupancy query initialized by call to
// `init_device_props()`
//
device_sms = device_sms_;
sm_occupancy = sm_occupancy_;
}
This produces a GemmUniversal::Params or GemmUniversalStreamk::Params object.
// Initialize params member
params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy);
return Status::kSuccess;
}
GemmUniversalBase::can_implement
It calls the kernel-level GemmUniversal::can_implement or GemmUniversalStreamk::can_implement for further checks.
public:
//---------------------------------------------------------------------------------------------
// Stateless API
//---------------------------------------------------------------------------------------------
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");
dim3 grid = get_grid_shape(args, cuda_adapter);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
grid.z <= std::numeric_limits<uint16_t>::max()))
{
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
GemmUniversalBase::get_workspace_size
Returns the workspace size (in bytes) needed for the problem geometry expressed by these arguments.
/// Returns the workspace size (in bytes) needed for the problem
/// geometry expressed by these arguments
static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
First, a GemmUniversalBase object is created.
Then GemmUniversalBase::init_params is called to initialize the parameters.
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
return 0;
}
UniversalParamsBase::get_workspace_size or GemmUniversalStreamk::Params::get_workspace_size is then called to obtain the global-memory workspace size the kernel needs.
// Get size from parameters
size_t workspace_bytes = base.params_.get_workspace_size();
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
return workspace_bytes;
}
GemmUniversalBase::get_grid_shape
/// Returns the grid extents in thread blocks to launch
static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
First, a GemmUniversalBase object is created.
Then GemmUniversalBase::init_params is called to initialize the parameters.
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args, cuda_adapter) != Status::kSuccess) {
return dim3(0,0,0);
}
UniversalParamsBase::get_grid_dims or GemmUniversalStreamk::Params::get_grid_dims is called to obtain the grid dimensions.
// Get dims from parameters
dim3 grid_dims = base.params_.get_grid_dims();
CUTLASS_TRACE_HOST(
" tiled_shape: " << base.params_.get_tiled_shape() << "\n"
<< " grid_dims: {" << grid_dims << "}");
return grid_dims;
}
GemmUniversalBase::maximum_active_blocks
The logic is similar to that in GemmUniversalBase::init_params.
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
int32_t device_sms = 0;
int32_t sm_occupancy = 0;
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
Status status = cuda_adapter->query_occupancy(
&device_sms,
&sm_occupancy,
kGemmKernelIndex,
GemmKernel::kThreadCount,
kSharedStorageSize);
CUTLASS_ASSERT(status == Status::kSuccess);
if (status != Status::kSuccess) {
return -1;
}
}
else {
return -1;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
return -1;
}
sm_occupancy = sm_occupancy_;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy;
}
GemmUniversalBase::initialize
// Stateful API
//---------------------------------------------------------------------------------------------
/// Initializes GEMM state from arguments and workspace memory
Status initialize(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
GemmUniversalBase::init_params is called to obtain a GemmUniversal::Params or GemmUniversalStreamk::Params.
// Initialize parameters from args
Status result = init_params(args, cuda_adapter);
if (result != Status::kSuccess) {
return result;
}
UniversalParamsBase::init_workspace or GemmUniversalStreamk::Params::init_workspace is called to zero the workspace.
// Assign and prepare workspace memory
if (args.mode == GemmUniversalMode::kGemm) {
return params_.init_workspace(workspace, stream);
}
return Status::kSuccess;
}
GemmUniversalBase::update
GemmUniversal::Params::update or GemmUniversalStreamk::Params::update is called to update the parameters.
/// Lightweight update given a subset of arguments.
Status update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
params_.update(args);
return Status::kSuccess;
}
GemmUniversalBase::run
The CUTLASS_TRACE_HOST macro is only active in debug builds.
GemmUniversal::kThreadCount and GemmUniversalStreamk::kThreadCount are both derived from WarpCount, which is MmaBase::WarpCount and is determined by the ThreadblockShape and WarpShape supplied by the application.
UniversalParamsBase::get_grid_dims or GemmUniversalStreamk::Params::get_grid_dims is called to obtain the grid dimensions.
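For example, assuming the 128×128×32 threadblock tile and 64×64×32 warp tile (an assumption about this example's configuration), WarpCount is (128/64)×(128/64)×1 = 4 warps, so kThreadCount = 32 × 4 = 128 threads per block.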
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
// Configure grid and block dimensions
dim3 block(GemmKernel::kThreadCount, 1, 1);
dim3 grid = params_.get_grid_dims();
CUTLASS_ASSERT is an assertion macro.
Kernel2 calls GemmUniversal::invoke or GemmUniversalStreamk::invoke.
The kernel function's parameter is a GemmUniversal::Params or GemmUniversalStreamk::Params object.
// Launch kernel
CUTLASS_TRACE_HOST(" "
"grid: (" << grid << "), "
"block: (" << block << "), "
"SMEM: (" << kSharedStorageSize << ")");
if constexpr (kEnableCudaHostAdapter) {
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
void* kernel_params[] = {&params_};
return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0);
}
else {
return Status::kErrorInternal;
}
}
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
Kernel2<GemmKernel><<<grid, block, kSharedStorageSize, stream>>>(params_);
// Query for errors
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
GemmUniversalBase::operator()
This operator overload simply calls GemmUniversalBase::run.
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr)
{
return run(stream, cuda_adapter);
}
GemmUniversalBase::operator()
The overload that takes arguments first calls GemmUniversalBase::initialize and then GemmUniversalBase::run.
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr)
{
Status status = initialize(args, workspace, stream, cuda_adapter);
if (status == Status::kSuccess) {
status = run(stream, cuda_adapter);
}
return status;
}
};
UniversalParamsBase
/// Parameters structure
template <
typename ThreadblockSwizzle,
typename ThreadblockShape,
typename ElementA,
typename ElementB,
typename ElementC,
typename LayoutA,
typename LayoutB>
struct UniversalParamsBase
{
//
// Data members
//
GemmCoord problem_size{};
GemmCoord grid_tiled_shape{};
int swizzle_log_tile{0};
GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count {0};
int gemm_k_size {0};
int64_t batch_stride_D {0};
int *semaphore = nullptr;
//
// Host dispatch API
//
/// Default constructor
UniversalParamsBase() = default;
UniversalParamsBase::UniversalParamsBase
The constructor calls UniversalParamsBase::init_grid_tiled_shape to compute the tiled grid shape.
/// Constructor
UniversalParamsBase(
UniversalArgumentsBase const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
problem_size(args.problem_size),
mode(args.mode),
batch_count(args.batch_count),
batch_stride_D(args.batch_stride_D),
semaphore(nullptr)
{
init_grid_tiled_shape();
}
UniversalParamsBase::get_workspace_size
GemmSplitKParallel needs a workspace of problem.m() * problem.n() * k_slice elements.
/// Returns the workspace size (in bytes) needed for this problem geometry
size_t get_workspace_size() const
{
size_t workspace_bytes = 0;
if (mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Split-K parallel always requires a temporary workspace
workspace_bytes =
sizeof(ElementC) *
size_t(batch_stride_D) *
size_t(grid_tiled_shape.k());
}
For serial split-K the workspace corresponds to the number of output tiles, because each output tile needs one synchronization semaphore for the reduction (a worked example follows this function).
else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
{
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
return workspace_bytes;
}
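As a rough sense of scale (hypothetical sizes, not taken from the example): for a 4096×4096×4096 problem with 128×128 threadblock tiles there are 32 × 32 = 1024 output tiles, so serial split-K needs 1024 × sizeof(int) = 4 KiB of semaphores; parallel split-K with grid_tiled_shape.k() = 4 and batch_stride_D = M·N instead needs sizeof(ElementC) × 4096 × 4096 × 4 bytes of partial outputs, i.e. 128 MiB for half-precision ElementC.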
UniversalParamsBase::init_workspace
It calls UniversalParamsBase::get_workspace_size to obtain the size.
cudaMemsetAsync zeroes the synchronization semaphores.
Assigns and initializes the specified workspace buffer, assuming the memory allocated for the workspace is at least as large as get_workspace_size().
/// Assign and initialize the specified workspace buffer. Assumes
/// the memory allocated to workspace is at least as large as get_workspace_size().
Status init_workspace(
void *workspace,
cudaStream_t stream = nullptr)
{
semaphore = static_cast<int *>(workspace);
// Zero-initialize entire workspace
if (semaphore)
{
size_t workspace_bytes = get_workspace_size();
CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes");
cudaError_t result = cudaMemsetAsync(
semaphore,
0,
workspace_bytes,
stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
UniversalParamsBase::get_tiled_shape
/// Returns the GEMM volume in thread block tiles
GemmCoord get_tiled_shape() const
{
return grid_tiled_shape;
}
UniversalParamsBase::get_grid_blocks
Returns the total number of thread blocks to launch.
UniversalParamsBase::get_grid_dims returns the grid dimensions.
/// Returns the total number of thread blocks to launch
int get_grid_blocks() const
{
dim3 grid_dims = get_grid_dims();
return grid_dims.x * grid_dims.y * grid_dims.z;
}
UniversalParamsBase::get_grid_dims
GemmIdentityThreadblockSwizzle::get_grid_shape computes the CUDA grid dimensions, in units of logical tiles, from the grid_tiled_shape passed in.
/// Returns the grid extents in thread blocks to launch
dim3 get_grid_dims() const
{
return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
}
UniversalParamsBase::init_grid_tiled_shape
It calls GemmIdentityThreadblockSwizzle::get_tiled_shape to obtain the problem shape in units of logical tiles.
GemmIdentityThreadblockSwizzle::get_log_tile computes the best rasterization width.
private:
CUTLASS_HOST_DEVICE
void init_grid_tiled_shape() {
// Get GEMM volume in thread block tiles
grid_tiled_shape = ThreadblockSwizzle::get_tiled_shape(
problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
batch_count);
swizzle_log_tile = ThreadblockSwizzle::get_log_tile(grid_tiled_shape);
// Determine extent of K-dimension assigned to each block
gemm_k_size = problem_size.k();
In the kGemm and kGemmSplitKParallel modes, grid_tiled_shape.k() is adjusted.
is_continous_k_aligned checks whether the K dimension is contiguous and aligned.
const_max returns the maximum of two integers.
CACHELINE_BYTES is 128; the code is written so that larger values would also work.
ceil_div is a rounding-up division.
gemm_k_size is the K extent handled per slice of the GEMM; the number of partitions along K is then derived from the problem size (a worked example follows).
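For example (hypothetical numbers, assuming half-precision A and B): with K = 4096 and a tile-splitting factor batch_count = 2, cacheline_elements_a = cacheline_elements_b = 128 / 2 = 64, so kAlignK = 64 when the K dimension is cache-line aligned; then gemm_k_size = round_up(ceil_div(4096, 2), 64) = 2048 and grid_tiled_shape.k() = ceil_div(4096, 2048) = 2, i.e. two slices cooperate along K for each output tile.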
if (mode == GemmUniversalMode::kGemm || mode == GemmUniversalMode::kGemmSplitKParallel)
{
static const uint32_t CACHELINE_BYTES = 128;
static const size_t element_bytes_a = sizeof(ElementA);
static const size_t element_bytes_b = sizeof(ElementB);
static const size_t cacheline_elements_a = CACHELINE_BYTES / element_bytes_a;
static const size_t cacheline_elements_b = CACHELINE_BYTES / element_bytes_b;
const bool cacheline_alignment_needed =
util::is_continous_k_aligned<LayoutA, LayoutB>(problem_size, cacheline_elements_a, cacheline_elements_b);
int const kAlignK = const_max(
const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value),
cacheline_alignment_needed ? const_max(cacheline_elements_a, cacheline_elements_b) : 1);
gemm_k_size = round_up(ceil_div(problem_size.k(), batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(problem_size.k(), gemm_k_size);
}
}
}
};
DefaultGemmUniversal
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB,
/// Scatter result D by using an index array
bool ScatterD,
/// Permute result D
typename PermuteDLayout,
/// Permute operand A
typename PermuteALayout,
/// Permute operand B
typename PermuteBLayout
>
struct DefaultGemmUniversal<
ElementA,
LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA,
ElementB,
LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
SharedMemoryClear,
GatherA,
GatherB,
ScatterD,
PermuteDLayout,
PermuteALayout,
PermuteBLayout,
typename platform::enable_if< ! cutlass::is_complex<ElementAccumulator>::value>::type
> {
DefaultGemmKernel is DefaultGemm::GemmKernel, i.e. Gemm.
DefaultGemmKernel::Mma is Gemm::Mma, i.e. DefaultGemm::Mma, i.e. DefaultMma::ThreadblockMma, which is MmaMultistage because the application specifies NumStages = 4.
DefaultGemmKernel::Epilogue is Gemm::Epilogue, i.e. DefaultGemm::Epilogue, i.e. DefaultGemm::RegularEpilogue, i.e. DefaultEpilogueTensorOp::Epilogue, which is Epilogue.
using DefaultGemmKernel = typename kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
true,
Operator,
SharedMemoryClear,
GatherA,
GatherB,
ScatterD,
PermuteDLayout,
PermuteALayout,
PermuteBLayout
>::GemmKernel;
SelectBase inherits from either GemmUniversal or GemmUniversalStreamk, deduced from whether the ThreadblockSwizzle passed in is GemmIdentityThreadblockSwizzle or ThreadblockSwizzleStreamK.
/// Universal kernel without StreamkFeature member type
template <class SwizzleT, class Enable = void>
class SelectBase :
public kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Universal kernel with StreamkFeature member type
template <class SwizzleT>
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
public kernel::GemmUniversalStreamk<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
using GemmKernel = SelectBase<ThreadblockSwizzle>;
};
GemmUniversal
Mma::Policy is DefaultMmaCore::MmaPolicy, i.e. DefaultMmaTensorOp::Policy, i.e. MmaTensorOpPolicy.
Mma::Operator is DefaultMmaCore::MmaTensorOp, i.e. DefaultMmaTensorOp::Type, i.e. MmaTensorOp.
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
class GemmUniversal<
Mma_,
Epilogue_,
ThreadblockSwizzle_,
void,
// 3.x kernels use the first template argument to define the ProblemShape
// We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API
cute::enable_if_t<not (cute::is_tuple<Mma_>::value || IsCutlass3ArrayKernel<Mma_>::value)>
> {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
GemmUniversal::Arguments
The bulk of the implementation is in the base class UniversalArgumentsBase.
//
// Structures
//
/// Argument structure
struct Arguments : UniversalArgumentsBase
{
//
// Data members
//
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A;
void const * ptr_B;
void const * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
typename LayoutC::Stride stride_c;
typename LayoutC::Stride stride_d;
typename LayoutA::Stride::LongIndex lda;
typename LayoutB::Stride::LongIndex ldb;
typename LayoutC::Stride::LongIndex ldc;
typename LayoutC::Stride::LongIndex ldd;
int const * ptr_gather_A_indices;
int const * ptr_gather_B_indices;
int const * ptr_scatter_D_indices;
GemmUniversal::Arguments::Arguments
//
// Methods
//
Arguments():
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr)
{}
GemmUniversal::Arguments::Arguments
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
GemmUniversal::Arguments::Arguments
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_count,
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride::LongIndex lda,
typename LayoutB::Stride::LongIndex ldb,
typename LayoutC::Stride::LongIndex ldc,
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
GemmUniversal::Arguments::transposed_problem
/// Returns arguments for the transposed problem
Arguments transposed_problem() const
{
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.stride_a, args.stride_b);
std::swap(args.batch_stride_A, args.batch_stride_B);
std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices);
return args;
}
};
GemmUniversal::Params
Likewise, the bulk of the implementation is in the base class UniversalParamsBase.
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC,
LayoutA,
LayoutB>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC,
LayoutA,
LayoutB>;
//
// Data members
//
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename EpilogueOutputOp::Params output_op;
void * ptr_A;
void * ptr_B;
void * ptr_C;
void * ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
//
// Host dispatch API
//
/// Default constructor
Params() = default;
GemmUniversal::Params::Params
/// Constructor
Params(
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}
GemmUniversal::Params::update
Updates the data pointers and batch strides.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
// Update input/output pointers
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
this->batch_stride_D = args.batch_stride_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
output_op = args.epilogue;
}
};
The mainloop and the epilogue use the same shared memory.
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
GemmUniversal::can_implement
Checks whether the problem size satisfies the alignment requirements of the three matrix layouts.
public:
//
// Host dispatch API
//
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
static int const kAlignmentA = (cute::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (cute::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (cute::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (cute::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (cute::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (cute::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (cute::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (cute::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (cute::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| cute::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (cute::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (cute::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (cute::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| cute::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (cute::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (cute::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (cute::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| cute::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
GemmUniversal::can_implement
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
GemmUniversal::invoke
A static member function that implements a factory-style invocation; GemmUniversal::operator() is the actual implementation.
public:
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const ¶ms,
SharedStorage &shared_storage)
{
GemmUniversal op;
op(params, shared_storage);
}
GemmUniversal::operator()
Calls GemmUniversal::run_with_swizzle.
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
ThreadblockSwizzle threadblock_swizzle;
run_with_swizzle(params, shared_storage, threadblock_swizzle);
}
GemmUniversal::run_with_swizzle
The implementation for the Gemm modes.
GemmIdentityThreadblockSwizzle::get_tile_offset is called to obtain the swizzled CTA coordinates.
If the CTA is out of range, it returns immediately.
/// Executes one GEMM with an externally-provided swizzling function
CUTLASS_DEVICE
void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
offset_k is the current CTA's offset along the K dimension, and problem_size_k is the extent of the K dimension as seen by this CTA.
In the kGemm and kGemmSplitKParallel modes, multiple CTAs cooperate along K.
In the kBatched and kArray modes, ptr_A and ptr_B must be adjusted to point at the current matrices.
Why is a __syncthreads() needed here?
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
__syncthreads();
Compute the CTA's initial logical coordinates in the A and B matrices.
MmaMultistage::IteratorA is DefaultMma::IteratorA, i.e. PredicatedTileAccessIterator; IteratorB is of the same type.
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.ptr_gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B,
params.ptr_gather_B_indices);
canonical_warp_idx_sync obtains the warp index.
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx_sync();
int lane_idx = threadIdx.x % 32;
gemm_k_iterations is the number of mainloop iterations the CTA performs along K.
MmaMultistage executes the mainloop.
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
Epilogue
EpilogueOutputOp is Epilogue::OutputOp, i.e. EpilogueOp, i.e. LinearCombination.
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
Create the inter-CTA synchronization Semaphore.
For serial split-K, Semaphore::fetch fetches the synchronization lock up front without blocking.
LinearCombination::set_k_partition sets beta_ according to the K-partition index of the reduction; it is 1 for every CTA except the first.
Recall the GEMM formula: D = αAB + βC.
This way the first CTA handles the C matrix as appropriate, while every other CTA loads the D matrix from global memory and accumulates the partial sums (see the recurrence below).
Epilogue::OutputTileIterator is DefaultEpilogueTensorOp::OutputTileIterator, i.e. DefaultEpilogueTensorOp::PackedOutputTileIterator, i.e. PredicatedTileIterator.
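Written out per K partition (notation of this post, not the source), serial split-K effectively computes D⁽⁰⁾ = α A⁽⁰⁾B⁽⁰⁾ + βC and D⁽ᵏ⁾ = α A⁽ᵏ⁾B⁽ᵏ⁾ + 1·D⁽ᵏ⁻¹⁾, so only the first partition applies β to C while every later partition accumulates onto the D written by its predecessor.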
//
// Fetch pointers based on mode.
//
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
}
else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) {
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C += threadblock_tile_offset.k() * params.batch_stride_C;
ptr_D += threadblock_tile_offset.k() * params.batch_stride_D;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[threadblock_tile_offset.k()];
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[threadblock_tile_offset.k()];
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.ptr_scatter_D_indices
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.ptr_scatter_D_indices
);
Create an Epilogue object.
If this is not the first CTA, the source matrix is switched so that the computation continues from the result produced by the previous thread block.
Semaphore::wait waits until the semaphore reaches this CTA's K index.
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
Semaphore::release releases the semaphore.
//
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
GemmUniversalStreamk
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock mapping function
>
struct GemmUniversalStreamk {
public:
//
// Types and constants
//
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
/// The per-thread tile of raw accumulators
using AccumulatorTile = typename Mma::FragmentC;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
__NV_STD_MAX 是可在常量表达式中使用的宏函数,因为 C++ 的等价功能需要编译器支持 constexpr。这类宏函数以 __NV_STD_* 为前缀。
kWorkspaceBytesPerBlock 取 Mma 部分和与 Epilogue 两者所需工作区字节数的最大值。
/// Workspace bytes per thread block
static size_t const kWorkspaceBytesPerBlock =
__NV_STD_MAX(
kThreadCount * sizeof(AccumulatorTile),
Epilogue::kWorkspaceBytesPerBlock);
/// Block-striped reduction utility
using BlockStripedReduceT = BlockStripedReduce<kThreadCount, AccumulatorTile>;
GemmUniversalStreamk::Arguments
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode = GemmUniversalMode::kGemm;
GemmCoord problem_size {};
int batch_count {1}; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor
typename EpilogueOutputOp::Params epilogue{};
void const * ptr_A = nullptr;
void const * ptr_B = nullptr;
void const * ptr_C = nullptr;
void * ptr_D = nullptr;
int64_t batch_stride_A{0};
int64_t batch_stride_B{0};
int64_t batch_stride_C{0};
int64_t batch_stride_D{0};
typename LayoutA::Stride stride_a{0};
typename LayoutB::Stride stride_b{0};
typename LayoutC::Stride stride_c{0};
typename LayoutC::Stride stride_d{0};
typename LayoutA::Stride::LongIndex lda{0};
typename LayoutB::Stride::LongIndex ldb{0};
typename LayoutC::Stride::LongIndex ldc{0};
typename LayoutC::Stride::LongIndex ldd{0};
int avail_sms{-1}; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
//
// Methods
//
/// Default Constructor
Arguments() = default;
GemmUniversalStreamk::Arguments::Arguments
该版本的构造函数接受 RowMajor::Stride(即 Coord)形式的步长参数。
/// Constructor
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride stride_a,
typename LayoutB::Stride stride_b,
typename LayoutC::Stride stride_c,
typename LayoutC::Stride stride_d,
int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
):
mode(mode),
problem_size(problem_size),
batch_count(batch_split),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms)
{
CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
}
GemmUniversalStreamk::Arguments::Arguments
该版本的构造函数接受 RowMajor::Stride::LongIndex 形式的前导维步长(lda、ldb 等)。
/// Constructor
Arguments(
GemmUniversalMode mode,
GemmCoord problem_size,
int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
typename EpilogueOutputOp::Params epilogue,
void const * ptr_A,
void const * ptr_B,
void const * ptr_C,
void * ptr_D,
int64_t batch_stride_A,
int64_t batch_stride_B,
int64_t batch_stride_C,
int64_t batch_stride_D,
typename LayoutA::Stride::LongIndex lda,
typename LayoutB::Stride::LongIndex ldb,
typename LayoutC::Stride::LongIndex ldc,
typename LayoutC::Stride::LongIndex ldd,
int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
):
mode(mode),
problem_size(problem_size),
batch_count(batch_split),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
}
GemmUniversalStreamk::Arguments::transposed_problem
交换 A 和 B 矩阵。
/// Returns arguments for the transposed problem
Arguments transposed_problem() const
{
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
std::swap(args.stride_a, args.stride_b);
std::swap(args.batch_stride_A, args.batch_stride_B);
return args;
}
};
GemmUniversalStreamk::Params
/// Parameters structure
struct Params
{
public:
//
// Data members
//
void * ptr_A = nullptr;
void * ptr_B = nullptr;
typename Mma::IteratorA::Params params_A{};
typename Mma::IteratorB::Params params_B{};
int64_t batch_stride_A{0};
int64_t batch_stride_B{0};
GemmUniversalMode mode = GemmUniversalMode::kGemm;
ThreadblockSwizzle block_mapping{};
void *barrier_workspace = nullptr;
void *partials_workspace = nullptr;
typename EpilogueOutputOp::Params output_op{};
void * ptr_D = nullptr;
void * ptr_C = nullptr;
typename Epilogue::OutputTileIterator::Params params_D{};
typename Epilogue::OutputTileIterator::Params params_C{};
int64_t batch_stride_D{0};
int64_t batch_stride_C{0};
GemmUniversalStreamk::Params::cacheline_align_up
内部定义静态变量 CACHELINE_SIZE,将给定的内存分配大小向上对齐到最近的缓存行边界,减少缓存冲突。
protected:
//
// Host-only dispatch-utilities
//
/// Pad the given allocation size up to the nearest cache line
static size_t cacheline_align_up(size_t size)
{
static const int CACHELINE_SIZE = 128;
return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE;
}
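cacheline_align_up 的效果可以用几个数值验证。下面是一段独立的主机端示意代码(假设缓存行大小为128字节,与上面的实现一致):
// 示意:向上对齐到 128 字节缓存行
#include <cassert>
#include <cstddef>

static size_t cacheline_align_up(size_t size) {
  static const int CACHELINE_SIZE = 128;
  return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE;
}

int main() {
  assert(cacheline_align_up(0)   == 0);
  assert(cacheline_align_up(1)   == 128);
  assert(cacheline_align_up(128) == 128);
  assert(cacheline_align_up(129) == 256);
  return 0;
}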
GemmUniversalStreamk::Params::get_barrier_workspace_size
计算执行屏障操作时所需的工作区大小。
ThreadblockSwizzleStreamK::sk_regions 返回 sk 区域的数量。
ThreadblockSwizzleStreamK::sk_blocks_per_region 返回每个区域中 SK CTA 的数量。
对于原子归约,每个 SK CTA 需要一个同步标志;
对于并行归约,每个归约 CTA 需要其自己的同步标志。
/// Get the workspace size needed for barrier
size_t get_barrier_workspace_size() const
{
// For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction,
// each reduction block needs its own synchronization flag.
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks);
return cacheline_align_up(sizeof(typename Barrier::T) * num_flags);
}
GemmUniversalStreamk::Params::get_partials_workspace_size
ThreadblockSwizzleStreamK::sk_regions 返回 sk 区域的数量。
ThreadblockSwizzleStreamK::sk_blocks_per_region 返回每个区域中 SK CTA 的数量。
kWorkspaceBytesPerBlock 为每个 CTA 累加结果需要的空间。
/// Get the workspace size needed for intermediate partial sums
size_t get_partials_workspace_size() const
{
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
}
public:
//
// Host dispatch API
//
/// Default constructor
Params() = default;
GemmUniversalStreamk::Params::Params
/// Constructor
Params(
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
mode(args.mode),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
ptr_D(args.ptr_D),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
barrier_workspace(nullptr),
partials_workspace(nullptr)
{
// Number of SMs to make available for StreamK decomposition
int avail_sms = (args.avail_sms == -1) ?
device_sms :
fast_min(args.avail_sms, device_sms);
创建一个ThreadblockSwizzleStreamK 对象。
// Initialize the block mapping structure
block_mapping = ThreadblockSwizzle(
args.mode,
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count,
sm_occupancy,
device_sms,
avail_sms,
sizeof(ElementA),
sizeof(ElementB),
sizeof(ElementC),
Epilogue::kAccumulatorFragments);
}
GemmUniversalStreamk::Params::get_workspace_size
调用 GemmUniversalStreamk::Params::get_barrier_workspace_size 和 GemmUniversalStreamk::Params::get_partials_workspace_size, 返回工作区大小。
/// Returns the workspace size (in bytes) needed for these parameters
size_t get_workspace_size() const
{
return
get_barrier_workspace_size() +
get_partials_workspace_size();
}
GemmUniversalStreamk::Params::init_workspace
/// Assign and initialize the specified workspace buffer. Assumes
/// the memory allocated to workspace is at least as large as get_workspace_size().
Status init_workspace(
void *workspace,
cudaStream_t stream = nullptr)
{
uint8_t *ptr = static_cast<uint8_t*>(workspace);
调用 GemmUniversalStreamk::Params::get_partials_workspace_size 函数获取大小。
// Establish partials workspace
partials_workspace = nullptr;
size_t partials_workspace_bytes = get_partials_workspace_size();
if (partials_workspace_bytes > 0)
{
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
partials_workspace = ptr;
ptr += partials_workspace_bytes;
}
workspace 的布局为 partials_workspace 在前、barrier_workspace 在后,后者的大小由 GemmUniversalStreamk::Params::get_barrier_workspace_size 给出。
// Establish barrier workspace
barrier_workspace = nullptr;
size_t barrier_workspace_bytes = get_barrier_workspace_size();
if (barrier_workspace_bytes > 0)
{
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
barrier_workspace = ptr;
ptr += barrier_workspace_bytes;
}
这里重复定义了局部变量 barrier_workspace_bytes(与上文同名)?随后将 barrier_workspace 清零。
// Zero-initialize barrier workspace
if (barrier_workspace)
{
size_t barrier_workspace_bytes = get_barrier_workspace_size();
CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes");
cudaError_t result = cudaMemsetAsync(
barrier_workspace,
0,
barrier_workspace_bytes,
stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
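init_workspace 把同一块 workspace 按先 partials、后 barrier 的顺序切分。下面是一段独立的主机端布局示意(假设性草图,其中 96KB 和 512B 的大小仅为举例):
// 示意:StreamK workspace 的布局切分(数值为虚构)
#include <cstdint>
#include <cstdio>

int main() {
  size_t partials_bytes = 96 * 1024;   // 假设的 get_partials_workspace_size() 返回值
  size_t barrier_bytes  = 512;         // 假设的 get_barrier_workspace_size() 返回值
  size_t total = partials_bytes + barrier_bytes;   // 即 get_workspace_size()

  uint8_t *workspace = new uint8_t[total];
  uint8_t *partials_workspace = workspace;                   // 区间 [0, partials_bytes)
  uint8_t *barrier_workspace  = workspace + partials_bytes;  // 区间 [partials_bytes, total)

  printf("total=%zu, partials at +%td, barrier at +%td\n",
         total, partials_workspace - workspace, barrier_workspace - workspace);
  delete[] workspace;
  return 0;
}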
GemmUniversalStreamk::Params::get_tiled_shape
调用 ThreadblockSwizzleStreamK::tiled_shape 返回三维图块数量。
/// Returns the GEMM volume in thread block tiles
cutlass::gemm::GemmCoord get_tiled_shape() const
{
return block_mapping.tiled_shape();
}
GemmUniversalStreamk::Params::get_grid_blocks
GemmUniversalStreamk::Params::get_grid_dims 得到网格维度。
/// Returns the total number of thread blocks to launch
int get_grid_blocks() const
{
dim3 grid_dims = get_grid_dims();
return grid_dims.x * grid_dims.y * grid_dims.z;
}
GemmUniversalStreamk::Params::get_grid_dims
调用 ThreadblockSwizzleStreamK::get_grid_dims 函数。
/// Returns the grid extents in thread blocks to launch
dim3 get_grid_dims() const
{
return block_mapping.get_grid_dims();
}
GemmUniversalStreamk::Params::update
更新指针、步长信息以及收尾操作。
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");
// Update input/output pointers
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
}
};
GemmUniversalStreamk::TileWorkDesc
结构体中包含图块索引、坐标、全局 MAC 起始索引、k 轴 MAC 起止索引、k 轴剩余 MAC 迭代数。
/// Tile work descriptor
struct TileWorkDesc
{
/// The linear tile index
int tile_idx;
/// The location of this tile (in threadblock-tile coordinates) in the output matrix
cutlass::gemm::GemmCoord tiled_coord;
// The first global-scoped MAC-iteration this threadblock will perform for this tile
int iter_begin;
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
int k_begin;
// The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
int k_end;
/// The number of remaining MAC-iterations this threadblock will perform for this tile
int k_iters_remaining;
GemmUniversalStreamk::TileWorkDesc::tile_started
判断当前 CTA 是否执行图块的第一个 MAC 迭代。
// Whether this block will perform the first iteration of this tile
CUTLASS_DEVICE
bool tile_started()
{
return (k_begin == 0);
}
GemmUniversalStreamk::TileWorkDesc::tile_finished
判断当前 CTA 是否执行图块的最后一个 MAC 迭代。
// Whether this block will perform the last iteration of this tile
CUTLASS_DEVICE
bool tile_finished(Params const ¶ms)
{
return (k_end == params.block_mapping.problem_size.k());
}
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
protected:
//
// Data members
//
/// GEMM problem parameters
Params params;
/// Shared storage reference
SharedStorage &shared_storage;
/// ID within the threadblock
int thread_idx;
/// ID of warp
int warp_idx;
/// ID of each thread within a warp
int lane_idx;
/// Threadblock scoped epilogue
Epilogue epilogue;
GemmUniversalStreamk::can_implement
检查问题的尺寸是否满足3个矩阵 layout 的对齐要求。
public:
//
// Host-only dispatch API
//
/// Determines whether the GEMM problem size satisfies this kernel's
/// alignment requirements
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
GemmUniversalStreamk::can_implement
/// Determines whether the GEMM problem satisfies this kernel's
/// alignment requirements
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
GemmUniversalStreamk::init_iterator_A
根据 GemmUniversalStreamk::TileWorkDesc 中的指针和形状信息初始化矩阵 A 的迭代器 PredicatedTileAccessIterator。
protected:
//
// Device-only utility methods
//
/// Iterator for fetching tile fragments from A
CUTLASS_DEVICE
typename Mma::IteratorA init_iterator_A(
TileWorkDesc &tile_work,
GemmUniversalMode mode)
{
// The input A matrix
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
如果是 Batched 模式,根据 k 坐标将 ptr_A 调整到对应批次;如果是 Array 模式,params.ptr_A 是一个指针数组。
// Update input pointers based on batched/array mode
if (mode == GemmUniversalMode::kBatched) {
ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A;
}
if (mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[tile_work.tiled_coord.k()];
}
MmaMultistage::IteratorA 为 DefaultMma::IteratorA,即 PredicatedTileAccessIterator。
int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM;
int m_end = params.block_mapping.problem_size.m();
return typename Mma::IteratorA(
params.params_A,
ptr_A,
{ m_end, tile_work.k_end },
threadIdx.x,
{ m_begin, tile_work.k_begin });
}
GemmUniversalStreamk::init_iterator_B
/// Iterator for fetching tile fragments from B
CUTLASS_DEVICE
typename Mma::IteratorB init_iterator_B(
TileWorkDesc &tile_work,
GemmUniversalMode mode)
{
// The input B matrix
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
// Update input pointers based on batched/array mode
if (mode == GemmUniversalMode::kBatched) {
ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B;
}
if (mode == GemmUniversalMode::kArray) {
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[tile_work.tiled_coord.k()];
}
int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN;
int n_end = params.block_mapping.problem_size.n();
return typename Mma::IteratorB(
params.params_B,
ptr_B,
{ tile_work.k_end, n_end },
threadIdx.x,
{ tile_work.k_begin, n_begin });
}
GemmUniversalStreamk::init_dp_tile_work
初始化 DP 图块的工作描述符 GemmUniversalStreamk::TileWorkDesc。一个 CTA 处理一个图块。
k_iters_remaining 表示线程块在当前图块中还要执行的 MAC 迭代次数。
CUTLASS_DEVICE
void init_dp_tile_work(
TileWorkDesc &tile_work,
int tile_idx)
{
// The linear tile index
tile_work.tile_idx = tile_idx;
// The first global-scoped MAC-iteration this threadblock will perform for this tile
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
// The number of MAC-iterations this threadblock will perform for this tile
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
处理图块的整个 k 轴。
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
tile_work.k_begin = 0;
// The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
tile_work.k_end = params.block_mapping.problem_size.k();
ThreadblockSwizzleStreamK::get_tile_offset 计算出当前线程块在网格中的二维平铺坐标。
// The location of this tile (in threadblock-tile coordinates) in the output matrix
tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx);
}
GemmUniversalStreamk::init_sk_tile_work
初始化 SK 图块的工作描述符 GemmUniversalStreamk::TileWorkDesc。
CUTLASS_DEVICE
void init_sk_tile_work(
TileWorkDesc &tile_work,
int tile_idx,
int block_iter_begin,
int block_iter_end)
{
// The linear tile index
tile_work.tile_idx = tile_idx;
// The first global-scoped MAC-iteration for this tile
int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile();
一个图块可能由多个 CTA 处理,因此 CTA 处理的第一个图块的起始迭代索引可能不是图块的起始索引。block_iter_begin 为 CTA 处理的迭代起始位置,tile_iter_begin 为图块的迭代起始位置,tile_work.iter_begin 为当前 CTA 负责处理的迭代起始位置。
k_iter_begin 为 CTA 在当前图块内的本地起始迭代索引。
tile_work.k_iters_remaining 为 CTA 在该图块上需要处理的迭代数。
// The first global-scoped MAC-iteration this threadblock will perform for this tile
tile_work.iter_begin = max(block_iter_begin, tile_iter_begin);
// The first tile-scoped MAC-iteration this threadblock will perform for this tile
int k_iter_begin = tile_work.iter_begin - tile_iter_begin;
// The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile
int k_iter_end = block_iter_end - tile_iter_begin;
// The number of MAC-iterations this threadblock will perform for this tile
tile_work.k_iters_remaining = k_iter_end - k_iter_begin;
tile_work.k_begin 和 tile_work.k_end 为该图块任务的 k 轴起止索引。
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
tile_work.k_begin = k_iter_begin * Mma::Shape::kK;
// The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile
tile_work.k_end = min(
params.block_mapping.problem_size.k(), // extent of k domain
(k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment
ThreadblockSwizzleStreamK::get_tile_offset 函数返回图块在输出矩阵中的位置。
// The location of this tile (in threadblock-tile coordinates) in the output matrix
tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx);
}
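为了更直观地理解 init_sk_tile_work,下面用一组假设的数值独立演算 k_begin / k_end 的推导(示意代码,数值为虚构,逻辑与上面的成员函数一致):
// 示意:SK CTA 在某个图块内的迭代范围(假设 iters_per_tile=8、Mma::Shape::kK=32、problem_k=250)
#include <algorithm>
#include <cstdio>

int main() {
  int iters_per_tile = 8, kK = 32, problem_k = 250;
  int tile_idx = 3;                                  // 该 CTA 正在处理的图块
  int block_iter_begin = 26, block_iter_end = 32;    // 该 CTA 的全局迭代区间

  int tile_iter_begin   = tile_idx * iters_per_tile;                   // 24
  int iter_begin        = std::max(block_iter_begin, tile_iter_begin); // 26
  int k_iter_begin      = iter_begin - tile_iter_begin;                // 2
  int k_iter_end        = block_iter_end - tile_iter_begin;            // 8
  int k_iters_remaining = k_iter_end - k_iter_begin;                   // 6
  int k_begin = k_iter_begin * kK;                                     // 64
  int k_end   = std::min(problem_k, k_iter_end * kK);                  // min(250, 256) = 250

  printf("iters=%d, k=[%d, %d)\n", k_iters_remaining, k_begin, k_end);
  return 0;
}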
GemmUniversalStreamk::share_accumulators
协作 CTA 将部分和汇总到 accum_tile_workspace 中。
/// Share accumulators with peers
CUTLASS_DEVICE
void share_accumulators(
AccumulatorTile const &accumulator_tile,
int block_idx,
int first_block_idx)
{
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
int accum_tile_offset = first_block_idx * kThreadCount;
如果是第一个 CTA:
- 调用 BlockStriped::store 将 accumulator_tile 中的部分和保存到 accum_tile_workspace;
否则:
- 等待其他 CTA:
  - 原子策略:调用 GenericBarrier::wait_lt 等待信号大于0,即第一个 CTA 完成保存;
  - 非原子策略:GenericBarrier::wait_eq 等待前面的 CTA 都完成;
- 调用 BlockStripedReduce::reduce 将自己的 accumulator_tile 累加到 accum_tile_workspace。
if (block_idx == first_block_idx)
{
// First peer initializes the workspace partials
BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
}
else
{
// Subsequent peers atomically accumulate into the workspace partials
if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic)
{
// Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them
Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1);
}
else
{
// Turnstile reduction order: wait until the previous peer has written
int wait_count = block_idx - first_block_idx;
Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count);
}
// Perform reduction in workspace
BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
}
GenericBarrier::arrive_inc 使用线程0增加某个标志位的到达计数(arrival count)。
// Signal our arrival
Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx);
}
GemmUniversalStreamk::acquire_accumulators
GenericBarrier::wait_eq_reset 使用0号线程等待前面的 num_carry_in 个 CTA 完成暂存。
/// Acquire accumulators from peers
CUTLASS_DEVICE
void acquire_accumulators(
AccumulatorTile &accumulator_tile,
int block_idx,
int first_block_idx)
{
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
// Wait for arrival
int num_carry_in = block_idx - first_block_idx;
Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in);
BlockStripedReduce::load_add 将 accum_tile_workspace 中对应偏移处的部分和累加到 accumulator_tile。
// Load and add peer-partials accumulator tile to local accumulator tile
int accum_tile_offset = first_block_idx * kThreadCount;
BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx);
}
GemmUniversalStreamk::do_epilogue
更新指针以指向正确的矩阵位置。
/// Perform epilogue computations and output
CUTLASS_DEVICE
void do_epilogue(
TileWorkDesc &tile_work,
AccumulatorTile &accumulator_tile)
{
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
// Update pointers for batched/array mode(s)
if (params.mode == GemmUniversalMode::kBatched) {
ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C;
ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D;
}
if (params.mode == GemmUniversalMode::kArray) {
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[tile_work.tiled_coord.k()];
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[tile_work.tiled_coord.k()];
}
确定图块在矩阵中的位置。
创建 C 和 D 矩阵的 PredicatedTileIterator 迭代器。
// Location of this tile in item-coords
MatrixCoord threadblock_item_begin(
tile_work.tiled_coord.m() * Mma::Shape::kM,
tile_work.tiled_coord.n() * Mma::Shape::kN
);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
通过 Epilogue 收尾。
// Execute the epilogue operator to update the destination tensor.
epilogue(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile,
iterator_C);
}
GemmUniversalStreamk::separate_reduction
根据 reduce_idx 确定图块索引 reduce_tile_idx 和片段索引 reduce_fragment_idx。
CUTLASS_DEVICE
void separate_reduction(int reduce_idx)
{
int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx;
// Reduce by sk-tile (every tile contributed to by one or more blocks)
reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
ThreadblockSwizzleStreamK::iters_per_tile 为每个图块上的迭代数。
计算当前归约操作的第一个和最后一个迭代位置。
ThreadblockSwizzleStreamK::get_sk_block_idx 计算出该迭代对应的第一个 SK CTA 索引。
peer_idx_begin 和 peer_idx_last 为处理这个图块的第一个和最后一个 SK CTA 的索引,用于后续的同步和归约操作。
int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile();
int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1;
peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first);
peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last);
Barrier 即 GenericBarrier。
GenericBarrier::wait_eq_reset 使用0号线程等待 num_peers 个 SK CTA 完成暂存。
// Wait for peers to complete
int peer_idx_end = peer_idx_last + 1;
int num_peers = peer_idx_end - peer_idx_begin;
Barrier::wait_eq_reset(
params.barrier_workspace,
thread_idx,
(reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx,
num_peers);
ThreadblockSwizzleStreamK::get_tile_offset 根据 reduce_tile_idx 计算出当前线程块在网格中的二维平铺坐标。
/// The location of this tile (in threadblock-tile coordinates) in the output matrix
GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx);
// Location of this tile in item-coords
MatrixCoord threadblock_item_begin(
tiled_coord.m() * Mma::Shape::kM,
tiled_coord.n() * Mma::Shape::kN
);
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
ptr_C,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
Epilogue::reduce 将来自多个 peer block 的累加器片段归约到片上,同时应用收尾计算并将最终结果写入输出矩阵。
与 EpilogueBaseStreamK::share 相对应。
// Execute the epilogue operator to update the destination tensor.
epilogue.reduce(
peer_idx_begin,
peer_idx_end,
reduce_fragment_idx,
params.partials_workspace,
EpilogueOutputOp(params.output_op),
iterator_D,
iterator_C);
}
GemmUniversalStreamk::process_tile
调用 GemmUniversalStreamk::init_iterator_A 和 GemmUniversalStreamk::init_iterator_B 初始化输入迭代器。
AccumulatorTile 即 MmaMultistage::FragmentC,即 Mma::FragmentC,即 Array。
创建一个 MmaMultistage 对象。
MmaMultistage::operator() 对当前 tile 执行一系列乘加(MAC)操作,累加结果存储在 accumulator_tile 中。
CUTLASS_DEVICE
void process_tile(
TileWorkDesc tile_work,
int block_idx,
int dp_start_block_idx,
int block_iter_begin)
{
// Initialize input iterators
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
// Initialize accumulators
AccumulatorTile accumulator_tile;
accumulator_tile.clear();
// Initialize MMA abstraction
Mma mma(
shared_storage.main_loop,
thread_idx,
warp_idx,
lane_idx);
// Perform this tile's range of multiply-accumulate (MAC) iterations
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
如果归约策略是原子的、没有归约块,或者当前为 DP CTA:
- ThreadblockSwizzleStreamK::get_first_block_idx 获取第一个处理 tile_work.tile_idx 图块的 CTA;
- GemmUniversalStreamk::TileWorkDesc::tile_finished 判断当前 CTA 是否完成该图块的末尾迭代:
  - 如果不是,GemmUniversalStreamk::share_accumulators 将部分和累加到 partials_workspace;
  - 如果是(DP CTA 或最后一个 SK CTA):
    - 若当前 CTA 不是从图块起始迭代开始(tile_started 为假),先调用 GemmUniversalStreamk::acquire_accumulators 将 partials_workspace 中的部分和累加到 accumulator_tile;
    - 再由 GemmUniversalStreamk::do_epilogue 调用 Epilogue 执行收尾操作。
if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
(params.block_mapping.reduction_blocks == 0) ||
(block_idx >= dp_start_block_idx))
{
//
// Cooperative SK peer reduction or DP block
//
int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx);
if (!tile_work.tile_finished(params)) {
// Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace
share_accumulators(accumulator_tile, block_idx, first_block_idx);
}
else
{
// DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile
if (!tile_work.tile_started())
{
// A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
acquire_accumulators(accumulator_tile, block_idx, first_block_idx);
}
do_epilogue(tile_work, accumulator_tile);
}
}
否则(需要单独归约波的 SK CTA):
- 通过 EpilogueBaseStreamK::share 将 accumulator_tile 分成多个片段存储到全局工作区 params.partials_workspace 中;
- GenericBarrier::arrive_range_inc 递增 Epilogue::kAccumulatorFragments 个标志位。
else
{
//
// Separate peer reduction
//
// Share accumulator partial sums with peer threadblock(s) through scratch workspace
epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started());
// Signal arrival
Barrier::arrive_range_inc(
params.barrier_workspace,
thread_idx,
tile_work.tile_idx * Epilogue::kAccumulatorFragments,
Epilogue::kAccumulatorFragments);
}
}
GemmUniversalStreamk::gemm
ThreadblockSwizzleStreamK::get_block_idx 返回线性 CTA 索引。
ThreadblockSwizzleStreamK::sk_regions 返回 SK 区域的数量。
ThreadblockSwizzleStreamK::sk_blocks_per_region 返回每个区域包含的 SK CTA 的个数。
sk_padding_start_block_idx 是 SK CTA 的结束位置。
3种 CTA 的顺序为 SK、DP、Reduce。
grid_padding_start_block_idx 是 Reduce CTA 的结束位置。
/// Executes one GEMM
CUTLASS_DEVICE
void gemm()
{
// Initialize block's iteration range
int tile_idx = 0;
int block_iter_begin = 0;
int block_iters_remaining = 0;
int block_idx = params.block_mapping.get_block_idx();
int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
创建一个 GemmUniversalStreamk::TileWorkDesc 结构体,然后不同类型的 CTA 对其初始化。
// Initialize tile work descriptor
TileWorkDesc tile_work;
bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
bool sk_block = (block_idx < sk_padding_start_block_idx);
bool reduce_block = (block_idx >= reduce_start_block_idx) &&
(block_idx < grid_padding_start_block_idx) &&
(ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);
如果是 Data-parallel 的分块:
- dp_block_idx 为当前 CTA 在 DP 块中的索引;tile_idx 为需要处理的图块索引,tile_allottment 为分配到的图块数量;
- 如果不在第一个波次,则每个 DP CTA 仅分配一个图块,且 tile_idx 要加上第一个波次带来的偏移量;
- block_iters_remaining 为需要处理的迭代次数;
- GemmUniversalStreamk::init_dp_tile_work 初始化 DP 图块的工作描述符;
- 检查 DP CTA 的图块是否与 SK 图块重叠(仅可能发生在 cohort 光栅化中),或者是否超出了矩阵的边界,若是则直接返回。
if (dp_block)
{
// This is a DP block
int dp_block_idx = block_idx - dp_start_block_idx;
int first_dp_tile = (params.block_mapping.cohort_raster) ? 0 : params.block_mapping.sk_tiles;
// Blocks in first DP wave get configured number of tiles
tile_idx = first_dp_tile + dp_block_idx;
int tile_allottment = params.block_mapping.dp_first_wave_tiles;
// Blocks in subsequent DP waves get 1 tile
if (dp_block_idx >= params.block_mapping.avail_sms) {
tile_allottment = 1;
tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
}
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
init_dp_tile_work(tile_work, tile_idx);
// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
if ((tile_idx < params.block_mapping.sk_tiles) ||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
{
return;
}
}
如果是 Stream-K 分块:
- ThreadblockSwizzleStreamK::get_iter_extents 确定当前 SK CTA 在工作分配中的迭代范围;
- ThreadblockSwizzleStreamK::get_sk_tile_idx 根据迭代索引推断出与之对应的图块索引;
- GemmUniversalStreamk::init_sk_tile_work 初始化 SK 图块的工作描述符。
else if (sk_block)
{
// This is a SK block
int block_iter_end;
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
block_iters_remaining = block_iter_end - block_iter_begin;
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
如果是归约分块:
- reduce_block_idx 为当前 CTA 在 Reduction 块中的相对索引;
- 调用 GemmUniversalStreamk::separate_reduction 将多个线程块的部分结果汇总成最终结果;
- 随后直接返回,不进入主循环。
else
{
if (reduce_block)
{
// This is a reduction threadblock
int reduce_block_idx = block_idx - reduce_start_block_idx;
separate_reduction(reduce_block_idx);
}
return;
}
调用 GemmUniversalStreamk::process_tile 函数根据 tile_work 中的信息(图块的坐标、迭代范围等)进行相应的计算,每次处理 tile_work.k_iters_remaining 次迭代;block_iters_remaining 为0时退出。
// Iteration-processing loop body
CUTLASS_PRAGMA_NO_UNROLL
while (true)
{
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
if (block_iters_remaining == 0)
{
break;
}
// Continue to next tile
__syncthreads();
处理下一个图块,
- 如果是 DP CTA,调整为下一个波中的图块,调用 GemmUniversalStreamk::init_dp_tile_work 初始化;
- 如果是 SK CTA,以倒序方式处理图块,调用 GemmUniversalStreamk::init_sk_tile_work 函数。
if (block_idx >= dp_start_block_idx)
{
// DP block consume their tiles at stride
tile_idx += params.block_mapping.avail_sms;
init_dp_tile_work(tile_work, tile_idx);
}
else
{
// SK blocks consume their tiles in backwards order
tile_idx--;
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
}
}
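gemm() 开头的几个边界把线性 block_idx 划分成 SK、DP、Reduce 与填充几段。下面用一组假设的数值独立演示这种划分(示意代码,数值为虚构):
// 示意:线性 block_idx 的分段(假设 SK CTA 100 个、sk_waves=1、avail_sms=108、dp_blocks=200、reduction_blocks=64)
#include <cstdio>

int main() {
  int sk_padding_start   = 100;                 // [0, 100) 为 SK CTA,[100, 108) 为 SK 填充
  int dp_start           = 1 * 108;             // sk_waves * avail_sms = 108
  int reduce_start       = dp_start + 200;      // 308,[108, 308) 为 DP CTA
  int grid_padding_start = reduce_start + 64;   // 372,[308, 372) 为 Reduce CTA,其后为网格填充

  int samples[] = {50, 105, 200, 350, 400};
  for (int block_idx : samples) {
    const char *kind =
        (block_idx < sk_padding_start)                                ? "SK"     :
        (block_idx >= dp_start && block_idx < reduce_start)           ? "DP"     :
        (block_idx >= reduce_start && block_idx < grid_padding_start) ? "Reduce" :
                                                                        "Padding";
    printf("block %d -> %s\n", block_idx, kind);
  }
  return 0;
}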
GemmUniversalStreamk::invoke
静态函数创建一个 GemmUniversalStreamk 对象,然后调用 GemmUniversalStreamk::operator()。
public:
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const ¶ms,
SharedStorage &shared_storage)
{
GemmUniversalStreamk op(params, shared_storage);
op();
}
GemmUniversalStreamk::GemmUniversalStreamk
// Constructor
CUTLASS_DEVICE
GemmUniversalStreamk(
Params const ¶ms,
SharedStorage &shared_storage)
:
params(params),
shared_storage(shared_storage),
thread_idx(threadIdx.x),
warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code
lane_idx(threadIdx.x % 32),
epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx)
{}
GemmUniversalStreamk::operator()
调用 GemmUniversalStreamk::gemm 函数。
/// Executes one GEMM
CUTLASS_DEVICE
void operator()()
{
// Generic SK code path
gemm();
}
};
ThreadblockSwizzleStreamK
/// Threadblock mapping control for GEMMs
struct ThreadblockSwizzleStreamK {
/// Advertise StreamkFeature
using StreamkFeature = void;
/// Kernel traits
template <typename GemmKernel>
struct KernelTraits {};
3种归约方法,这里使用 Mixed 模式。
/// Reduction strategy
enum ReductionStrategy
{
kNone, // Data-parallel strategy (no seams, fixup, etc.)
kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2
kMixed, // Deterministic reduction of SK-block partials employing either:
// (a) A separate wave of reduction thread blocks" (for scenarios with lots of
// SK-blocks per SK-tile)
// (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few
// SK-blocks per SK-tile)
};
static ReductionStrategy const kReductionStrategy = kMixed;
kDpEfficiencyThreshold 没有用到。
光栅化群组(cohort)为 8x4,每个群组包含32个 CTA。
kFixupStartupIterEquiv 和 kFixupPeerIterEquiv 没有用到。
//
// Heuristics
//
/// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel)
static float constexpr kDpEfficiencyThreshold = 0.92f;
/// Minimum number of MAC-iterations per streamk block
static int const kMinItersPerSkBlock = 2;
/// Height in CTAs of a grid rasterization cohort
static int const kCohortCtasM = 8;
/// Width in CTAs of a grid rasterization cohort
static int const kCohortCtasN = 4;
/// Number of CTAs per cohort
static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM;
/// Cost-equivalent number of SM-iterations for fixup I/O
static int const kFixupStartupIterEquiv = 10;
static int const kFixupPeerIterEquiv = 3;
//
// Member state
//
/// The 3D value-extents of the GEMM computation volume (m,n,k)
GemmCoord problem_size;
/// Div/mod accelerators
FastDivmod div_mod_tiled_shape_m;
FastDivmod div_mod_tiled_shape_n;
FastDivmod div_mod_tiled_cohort_shape_n;
FastDivmod div_mod_iters_per_tile;
/// Whether to perform cohort CTA rasterization
bool cohort_raster;
// Whether to pad and remap block indices
bool remap_block_indices;
/// CTA occupancy per SM
int sm_occupancy;
/// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size)
int avail_sms;
int dp_blocks; /// Number of data-parallel thread blocks in the grid
int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
/// Number of reduction blocks in the grid
int reduction_blocks;
int sk_waves;
int sk_tiles;
int sk_big_blocks_per_region;
int sk_iters_per_region;
/// Div/mod accelerators
FastDivmod div_mod_sk_iters_per_normal_block;
FastDivmod div_mod_sk_iters_per_big_block;
FastDivmod div_mod_sk_iters_per_region;
FastDivmod div_mod_sk_regions; //!! used in block map
FastDivmod div_mod_sk_blocks_per_region; //!! used in block map
/// The batch count
int batch_count;
//
// Host+device interface
//
/// Constructor
ThreadblockSwizzleStreamK() = default;
ThreadblockSwizzleStreamK::tiled_shape
返回一个 GemmCoord 对象。
batch_count 放到 k 维上。
/// Returns the GEMM volume in thread block tiles
CUTLASS_HOST_DEVICE
GemmCoord tiled_shape() const
{
return GemmCoord(
static_cast<int>(div_mod_tiled_shape_m),
static_cast<int>(div_mod_tiled_shape_n),
batch_count);
}
ThreadblockSwizzleStreamK::iters_per_tile
FastDivmod 的 operator int() 转换可以取出原始的除数。
/// Number of iterations per output tile
CUTLASS_HOST_DEVICE
int iters_per_tile() const
{
return static_cast<int>(div_mod_iters_per_tile);
}
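FastDivmod 预先计算快速除法所需的魔数,设备端可以用乘法和移位代替除法。下面是一段主机端的用法示意(假设 FastDivmod 定义于 cutlass/fast_math.h,且具备本文引用代码中出现的 operator int()、div() 与 operator()(q, r, dividend) 接口):
// 示意:FastDivmod 的基本用法(接口以本文引用的代码为准)
#include "cutlass/fast_math.h"
#include <cstdio>

int main() {
  cutlass::FastDivmod div_mod_iters_per_tile(8);             // 除数为 8

  int divisor  = static_cast<int>(div_mod_iters_per_tile);   // 取回原始除数:8
  int tile_idx = div_mod_iters_per_tile.div(27);             // 27 / 8 = 3

  int quotient, remainder;
  div_mod_iters_per_tile(quotient, remainder, 27);           // 27 = 3 * 8 + 3

  printf("divisor=%d, div=%d, q=%d, r=%d\n", divisor, tile_idx, quotient, remainder);
  return 0;
}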
ThreadblockSwizzleStreamK::sk_iters_per_normal_block
/// Number of iterations for normal SK-blocks
CUTLASS_HOST_DEVICE
int sk_iters_per_normal_block() const
{
return static_cast<int>(div_mod_sk_iters_per_normal_block);
}
ThreadblockSwizzleStreamK::sk_regions
/// Number of SK regions
CUTLASS_HOST_DEVICE
int sk_regions() const
{
return static_cast<int>(div_mod_sk_regions);
}
ThreadblockSwizzleStreamK::sk_blocks_per_region
/// Number of SK blocks per region (splitting factor)
CUTLASS_HOST_DEVICE
int sk_blocks_per_region() const
{
return static_cast<int>(div_mod_sk_blocks_per_region);
}
ThreadblockSwizzleStreamK::Print
tiles = dp_tiles + sk_tiles
//
// Host-side interface
//
/// Debug print
void Print()
{
#ifndef __CUDA_ARCH__
auto tiles = tiled_shape().mn().product();
std::cout <<
"problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" <<
", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" <<
", tiles: " << tiles <<
", dp_tiles: " << tiles - sk_tiles <<
", sk_tiles: " << sk_tiles <<
", iters_per_tile: " << iters_per_tile() <<
", reduction_blocks: " << reduction_blocks <<
", dp_blocks: " << dp_blocks <<
", dp_waves: " << dp_blocks / avail_sms <<
", dp_first_wave_tiles: " << dp_first_wave_tiles <<
", sk_blocks_per_region: " << sk_blocks_per_region() <<
", sk_regions: " << sk_regions() <<
", sk_waves: " << sk_waves <<
", sk_iters_per_normal_block: " << sk_iters_per_normal_block() <<
", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
", remap_block_indices: " << remap_block_indices <<
", cohort_raster: " << cohort_raster <<
", sm_occupancy: " << sm_occupancy <<
", avail_sms: " << avail_sms <<
", num_blocks: " << get_num_blocks() <<
"\n\n";
#endif
}
ThreadblockSwizzleStreamK::get_sk_blocks
初始化 savings_iters 为最小整数值,sk_blocks 为0。如果 sk_tiles 为0,则直接返回。
// Compute sk_blocks to dispatch for a given number of sk_tiles
static void get_sk_blocks(
int &sk_blocks, /// [out]
int &savings_iters, /// [out]
int sk_tiles,
int iters_per_tile,
int avail_sms,
int max_sk_occupancy,
bool allow_partial_wave)
{
savings_iters = INT_MIN;
sk_blocks = 0;
if (sk_tiles == 0) {
return;
}
sk_iters 为 SK 图块的总迭代次数。dp_equiv_iters 为这些图块若按 DP 方式处理所需的等效迭代时间,即波数 dp_equiv_waves 乘以 iters_per_tile;由于波数向上取整,其中包含量化开销。
int sk_iters = sk_tiles * iters_per_tile;
int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms;
int dp_equiv_iters = iters_per_tile * dp_equiv_waves;
kMinItersPerSkBlock 为每个 SK 块内的最小 MAC 循环次数。
如果允许部分波次分配,则最小 SK 线程块数为 avail_sms 和 sk_tiles + 1 之间的最小值;否则,最小线程块数等于 avail_sms。
最大的 SK 块数 max_sk_blocks 受 avail_sms * max_sk_occupancy 和 sk_iters / kMinItersPerSkBlock 两者的限制(取较小值)。
int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms;
int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock);
$$\begin{aligned} time_{CTA}(g) \leftarrow & \; a + b\,(FixupPeers(g) > 1) \\ & + c\,(ItersPerCta(g)) + d\,(FixupPeers(g) - 1) \end{aligned}$$
其中:
$$\begin{aligned} ItersPerCta(g) \leftarrow & \left\lceil \frac{ \lceil \frac{m}{\text{BLK\_M}} \rceil \times \lceil \frac{n}{\text{BLK\_N}} \rceil \times \lceil \frac{k}{\text{BLK\_K}} \rceil}{g}\right\rceil \\ FixupPeers(g) \leftarrow & \left\lceil \frac{\left\lceil\frac{k}{\text{BLK\_K}} \right\rceil}{ItersPerCta(g)} \right\rceil \end{aligned}$$
遍历 [min_sk_blocks, max_sk_blocks] 区间的所有分块数:
- 根据 trial_sk_blocks 计算 SK 波数 sk_waves 和每个 CTA 处理的最大 SK 迭代次数 max_sk_iters_per_block,即 $ItersPerCta(g)$;sk_iter_equiv 为等效的 SK 迭代次数;num_peers 为处理同一图块的 CTA 数量,即 $FixupPeers(g)$;base_cost 为 CTA 的固定成本 $a$,iter_cost 为迭代成本 $c$,peer_cost 为 CTA 协作成本 $b$ 和 $d$;
- 如果 trial_savings_iters 大于或等于当前的 savings_iters,则更新 savings_iters 和 sk_blocks。
for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks)
{
int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms;
int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks;
int sk_iter_equiv = max_sk_iters_per_block * sk_waves;
int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew
float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);
if (trial_sk_blocks % sk_tiles == 0)
{
// aligned
num_peers = (trial_sk_blocks / sk_tiles);
iter_cost = 0.0f;
}
float peer_cost = 2.0f * float(num_peers);
float base_cost = 2.0f * float(sk_waves);
int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);
int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv;
if (trial_savings_iters >= savings_iters) {
savings_iters = trial_savings_iters;
sk_blocks = trial_sk_blocks;
}
}
}
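结合上面的成本模型,可以代入一组假设的数值体会 savings_iters 的含义。下面的示意代码按同样的公式评估单个候选 trial_sk_blocks(数值为虚构,系数沿用实现中的 0.02 与 2.0):
// 示意:评估一个候选 sk_blocks 相对 DP 的节省迭代数
#include <cstdio>

int main() {
  int sk_tiles = 4, iters_per_tile = 16, avail_sms = 8;
  int sk_iters = sk_tiles * iters_per_tile;                        // 64
  int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms;     // 1
  int dp_equiv_iters = iters_per_tile * dp_equiv_waves;            // 16

  int trial_sk_blocks = 8;
  int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms;                       // 1
  int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks;    // 8
  int sk_iter_equiv = max_sk_iters_per_block * sk_waves;                              // 8

  int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1;                  // 3
  float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);
  if (trial_sk_blocks % sk_tiles == 0) {   // 对齐:每个图块恰好 2 个 CTA
    num_peers = trial_sk_blocks / sk_tiles;
    iter_cost = 0.0f;
  }
  float peer_cost = 2.0f * float(num_peers);                                          // 4
  float base_cost = 2.0f * float(sk_waves);                                           // 2
  int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);                      // 6

  printf("savings_iters = %d\n", dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv);  // 2,有收益
  return 0;
}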
ThreadblockSwizzleStreamK::get_blocks
计算全波图块数 full_wave_tiles 和部分波图块数 partial_wave_tiles。
如果输出块数能够被 SM 整除,则只使用 DP 分块。
/// Determine the populations of DP and SK blocks to invoke for the given number of output tiles
static void get_blocks(
int &dp_tiles, /// [out]
int &sk_blocks, /// [out]
int output_tiles,
int iters_per_tile,
int avail_sms,
int sm_occupancy)
{
int full_waves = output_tiles / avail_sms;
int full_wave_tiles = full_waves * avail_sms;
int partial_wave_tiles = output_tiles - full_wave_tiles;
int score = -1;
dp_tiles = output_tiles;
sk_blocks = 0;
if (partial_wave_tiles == 0)
{
// Perfect quantization
return;
}
如果 full_waves 小于 sm_occupancy,即未达到满 GPU 占用,则通过形成 SK 波来补足占用率。
max_sk_occupancy 为可用于 SK 的最大占用波数。
调用 ThreadblockSwizzleStreamK::get_sk_blocks 函数计算 sk_blocks 和 score。
if (full_waves < sm_occupancy)
{
// We're less than full GPU occupancy
// Form the SK wave from the partial wave to get us up to full GPU occupancy
int max_sk_occupancy = sm_occupancy - full_waves;
dp_tiles = full_wave_tiles;
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles,
iters_per_tile,
avail_sms,
max_sk_occupancy,
true); // we can run with less than a full wave of SK-blocks
if (score < 0) {
// not profitable
sk_blocks = 0;
dp_tiles = output_tiles;
}
return;
}
如果 sm_occupancy 大于1,且完整波数对 sm_occupancy 取模等于 sm_occupancy - 1,意味着再添加一个波即可凑满一个波组;此时尝试把剩余的 partial_wave_tiles 组成 SK 波,以确保 GPU 被完全占用。
// We're at (or greater) than GPU occupancy
if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
{
// If occupancy is more than one CTA per SM, form the SK wave from the partial
// wave to get us to full GPU occupancy
int max_sk_occupancy = 1;
dp_tiles = full_wave_tiles;
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles,
iters_per_tile,
avail_sms,
max_sk_occupancy,
true); // we can run with less than a full wave of SK-blocks
if (score >= 0) {
return;
}
}
在 GPU 占用不足时,通过结合最后一个完整波和部分波来形成 SK 波(Stream-K 波)。其目的是调整 SK 块以优化 GPU 资源的利用。
减少 DP 分块的数量。
// Form the SK wave by combining the last full wave and the partial wave
// We're less than full GPU occupancy
dp_tiles = full_wave_tiles - avail_sms;
int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy);
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles + avail_sms,
iters_per_tile,
avail_sms,
max_sk_occupancy,
false); // we cannot run with less than a full wave of SK-blocks
if (score < 0) {
// not profitable
sk_blocks = 0;
dp_tiles = output_tiles;
}
}
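下面用一组假设的数值独立演示 get_blocks 的分支选择(示意代码,仅复现分支条件与最后一个分支的关键变量,数值为虚构):
// 示意:get_blocks 的分支判断
#include <cstdio>

int main() {
  int output_tiles = 100, avail_sms = 8, sm_occupancy = 4;
  int full_waves = output_tiles / avail_sms;                 // 12
  int full_wave_tiles = full_waves * avail_sms;              // 96
  int partial_wave_tiles = output_tiles - full_wave_tiles;   // 4

  if (partial_wave_tiles == 0) {
    printf("完美量化:全部使用 DP 分块\n");
  } else if (full_waves < sm_occupancy) {
    printf("欠占用:由部分波组成 SK 波\n");
  } else if (sm_occupancy > 1 && full_waves % sm_occupancy == sm_occupancy - 1) {
    printf("差一个波即满占用:由部分波组成 SK 波\n");
  } else {
    int dp_tiles = full_wave_tiles - avail_sms;                                   // 88
    int sk_candidate_tiles = partial_wave_tiles + avail_sms;                      // 12
    int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy);      // 1
    printf("合并最后一个完整波:dp_tiles=%d, SK 候选图块=%d, max_sk_occupancy=%d\n",
           dp_tiles, sk_candidate_tiles, max_sk_occupancy);
  }
  return 0;
}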
ThreadblockSwizzleStreamK::ThreadblockSwizzleStreamK
iters_per_tile 为每个图块在 k 方向上的迭代次数。
/// Constructor: *Gemm* problem size (m, n, k)
ThreadblockSwizzleStreamK(
GemmUniversalMode const mode_,
GemmCoord const problem_size_,
GemmCoord const tile_size_,
int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
int const sm_occupancy_,
int const device_sms_,
int const avail_sms_, /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
size_t const element_A_bytes_,
size_t const element_B_bytes_,
size_t const element_C_bytes_,
int const epilogue_acc_fragments_)
:
problem_size(problem_size_),
batch_count((mode_ == GemmUniversalMode::kBatched || mode_ == GemmUniversalMode::kArray) ? batch_split_ : 1),
reduction_blocks(0),
dp_blocks(0),
dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks
sk_tiles(0),
sk_big_blocks_per_region(0),
sk_iters_per_region(0),
sk_waves(0),
sm_occupancy(sm_occupancy_),
remap_block_indices(false),
avail_sms(fast_max(1, avail_sms_)),
cohort_raster(false)
{
int gpu_occupancy = device_sms_ * sm_occupancy;
int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k();
int sk_iters_per_normal_block = 0;
int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles)
int sk_blocks_per_region = 0;
tiled_shape 为图块在3个维度上的数量。flops_per_byte 和 dp_efficiency 没有用到。
GemmCoord tiled_shape(
(problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
(problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
batch_count);
size_t problem_bytes =
(element_C_bytes_ * problem_size.m() * problem_size.n()) +
(element_A_bytes_ * problem_size.m() * problem_size.k()) +
(element_B_bytes_ * problem_size.k() * problem_size.n());
size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2;
[[maybe_unused]] float flops_per_byte = float(problem_flops) / float(problem_bytes);
int output_tiles = tiled_shape.m() * tiled_shape.n();
int waves = (output_tiles + avail_sms - 1) / avail_sms;
[[maybe_unused]] float dp_efficiency = float(output_tiles) / float(waves * avail_sms);
首先初始化为仅使用 DP 图块。
//
// Determine dispatch composition of DP-tiles and SK-blocks
//
// Start with a DP-only configuration
int dp_tiles = output_tiles; // Number of data-parallel tiles
int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles
仅 kGemm 模式支持 SK 负载均衡:
- 如果 split_factor 大于1,则每个图块由 split_factor 个 SK CTA 处理,不使用 DP CTA;
- 否则,如果 kReductionStrategy 不为 kNone 且 avail_sms 大于1,调用 ThreadblockSwizzleStreamK::get_blocks 启发式地计算 dp_tiles 和 sk_blocks。
// Only kGemm mode allows for SK load balancing
if (mode_ == GemmUniversalMode::kGemm)
{
int split_factor = batch_split_;
if (split_factor > 1)
{
// Split-K override
dp_tiles = 0;
sk_blocks = output_tiles * split_factor;
}
else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled
(avail_sms > 1)) // Plurality of SMs to load balance across
{
// Use heuristics
get_blocks(
dp_tiles, /// [out]
sk_blocks, /// [out]
output_tiles,
iters_per_tile,
avail_sms,
sm_occupancy);
}
}
计算 SK CTA 的信息。
sk_iters_per_normal_block 为正常 SK CTA 的迭代数。
sk_regions 表示 SK 迭代空间被划分出的子分区数量。当 sk_blocks 大于 sk_tiles 且能被 sk_tiles 整除时(即 Split-K 分解),sk_regions 等于 SK 图块的数量。
进而得到 sk_blocks_per_region、sk_big_blocks_per_region 和 sk_iters_per_region 三个变量。
sk_tiles = output_tiles - dp_tiles;
// Compute SK block iteration details
if (sk_blocks > 0)
{
sk_waves = (sk_blocks + avail_sms - 1) / avail_sms;
int sk_iters = sk_tiles * iters_per_tile;
sk_blocks = fast_min(sk_blocks, sk_iters);
sk_iters_per_normal_block = sk_iters / sk_blocks;
int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks);
int sk_big_blocks = extra_sk_iters;
if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0))
{
// Split-K decomposition
sk_regions = sk_tiles;
}
sk_blocks_per_region = sk_blocks / sk_regions;
sk_big_blocks_per_region = sk_big_blocks / sk_regions;
sk_iters_per_region = sk_iters / sk_regions;
使用单独的归约波的条件:
- 使用非原子归约策略;
- SK 波的数量不足以完全占用 GPU;
- 有超过三个 CTA 共同处理一个 SK 图块。
epilogue_acc_fragments_ 为 Epilogue::kAccumulatorFragments,即 LinearCombination::kCount,即 AlignmentC,等于8。
// Use a separate reduction wave when all of:
// - Non-atomic reduction stratgy
// - The number of SK waves won't fully occupy the GPU (Otherwise we don't have
// a strong-scaling case for more parallel reduction)
// - More than three peers working on an SK tile. (This occurs when the ratio of
// SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
// e.g.:[partial-block | block | block | partial-block] ). With three or
// less peers, the two non-finishing SK-blocks are not expexted to contend.
if ((kReductionStrategy == kMixed) &&
(sk_waves < sm_occupancy) &&
(sk_blocks > 2 * sk_tiles))
{
// Launch a reduction block for every accumulator fragment in each SK-tile
reduction_blocks = sk_tiles * epilogue_acc_fragments_;
}
重新映射块索引的条件:
- 每个 SM 可驻留多个 CTA(sm_occupancy > 1);
- 设备上的所有 SM 都参与调度(device_sms_ == avail_sms);
- 活动 CTA 数量大于可用 SM 数量的两倍。
// When we have a multi-occupancy kernel and at least two waves of active blocks (where
// at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2)
// remap the block indices so that we can reliably spread the SK blocks evenly across the
// device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx().
remap_block_indices = (
(sm_occupancy > 1) &&
(device_sms_ == avail_sms) &&
(get_num_active_blocks() > avail_sms * 2));
// Initialize fast div/mod members related to SK
div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region);
div_mod_sk_regions = FastDivmod(sk_regions);
div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
}
计算 DP CTA 的信息。
在2维平面上将图块划分为群组(cohort),从而提升 L2 缓存的重用率,类似于 swizzling 方法;群组网格形状为 tiled_cohort_shape。
cohort_blocks 为按整群组划分后覆盖的 CTA 总数。
cohort_efficiency 为其中有效 CTA 的占比;由于向上取整对齐,可能小于1。
//
// Compute DP blocks
//
dp_blocks = dp_tiles;
cutlass::gemm::GemmCoord tiled_cohort_shape(
(tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM,
(tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN,
tiled_shape.k());
int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort;
float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);
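下面用一组假设的图块网格尺寸独立演算群组划分与 cohort_efficiency(示意代码,并非构造函数的一部分,数值为虚构):
// 示意:群组(cohort)划分与有效率(kCohortCtasM=8、kCohortCtasN=4、kCtasPerCohort=32)
#include <cstdio>

int main() {
  int tiled_m = 20, tiled_n = 10;      // 输出图块网格为 20 x 10
  int kCohortCtasM = 8, kCohortCtasN = 4, kCtasPerCohort = 32;

  int cohort_m = (tiled_m + kCohortCtasM - 1) / kCohortCtasM;          // 3
  int cohort_n = (tiled_n + kCohortCtasN - 1) / kCohortCtasN;          // 3
  int cohort_blocks = cohort_m * cohort_n * kCtasPerCohort;            // 288
  int dp_blocks = tiled_m * tiled_n;                                   // 200(假设全部为 DP 图块)
  float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);   // 约 0.69,低于 0.85,不会启用群组光栅化

  printf("cohort grid %dx%d, blocks=%d, efficiency=%.2f\n",
         cohort_m, cohort_n, cohort_blocks, cohort_efficiency);
  return 0;
}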
计算最后一个 SK 图块所在的群组网格坐标 (cohort_grid_m, cohort_grid_n),并检查其是否超出了 tiled_shape。
// Check if the SK tiles would be in cohorts that are in-bounds
bool sk_in_range = true;
if (sk_tiles > 0)
{
int last_sk_tile = sk_tiles - 1;
int cohort_tile_idx = last_sk_tile / kCtasPerCohort;
int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n();
int cohort_grid_n = (cohort_grid_m > 0) ?
tiled_cohort_shape.n() - 1 :
cohort_tile_idx % tiled_cohort_shape.n();
if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) ||
(((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n()))
{
sk_in_range = false;
}
}
如果 SK 图块未越界、DP CTA 的数量不少于 GPU 占用(gpu_occupancy)的两倍,且 cohort_efficiency 大于0.85,则:
- 启用群组光栅化,将 dp_blocks 更新为 cohort_blocks;
否则,当存在 SK 组件且未进行群组光栅化时,更新 DP 第一波的半持久性,以凑齐完整的网格波组:
- dp_tile_waves 为 DP 图块需要的波数,full_dp_tile_waves 为其中完整波的数量;dp_first_wave_tiles 为第一个 DP 波中每个 CTA 产生的输出图块数;waveset_excess 为 SK 波与 DP 图块波之和对 sm_occupancy 取模的余数;
- 如果 dp_first_wave_tiles + waveset_excess 小于或等于 full_dp_tile_waves,则第一个 DP 波多处理 waveset_excess 个图块,并相应减少 dp_blocks。
// Decide if we're going to be doing cohort raster
if (sk_in_range &&
(dp_blocks >= gpu_occupancy * 2) &&
(cohort_efficiency > 0.85f))
{
cohort_raster = true;
dp_blocks = cohort_blocks;
}
else if (sk_waves > 0)
{
// Update semi-persistence of first DP wave to ensure full grid wavesets
// (Only applies when there's an SK component and we're not doing blocked cohort rasterization)
int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms;
int full_dp_tile_waves = dp_tiles / avail_sms;
int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy;
if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves)
{
dp_first_wave_tiles += waveset_excess;
dp_blocks -= (waveset_excess * avail_sms);
}
}
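下面以一组假设的数值(108 个 SM、SM 占用率 2、2 个 SK 波、500 个 DP 图块)演示上述半持久性调整的计算过程,仅为示意:

```cpp
#include <cstdio>

// 演示:存在 SK 波时,调整第一波 DP CTA 的半持久性(数值均为演示假设)。
// dp_first_wave_tiles 初始为 1,表示第一波中每个 CTA 只处理 1 个输出图块。
int main() {
  int avail_sms = 108, sm_occupancy = 2;
  int sk_waves = 2;
  int dp_tiles = 500;
  int dp_blocks = dp_tiles;
  int dp_first_wave_tiles = 1;

  int dp_tile_waves      = (dp_tiles + avail_sms - 1) / avail_sms;      // 5:DP 图块需要的波数
  int full_dp_tile_waves = dp_tiles / avail_sms;                        // 4:完整的 DP 波数
  int waveset_excess     = (sk_waves + dp_tile_waves) % sm_occupancy;   // 1

  if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves) {
    dp_first_wave_tiles += waveset_excess;    // 第一波每个 CTA 多处理 waveset_excess 个图块
    dp_blocks -= waveset_excess * avail_sms;  // 相应减少启动的 DP CTA 数量
  }

  std::printf("dp_first_wave_tiles = %d, dp_blocks = %d\n", dp_first_wave_tiles, dp_blocks);
  // 输出:dp_first_wave_tiles = 2, dp_blocks = 392
  return 0;
}
```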
将图块划分相关的除数构造为 FastDivmod,以加速设备端的除法和取模运算。
// Setup fast-div/mod for device-side usage
div_mod_tiled_shape_m = FastDivmod(tiled_shape.m());
div_mod_tiled_shape_n = FastDivmod(tiled_shape.n());
div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
div_mod_iters_per_tile = FastDivmod(iters_per_tile);
}
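FastDivmod 的基本思路是:在主机端为固定除数预先计算乘数和移位量,设备端只需乘法和移位即可完成除法与取模,避免昂贵的硬件除法。下面是一个思路示意(并非 CUTLASS 的原版实现,具体以 CUTLASS 中 FastDivmod 的定义为准),并用它把一个全局 MAC 迭代索引分解为图块索引与图块内偏移:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// 思路示意:预计算 multiplier 与 shift,使除法退化为乘法加移位(假定被除数小于 2^31)。
struct SimpleFastDivmod {
  uint32_t divisor;
  uint64_t multiplier;
  int shift;

  explicit SimpleFastDivmod(uint32_t d) : divisor(d), shift(0) {
    while ((uint32_t(1) << shift) < d) { ++shift; }            // ceil(log2(d))
    multiplier = ((uint64_t(1) << (32 + shift)) + d - 1) / d;  // ceil(2^(32+shift) / d)
  }

  void divmod(uint32_t &quotient, uint32_t &remainder, uint32_t x) const {
    quotient = uint32_t((uint64_t(x) * multiplier) >> (32 + shift));
    remainder = x - quotient * divisor;
  }
};

int main() {
  // 假设每个输出图块需要 11 次 MAC 迭代,把全局迭代索引分解为(图块索引, 图块内偏移)。
  SimpleFastDivmod div_mod_iters_per_tile(11);
  for (uint32_t iter : {0u, 10u, 11u, 35u}) {
    uint32_t tile_idx, iter_in_tile;
    div_mod_iters_per_tile.divmod(tile_idx, iter_in_tile, iter);
    std::printf("iter %2u -> tile %u, iter_in_tile %u\n", iter, tile_idx, iter_in_tile);
  }
  return 0;
}
```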
ThreadblockSwizzleStreamK::get_num_active_blocks
计算有效 CTA 的数量:SK CTA + DP CTA + Reduction CTA。
/// Number of blocks performing useful work
int get_num_active_blocks() const
{
return (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
}
ThreadblockSwizzleStreamK::get_num_blocks
获取当前 GEMM 操作中使用的 CTA 数量。
先调用 ThreadblockSwizzleStreamK::get_num_active_blocks 获得实际需要的 CTA 总数。
如果需要重新映射块索引,则网格至少调度 4 个波,即 `avail_sms * 4` 个 CTA,不足部分用填充块补齐。
/// Obtains number of threadblocks per GEMM
int get_num_blocks() const
{
int active_blocks = get_num_active_blocks();
if (remap_block_indices)
{
// Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves
return fast_max(active_blocks, avail_sms * 4);
}
return active_blocks;
}
ThreadblockSwizzleStreamK::get_grid_dims
调用 ThreadblockSwizzleStreamK::get_num_blocks 得到 CTA 总数。
/// Obtains grid extents in CTAs
dim3 get_grid_dims() const
{
return dim3(get_num_blocks(), 1, batch_count);
}
ThreadblockSwizzleStreamK::device_num_blocks
//
// Device-side interface
//
/// Obtains number of threadblocks per GEMM
CUTLASS_DEVICE
int device_num_blocks() const
{
return gridDim.x;
}
ThreadblockSwizzleStreamK::get_sk_tile_idx
/// Obtains tile index for the given sk iteration
CUTLASS_DEVICE
int get_sk_tile_idx(int iter) const
{
int tile_idx = div_mod_iters_per_tile.div(iter);
return tile_idx;
}
ThreadblockSwizzleStreamK::get_batch_idx
/// Obtains the batch index
CUTLASS_DEVICE
int get_batch_idx() const
{
return RematerializeBlockIdxZ();
}
ThreadblockSwizzleStreamK::get_tile_offset
根据给定的 `tile_idx` 计算当前线程块在图块网格中的二维坐标 `(m, n)`,并封装为 GemmCoord 对象。
首先按行主序(row-major)计算线程块的二维坐标。
/// Obtains the calling threadblock's tiled coordinates for the given tile index
CUTLASS_DEVICE
GemmCoord get_tile_offset(int tile_idx) const
{
int m, n;
// row-major raster
div_mod_tiled_shape_n(m, n, tile_idx);
如果图块网格的行数 `tiled_shape().m()` 小于列数 `tiled_shape().n()`(输出矩阵偏宽),则切换为列主序(column-major)光栅化,可能提高访存性能和资源利用率。
if (tiled_shape().m() < tiled_shape().n())
{
// column-major raster
div_mod_tiled_shape_m(n, m, tile_idx);
}
当启用 `cohort_raster` 时,线程块按群组进行光栅化排列:
- 计算所属群组的线性索引 `cohort_tile_idx`,并转换为群组网格中的二维坐标 `(cohort_grid_m, cohort_grid_n)`;
- 计算 CTA 在群组内的线性索引 `block_idx_cohort`,进一步分解为群组内二维坐标 `(block_cohort_m, block_cohort_n)`;
- 由群组网格坐标和群组内坐标合成最终的 `m` 和 `n`。
这种光栅化方式使相邻线程块在输出矩阵上成块分布,可能有助于 L2 缓存重用、负载均衡和减少资源竞争。代码之后给出一个数值示例。
if (cohort_raster)
{
// tiled cohort raster
int cohort_tile_idx = tile_idx / kCtasPerCohort;
int cohort_grid_m, cohort_grid_n;
div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);
int block_idx_cohort = tile_idx % kCtasPerCohort;
int block_cohort_m = block_idx_cohort / kCohortCtasN;
int block_cohort_n = block_idx_cohort % kCohortCtasN;
m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
}
return GemmCoord(m, n, get_batch_idx());
}
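下面用假设的群组尺寸(kCohortCtasM = 8、kCohortCtasN = 4)和 `tiled_cohort_shape.n() = 5` 演示由 `tile_idx` 推导 `(m, n)` 的过程,数值仅为示意:

```cpp
#include <cstdio>

// 演示:cohort raster 下由 tile_idx 推导 (m, n) 的过程(参数为演示假设)。
int main() {
  int const kCohortCtasM = 8, kCohortCtasN = 4;
  int const kCtasPerCohort = kCohortCtasM * kCohortCtasN;
  int tiled_cohort_shape_n = 5;

  int tile_idx = 70;
  int cohort_tile_idx = tile_idx / kCtasPerCohort;               // 2:所属群组的线性索引
  int cohort_grid_m   = cohort_tile_idx / tiled_cohort_shape_n;  // 0
  int cohort_grid_n   = cohort_tile_idx % tiled_cohort_shape_n;  // 2
  int block_idx_cohort = tile_idx % kCtasPerCohort;              // 6:群组内的线性索引
  int block_cohort_m   = block_idx_cohort / kCohortCtasN;        // 1
  int block_cohort_n   = block_idx_cohort % kCohortCtasN;        // 2

  int m = cohort_grid_m * kCohortCtasM + block_cohort_m;         // 1
  int n = cohort_grid_n * kCohortCtasN + block_cohort_n;         // 10
  std::printf("tile %d -> (m, n) = (%d, %d)\n", tile_idx, m, n);
  return 0;
}
```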
ThreadblockSwizzleStreamK::get_tile_offset_row_major
使用行主序(row-major)光栅化的方式来计算每个线程块在网格中的位置,并返回一个 GemmCoord 对象。
ThreadblockSwizzleStreamK::get_batch_idx 返回网格的 z 轴索引。
/// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization)
CUTLASS_DEVICE
GemmCoord get_tile_offset_row_major(int tile_idx) const
{
// row-major raster
int m, n;
div_mod_tiled_shape_n(m, n, tile_idx);
return GemmCoord(m, n, get_batch_idx());
}
ThreadblockSwizzleStreamK::get_block_idx
获取当前 CTA 的线性索引。
首先获取原始块索引。
/// Obtains calling threadblock's linear threadblock index
CUTLASS_DEVICE
int get_block_idx() const
{
int block_idx = RematerializeBlockIdxX();
如果启用了 `remap_block_indices`,且当前 CTA 位于前两个波中,则重新映射块索引:`remapped_block_idx` 把原本相邻的两个线程块分派到不同的波次(wave),使 SK 块更均匀地分布在各个 SM 上。
// Remap the block indices for the first two waves of thread blocks if
// we have multi-occupancy and the grid constitutes four or more waves
if (remap_block_indices && (block_idx < avail_sms * 2))
{
int dest_sm = block_idx / 2;
int dest_wave = block_idx % 2;
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
block_idx = remapped_block_idx;
}
如果当前 CTA 属于 SK 区域,则进一步重新映射:调用 `div_mod_sk_regions` 时把商写入 `block_in_region`、余数写入 `region`,相当于把相邻的块交错分配到不同区域,再按 `region * sk_blocks_per_region() + block_in_region` 重组索引。这样交错可以减少区域内的相互等待,提高计算效率。
假设区域数为 3,每个区域有 4 个 SK 块。按上述映射规则,各 `block_idx` 对应的 `region`、`block_in_region` 及重映射后的索引如下:
| block_idx | region | block_in_region | 重映射后的索引 |
|---|---|---|---|
| 0 | 0 | 0 | 0 |
| 1 | 1 | 0 | 4 |
| 2 | 2 | 0 | 8 |
| 3 | 0 | 1 | 1 |
| 4 | 1 | 1 | 5 |
| 5 | 2 | 1 | 9 |
| 6 | 0 | 2 | 2 |
| 7 | 1 | 2 | 6 |
| 8 | 2 | 2 | 10 |
| 9 | 0 | 3 | 3 |
| 10 | 1 | 3 | 7 |
| 11 | 2 | 3 | 11 |
不确定该操作对访存模式和缓存重用的影响。
// Remap block indices to interleave SK regions to limit intra-region waiting
if (block_idx < sk_regions() * sk_blocks_per_region())
{
int block_in_region;
int region;
div_mod_sk_regions(block_in_region, region, block_idx);
block_idx = (region * sk_blocks_per_region()) + block_in_region;
}
return block_idx;
}
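下面的小程序按正文的假设(3 个区域、每个区域 4 个 SK CTA)复现上表的重映射结果,便于核对:

```cpp
#include <cstdio>

// 演示:复现 SK 区域交错重映射(假设 3 个区域、每区域 4 个 SK CTA)。
int main() {
  int sk_regions = 3, sk_blocks_per_region = 4;
  for (int block_idx = 0; block_idx < sk_regions * sk_blocks_per_region; ++block_idx) {
    int block_in_region = block_idx / sk_regions;  // div_mod_sk_regions 的商
    int region          = block_idx % sk_regions;  // div_mod_sk_regions 的余数
    int remapped = region * sk_blocks_per_region + block_in_region;
    std::printf("block_idx %2d -> region %d, block_in_region %d, remapped %2d\n",
                block_idx, region, block_in_region, remapped);
  }
  return 0;
}
```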
ThreadblockSwizzleStreamK::get_sk_block_idx
根据给定的迭代索引 `iter`,计算执行该迭代的 SK CTA 索引(传入某个图块的首个迭代时,得到的就是处理该图块的第一个 CTA)。
首先计算 `iter` 属于哪个区域 `region_idx`,以及它在区域内的偏移量 `iter_in_region`。
/// Obtains calling linear threadblock index of the first block to work on the given tile
CUTLASS_DEVICE
int get_sk_block_idx(int iter) const
{
int region_idx;
int iter_in_region;
div_mod_sk_iters_per_region(region_idx, iter_in_region, iter);
ThreadblockSwizzleStreamK::sk_iters_per_normal_block 为普通 SK CTA(normal block)的迭代次数。
`big_block_iters` 为区域内所有 big block 的迭代次数总和;`normal_block_iters` 为 `iter_in_region` 减去 `big_block_iters`,即该迭代落在 normal block 段内的偏移量。big block 比 normal block 多一次迭代,且区域内 big block 在前、normal block 在后。
int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks
int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal blocks
先假设该迭代落在 big block 中,计算对应索引 `big_block_idx_in_region`;再假设它落在 normal block 中,计算对应索引 `normal_block_idx_in_region`。若 `big_block_idx_in_region` 小于 `sk_big_blocks_per_region`,说明迭代确实属于 big block,否则取后者,得到真正的区域内索引 `block_idx_in_region`。
int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region);
int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters);
int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
big_block_idx_in_region :
normal_block_idx_in_region;
ThreadblockSwizzleStreamK::sk_blocks_per_region 为每个区域中 SK CTA 的数量。
最终得到执行当前迭代的线程块在所有 SK CTA 中的全局索引 `owning_block_idx`。代码之后给出一个数值示例。
int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region;
return owning_block_idx;
}
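下面用一组假设的数值(每区域 4 个 SK CTA,其中 2 个 big block,normal block 每块 5 次迭代)演示这一选择过程,仅为示意:

```cpp
#include <cstdio>
#include <initializer_list>

// 演示:get_sk_block_idx 中 big/normal block 的划分逻辑(数值为演示假设)。
// 假设每区域 4 个 SK CTA,前 2 个是 big block(每个 6 次迭代),其余为 normal block(每个 5 次迭代)。
int sk_block_idx_in_region(int iter_in_region,
                           int sk_iters_per_normal_block,
                           int sk_big_blocks_per_region) {
  int big_block_iters = sk_big_blocks_per_region * (sk_iters_per_normal_block + 1);
  int normal_block_iters = iter_in_region - big_block_iters;

  int big_block_idx    = iter_in_region / (sk_iters_per_normal_block + 1);
  int normal_block_idx = sk_big_blocks_per_region +
                         normal_block_iters / sk_iters_per_normal_block;

  return (big_block_idx < sk_big_blocks_per_region) ? big_block_idx : normal_block_idx;
}

int main() {
  // 区域内迭代布局:block0: [0,6)  block1: [6,12)  block2: [12,17)  block3: [17,22)
  for (int iter : {3, 7, 14, 20}) {
    std::printf("iter_in_region %2d -> block %d\n",
                iter, sk_block_idx_in_region(iter, 5, 2));
  }
  return 0;
}
```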
ThreadblockSwizzleStreamK::get_iter_extents
计算线程块应处理的开始迭代和结束迭代索引。
首先确定 `sk_block_idx` 所属的区域 `region_idx`,以及它在区域内的索引 `block_idx_in_region`。
先假定没有 big block,计算全局起始迭代索引 `block_iter_begin`。
/// Obtains iteration extends for the given SK block index
CUTLASS_DEVICE
void get_iter_extents(
int sk_block_idx,
int &block_iter_begin,
int &block_iter_end) const
{
int region_idx;
int block_idx_in_region;
div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);
block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block());
调整 `block_iter_begin` 的值:
- 如果当前 CTA 是 big block,起点加上 `block_idx_in_region`(前面每个 big block 各多出一次迭代),同时自身迭代数加一;
- 如果是 normal block,则起点加上 `sk_big_blocks_per_region`,跳过前面所有 big block 的额外迭代。
代码之后给出一个数值示例。
// Adjust extents for the first "num_big_blocks" blocks that get one extra iteration
int block_iters = sk_iters_per_normal_block();
if (block_idx_in_region < sk_big_blocks_per_region) {
// This is a +1 iteration block
block_iter_begin += block_idx_in_region;
block_iters++;
} else {
// This is a regular block
block_iter_begin += sk_big_blocks_per_region;
}
计算迭代的末尾。
block_iter_end = block_iter_begin + block_iters;
}
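沿用上例的假设(每区域 4 个 SK CTA、2 个 big block、normal block 每块 5 次迭代,每区域共 22 次迭代),下面演示 get_iter_extents 为区域内 4 个 SK CTA 计算出的迭代区间,结果应与上例的迭代布局一致:

```cpp
#include <cstdio>

// 演示:get_iter_extents 的计算过程(数值为演示假设,与上例一致)。
int main() {
  int sk_iters_per_normal_block = 5;
  int sk_big_blocks_per_region = 2;
  int sk_iters_per_region = 22;
  int region_idx = 0;

  for (int block_idx_in_region = 0; block_idx_in_region < 4; ++block_idx_in_region) {
    int begin = region_idx * sk_iters_per_region +
                block_idx_in_region * sk_iters_per_normal_block;
    int iters = sk_iters_per_normal_block;
    if (block_idx_in_region < sk_big_blocks_per_region) {
      begin += block_idx_in_region;        // 前面每个 big block 各多 1 次迭代
      iters++;                             // 自己也是 big block,多 1 次迭代
    } else {
      begin += sk_big_blocks_per_region;   // 跳过所有 big block 的额外迭代
    }
    std::printf("block %d: iters [%d, %d)\n", block_idx_in_region, begin, begin + iters);
  }
  return 0;
}
```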
ThreadblockSwizzleStreamK::get_first_block_idx
获取处理图块 `tile_idx` 时第一个开始工作的 CTA 索引。
如果是 DP 图块,直接返回传入的 CTA 索引;
/// Obtains calling linear threadblock index of the first block to work on the given tile
CUTLASS_DEVICE
int get_first_block_idx(int tile_idx, int block_idx) const
{
if (tile_idx >= sk_tiles) {
// DP tile
return block_idx;
}
否则由 ThreadblockSwizzleStreamK::iters_per_tile 计算出该图块第一个 MAC 迭代的全局索引 `iter`,再通过 ThreadblockSwizzleStreamK::get_sk_block_idx 得到负责该迭代的 CTA 索引,即第一个处理该图块的 SK CTA。
int iter = tile_idx * iters_per_tile();
return get_sk_block_idx(iter);
}
};
参考资料:
- Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU
- 聊聊苹果 AMX 矩阵运算单元
- 探索 AMX: 解锁 Apple Silicon 隐藏性能
- [QST] How should I set batch_stride of gemm_universal? #702
- [QST] StreamK ReductionStrategy: “Atomic” or “Mixed” #1488
- StreamK in 47_ampere_gemm_universal_streamk
- High-Performance Software Rasterization on GPUs
- variable cache line width ?
- Dissecting GPU Memory Hierarchy through Microbenchmarking
- TESLA V100 GPU
- Comparing LLC-memory Traffic between CPU and GPU Architectures
- OPTIMIZING CUDA APPLICATIONS FOR NVIDIA A100 GPU
- SC18 MatMul CublasLt CUTLASS
- 9.7.12.6. Parallel Synchronization and Communication Instructions: red
- New instruction for inter-CTA barrier in future GPUs? #1502
- 9.7.8.8. Data Movement and Conversion Instructions: ld
- Warps and occupancy - GTC
- [BUG] gemm universal streamk core dump when cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags return sm_occupancy_ is 0 #1223
- cudaGetDevice function what does it do not so clear?
- 算子性能优化 方法介绍
- CS 701 Software Pipelining
- Using CUDA Warp-Level Primitives
- libcu++
- C++雾中风景16:std::make_index_sequence, 来试一试新的黑魔法吧
- 谈谈 C++ 中的内存顺序 (Memory Order)
- 8.8. Release and Acquire Patterns
- 7.6. Synchronization Functions
- [QST] ThreadblockSwizzleStreamK cost modeling questions #1489
- How register liveness reducing is achieved #1515
- Branching with Predication
- CUTLASS GEMM API
- cutlass GEMM 流水线——single-stage、pipelined、multi-stage