Skip to content

Commit

Permalink
[FA]: squeeze Z and H into one axis to align with XeTLA (#2618)
Browse files Browse the repository at this point in the history
Squeeze Z and H into a single axis, as XeTLA does. This change yields
about a 3% benefit for N_CTX = 512 shapes.
  • Loading branch information
quintinwang5 authored Nov 6, 2024
1 parent 49a52a2 commit 99778f4
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
start_m = tl.program_id(2)
off_z = tl.program_id(0)
off_h = tl.program_id(1)
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
if N_CTX <= 512:
start_m = tl.program_id(0)
off_z = tl.program_id(2)
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
qvk_offset = off_z.to(tl.int64) * stride_qh

# block pointers
Q_block_ptr = tl.make_block_ptr(
Expand Down Expand Up @@ -181,7 +182,7 @@ def forward(q, k, v, causal, sm_scale):
grid = lambda args: (q.shape[0], q.shape[1], triton.cdiv(q.shape[2], args['BLOCK_M']))
n_ctx = q.shape[2]
if n_ctx <= 512:
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), q.shape[1], q.shape[0])
grid = lambda args: (triton.cdiv(q.shape[2], args['BLOCK_M']), 1, q.shape[0] * q.shape[1])
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0':
Expand Down

0 comments on commit 99778f4

Please sign in to comment.