Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legacy flash remat fix #943

Open
wants to merge 2 commits into from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions axlearn/common/flash_attention/remat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,13 @@ def test_remat_combine_policy(self):
)
)

remat_hlo = str(jax.jit(remat).lower(params, inputs).as_text("hlo"))
self.assertEqual(
str(jax.make_jaxpr(remat)(params, inputs)).count("_mha_forward_kernel"),
1,
remat_hlo.count('custom_call_target="__gpu$xla.gpu.triton"'),
3,
)
self.assertEqual(
str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
remat_hlo.count(" dot("),
no_remat_dots_count,
)

Expand Down Expand Up @@ -229,4 +230,3 @@ def test_remat_combine_policy(self):
str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
no_remat_dots_count,
)
jax.jit(remat).lower(params, inputs).as_text("hlo")
36 changes: 21 additions & 15 deletions axlearn/common/flash_attention/tpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,17 @@ def pallas_tpu_flash_attention(
batch_size, num_heads, q_seq_len, kv_seq_len, d_model
)
return _flash_attention(
q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret
q, k, v, ab, segment_ids, causal, softmax_scale, block_sizes, debug, interpret
)


@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11))
@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
def _flash_attention(
q,
k,
v,
ab,
segment_ids,
save_residuals,
causal,
softmax_scale,
block_sizes,
Expand All @@ -497,7 +496,7 @@ def _flash_attention(
v,
ab,
segment_ids,
save_residuals,
False,
causal,
softmax_scale,
block_sizes.block_b,
Expand All @@ -515,23 +514,32 @@ def _flash_attention_fwd(
v,
ab,
segment_ids,
save_residuals,
causal,
softmax_scale,
block_sizes,
debug,
interpret,
):
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")
o, l, m = _flash_attention(
q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret
o, l, m = _flash_attention_impl(
q,
k,
v,
ab,
segment_ids,
True,
causal,
softmax_scale,
block_sizes.block_b,
block_sizes.block_q,
block_sizes.block_k_major,
block_sizes.block_k,
debug,
interpret,
)
return o, (q, k, v, ab, segment_ids, o, l, m)


def _flash_attention_bwd(
save_residuals: bool,
causal: bool,
softmax_scale: float,
block_sizes: LegacyBlockSizes,
Expand All @@ -541,8 +549,6 @@ def _flash_attention_bwd(
do,
):
"""VJP rule for FlashAttention."""
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")
(q, k, v, ab, segment_ids, o, l, m) = residuals
if not block_sizes.has_backward_blocks:
raise ValueError(
Expand Down Expand Up @@ -788,11 +794,11 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
)
),
)(q, k, v, ab, q_segment_ids, kv_segment_ids)
o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
if save_residuals:
l, m = (v[..., 0] for v in aux[-2:])
o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
return (o, l, m)
else:
return o
Expand Down
Loading