mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 Test
- 1. Reference documents
- 2. numpy test
- 3. CUDA kernel test
- 4. Screenshots
This article demonstrates how to load data according to the fragment layout required by the PTX instruction documentation, execute the mma instruction, and check that the result is consistent with numpy.
1. Reference documents
- Matrix Fragments for mma.m16n8k16 with floating point type
- Warp-level matrix load instruction: ldmatrix
2. numpy test
tee numpy_gemm.py<<-'EOF'
import numpy as np
def print_data(data):
    r,c=data.shape
    for i in range(r):
        for j in range(c):
            print(f"{data[i][j]:8.3f}",end=",")
        print("\n",end="")
    print("\n",end="")
M=16
N=8
K=16
input_a=np.arange(M*K,dtype=np.float32).reshape(M,K)*0.01
input_a=input_a.astype(np.float16)
input_b=np.arange(K*N,dtype=np.float32).reshape(K,N)*0.01
input_b=input_b.astype(np.float16)
output_d=np.dot(input_a,input_b)
print_data(output_d)
EOF
python numpy_gemm.py
Output:
0.992, 1.004, 1.016, 1.028, 1.040, 1.052, 1.063, 1.076,
2.527, 2.566, 2.604, 2.641, 2.678, 2.717, 2.754, 2.791,
4.062, 4.125, 4.191, 4.254, 4.316, 4.379, 4.441, 4.508,
5.602, 5.688, 5.777, 5.867, 5.953, 6.043, 6.133, 6.223,
7.137, 7.250, 7.363, 7.480, 7.594, 7.707, 7.820, 7.938,
8.672, 8.812, 8.953, 9.094, 9.234, 9.375, 9.508, 9.656,
10.211, 10.375, 10.539, 10.703, 10.867, 11.031, 11.203, 11.367,
11.742, 11.938, 12.125, 12.320, 12.508, 12.695, 12.891, 13.086,
13.281, 13.500, 13.711, 13.930, 14.148, 14.359, 14.578, 14.797,
14.812, 15.055, 15.297, 15.547, 15.789, 16.031, 16.266, 16.516,
16.359, 16.625, 16.891, 17.156, 17.422, 17.688, 17.953, 18.234,
17.891, 18.188, 18.469, 18.766, 19.062, 19.359, 19.641, 19.938,
19.422, 19.750, 20.062, 20.391, 20.703, 21.016, 21.344, 21.656,
20.969, 21.312, 21.656, 22.000, 22.344, 22.688, 23.031, 23.375,
22.500, 22.859, 23.234, 23.609, 23.984, 24.344, 24.719, 25.094,
24.031, 24.422, 24.828, 25.219, 25.609, 26.016, 26.406, 26.797,
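The CUDA kernels in the next section place each thread's A/B fragment elements at the (row, col) coordinates prescribed by the PTX documentation. Before reading them, the index formulas can be previewed with a small Python sketch (not part of the original test; names are illustrative):
tee layout_preview.py<<-'EOF'
# A minimal sketch of the per-thread fragment coordinates of mma.m16n8k16,
# evaluating the formulas quoted from the PTX ISA documentation.
def a_coord(laneid, i):              # A fragment: elements a0..a7 per thread
    group_id = laneid >> 2
    tid_in_group = laneid % 4
    row = group_id if (0 <= i < 2 or 4 <= i < 6) else group_id + 8
    col = tid_in_group * 2 + (i & 0x1) + (8 if i >= 4 else 0)
    return row, col
def b_coord(laneid, i):              # B fragment: elements b0..b3 per thread
    group_id = laneid >> 2
    tid_in_group = laneid % 4
    row = tid_in_group * 2 + (i & 0x1) + (8 if i >= 2 else 0)
    col = group_id
    return row, col
for lane in (0, 1, 4, 31):           # a few representative lanes
    print("lane", lane,
          "A:", [a_coord(lane, i) for i in range(8)],
          "B:", [b_coord(lane, i) for i in range(4)])
EOF
python layout_preview.py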
3. CUDA kernel test
tee mma_ops.cu<<-'EOF'
#include <iostream>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cuda.h>
#include <mma.h>
using namespace nvcuda;
#define WARP_SIZE 32
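// Overview: both kernels below multiply A(16x16, fp16, row-major) by B(16x8, fp16, row-major) with a single
// warp-level mma.m16n8k16 and store the 16x8 fp16 result into C; they differ only in how the fragments reach
// the registers (direct global-memory loads vs. cp.async into shared memory followed by ldmatrix).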
#define CHECK_CUDA(status)                                                     \
    {                                                                          \
        cudaError_t error = status;                                            \
        if (error != cudaSuccess) {                                            \
            std::cerr << "Got bad cuda status: " << cudaGetErrorString(error)  \
                      << " at line: " << __LINE__ << std::endl;                \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    }
// mma instruction: per thread, D (2 regs = 4 halves) = A (4 regs = 8 halves) * B (2 regs = 4 halves) + C (2 regs)
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1)                                                    \
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" \
                 : "=r"(RD0), "=r"(RD1)                                                                                \
                 : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
// load the A fragments (A is stored row-major)
#define LDMATRIX_X4(R0, R1, R2, R3, addr)                                              \
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"  \
                 : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3)                              \
                 : "l"(addr))
// load the B fragments (B is stored row-major, so the transposed .trans form is needed)
#define LDMATRIX_X2(R0, R1, addr) \
    asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "l"(addr))
// asynchronous copy from global to shared memory
#define CP_ASYNC_CG(dst, src, Bytes) \
    asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(Bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
/*
Test 1: load the data from DRAM directly into registers following the mma fragment layout, then execute the mma instruction
*/
__global__ void ptx_mma_global(half* input_A, half* input_B, half* input_C, int M, int N, int K) {
    const size_t laneid = threadIdx.x % WARP_SIZE;
    uint32_t RA[4];
    uint32_t RB[2];
    uint32_t RC[2];
    RC[0] = 0;
    RC[1] = 0;
    // A matrix: M*K = 16*16
    /*
    The instruction documentation requires each thread to hold its elements a0..a7 at:
        groupID           = %laneid >> 2
        threadID_in_group = %laneid % 4
        row = groupID                                   for ai where 0 <= i < 2 || 4 <= i < 6
              groupID + 8                               otherwise
        col = (threadID_in_group * 2) + (i & 0x1)       for ai where i < 4
              (threadID_in_group * 2) + (i & 0x1) + 8   for ai where i >= 4
    */
    clock_t begin = clock64();
    int groupID = laneid / 4;
    int threadID_in_group = laneid % 4;
    int row_a0 = groupID;
    int col_a0 = (threadID_in_group * 2) + (0 & 0x1);
    int row_a2 = groupID + 8;
    int col_a2 = (threadID_in_group * 2) + (2 & 0x1);
    int row_a4 = groupID;
    int col_a4 = (threadID_in_group * 2) + (4 & 0x1) + 8;
    int row_a6 = groupID + 8;
    int col_a6 = (threadID_in_group * 2) + (6 & 0x1) + 8;
    // a0/a1 (likewise a2/a3, a4/a5, a6/a7) are adjacent in A, so each pair is loaded as one uint32_t
    RA[0] = *(uint32_t*)&input_A[row_a0 * K + col_a0];
    RA[1] = *(uint32_t*)&input_A[row_a2 * K + col_a2];
    RA[2] = *(uint32_t*)&input_A[row_a4 * K + col_a4];
    RA[3] = *(uint32_t*)&input_A[row_a6 * K + col_a6];
    // B matrix: K*N = 16*8
    /*
        groupID           = %laneid >> 2
        threadID_in_group = %laneid % 4
        row = (threadID_in_group * 2) + (i & 0x1)       for bi where i < 2
              (threadID_in_group * 2) + (i & 0x1) + 8   for bi where i >= 2
        col = groupID
    */
    // a thread's B elements are not contiguous in memory, so each half is fetched individually
    int row_b0 = (threadID_in_group * 2) + (0 & 0x1);
    int col_b0 = groupID;
    int row_b1 = (threadID_in_group * 2) + (1 & 0x1);
    int col_b1 = groupID;
    int row_b2 = (threadID_in_group * 2) + (2 & 0x1) + 8;
    int col_b2 = groupID;
    int row_b3 = (threadID_in_group * 2) + (3 & 0x1) + 8;
    int col_b3 = groupID;
    half *ptr_b = (half*)RB;
    ptr_b[0] = input_B[row_b0 * N + col_b0];
    ptr_b[1] = input_B[row_b1 * N + col_b1];
    ptr_b[2] = input_B[row_b2 * N + col_b2];
    ptr_b[3] = input_B[row_b3 * N + col_b3];
    // C/D matrix: M*N = 16*8
    /*
        groupID           = %laneid >> 2
        threadID_in_group = %laneid % 4
        row = groupID                                   for ci where i < 2
              groupID + 8                               for ci where i >= 2
        col = (threadID_in_group * 2) + (i & 0x1)       for ci where i = {0,..,3}
    */
    int row_c0 = groupID;
    int col_c0 = (threadID_in_group * 2) + (0 & 0x1);
    int row_c2 = groupID + 8;
    int col_c2 = (threadID_in_group * 2) + (2 & 0x1);
    HMMA16816(RC[0], RC[1],
              RA[0], RA[1], RA[2], RA[3],
              RB[0], RB[1],
              RC[0], RC[1]);
    // c0/c1 and c2/c3 are adjacent in C, so each pair is written back as one uint32_t
    *(uint32_t*)&input_C[row_c0 * N + col_c0] = RC[0];
    *(uint32_t*)&input_C[row_c2 * N + col_c2] = RC[1];
    clock_t end = clock64();
    if (laneid == 0) {
        printf("ptx_mma_global kernel e2e:%ld\n", end - begin);
    }
}
/*
Test 2: copy the data from DRAM into shared memory, load the fragments into registers with ldmatrix, then execute the mma instruction
*/
__global__ void ptx_mma_shared(half* input_A, half* input_B, half* input_C, int M, int N, int K) {
    const size_t laneid = threadIdx.x % WARP_SIZE;
    __shared__ half A[16 * 16];
    __shared__ half B[16 * 8];
    clock_t begin = clock64();
    // each of the 32 lanes copies one 16-byte chunk (8 halves) of A into shared memory
    uint32_t smem_lane_addr = __cvta_generic_to_shared(&A[laneid * 8]);
    CP_ASYNC_CG(smem_lane_addr, &input_A[laneid * 8], 16);
    // B has only 16*8 halves, so 16 lanes suffice
    if (laneid < 16) {
        uint32_t smem_lane_addr = __cvta_generic_to_shared(&B[laneid * 8]);
        CP_ASYNC_CG(smem_lane_addr, &input_B[laneid * 8], 16);
    }
    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0);
    __syncthreads();
    uint32_t RA[4];
    uint32_t RB[2];
    uint32_t RC[2];
    RC[0] = 0;
    RC[1] = 0;
    /*
    From the ldmatrix documentation:
        "When reading 8x8 matrices, a group of four consecutive threads loads 16 bytes.
         The matrix addresses must be naturally aligned accordingly.
         Each thread in a warp loads fragments of a row."
    Therefore:
    1. The A matrix (16*16) is split into four 8x8 tiles (two tile rows by two tile columns) and loaded by the
       32 threads together: one 8-element half row is 16 bytes and is supplied by 4 consecutive threads, so an
       8x8 tile is loaded cooperatively by the warp.
    2. ldmatrix expects each thread to pass the start address of one matrix row, so laneid is mapped to a row address:
         laneid % 16 -> row index 0..15 (the row stride is 16 halves, since A has two tile columns)
         laneid / 16 -> tile column 0..1 (the column stride is 8 halves)
         row start = laneid % 16 * 16 + laneid / 16 * 8
       print([(x%16,x//16) for x in range(32)]) gives
       [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0), (11, 0), (12, 0), (13, 0), (14, 0), (15, 0),
        (0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1), (12, 1), (13, 1), (14, 1), (15, 1)]
       i.e. lanes 0-15 walk tile column 0 and lanes 16-31 walk tile column 1.
    3. The B matrix (16*8) has a single tile column, so the row start address is simply laneid * 8.
    */
    int aTile_index = laneid % 16 * 16 + laneid / 16 * 8;
    LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], __cvta_generic_to_shared(&A[aTile_index]));
    // for the .x2 form only the addresses supplied by lanes 0-15 are consumed
    int bTile_index = laneid * 8;
    LDMATRIX_X2(RB[0], RB[1], __cvta_generic_to_shared(&B[bTile_index]));
    // execute the mma instruction
    HMMA16816(RC[0], RC[1],
              RA[0], RA[1], RA[2], RA[3],
              RB[0], RB[1],
              RC[0], RC[1]);
    // C matrix: M*N = 16*8
    /*
        groupID           = %laneid >> 2
        threadID_in_group = %laneid % 4
        row = groupID                                   for ci where i < 2
              groupID + 8                               for ci where i >= 2
        col = (threadID_in_group * 2) + (i & 0x1)       for ci where i = {0,..,3}
    */
    int groupID = laneid / 4;
    int threadID_in_group = laneid % 4;
    int row_c0 = groupID;
    int col_c0 = (threadID_in_group * 2) + (0 & 0x1);
    int row_c2 = groupID + 8;
    int col_c2 = (threadID_in_group * 2) + (2 & 0x1);
    // write the result back to DRAM
    *(uint32_t*)&input_C[row_c0 * N + col_c0] = RC[0];
    *(uint32_t*)&input_C[row_c2 * N + col_c2] = RC[1];
    clock_t end = clock64();
    if (laneid == 0) {
        printf("ptx_mma_shared kernel e2e:%ld\n", end - begin);
    }
}
int M = 16;
int N = 8;
int K = 16;
void dump(half *host_c)
{
    for (int r = 0; r < M; r++)
    {
        for (int c = 0; c < N; c++)
        {
            printf("%8.3f ", __half2float(host_c[r * N + c]));
        }
        printf("\n");
    }
}
int main() {
    half *host_a = new half[M * K];
    half *host_b = new half[K * N];
    half *host_c = new half[M * N];
    half *dev_a;
    half *dev_b;
    half *dev_c;
    CHECK_CUDA(cudaMalloc(&dev_a, sizeof(half) * M * K));
    CHECK_CUDA(cudaMalloc(&dev_b, sizeof(half) * K * N));
    CHECK_CUDA(cudaMalloc(&dev_c, sizeof(half) * M * N));
    for (int i = 0; i < M * K; ++i) host_a[i] = __float2half(i * 0.01);
    for (int i = 0; i < K * N; ++i) host_b[i] = __float2half(i * 0.01);
    CHECK_CUDA(cudaMemcpy(dev_a, host_a, sizeof(half) * M * K, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dev_b, host_b, sizeof(half) * K * N, cudaMemcpyHostToDevice));
    for (int i = 0; i < M * N; ++i) host_c[i] = 0;
    CHECK_CUDA(cudaMemcpy(dev_c, host_c, sizeof(half) * M * N, cudaMemcpyHostToDevice));
    ptx_mma_global<<<1, 32>>>(dev_a, dev_b, dev_c, M, N, K);
    cudaDeviceSynchronize();
    cudaMemcpy(host_c, dev_c, sizeof(half) * M * N, cudaMemcpyDeviceToHost);
    dump(host_c);
    printf("------------------------------------------------------------\n");
    CHECK_CUDA(cudaMemcpy(dev_a, host_a, sizeof(half) * M * K, cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dev_b, host_b, sizeof(half) * K * N, cudaMemcpyHostToDevice));
    for (int i = 0; i < M * N; ++i) host_c[i] = 0;
    CHECK_CUDA(cudaMemcpy(dev_c, host_c, sizeof(half) * M * N, cudaMemcpyHostToDevice));
    ptx_mma_shared<<<1, 32>>>(dev_a, dev_b, dev_c, M, N, K);
    cudaDeviceSynchronize();
    cudaMemcpy(host_c, dev_c, sizeof(half) * M * N, cudaMemcpyDeviceToHost);
    dump(host_c);
    cudaFree(dev_a);
    cudaFree(dev_b);
    cudaFree(dev_c);
    // the host buffers were allocated with new[], so release them with delete[]
    delete[] host_a;
    delete[] host_b;
    delete[] host_c;
    return 0;
}
EOF
/usr/local/cuda/bin/nvcc -std=c++17 -g -arch=sm_86 -lineinfo mma_ops.cu -o mma_ops
./mma_ops
/usr/local/NVIDIA-Nsight-Compute/ncu --set full --section SpeedOfLight_HierarchicalTensorRooflineChart --target-processes all --clock-control=none \
--print-details all --export ncu_report_mma_ops -f ./mma_ops
# check tensor core utilization
/usr/local/NVIDIA-Nsight-Compute/ncu --metrics \
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.pct_of_peak_sustained_elapsed,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.peak_sustained,\
sm__ops_path_tensor_src_fp16_dst_fp16_sparsity_off.sum.per_second,\
sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed,\
sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active,\
sm__cycles_elapsed ./mma_ops
Output:
ptx_mma_global kernel e2e:564
0.992 1.004 1.016 1.028 1.040 1.052 1.063 1.076
2.529 2.566 2.604 2.641 2.678 2.717 2.754 2.791
4.062 4.125 4.191 4.254 4.316 4.379 4.441 4.508
5.598 5.688 5.777 5.867 5.953 6.043 6.133 6.223
7.137 7.250 7.363 7.480 7.594 7.707 7.820 7.938
8.672 8.812 8.953 9.094 9.234 9.375 9.508 9.656
10.211 10.375 10.539 10.703 10.867 11.039 11.203 11.367
11.742 11.938 12.125 12.320 12.508 12.695 12.891 13.086
13.281 13.500 13.711 13.930 14.148 14.359 14.578 14.797
14.820 15.055 15.297 15.547 15.789 16.031 16.266 16.516
16.359 16.625 16.891 17.156 17.422 17.688 17.953 18.234
17.891 18.188 18.469 18.766 19.062 19.359 19.641 19.938
19.422 19.750 20.062 20.391 20.703 21.016 21.328 21.656
20.969 21.312 21.641 22.000 22.344 22.688 23.031 23.375
22.500 22.859 23.234 23.609 23.984 24.344 24.719 25.094
24.031 24.422 24.828 25.219 25.625 26.016 26.406 26.812
------------------------------------------------------------
ptx_mma_shared kernel e2e:424
0.992 1.004 1.016 1.028 1.040 1.052 1.063 1.076
2.529 2.566 2.604 2.641 2.678 2.717 2.754 2.791
4.062 4.125 4.191 4.254 4.316 4.379 4.441 4.508
5.598 5.688 5.777 5.867 5.953 6.043 6.133 6.223
7.137 7.250 7.363 7.480 7.594 7.707 7.820 7.938
8.672 8.812 8.953 9.094 9.234 9.375 9.508 9.656
10.211 10.375 10.539 10.703 10.867 11.039 11.203 11.367
11.742 11.938 12.125 12.320 12.508 12.695 12.891 13.086
13.281 13.500 13.711 13.930 14.148 14.359 14.578 14.797
14.820 15.055 15.297 15.547 15.789 16.031 16.266 16.516
16.359 16.625 16.891 17.156 17.422 17.688 17.953 18.234
17.891 18.188 18.469 18.766 19.062 19.359 19.641 19.938
19.422 19.750 20.062 20.391 20.703 21.016 21.328 21.656
20.969 21.312 21.641 22.000 22.344 22.688 23.031 23.375
22.500 22.859 23.234 23.609 23.984 24.344 24.719 25.094
24.031 24.422 24.828 25.219 25.625 26.016 26.406 26.812
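Both kernel dumps match the numpy reference; the last digit occasionally differs by an fp16 ulp or two (e.g. 2.527 vs 2.529), plausibly because the fp32 numpy accumulation and the tensor-core path round differently before the final fp16 result. To quantify the agreement instead of eyeballing it, the printouts can be compared element-wise; a minimal sketch, assuming the two outputs were saved with `python numpy_gemm.py > numpy_out.txt` and `./mma_ops > mma_out.txt` (hypothetical file names):
tee compare_outputs.py<<-'EOF'
# A minimal comparison sketch (not part of the original test); the file names are assumptions.
import re
import numpy as np
def load_grid(path):
    rows = []
    with open(path) as f:
        for line in f:
            vals = re.findall(r"-?\d+\.\d+", line)   # keep only lines carrying 8 printed values
            if len(vals) == 8:
                rows.append([float(v) for v in vals])
    return np.array(rows)
ref = load_grid("numpy_out.txt")                     # 16x8 reference from section 2
out = load_grid("mma_out.txt")                       # two stacked 16x8 dumps from section 3
for name, block in (("ptx_mma_global", out[:16]), ("ptx_mma_shared", out[16:])):
    print(name, "max abs diff vs numpy:", np.abs(block - ref).max())
EOF
python compare_outputs.py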
4. Screenshots