Enabled running Pallas Flash Attention on CPU. #922
@ruomingp Could you take a look? From 975
A few thoughts missed in earlier reviews...
@@ -152,6 +153,8 @@ def test_decode_against_ref(
        kv_head_factor: int,
        window_len: int,
    ):
        if jax.default_backend() != "gpu" and seq_len > 1024:
Nit: can we check it against `"cpu"` directly instead of `!= "gpu"`?
Yes, done.
@@ -346,6 +357,9 @@ def test_cudnn_against_triton_ref(
        causal: bool,
        dtype: jnp.dtype,
    ):
        if jax.default_backend() == "cpu":
Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.
Suggested change:
-        if jax.default_backend() == "cpu":
+        if jax.default_backend() != "gpu":
I'll leave this code as-is, as you asked:
> Nit: can we check it against "cpu" directly instead of != "gpu"?

In addition, the check at the beginning of the file allows only "gpu" and "cpu", so `== "cpu"` is equivalent to `!= "gpu"` in this code:
    if jax.default_backend() not in ("gpu", "cpu"):
> In addition, at the begin of file, it allows only "gpu" and "cpu". So == "cpu" is != "gpu" in this code.

I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future?
In this case, requiring the backend to be "gpu" is both more robust and more readable. What's the downside?
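
To make the fragility concern concrete, here is a minimal sketch comparing the two skip conditions. The helper names are hypothetical and the backend string is passed in explicitly; in the actual tests it would come from `jax.default_backend()`.

```python
def skip_if_cpu(backend: str) -> bool:
    """Skip condition written as `== "cpu"` (hypothetical helper)."""
    return backend == "cpu"


def skip_unless_gpu(backend: str) -> bool:
    """Skip condition written as `!= "gpu"` (hypothetical helper)."""
    return backend != "gpu"


# The two conditions agree while only "cpu" and "gpu" are supported,
# and diverge as soon as another backend such as "tpu" is allowed.
for backend in ("cpu", "gpu", "tpu"):
    print(backend, skip_if_cpu(backend), skip_unless_gpu(backend))
```

With only "cpu" and "gpu" in play the two predicates behave identically; on "tpu", only the `!= "gpu"` form would still skip a cuDNN-only test.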
        if jax.default_backend() == "cpu":
            pytest.skip(reason="cudnn function needs GPU.")
And here and elsewhere.
As mentioned above, I'll keep using `jax.default_backend() == "cpu"`.
-        seq_len=[1024, 32768],
+        seq_len=[1024],
Since the sliding window size is 1024, it will be useful to keep a test case for seq_len > 1024. We can enable the test only on TPU if it's too slow on CPU. We can also use a seq_len such as 2048 for cpu if it's fast enough.
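
One way to keep a `seq_len > 1024` case on every backend could look like the sketch below. The helper `decode_seq_lens` is hypothetical (it is not part of the PR), and the backend string is passed in rather than read from `jax.default_backend()` so the snippet stays self-contained.

```python
def decode_seq_lens(backend: str, window_len: int = 1024) -> list[int]:
    """Hypothetical helper: choose seq_len test cases per backend."""
    if backend == "cpu":
        # Interpreted kernels are slow on CPU, so use a shorter sequence
        # that still exceeds the sliding window.
        long_len = 2 * window_len
    else:
        # Keep the original long-sequence case on accelerators.
        long_len = 32768
    return [window_len, long_len]


print(decode_seq_lens("cpu"))
print(decode_seq_lens("tpu"))
```

Either way, at least one case exercises a sequence longer than the sliding window.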
        softmax_scale=softmax_scale,
        block_size=block_size,
        interpret=(backend == "cpu"),
Given how often we do this across locations, I wonder if we can do the following:
- Make `interpret` default to None (instead of False);
- If it's None, assume `interpret=True` if the backend is "cpu";

WDYT?
Thank you for your suggestion. `interpret=True` applies only to the Pallas kernel. Therefore, having an `interpret` variable in the flash layer is not aligned with the appropriate level of abstraction: neither the JAX fallback nor the cudnn code paths needs this variable.
Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the `if` statement to:
    elif backend in ("cpu", "tpu"):
would allow debugging in `layer_test.py`.
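
As a rough illustration of that argument, the sketch below keeps `interpret` at the kernel call site instead of threading it through the layer. The function `select_attention_impl` and the implementation labels are hypothetical stand-ins, not the actual axlearn dispatch code.

```python
def select_attention_impl(backend: str) -> dict:
    """Hypothetical dispatch: which code path runs, and with what kwargs."""
    if backend == "gpu":
        # cuDNN path: no `interpret` argument needed.
        return {"impl": "cudnn", "kwargs": {}}
    elif backend in ("cpu", "tpu"):
        # Pallas kernel: interpret only when running on CPU.
        return {"impl": "pallas", "kwargs": {"interpret": backend == "cpu"}}
    # Anything else falls back to plain JAX, which has no `interpret` either.
    return {"impl": "jax_fallback", "kwargs": {}}


print(select_attention_impl("cpu"))
print(select_attention_impl("tpu"))
```

Only the Pallas branch ever sees the flag, which is the abstraction boundary the comment above argues for.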
Thank you for the review. I responded to all comments. Could you check again?
It seems possible to support `interpret=None` to simplify the code, with the following behavior:
- `interpret=True/False`: enable/disable interpret;
- `interpret=None`: let the implementation choose whether to interpret, depending on whether pallas is used and running on cpu vs. accelerator;

But I don't want to block this PR, as we can simplify it later.
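
A minimal sketch of the tri-state behavior described above, for a possible follow-up. The helper name `resolve_interpret` and its `backend` / `uses_pallas` parameters are assumptions made for illustration; nothing like it exists in this PR.

```python
from typing import Optional


def resolve_interpret(
    interpret: Optional[bool], *, backend: str, uses_pallas: bool
) -> bool:
    """Hypothetical helper: resolve a tri-state `interpret` flag."""
    if not uses_pallas:
        # Interpret mode is meaningful only for Pallas kernels.
        return False
    if interpret is None:
        # Let the implementation choose: interpret on CPU only.
        return backend == "cpu"
    # Explicit True/False from the caller always wins.
    return interpret


print(resolve_interpret(None, backend="cpu", uses_pallas=True))
print(resolve_interpret(None, backend="tpu", uses_pallas=True))
```

An explicit `True` remains available for debugging a kernel in interpret mode on an accelerator.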
Thank you for the review!
0e39bdd to 3f4a177
Pallas supports CPU simulation (`interpret=True`), so we can use the same TPU Pallas kernel on CPU, which makes debugging easier. This change lets the following unit tests run on CPU as if they were on TPU, enabling easier testing and debugging:
- `axlearn/common/flash_attention/tpu_attention_test.py`

Similarly, `gpu_attention_test.py` can also be run on CPU as if it were on GPU:
- `axlearn/common/flash_attention/gpu_attention_test.py`

CI now covers those tests on CPU as well. On an M3 Max MacBook Pro, test coverage and run times are as follows:
* `axlearn/common/flash_attention/gpu_attention_test.py`: 3024 passed, 1345 skipped in 200.38s (0:03:20)
* `axlearn/common/flash_attention/tpu_attention_test.py`: 18 passed, 435 skipped in 34.82s
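
For readers unfamiliar with interpret mode, the sketch below runs a toy Pallas kernel with `interpret=True`. The kernel `add_one_kernel` is an illustrative stand-in, not the flash attention kernel from this PR, and the snippet assumes a JAX installation that includes `jax.experimental.pallas`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Read the whole input block, add one, and write it to the output block.
    o_ref[...] = x_ref[...] + 1.0


x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # Evaluate the kernel without compiling it for an accelerator.
)(x)
print(out)
```

Because the kernel body is interpreted rather than lowered for GPU or TPU, the same code can be stepped through and tested on a CPU-only machine.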