[AI Compilers] Learning Triton: Matrix Multiplication Optimization

news2024/10/10 12:20:25

Matrix Multiplication

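Main topics:

Block-level matrix multiplication

Multi-dimensional pointer arithmetic

Reordering programs to improve the L2 cache hit rate

Automatic performance tuning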

Motivations

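Matrix multiplication is a key building block of most modern high-performance computing systems. It is notoriously hard to optimize, so its implementation is usually left to hardware vendors rather than written by hand by users. These implementations ship as so-called "kernel libraries" (e.g., cuBLAS), which are often proprietary and cannot easily be customized for the special needs of deep learning workloads (e.g., fused activation functions). This tutorial shows how to write efficient matrix multiplication kernels yourself with Triton, in a way that is easy to customize and extend.

Roughly speaking, the kernel we are going to write implements the following blocked algorithm to multiply an (M, K) matrix by a (K, N) matrix: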

# Do in parallel
for m in range(0, M, BLOCK_SIZE_M):
 # Do in parallel
 for n in range(0, N, BLOCK_SIZE_N):
   acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
   for k in range(0, K, BLOCK_SIZE_K):
     a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
     b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
     acc += dot(a, b)
   C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc

Here, each iteration of the doubly nested for-loop is executed by a dedicated Triton program instance.
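The blocked algorithm above can be checked on the CPU with a small sketch. This is plain Python on nested lists, not Triton code; the helper names `blocked_matmul` and `naive_matmul` are made up for this illustration, and the two outer loops run sequentially here where Triton would assign each (m, n) block to its own program instance.

```python
import random


def blocked_matmul(A, B, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
    """Multiply an (M, K) matrix by a (K, N) matrix one block at a time."""
    C = [[0.0] * N for _ in range(M)]
    # On a GPU, each (m, n) pair would be handled by its own program instance.
    for m in range(0, M, BLOCK_SIZE_M):
        for n in range(0, N, BLOCK_SIZE_N):
            m_end = min(m + BLOCK_SIZE_M, M)
            n_end = min(n + BLOCK_SIZE_N, N)
            # Accumulator for the current block of C.
            acc = [[0.0] * (n_end - n) for _ in range(m_end - m)]
            for k in range(0, K, BLOCK_SIZE_K):
                k_end = min(k + BLOCK_SIZE_K, K)
                # acc += dot(A[m:m_end, k:k_end], B[k:k_end, n:n_end])
                for i in range(m, m_end):
                    for j in range(n, n_end):
                        acc[i - m][j - n] += sum(A[i][p] * B[p][j] for p in range(k, k_end))
            # Write the finished block back to C.
            for i in range(m, m_end):
                for j in range(n, n_end):
                    C[i][j] = acc[i - m][j - n]
    return C


def naive_matmul(A, B, M, N, K):
    """Reference triple loop without blocking."""
    return [[sum(A[i][p] * B[p][j] for p in range(K)) for j in range(N)] for i in range(M)]


random.seed(0)
M, N, K = 5, 7, 6  # deliberately not multiples of the block sizes
A = [[random.random() for _ in range(K)] for _ in range(M)]
B = [[random.random() for _ in range(N)] for _ in range(K)]
C_blocked = blocked_matmul(A, B, M, N, K, BLOCK_SIZE_M=2, BLOCK_SIZE_N=4, BLOCK_SIZE_K=4)
C_naive = naive_matmul(A, B, M, N, K)
max_err = max(abs(C_blocked[i][j] - C_naive[i][j]) for i in range(M) for j in range(N))
print(f"max abs difference: {max_err:.2e}")
```

The sketch clips partial blocks with `min`, which is the simplest choice on the CPU; the Triton kernel handles the same edge cases with a modulo on the row and column offsets and a mask along K.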

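Compute Kernel

The algorithm above is actually fairly straightforward to implement in Triton. The main difficulty is computing the memory locations from which blocks of A and B must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.

Pointer Arithmetic

For a row-major 2D tensor X, the memory location of X[i, j] is given by &X[i, j] = X + i*stride_xi + j*stride_xj. Therefore, the blocks of pointers for A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] and B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] can be written in pseudo-code as: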

&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] =  a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] =  b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
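To make the stride formula concrete, here is a minimal sketch that treats a flat Python list as the memory buffer and list indices as "pointers". The helper `block_offsets` is hypothetical and only mimics the broadcasting of `rows[:, None]` against `cols[None, :]`; it is not part of Triton.

```python
def block_offsets(rows, cols, stride_row, stride_col):
    """Offsets of a 2D block: rows[:, None] * stride_row + cols[None, :] * stride_col."""
    return [[r * stride_row + c * stride_col for c in cols] for r in rows]


M, K = 4, 6
# Row-major (M, K) matrix stored in a flat buffer, so stride(0) = K and stride(1) = 1.
flat_A = list(range(M * K))
stride_am, stride_ak = K, 1

m, k = 2, 3  # top-left corner of the block
BLOCK_SIZE_M, BLOCK_SIZE_K = 2, 2
offsets = block_offsets(range(m, m + BLOCK_SIZE_M), range(k, k + BLOCK_SIZE_K), stride_am, stride_ak)
block = [[flat_A[o] for o in row] for row in offsets]

# The same block obtained through ordinary 2D indexing.
A_2d = [flat_A[i * K:(i + 1) * K] for i in range(M)]
expected = [row[k:k + BLOCK_SIZE_K] for row in A_2d[m:m + BLOCK_SIZE_M]]

print(offsets)            # [[15, 16], [21, 22]]
print(block == expected)  # True
```

Because the strides are passed in explicitly, the same arithmetic would also address a transposed (column-major) view by swapping the two stride values.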

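This means that the pointers for the blocks of A and B can be initialized (i.e., at k=0) in Triton as shown in the following code. Note that an extra modulo is needed to handle the case where M is not a multiple of BLOCK_SIZE_M or N is not a multiple of BLOCK_SIZE_N; in effect, the block is padded with some useless values that do not contribute to the result. The K dimension is handled later using masked loads.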

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
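The effect of the modulo can be seen with a few integers. The sketch below emulates `tl.arange` with a Python range; `wrapped_offsets` is a hypothetical helper, and the sizes are chosen so that M is not a multiple of the block size.

```python
def wrapped_offsets(pid, block_size, dim):
    """Emulates (pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) % dim."""
    return [(pid * block_size + i) % dim for i in range(block_size)]


M, BLOCK_SIZE_M = 10, 4
num_pid_m = -(-M // BLOCK_SIZE_M)  # ceiling division, 3 programs along M
all_offsets = [wrapped_offsets(pid_m, BLOCK_SIZE_M, M) for pid_m in range(num_pid_m)]
for pid_m, offs in enumerate(all_offsets):
    print(pid_m, offs)
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9, 0, 1]
```

In the last program, the two rows past the end of the matrix wrap around to rows 0 and 1. Those loads stay inside valid memory, and the values computed from them are never written out, because the final store is masked with `offs_cm < M`.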

The pointers are then updated in the inner loop as follows:

a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
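The interaction between advancing the pointers and masking the K dimension can be sketched in one dimension: a single row of A dotted with a single column of B. The function `masked_load` below is a stand-in written for this example that imitates the behaviour of `tl.load(ptrs, mask=..., other=0.0)` on a Python list; the real kernel does the same thing on 2D blocks.

```python
def masked_load(buffer, offsets, mask, other=0.0):
    """Emulates tl.load(ptrs, mask=mask, other=other) on a flat Python list."""
    return [buffer[o] if keep else other for o, keep in zip(offsets, mask)]


K, BLOCK_SIZE_K = 10, 4                   # K is not a multiple of BLOCK_SIZE_K
a_row = [float(i + 1) for i in range(K)]  # one row of A: 1.0, 2.0, ..., 10.0
b_col = [2.0] * K                         # one column of B
stride_ak = stride_bk = 1

offs_k = list(range(BLOCK_SIZE_K))
a_offs = [o * stride_ak for o in offs_k]
b_offs = [o * stride_bk for o in offs_k]
acc = 0.0
num_k_blocks = -(-K // BLOCK_SIZE_K)      # ceiling division
for k in range(num_k_blocks):
    # Lanes past the end of K are masked off and contribute 0 to the sum.
    mask = [o < K - k * BLOCK_SIZE_K for o in offs_k]
    a = masked_load(a_row, a_offs, mask)
    b = masked_load(b_col, b_offs, mask)
    acc += sum(x * y for x, y in zip(a, b))
    # Advance to the next block along K.
    a_offs = [o + BLOCK_SIZE_K * stride_ak for o in a_offs]
    b_offs = [o + BLOCK_SIZE_K * stride_bk for o in b_offs]

print(acc)  # 110.0, i.e. 2 * (1 + 2 + ... + 10)
```

In the third iteration only two of the four lanes are in bounds; the masked lanes are replaced by 0.0 without touching the buffer, so the partial block does not change the result.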

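L2 Cache Optimizations

As mentioned above, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C. The order in which these blocks are computed matters, because it affects the L2 cache hit rate of the program. Unfortunately, a simple row-major ordering of the blocks, like the following, hurts performance: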

pid = tl.program_id(axis=0)
grid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // grid_n
pid_n = pid % grid_n

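This ordering is simply not good enough.

One possible solution is to launch blocks in an order that promotes data reuse. This can be done by "super-grouping" blocks in groups of GROUP_SIZE_M rows before switching to the next column: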

# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
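To see what this mapping does, the same index arithmetic can be run in plain Python on a small grid. The wrapper function `grouped_pid` and the grid size (5 by 3 blocks with GROUP_SIZE_M = 2) are choices made for this illustration only.

```python
def grouped_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    """Map a linear program id to (pid_m, pid_n) using the grouped ordering."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group is smaller when num_pid_m is not divisible by GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


num_pid_m, num_pid_n, GROUP_SIZE_M = 5, 3, 2
order = [grouped_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M)
         for pid in range(num_pid_m * num_pid_n)]

# Show, for every block of C, the program id that computes it.
grid = [[0] * num_pid_n for _ in range(num_pid_m)]
for pid, (pid_m, pid_n) in enumerate(order):
    grid[pid_m][pid_n] = pid
for row in grid:
    print(row)
# [0, 2, 4]
# [1, 3, 5]
# [6, 8, 10]
# [7, 9, 11]
# [12, 13, 14]
```

Consecutive program ids walk down a group of two rows before moving to the next column, and the last group shrinks to a single row because 5 is not divisible by 2.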

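For example, in a matmul where each matrix is 9 blocks by 9 blocks, computing the output in row-major order requires loading 90 blocks into SRAM to compute the first 9 output blocks, whereas the grouped ordering requires only 54 blocks.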
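The two numbers can be reproduced by counting the distinct blocks of A and B that the first 9 programs touch. The sketch below assumes a group size of 3 rows, which is not stated in the text but is the value that yields 54; `blocks_loaded` is a hypothetical helper.

```python
def blocks_loaded(programs, num_k_blocks):
    """Count distinct blocks of A and B touched by a list of (pid_m, pid_n) programs."""
    a_blocks = {(pid_m, k) for pid_m, _ in programs for k in range(num_k_blocks)}
    b_blocks = {(k, pid_n) for _, pid_n in programs for k in range(num_k_blocks)}
    return len(a_blocks) + len(b_blocks)


NUM_BLOCKS = 9    # every matrix is 9 x 9 blocks
GROUP_SIZE_M = 3  # assumed group size
first_programs = range(9)

# Row-major order: the first 9 programs cover one full row of C.
row_major = [(pid // NUM_BLOCKS, pid % NUM_BLOCKS) for pid in first_programs]
# Grouped order: the first 9 programs cover a 3 x 3 tile of C.
grouped = [(pid % GROUP_SIZE_M, pid // GROUP_SIZE_M) for pid in first_programs]

row_major_loads = blocks_loaded(row_major, NUM_BLOCKS)
grouped_loads = blocks_loaded(grouped, NUM_BLOCKS)
print(row_major_loads)  # 90 = 9 blocks of A + 81 blocks of B
print(grouped_loads)    # 54 = 27 blocks of A + 27 blocks of B
```

The row-major order reuses one row of A but needs every block of B, while the grouped order needs only three rows of A and three columns of B, which is why more of the working set can stay in the L2 cache.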



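In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on an A100).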

Final Result

import torch

import triton
import triton.language as tl


def is_cuda():
   return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip_mi200():
   target = triton.runtime.driver.active.get_current_target()
   return target.backend == 'hip' and target.arch == 'gfx90a'


def get_cuda_autotune_config():
   return [
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                     num_warps=2),
       triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                     num_warps=2),
       # Good config for fp8 inputs.
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4)
   ]


def get_hip_autotune_config():
   return [
       triton.Config(
           {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
           num_warps=4, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
           num_warps=8, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
           num_warps=8, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
           num_warps=4, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
           num_warps=4, num_stages=0),
   ]


def get_autotune_config():
   if is_cuda():
       return get_cuda_autotune_config()
   else:
       return get_hip_autotune_config()


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
   configs=get_autotune_config(),
   key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
       # Pointers to matrices
       a_ptr, b_ptr, c_ptr,
       # Matrix dimensions
       M, N, K,
       # The stride variables represent how much to increase the ptr by when moving by 1
       # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
       # by to get the element one row down (A has M rows).
       stride_am, stride_ak,  #
       stride_bk, stride_bn,  #
       stride_cm, stride_cn,
       # Meta-parameters
       BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
       GROUP_SIZE_M: tl.constexpr,  #
       ACTIVATION: tl.constexpr  #
):
   """Kernel for computing the matmul C = A x B.
   A has shape (M, K), B has shape (K, N) and C has shape (M, N)
   """
   # -----------------------------------------------------------
   # Map program ids `pid` to the block of C it should compute.
   # This is done in a grouped ordering to promote L2 data reuse.
   # See above `L2 Cache Optimizations` section for details.
   pid = tl.program_id(axis=0)
   num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
   num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
   num_pid_in_group = GROUP_SIZE_M * num_pid_n
   group_id = pid // num_pid_in_group
   first_pid_m = group_id * GROUP_SIZE_M
   group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
   pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
   pid_n = (pid % num_pid_in_group) // group_size_m

   # ----------------------------------------------------------
   # Create pointers for the first blocks of A and B.
   # We will advance this pointer as we move in the K direction
   # and accumulate
   # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
   # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
   # See above `Pointer Arithmetic` section for details
   offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
   offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
   offs_k = tl.arange(0, BLOCK_SIZE_K)
   a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
   b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

   # -----------------------------------------------------------
   # Iterate to compute a block of the C matrix.
   # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
   # of fp32 values for higher accuracy.
   # `accumulator` will be converted back to fp16 after the loop.
   accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
   for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
       # Load the next block of A and B, generate a mask by checking the K dimension.
       # If it is out of bounds, set it to 0.
       a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
       b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
       # We accumulate along the K dimension.
       accumulator = tl.dot(a, b, accumulator)
       # Advance the ptrs to the next K block.
       a_ptrs += BLOCK_SIZE_K * stride_ak
       b_ptrs += BLOCK_SIZE_K * stride_bk
   # You can fuse arbitrary activation functions here
   # while the accumulator is still in FP32!
   if ACTIVATION == "leaky_relu":
       accumulator = leaky_relu(accumulator)
   c = accumulator.to(tl.float16)

   # -----------------------------------------------------------
   # Write back the block of the output matrix C with masks.
   offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
   offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
   c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
   c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
   tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
   return tl.where(x >= 0, x, 0.01 * x)

We can now create a convenience wrapper function that takes only two input tensors and (1) checks any shape constraints; (2) allocates the output; (3) launches the kernel above.

def matmul(a, b, activation=""):
   # Check constraints.
   assert a.shape[1] == b.shape[0], "Incompatible dimensions"
   assert a.is_contiguous(), "Matrix A must be contiguous"
   M, K = a.shape
   K, N = b.shape
   # Allocates output.
   c = torch.empty((M, N), device=a.device, dtype=torch.float16)
   # 1D launch kernel where each block gets its own program.
   grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
   matmul_kernel[grid](
       a, b, c,  #
       M, N, K,  #
       a.stride(0), a.stride(1),  #
       b.stride(0), b.stride(1),  #
       c.stride(0), c.stride(1),  #
       ACTIVATION=activation  #
   )
   return c
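The `grid` lambda in the wrapper decides how many program instances are launched: one per block of C, flattened into a 1D grid. A small CPU-only sketch of that calculation is shown below; `cdiv` is reimplemented locally to mirror `triton.cdiv`, and `launch_grid` is a name chosen for this example.

```python
def cdiv(x, y):
    """Ceiling division, mirroring triton.cdiv."""
    return (x + y - 1) // y


def launch_grid(M, N, META):
    """Number of programs in the 1D launch grid for a given tuning config."""
    return (cdiv(M, META['BLOCK_SIZE_M']) * cdiv(N, META['BLOCK_SIZE_N']),)


grid_small = launch_grid(512, 512, {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256})
grid_ragged = launch_grid(1000, 1000, {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64})
print(grid_small)   # (8,)   -> 4 blocks along M times 2 blocks along N
print(grid_ragged)  # (128,) -> 8 blocks along M times 16 blocks along N
```

Because the grid depends on `META`, the autotuner can try configurations with different block sizes and the launch size follows automatically.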

Unit Test

We can test our custom matrix multiplication against the native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_mi200() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
   print("✅ Triton and Torch match")
else:
   print("❌ Triton and Torch differ")

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
   torch.manual_seed(0)
   a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
   b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
   a = a.to(torch.float8_e5m2)
   # pre-transpose b for efficiency.
   b = b.T
   b = b.to(torch.float8_e5m2)
   triton_output = matmul(a, b)
   torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
   print(f"triton_output_with_fp8_inputs={triton_output}")
   print(f"torch_output_with_fp8_inputs={torch_output}")
   if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
       print("✅ Triton and Torch match")
   else:
       print("❌ Triton and Torch differ")

triton_output_with_fp16_inputs=tensor([[-10.9531,  -4.7109,  15.6953,  ..., -28.4062,   4.3320, -26.4219],
       [ 26.8438,  10.0469,  -5.4297,  ..., -11.2969,  -8.5312,  30.7500],
       [-13.2578,  15.8516,  18.0781,  ..., -21.7656,  -8.6406,  10.2031],
       ...,
       [ 40.2812,  18.6094, -25.6094,  ...,  -2.7598,  -3.2441,  41.0000],
       [ -6.1211, -16.8281,   4.4844,  ..., -21.0312,  24.7031,  15.0234],
       [-17.0938, -19.0000,  -0.3831,  ...,  21.5469, -30.2344, -13.2188]],
      device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-10.9531,  -4.7109,  15.6953,  ..., -28.4062,   4.3320, -26.4219],
       [ 26.8438,  10.0469,  -5.4297,  ..., -11.2969,  -8.5312,  30.7500],
       [-13.2578,  15.8516,  18.0781,  ..., -21.7656,  -8.6406,  10.2031],
       ...,
       [ 40.2812,  18.6094, -25.6094,  ...,  -2.7598,  -3.2441,  41.0000],
       [ -6.1211, -16.8281,   4.4844,  ..., -21.0312,  24.7031,  15.0234],
       [-17.0938, -19.0000,  -0.3831,  ...,  21.5469, -30.2344, -13.2188]],
      device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
triton_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  28.7031,   8.6719, -40.7500],
       [ 10.0000,  37.0000,  -5.5664,  ...,  20.9844,  46.8125,  30.8281],
       [ 19.5625,  -3.0078, -20.0469,  ...,  -2.1309,  -8.0625,  12.5625],
       ...,
       [-18.1562, -34.1562, -27.4219,  ..., -27.3906, -24.0938, -12.3516],
       [ -3.3945,  -8.6250, -23.6562,  ...,  -4.1094,  -3.5332, -16.0781],
       [-23.9688,  -3.2637, -33.6875,  ...,  17.3125, -36.6250,  25.8594]],
      device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  28.7031,   8.6719, -40.7500],
       [ 10.0000,  37.0000,  -5.5664,  ...,  20.9844,  46.8125,  30.8281],
       [ 19.5625,  -3.0078, -20.0469,  ...,  -2.1309,  -8.0625,  12.5625],
       ...,
       [-18.1562, -34.1562, -27.4219,  ..., -27.3906, -24.0938, -12.3516],
       [ -3.3945,  -8.6250, -23.6562,  ...,  -4.1094,  -3.5332, -16.0781],
       [-23.9688,  -3.2637, -33.6875,  ...,  17.3125, -36.6250,  25.8594]],
      device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
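The comparison in the unit test relies on the tolerance rule used by `torch.allclose`, which accepts two values when |a - b| <= atol + rtol * |b|. The scalar sketch below restates that rule in plain Python so the effect of `atol=1e-2` with `rtol=0` versus `rtol=1e-2` is easy to see; `close_enough` is a hypothetical helper, and the sample values are invented.

```python
def close_enough(a, b, atol, rtol):
    """Scalar version of the elementwise check |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# With rtol = 0 only the absolute tolerance applies.
strict_small_gap = close_enough(10.005, 10.0, atol=1e-2, rtol=0)
strict_large_gap = close_enough(10.02, 10.0, atol=1e-2, rtol=0)
# A relative tolerance scales with the magnitude of the reference value.
relaxed_large_gap = close_enough(10.02, 10.0, atol=1e-2, rtol=1e-2)

print(strict_small_gap)   # True
print(strict_large_gap)   # False
print(relaxed_large_gap)  # True
```

This is why the test enables a non-zero `rtol` only on MI200 devices: outputs with magnitudes in the tens can differ by more than 0.01 in absolute terms once reduced-precision arithmetic is involved.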

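Benchmark

Square Matrix Performance

We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark any other matrix shape.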

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

configs = []
for fp8_inputs in [False, True]:
   if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
       continue
   configs.append(
       triton.testing.Benchmark(
           x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
           x_vals=[128 * i for i in range(2, 33)],  # Different possible values for `x_name`
           line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
           # Possible values for `line_arg`
           # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
           line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
           line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
           styles=[("green", "-"), ("blue", "-")],
           ylabel="TFLOPS",  # Label name for the y-axis
           plot_name="matmul-performance-" +
           ("fp16" if not fp8_inputs else "fp8"),  # Name for the plot, used also as a file name for saving the plot.
           args={"fp8_inputs": fp8_inputs},
       ))


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
   a = torch.randn((M, K), device='cuda', dtype=torch.float16)
   b = torch.randn((K, N), device='cuda', dtype=torch.float16)
   if TORCH_HAS_FP8 and fp8_inputs:
       a = a.to(torch.float8_e5m2)
       b = b.T
       b = b.to(torch.float8_e5m2)
   quantiles = [0.5, 0.2, 0.8]
   if provider == ref_lib.lower():
       ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
   if provider == 'triton':
       ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
   perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
   return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)




matmul-performance-fp16:
        M       N       K      cuBLAS      Triton
0    256.0   256.0   256.0    4.096000    4.096000
1    384.0   384.0   384.0   12.288000   12.288000
2    512.0   512.0   512.0   26.214401   26.214401
3    640.0   640.0   640.0   42.666665   42.666665
4    768.0   768.0   768.0   63.195428   68.056616
5    896.0   896.0   896.0   78.051553   93.661869
6   1024.0  1024.0  1024.0  110.376426   99.864382
7   1152.0  1152.0  1152.0  135.726544  129.825388
8   1280.0  1280.0  1280.0  163.840004  163.840004
9   1408.0  1408.0  1408.0  155.765024  132.970149
10  1536.0  1536.0  1536.0  176.947204  157.286398
11  1664.0  1664.0  1664.0  179.978245  176.449258
12  1792.0  1792.0  1792.0  172.914215  204.353162
13  1920.0  1920.0  1920.0  200.347822  168.585369
14  2048.0  2048.0  2048.0  226.719125  190.650180
15  2176.0  2176.0  2176.0  211.827867  211.827867
16  2304.0  2304.0  2304.0  229.691080  228.592087
17  2432.0  2432.0  2432.0  203.583068  199.251522
18  2560.0  2560.0  2560.0  224.438347  218.453323
19  2688.0  2688.0  2688.0  200.704002  198.602388
20  2816.0  2816.0  2816.0  211.719459  210.696652
21  2944.0  2944.0  2944.0  218.579083  220.513412
22  3072.0  3072.0  3072.0  208.173173  208.173173
23  3200.0  3200.0  3200.0  214.046818  219.178074
24  3328.0  3328.0  3328.0  208.067338  208.973281
25  3456.0  3456.0  3456.0  217.308808  219.080343
26  3584.0  3584.0  3584.0  216.142772  211.565625
27  3712.0  3712.0  3712.0  209.428397  213.000737
28  3840.0  3840.0  3840.0  210.250955  205.179974
29  3968.0  3968.0  3968.0  213.142249  215.971570
30  4096.0  4096.0  4096.0  221.847481  215.784121
matmul-performance-fp8:
        M       N       K      Triton
0    256.0   256.0   256.0    3.276800
1    384.0   384.0   384.0   10.053818
2    512.0   512.0   512.0   20.164923
3    640.0   640.0   640.0   34.133334
4    768.0   768.0   768.0   42.130286
5    896.0   896.0   896.0   58.538665
6   1024.0  1024.0  1024.0   61.680940
7   1152.0  1152.0  1152.0   80.702267
8   1280.0  1280.0  1280.0  102.400003
9   1408.0  1408.0  1408.0   82.602666
10  1536.0  1536.0  1536.0   99.688560
11  1664.0  1664.0  1664.0  116.868992
12  1792.0  1792.0  1792.0  135.414749
13  1920.0  1920.0  1920.0  100.905113
14  2048.0  2048.0  2048.0  114.912434
15  2176.0  2176.0  2176.0  121.226797
16  2304.0  2304.0  2304.0  134.201527
17  2432.0  2432.0  2432.0  134.423269
18  2560.0  2560.0  2560.0  146.941707
19  2688.0  2688.0  2688.0  118.171514
20  2816.0  2816.0  2816.0  129.036114
21  2944.0  2944.0  2944.0  139.988852
22  3072.0  3072.0  3072.0  144.446699
23  3200.0  3200.0  3200.0  139.433550
24  3328.0  3328.0  3328.0  131.131689
25  3456.0  3456.0  3456.0  139.725414
26  3584.0  3584.0  3584.0  149.113421
27  3712.0  3712.0  3712.0  142.506914
28  3840.0  3840.0  3840.0  138.413021
29  3968.0  3968.0  3968.0  147.194128
30  4096.0  4096.0  4096.0  156.430916
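The TFLOPS figures in these tables come from the `perf` lambda in the benchmark: a matmul performs 2 * M * N * K floating-point operations (one multiply and one add per term), and the count is divided by the measured time. The sketch below applies the same formula in plain Python; the timing of 0.008192 ms is back-computed from the first fp16 row for illustration and is not a measured value.

```python
def tflops(M, N, K, ms):
    """Throughput in TFLOPS for an (M, K) x (K, N) matmul that took `ms` milliseconds."""
    flops = 2 * M * N * K  # one multiply and one add per inner-product term
    return flops * 1e-12 / (ms * 1e-3)


# Illustrative timing chosen to match the 256 x 256 x 256 row of the fp16 table.
small = tflops(256, 256, 256, ms=0.008192)
print(f"{small:.3f} TFLOPS")  # 4.096 TFLOPS

# Doubling every dimension multiplies the work by 8, so at equal
# throughput the runtime would also grow by a factor of 8.
same_rate = tflops(512, 512, 512, ms=8 * 0.008192)
print(f"{same_rate:.3f} TFLOPS")  # 4.096 TFLOPS
```

The identical values for small sizes in the fp16 table are a sign that those runs are bound by launch overhead and timer resolution rather than by arithmetic throughput.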
