
Enabled running Pallas Flash Attention on CPU. #922

Open · wants to merge 1 commit into main from flsh_cpu

Conversation

ds-hwang (Contributor)

Enabled running Pallas Flash Attention on CPU.

Pallas supports CPU simulation (interpret=True), so we can use the same
TPU Pallas kernel on CPU — making code debugging easier.
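For reference, here is a minimal sketch (not code from this PR) of what interpret mode looks like with pl.pallas_call; the kernel is just an illustrative element-wise add:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise add on the block handed to this kernel invocation.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    # interpret=True runs the kernel through the Pallas interpreter, so the
    # same kernel source can be stepped through on CPU without a TPU/GPU.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=(jax.default_backend() == "cpu"),
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```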

This change lets the following unit tests run on CPU as if they were on TPU,
enabling easier testing and debugging:

  • axlearn/common/flash_attention/tpu_attention_test.py

Similarly, gpu_attention_test.py can also be run on CPU as if it were on GPU.

  • axlearn/common/flash_attention/gpu_attention_test.py

CI now covers these tests on CPU as well.
On an M3 Max MacBook Pro, test counts and run times are as follows:

  • axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
  • axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners January 14, 2025 04:59
ds-hwang (Contributor, Author)

@ruomingp Could you take a look? (From #975.)

ruomingp (Contributor) left a comment:


A few thoughts missed in earlier reviews...

@@ -152,6 +153,8 @@ def test_decode_against_ref(
kv_head_factor: int,
window_len: int,
):
if jax.default_backend() != "gpu" and seq_len > 1024:
ruomingp (Contributor):

Nit: can we check it against "cpu" directly instead of != "gpu"?

ds-hwang (Contributor, Author):

Yes, done.

@@ -346,6 +357,9 @@ def test_cudnn_against_triton_ref(
causal: bool,
dtype: jnp.dtype,
):
if jax.default_backend() == "cpu":
ruomingp (Contributor):

Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.

Suggested change:
- if jax.default_backend() == "cpu":
+ if jax.default_backend() != "gpu":

ds-hwang (Contributor, Author):

I'll leave this code as-is, as you asked:

> Nit: can we check it against "cpu" directly instead of != "gpu"?

In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so == "cpu" is equivalent to != "gpu" in this code:

if jax.default_backend() not in ("gpu", "cpu"):

ruomingp (Contributor):

> In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so == "cpu" is equivalent to != "gpu" in this code.

I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future?

In this case, requiring the backend to be "gpu" is both more robust and more readable. What's the downside?

Comment on lines +447 to +452
if jax.default_backend() == "cpu":
    pytest.skip(reason="cudnn function needs GPU.")
ruomingp (Contributor):

And here and elsewhere.

ds-hwang (Contributor, Author):

As mentioned above, I'll keep using jax.default_backend() == "cpu".

Comment on lines 100 to 95
- seq_len=[1024, 32768],
+ seq_len=[1024],
ruomingp (Contributor):

Since the sliding window size is 1024, it will be useful to keep a test case for seq_len > 1024. We can enable the test only on TPU if it's too slow on CPU. We can also use a seq_len such as 2048 for cpu if it's fast enough.

ds-hwang (Contributor, Author) · Jan 14, 2025:

Done. I changed it back to restore the first PR's code.

We had this thread in #975:

> @ruomingp Do we need to support seq_len up to 1024? If the block size is 128, supporting <= 256 should be enough?
>
> @ds-hwang Agreed. I removed the 32k test with this if-statement.
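For what it's worth, a hypothetical sketch of the pattern suggested above (keep a seq_len > 1024 case, but skip the very long one where interpret mode is too slow); the parameter values and test body are illustrative, not the actual test code:

```python
import jax
import pytest


@pytest.mark.parametrize("seq_len", [1024, 2048, 32768])
def test_sliding_window_attention(seq_len):
    # Keep a case above the 1024 sliding-window size, but skip the very long
    # sequence on CPU, where Pallas interpret mode would be too slow.
    if jax.default_backend() == "cpu" and seq_len > 2048:
        pytest.skip(reason="seq_len too large for interpret mode on CPU.")
    ...  # build inputs here and compare against a reference implementation
```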

softmax_scale=softmax_scale,
block_size=block_size,
interpret=(backend == "cpu"),
ruomingp (Contributor):

Given how often we do this across locations, I wonder if we can do the following:

  • Make interpret default to None (instead of False);
  • If it's None, assume interpret=True if the backend is "cpu";

WDYT?

ds-hwang (Contributor, Author):

Thank you for your suggestion. interpret=True applies only to the Pallas kernel. Therefore, having an interpret variable in the flash layer is not aligned with the appropriate level of abstraction: neither the JAX fallback nor the cudnn code path needs this variable.

Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the if statement to:

elif backend in ("cpu", "tpu"):

would allow debugging in layer_test.py.
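To illustrate the kind of dispatch being described, a hypothetical sketch follows; the function names are made up and do not match the actual axlearn code paths:

```python
import functools

import jax


def _cudnn_attention(q, k, v):
    ...  # GPU-only path; it has no notion of interpret mode


def _pallas_attention(q, k, v, *, interpret=False):
    ...  # TPU Pallas kernel; also runs on CPU when interpret=True


def select_attention_fn():
    backend = jax.default_backend()
    if backend == "gpu":
        return _cudnn_attention
    elif backend in ("cpu", "tpu"):
        # Same Pallas kernel on both backends; interpret only on CPU, which
        # is what makes kernel-level debugging possible there.
        return functools.partial(_pallas_attention, interpret=(backend == "cpu"))
    raise NotImplementedError(backend)
```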

ds-hwang (Contributor, Author) left a comment:

Thank you for the review. I have responded to all comments. Could you check it again?

@ds-hwang ds-hwang requested a review from ruomingp January 14, 2025 16:25
ruomingp (Contributor) left a comment:

It seems possible to support interpret=None to simplify the code, with the following behavior:

  • interpret=True/False: enable/disable interpret;
  • interpret=None: let the implementation choose whether to interpret, depending on whether pallas is used and running on cpu vs. accelerator;

But I don't want to block this PR, as we can simplify it later.
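A small sketch of the interpret=None behavior described above (not part of this PR); the helper name is illustrative:

```python
from typing import Optional

import jax


def resolve_interpret(interpret: Optional[bool] = None) -> bool:
    """Decides whether a Pallas kernel should run in interpret mode."""
    if interpret is not None:
        # An explicit True/False always wins.
        return interpret
    # None: let the implementation choose; interpret only when there is no
    # accelerator, i.e. the default backend is CPU.
    return jax.default_backend() == "cpu"
```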

@ds-hwang ds-hwang added this pull request to the merge queue Jan 16, 2025
ds-hwang (Contributor, Author):

Thank you for the review!

@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 16, 2025
@ds-hwang ds-hwang added this pull request to the merge queue Jan 16, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 16, 2025
@ds-hwang ds-hwang enabled auto-merge January 16, 2025 20:01
@ds-hwang ds-hwang force-pushed the flsh_cpu branch 2 times, most recently from 0e39bdd to 3f4a177 Compare January 21, 2025 22:30