From c9e53ab63abab3891c2b9e4622aa79860feb6a2e Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Thu, 21 Nov 2024 14:13:28 -0800
Subject: [PATCH 1/3] [FA] fix an assertion failure due to refactoring in PR54

We move the static_assert to the top-level kernels. After the move, a
static_assert failure is caught by the autotuner:

    try:
        return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
    except (OutOfResources, CompileTimeAssertionFailure, PTXASError):
        return [float("inf"), float("inf"), float("inf")]

Prior to this change, the CompileTimeAssertionFailure was somehow not caught;
it got reported and failed the build.
---
 tritonbench/kernels/triton_fused_attention.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 89160948..06f4e702 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -454,10 +454,10 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]  # 64, 128]
-    for BN in [128]  # 64, 128]
-    for s in [3]  # , 4, 7]
-    for w in [8]  # 4, 8]
+    for BM in [64, 128]
+    for BN in [64, 128]
+    for s in [3, 4, 7]
+    for w in [4, 8]
 ]
 # TMA, WS, and CompPipe
 configsTmaWS = [
@@ -548,7 +548,6 @@ def _attn_fwd_compute(
     ENABLE_TMA: tl.constexpr,
     LOOP_SCHEDULE: tl.constexpr,
 ):
-    tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
     off_z = off_hz // H
@@ -993,6 +992,7 @@ def _attn_fwd(
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,
@@ -1072,6 +1072,7 @@ def _attn_fwd_opt(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,
@@ -1151,6 +1152,7 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute(
         Q,
         K,

From 4921007a71b4c25f747524f4b0021b51c7d779c0 Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Thu, 21 Nov 2024 15:54:33 -0800
Subject: [PATCH 2/3] fix for _ws

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 tritonbench/kernels/triton_fused_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 06f4e702..5654e4b5 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -728,7 +728,6 @@ def _attn_fwd_compute_ws(
     ENABLE_TMA: tl.constexpr,
     LOOP_SCHEDULE: tl.constexpr,
 ):
-    tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
     off_z = off_hz // H
@@ -913,6 +912,7 @@ def _attn_fwd_ws(
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute_ws(
         Q,
         K,
@@ -1232,6 +1232,7 @@ def _attn_fwd_tma_ws(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
 ):
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
     _attn_fwd_compute_ws(
         Q,
         K,

From 61b42992d335114cf1986e634688626493b7c40f Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Thu, 21 Nov 2024 16:16:54 -0800
Subject: [PATCH 3/3] add more configs so smaller headdim can pass

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 tritonbench/kernels/triton_fused_attention.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 5654e4b5..7b08f533 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -360,11 +360,11 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]
-    for BN in [128]
+    for BM in [64, 128]
+    for BN in [64, 128]
     for sched in schedList
     for enable_tma in [False]
-    for w in [8]
+    for w in [4, 8]
 ]
 # no WS, with TMA and CompPipe
 configsTma = [
@@ -393,11 +393,11 @@ def _attn_fwd_inner_ws(
             num_warps=w,
         )
     )
-    for BM in [128]
-    for BN in [128]
+    for BM in [64, 128]
+    for BN in [64, 128]
     for sched in schedList
     for enable_tma in [True]
-    for w in [8]
+    for w in [4, 8]
 ]
 # no TMA, with WS and CompPipe
 configsWS = [
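
For context, the pattern the first patch relies on is: tl.static_assert sits in
the top-level @triton.autotune'd kernel, so a config that violates it raises
CompileTimeAssertionFailure at compile time, and the benchmarking wrapper scores
that config as infinitely slow instead of failing the build. Below is a minimal,
self-contained sketch of that pattern; the kernel, its parameters, and
bench_config are illustrative names rather than code from tritonbench, and the
exception import paths may differ across Triton versions.

import triton
import triton.language as tl
from triton.compiler.errors import CompileTimeAssertionFailure
from triton.runtime.errors import OutOfResources


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": BN}, num_stages=s, num_warps=w)
        for BN in [64, 128]
        for s in [3, 4]
        for w in [4, 8]
    ],
    key=["HEAD_DIM"],
)
@triton.jit
def _toy_copy_kernel(x_ptr, out_ptr, n, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
    # Asserting in the top-level kernel means a config with BLOCK_N > HEAD_DIM
    # fails at compile time, where the autotuner can catch and skip it.
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    offs = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


def bench_config(kernel_call, do_bench):
    # Mirrors the try/except quoted in the first commit message: a config whose
    # compilation raises is scored as infinitely slow rather than aborting the
    # run. (The snippet in the commit also catches PTXASError, which newer
    # Triton versions expose.)
    try:
        return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
    except (OutOfResources, CompileTimeAssertionFailure):
        return [float("inf"), float("inf"), float("inf")]


# Example launch (needs a CUDA GPU and torch); configs with BLOCK_N > HEAD_DIM
# are pruned via the static_assert above:
#   x = torch.randn(1024, device="cuda"); out = torch.empty_like(x)
#   grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_N"]),)
#   _toy_copy_kernel[grid](x, out, x.numel(), HEAD_DIM=64)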