From 3f4a177331ab844899865b2f5dcc73f5359113bb Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Thu, 2 Jan 2025 16:30:23 -0800 Subject: [PATCH] Enabled running Pallas Flash Attention on CPU. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pallas supports CPU simulation (`interpret=True`), so we can use the same TPU Pallas kernel on CPU — making code debugging easier. This change lets the following unittests run on CPU as if they were on TPU, enabling easier testing and debugging: - `axlearn/common/flash_attention/tpu_attention_test.py` Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU. - `axlearn/common/flash_attention/gpu_attention_test.py` Now CI covers those tests on CPU as well. In M3 Max MacBook Pro, test coverages and processing time are as follows, * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20) * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s --- .../flash_attention/gpu_attention_test.py | 44 +++++++++++++++++-- axlearn/common/flash_attention/layer_test.py | 9 +++- .../common/flash_attention/tpu_attention.py | 35 ++++++++------- .../flash_attention/tpu_attention_test.py | 19 ++++---- axlearn/common/flash_attention/utils.py | 13 +++--- 5 files changed, 80 insertions(+), 40 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 181a68883..2a2c8b38a 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -28,7 +28,7 @@ from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference from axlearn.common.test_utils import TestCase -if jax.default_backend() != "gpu": +if jax.default_backend() not in ("gpu", "cpu"): pytest.skip(reason="Incompatible hardware", allow_module_level=True) @@ -69,6 +69,8 @@ def test_triton_fwd_only_against_ref( kv_seq_len = seq_len if kv_seq_len != seq_len and use_segment_ids: pytest.skip() + if jax.default_backend() == "cpu" and kv_seq_len > 128: + pytest.skip(reason="CI got OOM.") k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5) q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype) @@ -101,6 +103,7 @@ def test_triton_fwd_only_against_ref( causal=causal, softmax_scale=softmax_scale, dropout_rate=dropout_rate, + interpret=(jax.default_backend() == "cpu"), ) o_ref = mha_reference( q, @@ -152,6 +155,8 @@ def test_decode_against_ref( kv_head_factor: int, window_len: int, ): + if jax.default_backend() == "cpu" and seq_len > 1024: + pytest.skip(reason="Too slow on CPU.") self.assertEqual(num_heads % kv_head_factor, 0) assert num_heads % kv_head_factor == 0 k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4) @@ -180,7 +185,14 @@ def test_decode_against_ref( if window_len > 0: mask_fn = sliding_window_causal_mask(window_len) o = flash_decoding( - q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn + q, + k, + v, + bias=bias, + softmax_scale=softmax_scale, + kv_seq_len=seq_len, + mask_fn=mask_fn, + interpret=(jax.default_backend() == "cpu"), ) if bias is not None: bias = bias[:, :, :, :seq_len] @@ -269,6 +281,7 @@ def test_triton_against_xla_ref( block_q=block_size, block_k=block_size, dropout_rate=dropout_rate, + interpret=(jax.default_backend() == "cpu"), ) jax_out = call_flash( q, @@ -346,6 +359,9 @@ def test_cudnn_against_triton_ref( causal: bool, dtype: jnp.dtype, ): + if jax.default_backend() == "cpu": + pytest.skip(reason="cudnn function needs GPU.") + k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype) k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype) @@ -357,7 +373,15 @@ def test_cudnn_against_triton_ref( jax_out = cudnn_dot_product_attention( q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale ) - jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale) + jax_ref_out = flash_attention( + q, + k, + v, + bias=None, + causal=causal, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), + ) if dtype == jnp.bfloat16: # We relax the atol to support bf16 in the unit test. chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5) @@ -372,7 +396,15 @@ def fn(q, k, v): ).sum() def ref_fn(q, k, v): - return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum() + return flash_attention( + q, + k, + v, + bias=None, + causal=causal, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), + ).sum() # Compare gradients. jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v) @@ -414,6 +446,8 @@ def test_cudnn_dropout_against_xla_dropout( by setting V to the identity matrix. However, this only works when seq_len == per_head_dim, i.e. when the shape of output is the same as the shape of the dropout mask. """ + if jax.default_backend() == "cpu": + pytest.skip(reason="cudnn function needs GPU.") qkv_shape = (batch_size, seq_len, num_heads, per_head_dim) softmax_scale = 1.0 cudnn_attn = functools.partial( @@ -481,6 +515,8 @@ def ref_fn(q, k, v): def test_cudnn_dropout_determinism(): """Tests that cuDNN dropout produces identical outputs across runs.""" + if jax.default_backend() == "cpu": + pytest.skip(reason="cudnn function needs GPU.") k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3) q = jax.random.normal(k1, (1, 128, 2, 64), dtype=jnp.float16) k = jax.random.normal(k2, (1, 128, 2, 64), dtype=jnp.float16) diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index c5f46a670..36f5fa0f4 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -21,7 +21,7 @@ import jax import jax.numpy as jnp import pytest -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax.experimental import mesh_utils from jax.sharding import Mesh @@ -100,6 +100,7 @@ def _prepare_layers( sliding_window_size, inference=False, set_layer_bias_recursively=False, + tpu_block_size=512, dropout_rate=0.0, ): hidden_dim = num_heads * per_head_dim @@ -124,6 +125,7 @@ def _prepare_layers( .set( mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names), output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names), + tpu_block_size=tpu_block_size, ) ) if inference: @@ -550,6 +552,7 @@ def test_forward( causal=causal, sliding_window_size=sliding_window_size, dropout_rate=dropout_rate, + tpu_block_size=128, ) query_len = int(query_len_multiplier * seq_len) @@ -916,3 +919,7 @@ def test_extend_step( atol=2e-2, ) jax.extend.backend.clear_backends() + + +if __name__ == "__main__": + absltest.main() diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py index 016c29551..82e6fa81d 100644 --- a/axlearn/common/flash_attention/tpu_attention.py +++ b/axlearn/common/flash_attention/tpu_attention.py @@ -2,7 +2,7 @@ """Wrappers for FlashAttention on TPU in JAX with logit bias support.""" import functools -from typing import Optional, Union +from typing import Optional import jax import jax.numpy as jnp @@ -40,6 +40,8 @@ ) from axlearn.common.utils import Tensor +MaskFnOrZero = MaskFnAttentionBias | ZeroAttentionBias + def tpu_flash_attention( query: Tensor, # [batch_size, target_len, num_heads, head_dim] @@ -48,7 +50,7 @@ def tpu_flash_attention( bias: Tensor = None, # [batch_size, num_heads, target_len, source_len] segment_ids: Tensor = None, # [batch_size, target_len] *, - mask: Optional[MaskFnAttentionBias] = None, + mask: MaskFnOrZero, softmax_scale: float = 1.0, block_size: int = 128, interpret: bool = False, @@ -113,7 +115,7 @@ def tpu_flash_attention( f"Source seq len {key.shape[1]} must be divisible by block size {block_size}." ) - mask: Union[MaskFnAttentionBias | ZeroAttentionBias] = as_attention_bias(mask) + mask: MaskFnOrZero = as_attention_bias(mask) # Switch num_heads and seq_len axes. query = jnp.einsum("btnh->bnth", query) @@ -121,8 +123,9 @@ def tpu_flash_attention( value = jnp.einsum("bsnh->bnsh", value) try: check_tpu_splash_attention( - query=query, - key=key, + target_len=query.shape[2], + source_len=key.shape[2], + head_dim=query.shape[3], mask=mask, has_segment_ids=(segment_ids is not None), has_bias=(bias is not None), @@ -199,7 +202,7 @@ def _legacy_tpu_flash_attention( bias: Tensor = None, # [batch_size, num_heads, target_len, source_len] segment_ids: Tensor = None, # [batch_size, target_len] *, - mask: MaskFnAttentionBias, + mask: MaskFnOrZero, block_sizes: Optional[LegacyBlockSizes] = None, interpret: bool = False, ) -> Tensor: # [batch_size, num_heads, target_len, head_dim]. @@ -253,17 +256,19 @@ class SplashAttentionUnsupportedError(NotImplementedError): def check_tpu_splash_attention( *, - query: Tensor, # [batch_size, num_heads, source_len, head_dim] - key: Tensor, # [batch_size, num_heads, target_len, head_dim] - mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + target_len: int, + source_len: int, + head_dim: int, + mask: MaskFnOrZero, has_segment_ids: bool = False, has_bias: bool = False, ): """Checks if splash attention is supported on TPU for the given arguments. Args: - query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim]. - key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim]. + target_len: The length of the target sequence. + source_len: The length of the source sequence. + head_dim: The dimension of each head. mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. has_segment_ids: Whether segment_ids is None or not. has_bias: Whether attention involves a bias. @@ -272,10 +277,6 @@ def check_tpu_splash_attention( SplashAttentionUnsupportedError: If splash attention is not supported for the given arguments. """ - target_len = query.shape[2] - source_len = key.shape[2] - head_dim = query.shape[3] - if has_bias: raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.") with jax.ensure_compile_time_eval(): @@ -305,7 +306,7 @@ def check_tpu_splash_attention( def _to_splash_mask( - mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + mask: MaskFnOrZero, *, mask_shape: tuple[int, int], q_seq_shards: int = 1, @@ -344,7 +345,7 @@ def _tpu_splash_attention( key: Tensor, # [batch_size, num_heads, source_len, head_dim] value: Tensor, # [batch_size, num_heads, source_len, head_dim] *, - mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + mask: MaskFnOrZero, segment_ids: Optional[Tensor] = None, # [batch_size, target_len] block_sizes: Optional[splash_attention_kernel.BlockSizes] = None, interpret: bool = False, diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index f9a99c310..52b79649d 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -5,7 +5,6 @@ import unittest -import chex import jax import jax.numpy as jnp import numpy as np @@ -29,14 +28,10 @@ from axlearn.common.test_utils import TestCase, is_supported_mesh_shape from axlearn.common.utils import Tensor -# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly. -if jax.default_backend() != "tpu": - pytest.skip(reason="Incompatible hardware", allow_module_level=True) - def setUpModule(): - # If on CPU, emulate 4 devices. - chex.set_n_cpu_devices(4) + if jax.default_backend() not in ("tpu", "cpu"): + pytest.skip(reason="Incompatible hardware", allow_module_level=True) def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor: @@ -102,7 +97,6 @@ def test_to_splash_mask(self, mask, expected): sliding_window_size=[1024], num_heads=[4], per_head_dim=[256], - mesh=[(4, 1)], mesh_axis_names=[("data", "model")], ) def test_forward( @@ -113,11 +107,12 @@ def test_forward( per_head_dim, mask_fn, sliding_window_size, - mesh, mesh_axis_names, ): - if not is_supported_mesh_shape(mesh): - pytest.skip(reason=f"Unsupported mesh {mesh}.") + if jax.default_backend() == "cpu" and seq_len > 1024: + pytest.skip(reason="Too slow on CPU.") + mesh = (1, 1) if jax.default_backend() == "cpu" else (4, 1) + self.assertTrue(is_supported_mesh_shape(mesh)) k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) q = jax.random.normal( @@ -254,6 +249,8 @@ def ref_fn(q, k, v, bias, ids): if mask is not None: mask = MaskFnAttentionBias(mask, shape=(query_len, kv_len)) + else: + mask = ZeroAttentionBias() def fn(q, k, v, bias, ids): record_legacy_call = unittest.mock.patch.object( diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 2b4546969..e9b92e069 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -17,7 +17,6 @@ MaskFnAttentionBias, SegmentIdAttentionBias, TensorAttentionBias, - ZeroAttentionBias, split, ) from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention @@ -203,6 +202,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: mask_fn=mask_fn, kv_seq_len=kv_seq_len, softmax_scale=softmax_scale, + interpret=(backend == "cpu"), ) key = _repeat_kv_heads(query.shape[2], key) @@ -237,6 +237,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: softmax_scale=softmax_scale, causal=causal.has_value(), dropout_rate=dropout_rate, + interpret=(backend == "cpu"), ) else: explicit_bias += segment_ids @@ -268,20 +269,18 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: value, bias=explicit_bias.value(), segment_ids=get_segment_ids(segment_ids), - # The `from_sequence()` function guarantees that if there is only one - # mask, it is returned without modification. - # This allows the `causal` path in `_legacy_tpu_flash_attention()` to work. - mask=mask if not isinstance(mask, ZeroAttentionBias) else None, + mask=mask, softmax_scale=softmax_scale, block_size=block_size, + interpret=(backend == "cpu"), ) elif backend in ("cpu", "xla"): key = _repeat_kv_heads(query.shape[2], key) value = _repeat_kv_heads(query.shape[2], value) if backend == "cpu": - logging.warning("Flash attention CPU backend is for testing only.") - logging.warning("Flash attention falling back using plain MHA implementation") + logging.info("Flash attention CPU backend is for testing only.") + logging.info("Flash attention falling back using plain MHA implementation") # `causal` is supported. # `segment_ids` is supported.