From e27d722c9bbd9c4970c515c20323bce3c8294baa Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Thu, 9 Jan 2025 20:25:36 -0500
Subject: [PATCH] Add int8 to gemm w/ addmatrix (#3040)

Update the gemm addmatrix benchmark to support int8 inputs as well as
bfloat16. Exclude all int8 shapes from correctness testing because
PyTorch matmul does not support int8 on GPU yet.
---
 .github/workflows/triton-benchmarks.yml | 12 ++-
 .../gemm_postop_addmatrix_benchmark.py   | 83 ++++++++++++++-----
 2 files changed, 70 insertions(+), 25 deletions(-)

diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index 84082bbe77..380f32b4ac 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -206,13 +206,21 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-gelu.csv $REPORTS/gemm-postop-gelu-triton-report.csv --benchmark gemm-postop-gelu --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
-      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark
+      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark bfloat16
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
         run: |
           cd benchmarks/triton_kernels_benchmark
           python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
           source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-bfloat16.csv $REPORTS/gemm-postop-addmatrix-bfloat16-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+      - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark int8
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
+          source ../../scripts/capture-hw-details.sh
+          python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-int8.csv $REPORTS/gemm-postop-addmatrix-int8-triton-report.csv --benchmark gemm-postop-addmatrix-int8 --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
       - name: Run Triton FA kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 7d9d877660..33551bf8b4 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -5,6 +5,7 @@
 This benchmark is modified from gemm_benchmark.py to add a matrix to the output of the gemm operation.
""" +import os import torch import triton @@ -12,6 +13,25 @@ import triton_kernels_benchmark as benchmark_suit +INT8_ONLY_OPTION = os.getenv('INT8_ONLY', '0') == '1' +ALL_DTYPES_OPTION = os.getenv('ALL_DTYPES', '0') == '1' + + +def dtypes(): + if ALL_DTYPES_OPTION: + return [torch.bfloat16, torch.int8] + if INT8_ONLY_OPTION: + return [torch.int8] + return [torch.bfloat16] + + +def suffix(): + if ALL_DTYPES_OPTION: + return 'all' + if INT8_ONLY_OPTION: + return 'int8' + return 'bfloat16' + @triton.autotune( configs=[ @@ -43,7 +63,8 @@ def matmul_kernel_with_block_pointers( stride_am: tl.constexpr, stride_ak: tl.constexpr, # stride_bk: tl.constexpr, stride_bn: tl.constexpr, # stride_cm: tl.constexpr, stride_cn: tl.constexpr, # - stride_dm: tl.constexpr, stride_dn: tl.constexpr, + stride_dm: tl.constexpr, stride_dn: tl.constexpr, # + ACCUMULATOR_DTYPE: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): pid = tl.program_id(axis=0) @@ -63,7 +84,7 @@ def matmul_kernel_with_block_pointers( offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE) for _ in range(0, K, BLOCK_SIZE_K): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) @@ -117,7 +138,8 @@ def matmul_kernel_with_block_pointers_batched( stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, # stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, # stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, # - stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, + stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, # + ACCUMULATOR_DTYPE: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): bid = tl.program_id(axis=0) @@ -141,7 +163,7 @@ def matmul_kernel_with_block_pointers_batched( offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE) for _ in range(0, K, BLOCK_SIZE_K): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) @@ -185,7 +207,8 @@ def matmul(a, b, d, c): a.stride(0), a.stride(1), a.stride(2), # b.stride(0), b.stride(1), b.stride(2), # c.stride(0), c.stride(1), c.stride(2), # - d.stride(0), d.stride(1), d.stride(2)) + d.stride(0), d.stride(1), d.stride(2), # + tl.float32 if a.dtype.is_floating_point else tl.int32) elif len(a.shape) == 2 and len(b.shape) == 2: assert a.shape[1] == b.shape[0], 'Incompatible dimensions' assert a.is_contiguous(), 'Matrix A must be contiguous' @@ -199,7 +222,8 @@ def matmul(a, b, d, c): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - d.stride(0), d.stride(1)) + d.stride(0), d.stride(1), # + tl.float32 if a.dtype.is_floating_point else tl.int32) else: assert False, 'Input matrixs dimensions mismatch' return c @@ -209,10 +233,10 @@ def matmul(a, b, d, c): @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'M', 'K', 'N'], + x_names=['B', 'M', 'K', 'N', 
'dtype'], # different possible values for `x_name` - x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + # - [ # + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] + # + [[*shape, dtype] for shape in [ # [1, 1, 5120, 13824], # [1, 4, 4096, 12288], # [1, 512, 8192, 8192], # @@ -232,8 +256,8 @@ def matmul(a, b, d, c): [4, 32768, 4096, 128], # [32, 4096, 4096, 128], # [4096, 8, 128, 16384], # - [4096, 8, 16384, 128] - ], + [4096, 8, 16384, 128] # + ] for dtype in dtypes()], line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -243,33 +267,46 @@ def matmul(a, b, d, c): # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel=['GB/s', 'TFlops'], # label name for the y-axis - plot_name='matmul-performance-postop-addmatrix', + plot_name='matmul-performance-postop-addmatrix' + '-' + suffix(), # name for the plot. Used also as a file name for saving the plot. args={}, )) -def benchmark(B, M, N, K, provider): +def benchmark(B, M, N, K, dtype, provider): + res_dtype = torch.float32 if dtype.is_floating_point else torch.int32 + if dtype.is_floating_point: + rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype) + else: + rand = lambda shape, dtype: torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) if B == 1: - a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16) - b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16) - d = torch.rand((M, N), device='xpu', dtype=torch.float32) + a = rand((M, K), dtype) + b = rand((K, N), dtype) + d = rand((M, N), res_dtype) else: - a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16) - b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16) - d = torch.rand((B, M, N), device='xpu', dtype=torch.float32) + a = rand((B, M, K), dtype) + b = rand((B, K, N), dtype) + d = rand((B, M, N), res_dtype) quantiles = [0.5, 0.0, 1.0] if provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: - c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) + c = torch.empty((B, M, N), device='xpu', dtype=res_dtype) else: assert len(a.shape) == 2, 'Expecting shape of length 2' - c = torch.empty((M, N), device='xpu', dtype=torch.float32) + c = torch.empty((M, N), device='xpu', dtype=res_dtype) triton_fn = lambda: matmul(a, b, d, c) - torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d + # Torch does not support integer calculation in matmul + torch_device = 'xpu' if dtype.is_floating_point else 'cpu' + torch_dtype = dtype if dtype.is_floating_point else res_dtype + torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype), + b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype + ) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048], + [1, 512, 8192, 32768], [4, 32768, 4096, 128]]: + # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else:
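
Usage note (reviewer sketch, not part of the commit): the dtype sweep is selected by the
INT8_ONLY and ALL_DTYPES environment variables introduced above, and suffix() picks the
matching report-file suffix. Assuming a checkout with the benchmark dependencies installed
and $REPORTS pointing at any writable directory, the three modes can be exercised locally
as follows:

    cd benchmarks/triton_kernels_benchmark

    # Default: bfloat16 shapes only (writes matmul-performance-postop-addmatrix-bfloat16.csv)
    python gemm_postop_addmatrix_benchmark.py --reports $REPORTS

    # int8 shapes only, mirroring the new CI step
    INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS

    # Both dtypes in a single run (suffix 'all'); not wired into CI by this patch
    ALL_DTYPES=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS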