Add int8 to gemm w/ addmatrix (#3040)
Update the gemm addmatrix benchmark to support int8 inputs as well as bfloat16. Exclude most int8 shapes from correctness testing because PyTorch matmul does not support int8 on GPU yet; only a handful of int8 shapes are checked against a CPU reference to keep runtime down.
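For reference, dtype selection in the updated benchmark is driven by two environment variables, INT8_ONLY and ALL_DTYPES. A minimal sketch of the mechanism, mirroring the dtypes() helper added in this commit (the final print call is purely illustrative):

import os

import torch

INT8_ONLY = os.getenv('INT8_ONLY', '0') == '1'
ALL_DTYPES = os.getenv('ALL_DTYPES', '0') == '1'


def dtypes():
    # ALL_DTYPES takes precedence over INT8_ONLY; the default stays bfloat16-only.
    if ALL_DTYPES:
        return [torch.bfloat16, torch.int8]
    if INT8_ONLY:
        return [torch.int8]
    return [torch.bfloat16]


print(dtypes())  # e.g. [torch.int8] when invoked as INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py

The CI workflow below runs the benchmark twice, once in the default bfloat16 mode and once with INT8_ONLY=1, writing a separate report CSV for each run.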
alexbaden authored Jan 10, 2025
1 parent e8e47af commit e27d722
Showing 2 changed files with 70 additions and 25 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/triton-benchmarks.yml
@@ -206,13 +206,21 @@ jobs:
source ../../scripts/capture-hw-details.sh
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-gelu.csv $REPORTS/gemm-postop-gelu-triton-report.csv --benchmark gemm-postop-gelu --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark bfloat16
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
run: |
cd benchmarks/triton_kernels_benchmark
python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
source ../../scripts/capture-hw-details.sh
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-bfloat16.csv $REPORTS/gemm-postop-addmatrix-bfloat16-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark int8
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
run: |
cd benchmarks/triton_kernels_benchmark
INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
source ../../scripts/capture-hw-details.sh
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-int8.csv $REPORTS/gemm-postop-addmatrix-int8-triton-report.csv --benchmark gemm-postop-addmatrix-int8 --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- name: Run Triton FA kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
83 changes: 60 additions & 23 deletions benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -5,13 +5,33 @@
This benchmark is modified from gemm_benchmark.py to add a matrix to the output of the gemm operation.
"""
import os

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

INT8_ONLY_OPTION = os.getenv('INT8_ONLY', '0') == '1'
ALL_DTYPES_OPTION = os.getenv('ALL_DTYPES', '0') == '1'


def dtypes():
if ALL_DTYPES_OPTION:
return [torch.bfloat16, torch.int8]
if INT8_ONLY_OPTION:
return [torch.int8]
return [torch.bfloat16]


def suffix():
if ALL_DTYPES_OPTION:
return 'all'
if INT8_ONLY_OPTION:
return 'int8'
return 'bfloat16'


@triton.autotune(
configs=[
@@ -43,7 +63,8 @@ def matmul_kernel_with_block_pointers(
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
stride_dm: tl.constexpr, stride_dn: tl.constexpr,
stride_dm: tl.constexpr, stride_dn: tl.constexpr, #
ACCUMULATOR_DTYPE: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
pid = tl.program_id(axis=0)
@@ -63,7 +84,7 @@ def matmul_kernel_with_block_pointers(
offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
order=(1, 0))

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
for _ in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_block_ptr, boundary_check=(0, 1))
b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -117,7 +138,8 @@ def matmul_kernel_with_block_pointers_batched(
stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, #
stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, #
ACCUMULATOR_DTYPE: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
bid = tl.program_id(axis=0)
@@ -141,7 +163,7 @@ def matmul_kernel_with_block_pointers_batched(
offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
order=(1, 0))

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
for _ in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_block_ptr, boundary_check=(0, 1))
b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -185,7 +207,8 @@ def matmul(a, b, d, c):
a.stride(0), a.stride(1), a.stride(2), #
b.stride(0), b.stride(1), b.stride(2), #
c.stride(0), c.stride(1), c.stride(2), #
d.stride(0), d.stride(1), d.stride(2))
d.stride(0), d.stride(1), d.stride(2), #
tl.float32 if a.dtype.is_floating_point else tl.int32)
elif len(a.shape) == 2 and len(b.shape) == 2:
assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
assert a.is_contiguous(), 'Matrix A must be contiguous'
@@ -199,7 +222,8 @@ def matmul(a, b, d, c):
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
d.stride(0), d.stride(1))
d.stride(0), d.stride(1), #
tl.float32 if a.dtype.is_floating_point else tl.int32)
else:
assert False, 'Input matrix dimensions mismatch'
return c
@@ -209,10 +233,10 @@ def matmul(a, b, d, c):
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['B', 'M', 'K', 'N'],
x_names=['B', 'M', 'K', 'N', 'dtype'],
# different possible values for `x_name`
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + #
[ #
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] + #
[[*shape, dtype] for shape in [ #
[1, 1, 5120, 13824], #
[1, 4, 4096, 12288], #
[1, 512, 8192, 8192], #
@@ -232,8 +256,8 @@ def matmul(a, b, d, c):
[4, 32768, 4096, 128], #
[32, 4096, 4096, 128], #
[4096, 8, 128, 16384], #
[4096, 8, 16384, 128]
],
[4096, 8, 16384, 128] #
] for dtype in dtypes()],
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
@@ -243,33 +267,46 @@ def matmul(a, b, d, c):
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
plot_name='matmul-performance-postop-addmatrix',
plot_name='matmul-performance-postop-addmatrix' + '-' + suffix(),
# name for the plot. Used also as a file name for saving the plot.
args={},
))
def benchmark(B, M, N, K, provider):
def benchmark(B, M, N, K, dtype, provider):
res_dtype = torch.float32 if dtype.is_floating_point else torch.int32
if dtype.is_floating_point:
rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype)
else:
rand = lambda shape, dtype: torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
if B == 1:
a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
d = torch.rand((M, N), device='xpu', dtype=torch.float32)
a = rand((M, K), dtype)
b = rand((K, N), dtype)
d = rand((M, N), res_dtype)
else:
a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
d = torch.rand((B, M, N), device='xpu', dtype=torch.float32)
a = rand((B, M, K), dtype)
b = rand((B, K, N), dtype)
d = rand((B, M, N), res_dtype)

quantiles = [0.5, 0.0, 1.0]

if provider == 'triton':
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
c = torch.empty((B, M, N), device='xpu', dtype=res_dtype)
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
c = torch.empty((M, N), device='xpu', dtype=res_dtype)
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
# Torch does not support integer calculation in matmul
torch_device = 'xpu' if dtype.is_floating_point else 'cpu'
torch_dtype = dtype if dtype.is_floating_point else res_dtype
torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype),
b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype
) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
[1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
# torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles)
else:
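For reference, the int8 correctness-check path added above boils down to the following standalone sketch. It assumes an XPU device is available; the shapes are illustrative only. The Triton kernel consumes int8 operands resident on the device, while the torch reference matmul is computed on the CPU in int32, since PyTorch does not yet support int8 matmul on GPU:

import torch

M, K, N = 256, 256, 256  # illustrative sizes only
a = torch.randint(-127, 128, (M, K), device='xpu', dtype=torch.int8)
b = torch.randint(-127, 128, (K, N), device='xpu', dtype=torch.int8)
d = torch.randint(-127, 128, (M, N), device='xpu', dtype=torch.int32)

# Reference result: widen to int32, run the matmul on CPU, then move it back to XPU.
ref = torch.matmul(a.to(device='cpu', dtype=torch.int32),
                   b.to(device='cpu', dtype=torch.int32)).to(device='xpu', dtype=torch.int32) + d

The Triton output c would then be compared against ref with benchmark_suit.assert_close, exactly as in the floating-point path.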
