diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 33b8ec57..8c4a8b7b 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -134,8 +134,8 @@
 from axlearn.common.utils import (
     Nested,
     NestedTensor,
-    OffloadPolicy,
     PartitionSpec,
+    RematPolicy,
     SavePattern,
     Tensor,
     TensorSpec,
@@ -4003,8 +4003,6 @@ def forward(
     # TODO(sneha): extend_step


-# Adapted from jax source code to support regex. Reference:
-# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
 # TODO(kelvin-zou): deprecated, keep it here to minimize distruption to the golden configs.
 # Please use axlearn.common.utils.extended_checkpoint_policies instead.
 def _save_and_offload_only_these_names_regex(
@@ -4013,7 +4011,7 @@ def _save_and_offload_only_these_names_regex(
     names_which_can_be_offloaded: SavePattern,
     offload_src: str,
     offload_dst: str,
-) -> OffloadPolicy:
+) -> RematPolicy:
     return save_and_offload_only_these_names_regex(
         names_which_can_be_saved=names_which_can_be_saved,
         names_which_can_be_offloaded=names_which_can_be_offloaded,
@@ -4029,7 +4027,9 @@ class RematRegexSavePatterns(enum.Enum):
     CONTEXT = r".*context"
     LINEAR1_X = r".*linear1_[01]"
     LINEAR2_X = r".*linear2_[01]"
-    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    # This is called native attention because the "context" remat point only exists when using
+    # native attention, e.g. `MultiheadAttention` or `GroupedQueryAttention`.
+    NATIVE_ATTENTION = ".*([qkvo]_proj|context)"
     FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])


@@ -4037,22 +4037,33 @@ def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
+    save_pattern: SavePattern = RematRegexSavePatterns.NATIVE_ATTENTION.value,
     offload_pattern: SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
     """Configures how the Transformer or Conformer stack will save the linearization points.

     We try to save activations from the forward pass that are inefficient to recompute on the
-    backward pass. We choose the linearization points in the MultiHeadAttention layer, as that
-    demonstrated (empirically) the best throughput, allowing us to train with a batch size of 16 on
-    gpt2-10b with adamw and full sharding across 4 TPU v4 chips and a RepeatedTransformerLayer,
-    with 1.8x the step time of a stacked layer with a batch size of 8 and the same sharding config.
+    backward pass, which mainly involves matrix multiplications. By default, we don't save linear
+    layers' outputs due to the large expansion factor.

     For conformer model, we start from the same remat policy as language models.
     TODO(zhiyunlu): investigate Conformer model's memory/step-time tradeoffs. Possibly we need to
     save points in the LConv module.

+    Note that the default `save_pattern`, `NATIVE_ATTENTION`, doesn't save the context tensor when
+    using `FlashAttention`. To save it when using `FlashAttention`, use the policy from the module
+    `axlearn.common.flash_attention.remat`:
+
+    ```python
+    from axlearn.common.utils import save_and_offload_only_these_names_regex
+    from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy
+    combine_remat_policies(
+        save_and_offload_only_these_names_regex(...),
+        save_or_offload_flash_attention_policy()
+    )
+    ```
+
     Args:
         stack_cfg: A transformer config.
         save_pattern: Activation regex pattern to save in HBM.
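As a usage sketch (not part of the patch itself): assuming the helpers behave as described in the docstring above and in the new `axlearn/common/flash_attention/remat.py` below, the combined policy can be passed straight to `jax.remat`, mirroring the usage in `remat_test.py`. The loss function here is a hypothetical stand-in.

```python
# Illustrative sketch only: combine the regex-based save policy with the FlashAttention policy
# and apply it via jax.remat. `loss_fn` is a hypothetical stand-in for a transformer forward pass.
import jax

from axlearn.common.attention import RematRegexSavePatterns
from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy
from axlearn.common.utils import combine_remat_policies, save_and_offload_only_these_names_regex

combined_policy = combine_remat_policies(
    save_and_offload_only_these_names_regex(
        names_which_can_be_saved=RematRegexSavePatterns.NATIVE_ATTENTION.value,
        names_which_can_be_offloaded=None,
        offload_src="device",
        offload_dst="pinned_host",
    ),
    save_or_offload_flash_attention_policy(),
)


def loss_fn(params, batch):
    ...  # Hypothetical: run the flash-attention model and reduce to a scalar loss.


# The backward pass recomputes everything except the saved q/k/v/o projections and the
# FlashAttention kernel output.
remat_loss_fn = jax.remat(loss_fn, policy=combined_policy)
grad_fn = jax.grad(remat_loss_fn)
```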
diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index d4555b11..c1c056cd 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -174,6 +174,9 @@ def infer_tpu_version(tpu_type: str) -> str: """ tpu_type = infer_tpu_type(tpu_type) tpu_version = tpu_type.rsplit("-", 1)[0] # split from the last occurrence of '-' + # Resolve aliases like v5e to v5litepod, since in some cases (e.g. aot compilation) v5e is + # expected. + tpu_version = _TPU_VERSION_ALIASES.get(tpu_version, tpu_version) if tpu_version not in _TPU_VERSIONS: raise ValueError(f"Unknown TPU version {tpu_version}. Expected one of {_TPU_VERSIONS}") return tpu_version @@ -238,4 +241,5 @@ def infer_xsc_compiler_options( return options +_TPU_VERSION_ALIASES = {"v5e": "v5litepod"} _TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e") diff --git a/axlearn/common/compiler_options_test.py b/axlearn/common/compiler_options_test.py index 30d1b2d3..ff146c2e 100644 --- a/axlearn/common/compiler_options_test.py +++ b/axlearn/common/compiler_options_test.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp import pytest +from absl.testing import parameterized from axlearn.common import compiler_options, test_utils from axlearn.common.utils import Tensor @@ -48,3 +49,9 @@ def test_xsc_compiler_options(self): ) for name, option in options.items(): self.assertEqual(option, expected_options[name]) + + @parameterized.parameters( + dict(tpu_type="v5e-16", expected="v5litepod"), + ) + def test_tpu_version_alias(self, tpu_type: str, expected: str): + self.assertEqual(expected, compiler_options.infer_tpu_version(tpu_type)) diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py index 027cf5da..fabf01ed 100644 --- a/axlearn/common/decoder.py +++ b/axlearn/common/decoder.py @@ -659,6 +659,7 @@ def prefill_states( ), **kwargs, ) + self.add_module_output("prefill_hidden_states", outputs["hidden_states"]) states = dict(time_step=time_step, input_ids=input_ids, **states) return states, outputs diff --git a/axlearn/common/decoder_test.py b/axlearn/common/decoder_test.py index 288124af..b9b1ebb7 100644 --- a/axlearn/common/decoder_test.py +++ b/axlearn/common/decoder_test.py @@ -613,8 +613,14 @@ def tokens_to_scores(token_ids, cache): side_effect=mock_tokens_to_scores, ) + # Drop any module outputs added in `method` to avoid leaking tracers via checkify. + def method_fn(*args, **kwargs): + out = getattr(decoder, method)(*args, **kwargs) + decoder.get_invocation_context().get_module_outputs().clear() + return out + # Checkify the decoding method being called. - decoder._checked_method = checkify.checkify(getattr(decoder, method)) + decoder._checked_method = checkify.checkify(method_fn) # pylint: enable=protected-access with mock_ctx: diff --git a/axlearn/common/dit.py b/axlearn/common/dit.py index a388d192..dddd00d3 100644 --- a/axlearn/common/dit.py +++ b/axlearn/common/dit.py @@ -234,9 +234,10 @@ def forward(self, input: Tensor) -> Tensor: Each tensor has shape [batch_size, 1|num_length, dim]. 
""" cfg = self.config + if input.ndim not in (2, 3): + raise ValueError(f"The input must be rank 2 or 3, but got the {input.shape} tensor.") x = get_activation_fn(cfg.activation)(input) output = self.linear(x) - assert output.ndim in (2, 3) if output.ndim == 2: output = einops.rearrange(output, "b d -> b 1 d") output = jnp.split(output, cfg.num_outputs, axis=-1) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 42110673..9cff2723 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -94,6 +94,9 @@ def _mha_forward_kernel( See also `_mha_backward_kernel` for the backward pass. + Note: the kernel name is used to do string matching for rematerialization in `remat.py`. Be + careful when renaming this. + Args: q_ref: Input query ref. k_ref: Input key ref. diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index a3330a82..6bfe31ee 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -146,6 +146,32 @@ def _prepare_layers( return test_layer, ref_layer, params, hidden_dim +class DummyModel(BaseLayer): + """A dummy model.""" + + @config_class + class Config(BaseLayer.Config): + layer: GroupedQueryAttention.Config = GroupedQueryAttention.default_config() + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + self._add_child("layer", cfg.layer) + + def forward(self, *, query, key, value, attention_logit_biases, segment_ids): + # [batch, target_length, target_dim]. + x = self.layer( + query, + key=key, + value=value, + attention_logit_biases=attention_logit_biases, + segment_ids=segment_ids, + ) + # TODO(markblee,zhaoyi-zhang): The atol needs to increase significantly if using + # jnp.sum, as we no longer scale by the size of the data dims. + return jnp.mean(x.data, dtype=query.dtype) + + class TestFlashAttention(TestCase): """Tests FlashAttention layer.""" @@ -512,34 +538,7 @@ def test_backward( pytest.skip(reason="Only one of causal and use_bias can be True.") with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names): - - class DummyModel(BaseLayer): - """A dummy model.""" - - @config_class - class Config(BaseLayer.Config): - layer: GroupedQueryAttention.Config = GroupedQueryAttention.default_config() - - def __init__(self, cfg: Config, *, parent: Module): - super().__init__(cfg, parent=parent) - cfg = self.config - self._add_child("layer", cfg.layer) - - def forward(self, *, query, key, value, attention_logit_biases, segment_ids): - # [batch, target_length, target_dim]. - x = self.layer( - query, - key=key, - value=value, - attention_logit_biases=attention_logit_biases, - segment_ids=segment_ids, - ) - # TODO(markblee,zhaoyi-zhang): The atol needs to increase significantly if using - # jnp.sum, as we no longer scale by the size of the data dims. - return jnp.mean(x.data, dtype=query.dtype) - hidden_dim = num_heads * per_head_dim - if sliding_window_size is not None: mask_fn = config_for_function(sliding_window_causal_mask).set( sliding_window_size=sliding_window_size diff --git a/axlearn/common/flash_attention/remat.py b/axlearn/common/flash_attention/remat.py new file mode 100644 index 00000000..64c555fa --- /dev/null +++ b/axlearn/common/flash_attention/remat.py @@ -0,0 +1,59 @@ +# Copyright © 2025 Apple Inc. 
+"""Remat policy for FlashAttention kernels.""" + +from jax._src.cudnn.fused_attention_stablehlo import _dot_product_attention_fwd_p_wrapper +from jax.custom_derivatives import custom_vjp_call_jaxpr_p +from jax.experimental.pallas import pallas_call_p + +from axlearn.common.utils import Recompute, RematPolicy, RematType, Saveable + + +def save_or_offload_flash_attention_policy(remat_type: RematType = Saveable) -> RematPolicy: + """Returns a remat policy for FlashAttention output. + + This remat policy allows saving attention output, which is the tensor before out projection + commonly named "context". More precisely, it saves the attention output of GPU Pallas kernel, + TPU Legacy Pallas kernel, TPU SplashAttention kernel, and cuDNN FlashAttention kernel. + + Because cuDNN FlashAttention and TPU SplashAttention invocations are in Jax source code, it's + not feasible to save the output using `checkpoint_name`. Therefore, we match the Jax primitives + to implement this save policy. + + Note for users: for context length >= 4096, FlashAttention kernel takes noticeably longer on + both TPU and GPU to execute than o_proj. Therefore, saving the output of FlashAttention is + more advantages than saving o_proj since they have roughly the same memory footprint if the HBM + capacity doesn't allow saving both. + + Args: + remat_type: Remat type. Defaults to Saveable (save to HBM) and only supports Saveable. + + Returns: + A RematPolicy. Users can combine this remat policy with any existing policy with + `axlearn.common.utils.combine_remat_policies`. + """ + # Jax bug: https://github.com/jax-ml/jax/issues/25841. + # TODO(hanzhi-zhou): add support for Offloadable when jax supports it. + if remat_type is not Saveable: + raise NotImplementedError(f"{remat_type=} is not implemented.") + + def policy(prim, *_, **params): + src_info = "" + # Primitives could be copies if modules are reinitialized, so `is` check is unreliable. + # Use string equality instead. + prim_s = str(prim) + if prim_s == str(pallas_call_p): + src_info = str(params.get("name_and_src_info", "")) + if prim_s == str(custom_vjp_call_jaxpr_p): + src_info = str(params.get("fun_jaxpr", "")) + # GPU Pallas kernel. + if "_mha_forward_kernel" in src_info: + return remat_type + # TPU new and legacy Pallas kernel. + if "flash_attention_kernel" in src_info: + return remat_type + # cuDNN kernel. + if prim_s == str(_dot_product_attention_fwd_p_wrapper): + return remat_type + return Recompute + + return policy diff --git a/axlearn/common/flash_attention/remat_test.py b/axlearn/common/flash_attention/remat_test.py new file mode 100644 index 00000000..d9ec3f97 --- /dev/null +++ b/axlearn/common/flash_attention/remat_test.py @@ -0,0 +1,224 @@ +# Copyright © 2025 Apple Inc. + +"""Tests FlashAttention remat policy.""" +# pylint: disable=ungrouped-imports +import os + +# Due to reference layer using XLA, +# set the following environment variables to avoid OOM in GPU tests. 
+# pylint: disable=wrong-import-position +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" +from contextlib import nullcontext + +# pylint: enable=wrong-import-position +import jax +import jax.numpy as jnp +from absl.testing import parameterized +from jax.ad_checkpoint import checkpoint_policies +from jax.experimental import mesh_utils +from jax.sharding import Mesh + +from axlearn.common.flash_attention.layer import ( + FlashAttention, + default_mha_dim_to_partition_spec, + default_output_dim_to_partition_spec, +) +from axlearn.common.flash_attention.layer_test import DummyModel, _fake_inputs +from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy +from axlearn.common.module import functional as F +from axlearn.common.test_utils import TestCase, is_supported_mesh_shape +from axlearn.common.utils import ( + Offloadable, + Saveable, + combine_remat_policies, + default_remat_combine_fn, + offload_dots_saveable, +) + + +class TestFlashAttentionRemat(TestCase): + """Tests FlashAttention remat policy.""" + + def _get_remat_test_data(self, use_segment_ids): + if jax.default_backend() not in ("gpu", "tpu"): + self.skipTest("Requires TPU or GPU to run.") + batch = 8 + seq_len = 128 + num_heads = 1 + per_head_dim = 128 + # Mesh shape doesn't matter. Find any supported mesh shape. + mesh = (8, 1) + if not is_supported_mesh_shape(mesh): + mesh = (4, 1) + if not is_supported_mesh_shape(mesh): + mesh = (1, 1) + mesh_axis_names = ("data", "model") + mesh = Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names) + with mesh: + hidden_dim = num_heads * per_head_dim + + kwargs = dict( + query_dim=hidden_dim, + key_dim=hidden_dim, + value_dim=hidden_dim, + num_heads=num_heads, + dtype=jnp.bfloat16, + causal=True, + mask=None, + ) + test_cfg = DummyModel.default_config().set( + layer=FlashAttention.default_config().set( + tpu_block_size=128, + mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names), + output_dim_to_partition_spec=default_output_dim_to_partition_spec( + mesh_axis_names + ), + **kwargs, + ) + ) + test_layer = test_cfg.set(name="test").instantiate(parent=None) + + # Use the same params for both. Only attention implementation differs. 
+ params = test_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + inputs = _fake_inputs( + batch=batch, + num_heads=num_heads, + kv_len=seq_len, + query_len=seq_len, + hidden_dim=hidden_dim, + use_bias=False, + use_segment_ids=use_segment_ids, + ) + + def loss(params, inputs): + loss, _ = F( + test_layer, + inputs=inputs, + state=params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + return loss + + return mesh, params, inputs, loss + + @parameterized.product( + use_segment_ids=[True, False], + remat_type=[Saveable, Offloadable(src="device", dst="pinned_host")], + ) + def test_flash_remat(self, use_segment_ids, remat_type): + if jax.default_backend() not in ("gpu", "tpu"): + self.skipTest("Remat test requires either GPU or TPU.") + + mesh, params, inputs, loss = self._get_remat_test_data(use_segment_ids) + with mesh: + no_remat = jax.value_and_grad(loss) + full_remat = jax.value_and_grad(jax.remat(loss)) + with self.assertRaises(NotImplementedError) if isinstance( + remat_type, Offloadable + ) else nullcontext(): + save_flash = jax.value_and_grad( + jax.remat(loss, policy=save_or_offload_flash_attention_policy(remat_type)) + ) + if isinstance(remat_type, Offloadable): + return + fn_expected_fw_count = [ + (no_remat, 1), + (full_remat, 2), + (save_flash, 1), + ] + + # Note: we don't use HLO for Pallas since HLO doesn't contain the kernel name. + if jax.default_backend() == "gpu": + if use_segment_ids: + # Pallas kernel case. + for fn, count in fn_expected_fw_count: + self.assertEqual( + str(jax.make_jaxpr(fn)(params, inputs)).count("_mha_forward_kernel"), + count, + ) + else: + # cuDNN case. + # Note: the backward kernel is called "__cudnn$fmhaSoftmaxBackward". + # Use " to distinguish forward and backward kernel. + for fn, count in fn_expected_fw_count: + self.assertEqual( + jax.jit(fn) + .lower(params, inputs) + .as_text("hlo") + .count('"__cudnn$fmhaSoftmax"'), + count, + ) + elif jax.default_backend() == "tpu": + for fn, count in fn_expected_fw_count: + self.assertEqual( + jax.jit(fn) + .lower(params, inputs) + .as_text("hlo") + .count('custom_call_target="tpu_custom_call"'), + # +1 because this custom call also matches the backward call. + # Also +1 for legacy code path since the backward kernels are + # not fused. I.e. there are two calls, one for dkdv and one for dq. + count + 1 + int(use_segment_ids), + ) + + def test_remat_combine_policy(self): + if jax.default_backend() != "gpu": + self.skipTest("Need GPU for this test.") + mesh, params, inputs, loss = self._get_remat_test_data(True) + with mesh: + no_remat = jax.value_and_grad(loss) + no_remat_dots_count = str(jax.jit(no_remat).lower(params, inputs).as_text("hlo")).count( + " dot(" + ) + offload = Offloadable(src="device", dst="pinned_host") + remat = jax.value_and_grad( + jax.remat( + loss, + policy=combine_remat_policies( + checkpoint_policies.dots_saveable, + save_or_offload_flash_attention_policy(), + ), + ) + ) + + self.assertEqual( + str(jax.make_jaxpr(remat)(params, inputs)).count("_mha_forward_kernel"), + 1, + ) + self.assertEqual( + str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("), + no_remat_dots_count, + ) + + with self.assertRaises(RuntimeError): + # Tests conflicting save policy for dots. 
+ remat = jax.value_and_grad( + jax.remat( + loss, + policy=combine_remat_policies( + checkpoint_policies.everything_saveable, + offload_dots_saveable("device", "pinned_host"), + ), + ) + ) + jax.jit(remat).lower(params, inputs).as_text("hlo") + + # Tests conflicting save policy should works if preferred type is specified. + for preferred in [Saveable, offload]: + remat = jax.value_and_grad( + jax.remat( + loss, + policy=combine_remat_policies( + checkpoint_policies.everything_saveable, + offload_dots_saveable("device", "pinned_host"), + combine_fn=default_remat_combine_fn(preferred), + ), + ) + ) + self.assertEqual( + str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("), + no_remat_dots_count, + ) + jax.jit(remat).lower(params, inputs).as_text("hlo") diff --git a/axlearn/common/input_base.py b/axlearn/common/input_base.py index 31b7d3e9..efec40ea 100644 --- a/axlearn/common/input_base.py +++ b/axlearn/common/input_base.py @@ -148,3 +148,10 @@ def constrain_batch_axis(batch): global_physical_batch, batch_axis_names=batch_axis_names ) return constrain_batch_axis(global_logical_batch) + + def element_spec(self) -> Nested[jax.ShapeDtypeStruct]: + """Returns the per-feed logical batch spec. + + This is used e.g. for AOT compilation and is not strictly required for training. + """ + raise NotImplementedError(type(self)) diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py index e06d8ae3..10c78456 100644 --- a/axlearn/common/input_grain.py +++ b/axlearn/common/input_grain.py @@ -159,6 +159,7 @@ def maybe_repeat(ds: Dataset): ds = ds.repeat() return ds + # TODO(markblee): Support mixing grain.IterDataset. return grain.MapDataset.mix( datasets=[maybe_repeat(source) for source in sources], weights=weights, @@ -629,3 +630,22 @@ def dataset(self) -> grain.IterDataset: f"Please make sure to call {shard_dataset.__name__} if using input dispatch." ) return maybe_to_iter_dataset(ds) + + def element_spec(self) -> utils.Nested[jax.ShapeDtypeStruct]: + """Infers the element spec. + + Grain requires fetching an example from the dataset to extract the spec. To avoid reading + actual data, replace your source dataset with one from `input_fake.fake_grain_source`. + """ + ds = self.dataset() + if isinstance(ds, grain.MapDataset): + example = ds[0] + else: + example = next(ds.__iter__()) # pylint: disable=unnecessary-dunder-call + + def shape_dtype(x): + if not hasattr(x, "shape") or not hasattr(x, "dtype"): + raise ValueError(f"element_spec() requires Tensor-like leaves, got: {x}.") + return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) + + return jax.tree.map(shape_dtype, example) diff --git a/axlearn/common/input_grain_test.py b/axlearn/common/input_grain_test.py index 4cfed0e7..b052d2ca 100644 --- a/axlearn/common/input_grain_test.py +++ b/axlearn/common/input_grain_test.py @@ -567,3 +567,24 @@ def test_dispatch_tpu(self): # Should contain the right ids. self.assertEqual([0, 1, 2, 3], replicate_to_local_data(batch)["input_ids"].tolist()) break + + def test_element_spec(self): + ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": x}) + grain_input: Input = self._input_config(ds).instantiate(parent=None) + # element_spec() requires Tensor-like leaves. 
+ with self.assertRaisesRegex(ValueError, "Tensor"): + grain_input.element_spec() + + ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": np.array(x)}) + cfg = self._input_config( + ds.repeat(num_epochs=None), + per_process=lambda ds: ds.batch(2), + process_count=4, + process_index=0, + ) + grain_input: Input = cfg.instantiate(parent=None) + self.assertEqual( + # Element spec should reflect the per-process shape. + {"input_ids": jax.ShapeDtypeStruct(shape=(2,), dtype=np.int64)}, + grain_input.element_spec(), + ) diff --git a/axlearn/common/input_tf_data.py b/axlearn/common/input_tf_data.py index 1d44ad1b..6cb1e752 100644 --- a/axlearn/common/input_tf_data.py +++ b/axlearn/common/input_tf_data.py @@ -46,6 +46,7 @@ from axlearn.common.module import Module from axlearn.common.utils import ( PHYSICAL_TO_LOGICAL_DISPATCH_KEY, + Nested, Tensor, get_data_dir, get_recursively, @@ -1209,6 +1210,16 @@ def processor(self) -> DatasetToDatasetFn: def dataset(self) -> tf.data.Dataset: return self._batcher(self._processor(self._source())) + def element_spec(self) -> Nested[jax.ShapeDtypeStruct]: + """Returns the tfds element spec.""" + + return jax.tree.map( + lambda tf_spec: jax.ShapeDtypeStruct( + shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype + ), + self.dataset().element_spec, + ) + def disable_shuffle_recursively(cfg: Input.Config): """Disables all shuffling on the input config. diff --git a/axlearn/common/input_tf_data_test.py b/axlearn/common/input_tf_data_test.py index 63b859b4..f43c5d5e 100644 --- a/axlearn/common/input_tf_data_test.py +++ b/axlearn/common/input_tf_data_test.py @@ -1583,5 +1583,28 @@ def test_disable_shuffle_recursively(self): self.assertEqual(cfg.source.source.train_shuffle_files, False) +class ElementSpecTest(parameterized.TestCase): + """Tests Input.element_spec().""" + + def test_element_spec(self): + cfg = Input.default_config().set( + source=config_for_function(with_processor).set( + source=config_for_function(fake_text_source), + processor=config_for_function(identity), + ), + processor=config_for_function(identity), + batcher=config_for_function(batch).set( + global_batch_size=2, + pad_example_fn=default_pad_example_fn, + ), + is_training=True, + name="test", + ) + self.assertEqual( + {"text": jax.ShapeDtypeStruct(shape=(2,), dtype=object)}, + cfg.instantiate(parent=None).element_spec(), + ) + + if __name__ == "__main__": absltest.main() diff --git a/axlearn/common/lora.py b/axlearn/common/lora.py index f0500671..4785b844 100644 --- a/axlearn/common/lora.py +++ b/axlearn/common/lora.py @@ -501,7 +501,7 @@ def __init__(self, cfg: Config, *, parent: Module): "adapter", cfg.adapter.set( input_dim=cfg.query_dim, - output_dim=cfg.query_dim, + output_dim=cfg.num_heads * cfg.per_head_dim, num_heads=cfg.num_heads, ), ) diff --git a/axlearn/common/lora_test.py b/axlearn/common/lora_test.py index 02cf9584..febf547f 100644 --- a/axlearn/common/lora_test.py +++ b/axlearn/common/lora_test.py @@ -137,7 +137,7 @@ def test_alpha_is_zero(self): class LoraFusedQKVLinearTest(TestCase): def test_forward(self): - model_dim = 6 + model_dim = 16 num_heads = 2 per_head_dim = 3 seq_len = 4 @@ -197,7 +197,7 @@ def test_forward(self): ), ) def test_extend_step(self, layer): - model_dim = 8 + model_dim = 16 num_heads = 2 per_head_dim = 4 # change this to 4 to adapt the need of RoPE. 
seq_len = 4 @@ -267,7 +267,7 @@ def test_extend_step(self, layer): ) def test_prefill_states(self): - model_dim = 6 + model_dim = 16 num_heads = 2 per_head_dim = 3 seq_len = 4 diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 0e8ef798..9211da27 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -29,6 +29,7 @@ import jax import jax.ad_checkpoint +from einops import rearrange, repeat from jax import numpy as jnp from jax._src.mesh import thread_resources from jax.experimental.shard_map import shard_map @@ -43,11 +44,15 @@ ) from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class -from axlearn.common.convolution import Conv1D -from axlearn.common.layers import Linear, MultiLinear, RMSNorm +from axlearn.common.layers import Conv1D, GroupNorm, Linear, MultiLinear, NormType, RMSNorm from axlearn.common.module import Module from axlearn.common.param_init import FanAxes, Initializer, Shape, constant_initializer, uniform from axlearn.common.ssm_kernels.mamba_kernels import compute_mamba_scan +from axlearn.common.ssm_kernels.ssd_kernels import ( + ssd, + ssd_linear_scan_w_hidden_states, + ssd_linear_scan_w_timestep, +) from axlearn.common.utils import Nested, Tensor, TensorSpec, with_sharding_constraint @@ -73,6 +78,7 @@ class Config(Initializer.Config): # Clamp dt projection's bias to at least this value. dt_init_floor: float = 1e-4 # One of 'random' or 'constant'. + # If 'constant', the projection matrix is initialized to a constant; otherwise, random. # pylint: disable=C0301 mode: str = "random" def initialize( @@ -1123,7 +1129,7 @@ class Config(BaseSSMLayer.Config): """Configures a Mamba block.""" norm: InstantiableConfig = RMSNorm.default_config() - mamba_layer: MambaMixerLayer.Config = MambaMixerLayer.default_config() + mamba_layer: BaseLayer.Config = MambaMixerLayer.default_config() residual_mode: BlockResidualMode = BlockResidualMode.FP32 def __init__(self, cfg: Config, *, parent: Module): @@ -1455,3 +1461,1044 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): for i in range(cfg.num_layers) ] super().__init__(cfg.set(layer=layers), parent=parent) + + +# Naming convention for Mamba2: +# * `SSD` is used to denote any kernel-specific parameters/functions (consistent with the kernel), +# * `Mamba2`` (where SSD is a sub-module) is used to denote layer-level parameters/functions. + + +class SSDdtBiasInitializer(Initializer): + """Initializes the bias of the dt projection in the SSD layer of Mamba2. + + The weight matrix of the dt projection is seperately constructed and initialized. + """ + + @config_class + class Config(Initializer.Config): + """Configures SSDdtBiasInitializer. + + The initialization is different from Mamba1 in that there is no low-rank parameterization. + and we only need to initialize the bias term. + + Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py. + """ + + # Initialization stddev is set to `dt_scale` * 1/sqrt{dt_rank} when random. + dt_scale: float = 1.0 + # Minimum value of the dt projection's bias after applying softplus. + dt_min: float = 1e-3 + # Maximum value of the dt projection's bias after applying softplus. + dt_max: float = 1e-1 + # Clamp dt projection's bias to at least this value. 
+ dt_init_floor: float = 1e-4 + + def initialize( + self, + name: str, + *, + prng_key: Tensor, + shape: Shape, + dtype: jnp.dtype, + axes: Optional[FanAxes] = None, + ) -> Tensor: + """Initializes the SSD dt projection bias following the official implementation.""" + if axes is not None: + raise ValueError("SSDdtBiasInitializer does not support FanAxes.") + cfg = self.config + assert 0 < cfg.dt_min < cfg.dt_max, "`dt_min` must be < `dt_max`." + dt = jnp.exp( + uniform(scale=1.0, dtype=dtype)(prng_key, shape) + * (math.log(cfg.dt_max) - math.log(cfg.dt_min)) + + math.log(cfg.dt_min) + ).astype( + dtype + ) # math.log may return float64, so we need to cast to dtype + dt = jnp.clip(dt, a_min=cfg.dt_init_floor) + # Get inverse of softplus. + inv_dt = dt + jnp.log(-jnp.expm1(-dt)) + return inv_dt + + +class SSDLLogAInitializer(Initializer): + """Initializes SSD's log-log A parameter, a = exp(-exp(llog_a)).""" + + @config_class + class Config(Initializer.Config): + """Configures SSDLLogAInitializer. + + Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py. + """ + + # `A` will be initialized within the range of [a_min, a_max], usually not tuned. + a_min: int = 1 + a_max: int = 16 + + def initialize( + self, + name: str, + *, + prng_key: Tensor, + shape: Shape, + dtype: jnp.dtype, + axes: Optional[FanAxes] = None, + ) -> jnp.ndarray: + """Returns a [num_heads] shaped vector.""" + if axes is not None: + raise ValueError("SSDLLogAInitializer does not support FanAxes.") + + cfg = self.config + return jnp.log( + jax.random.uniform(prng_key, shape, dtype=dtype, minval=cfg.a_min, maxval=cfg.a_max) + ) + + +class BaseSSDRecurrence(BaseLayer): + """An abstract class representing a layer that computes the SSD recurrence.""" + + class Output(NamedTuple): + """Defines the output of the SSD recurrence.""" + + data: Tensor # [batch, num_heads, target_length, head_dim] + states: Tensor # [batch, num_heads, target_length, state_dim, head_dim] + + @config_class + class Config(BaseLayer.Config): + """Configures a BaseSSDRecurrence.""" + + output_mode: MambaRecurrenceOutputMode = MambaRecurrenceOutputMode.OUTPUTS + + def forward( + self, x: Tensor, *, log_a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> Output: + """Computes the Mamba2's SSD recurrence output given full-sequence inputs and parameters. + + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [num_heads] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [head_dim] + + Returns: + An instance of BaseSSDRecurrence.Output. + """ + raise NotImplementedError(type(self)) + + +class PallasSSDRecurrence(BaseSSDRecurrence): + """A layer that computes the Mamba2's SSD recurrence with a Pallas-based chunk-wise scan.""" + + @config_class + class Config(BaseSSDRecurrence.Config): + """Configures a PallasSSDRecurrence.""" + + mamba2_dim_to_partition_spec: dict[str, PartitionSpec] = { + "bhtd": PartitionSpec(None), + "bht": PartitionSpec(None), + } + + output_partition_spec: PartitionSpec = PartitionSpec(None) + + def forward( + self, x: Tensor, *, log_a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> BaseSSDRecurrence.Output: + """Computes Mamba2's SSD recurrence with a Pallas-based chunk-wise scan. 
+ + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [1, num_heads, 1] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [1, num_heads, 1, 1] + + Returns: + An BaseSSDRecurrence.Output instance, where .data is the same shape as x and .states is + None (no need to return hidden states during training). + + Unlike the Mamba recurrence, discretizations of parameters are not explicitly computed. + More specifically, \bar a (i.e., discretized a) is computed outside the kernel whereas + \bar b is computed implicitly via adding the delta term to the input + x -- \bar x = x * delta. + See the following line from the official repo for details - + https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/modules/ssd_minimal.py#L103 + + Note that `ssd` functions need to be wrapped, otherwise the following error will be raised: + ``NotImplementedError: Mosaic kernels cannot be automatically partitioned.`` + The current version of `ssd` function assumes that h0 is None, so there is no need to + provide its partition spec. + """ + cfg = self.config + + sharded_ssd = shard_map( + ssd, + mesh=thread_resources.env.physical_mesh, + in_specs=( + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bht"], + ), + out_specs=cfg.output_partition_spec, + check_rep=False, + ) + # The kernel code `ssd_kernels.py` uses q/k/v notations, which corresponds to b/c/x. + x_bar = x * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + o = sharded_ssd(c, b, x_bar, loga_bar) + + o = o + d * x + return BaseSSDRecurrence.Output(data=o, states=None) + + +class LinearScanSSDRecurrence(BaseSSDRecurrence): + """A layer that computes the Mamba2's SSD recurrence with a Jax-based linear scan.""" + + def forward( + self, + x: Tensor, + *, + log_a: Tensor, + b: Tensor, + c: Tensor, + delta: Tensor, + d: Tensor, + time_step: Optional[Tensor] = None, + ) -> BaseSSDRecurrence.Output: + """Computes the Mamba2's SSD recurrence with a Jax-based linear scan. + + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [1, num_heads, 1] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [1, num_heads, 1, 1] + time_step: [batch_size] or None + + Returns: + An BaseSSDRecurrence.Output instance, where .data is the same shape as x and .states is + the hidden states of shape [batch_size, num_heads, seq_len, state_dim, head_dim] for + the given time step if `time_step` is not None, otherwise the full hidden states of + shape [batch_size, num_heads, seq_len, state_dim, head_dim] is returned. + """ + # Same procedure as the pallas version above. + x_bar = x * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + + if time_step is None: + # Return the full hidden states. + o, states = ssd_linear_scan_w_hidden_states(c, b, x_bar, loga_bar) + else: + # Return the hidden states at the given time step. 
+ o, states = ssd_linear_scan_w_timestep(c, b, x_bar, loga_bar, time_step) + + o = o + d * x + return BaseSSDRecurrence.Output(data=o, states=states) + + +class Mamba2MixerLayer(BaseLayer): + """A layer that computes the Mamba2 recurrence over its input.""" + + @config_class + class Config(BaseLayer.Config): + """Configures a Mamba2MixerLayer.""" + + # `d_model` increases as models get larger. + input_dim: Required[int] = REQUIRED + # `d_state` typically in {64, 128} + state_dim: Required[int] = REQUIRED + # num_heads = input_dim // head_dim, head_dim is typically 128. + num_heads: Required[int] = REQUIRED + + # `G` in the paper, typically 8 + num_groups: Required[int] = REQUIRED + + # See sec 8.2 for the parameterization. More details (e.g., conv + # for bc projection) can be found in the following link: + # https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/modules/mamba2.py # pylint: disable=C0301 + + xz_proj: MultiLinear.Config = MultiLinear.default_config().set( + bias=False, + param_partition_spec=(None, None, "model"), + ) + + bc_proj: MultiLinear.Config = MultiLinear.default_config().set( + bias=False, + param_partition_spec=(None, None, "model"), + ) + # A causal convolution. The window defaults to 4, the same as mamba1. + x_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + b_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + c_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + + # `dt_bias` is separately created and initialized. + dt_proj: Linear.Config = Linear.default_config().set( + bias=False, param_partition_spec=(None, "model") + ) + pre_out_proj_norm: InstantiableConfig = GroupNorm.default_config().set( + norm_type=NormType.RMSNORM, + norm_axes=-1, + ) + out_proj: Linear.Config = Linear.default_config().set( + bias=False, + param_partition_spec=("model", None), + ) + + expansion_factor: float = 2.0 + cache_dtype: Optional[jnp.dtype] = None + bc_norm: Optional[InstantiableConfig] = RMSNorm.default_config() + norm_eps: float = 1e-5 + norm_dtype: Optional[jnp.dtype] = None + + # The recurrence implementation to use for full-sequence inputs. + ssd_recurrence: BaseSSDRecurrence = PallasSSDRecurrence.default_config() + # The recurrence implementation to use for inference. + inference_mamba_recurrence: BaseSSDRecurrence = ( + LinearScanSSDRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + ) + ) + + class Mamba2Output(NamedTuple): + """Defines the output of the Mamba2MixerLayer.""" + + data: Tensor # [batch, num_heads, target_length, head_dim] + ssd_state: Tensor # [batch, num_heads, state_dim, head_dim] + + class SSDParameters(NamedTuple): + """Defines the parameters of the SSD recurrence.""" + + log_a: Tensor # [1, num_heads, 1] + b: Tensor # [batch_size, num_groups, seq_len, state_dim] + c: Tensor # [batch_size, num_groups, seq_len, state_dim] + delta: Tensor # [batch_size, num_heads, seq_len] + d: Tensor # [1, num_heads, 1, 1] + + # Cache used for internal inference, whereas Mamba2Output is external output. + class Mamba2Cache(NamedTuple): + """Defines the cache of the Mamba2MixerLayer for inference.""" + + # Naming is a bit different from Mamba1: conv_input -> conv_state. 
+ x_conv_state: Tensor # [batch_size, seq_len, inner_dim] + b_conv_state: Tensor # [batch_size, seq_len, state_dim * 2] + c_conv_state: Tensor # [batch_size, seq_len, state_dim * 2] + ssd_state: Tensor # [batch_size, num_heads, state_dim, head_dim] + time_step: Optional[Tensor] = None # [batch] + + @property + def inner_dim(self): + cfg = self.config + return int(cfg.input_dim * cfg.expansion_factor) + + @property + def head_dim(self): + cfg = self.config + return self.inner_dim // cfg.num_heads + + @property + def output_dim(self): + cfg = self.config + return cfg.input_dim + + @property + def group_dim(self): + cfg = self.config + return self.inner_dim // cfg.num_groups + + @property + def bc_state_dim(self): + cfg = self.config + return cfg.state_dim * cfg.num_groups + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + + self._add_child( + "xz_proj", + cfg.xz_proj.set( + input_dim=cfg.input_dim, + num_outputs=2, + output_dim=self.inner_dim, + bias=False, + ), + ) + self._add_child( + "bc_proj", + cfg.bc_proj.set( + input_dim=cfg.input_dim, + num_outputs=2, + output_dim=self.bc_state_dim, + bias=False, + ), + ) + self._add_child( + "x_conv", + cfg.x_conv.set( + padding=((cfg.x_conv.window - 1, 0),), # A causal convolution. + input_dim=self.inner_dim, + output_dim=self.inner_dim, + num_input_dim_groups=self.inner_dim, + ), + ) + self._add_child( + "b_conv", + cfg.b_conv.set( + padding=((cfg.b_conv.window - 1, 0),), # A causal convolution. + input_dim=self.bc_state_dim, + output_dim=self.bc_state_dim, + num_input_dim_groups=self.bc_state_dim, + ), + ) + self._add_child( + "c_conv", + cfg.c_conv.set( + padding=((cfg.c_conv.window - 1, 0),), # A causal convolution. + input_dim=self.bc_state_dim, + output_dim=self.bc_state_dim, + num_input_dim_groups=self.bc_state_dim, + ), + ) + + # b/c norm is analoguous to q/k norm in standard attention. + if cfg.bc_norm: + self._add_child( + "b_norm", + cfg.bc_norm.clone().set( + input_dim=cfg.state_dim, eps=cfg.norm_eps, forward_dtype=cfg.norm_dtype + ), + ) + self._add_child( + "c_norm", + cfg.bc_norm.clone().set( + input_dim=cfg.state_dim, eps=cfg.norm_eps, forward_dtype=cfg.norm_dtype + ), + ) + + self._add_child( + "dt_proj", + cfg.dt_proj.set( + input_dim=cfg.input_dim, + output_dim=cfg.num_heads, + bias=False, + ), + ) + self._add_child( + "pre_out_proj_norm", + cfg.pre_out_proj_norm.set( + input_dim=self.inner_dim, num_groups=cfg.num_groups, eps=cfg.norm_eps + ), + ) + self._add_child( + "out_proj", + cfg.out_proj.set( + input_dim=self.inner_dim, + output_dim=cfg.input_dim, + bias=False, + ), + ) + + self._add_child("recurrence", cfg.ssd_recurrence) + self._add_child( + "inference_recurrence", + cfg.inference_mamba_recurrence.set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + ), + ) + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + """Creates parameter specs. + + Returns: + A dict mapping `llog_a`, `dt_bias` and `d` to their respective ParameterSpecs. + """ + cfg = self.config + params = dict( + llog_a=ParameterSpec( + # Initialize with a shape that avoids expansion later. 
+ shape=(1, cfg.num_heads, 1), + mesh_axes=(None, "model", None), + initializer=SSDLLogAInitializer.default_config().instantiate(), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + dt_bias=ParameterSpec( + shape=(cfg.num_heads,), + mesh_axes=("model",), + initializer=SSDdtBiasInitializer.default_config().instantiate(), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + d=ParameterSpec( + # Initialize with a shape that avoids expansion later. + shape=(1, cfg.num_heads, 1, 1), + mesh_axes=(None, "model", None, None), + initializer=constant_initializer(1.0), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + ) + return params + + def _project_input(self, inputs: Tensor) -> tuple[Tensor, Tensor]: + """Projects inputs into tensors with dimension inner_dim. + + Args: + inputs: [batch_size, seq_len, input_dim] + + Returns: + x, z of the same size [batch_size, seq_len, inner_dim] + """ + xz = self.xz_proj(inputs) + x, z = jnp.split(xz, 2, axis=-2) # [batch_size, seq_len, 1, inner_dim] + return jnp.squeeze(x, axis=2), jnp.squeeze(z, axis=2) + + def _ssm_parameters( + self, + inputs: Tensor, + b_input: Optional[Tensor] = None, + c_input: Optional[Tensor] = None, + ) -> SSDParameters: + """Computes input-dependent SSD parameters. + + Args: + inputs: [batch_size, seq_len, inner_dim] + b_input: [batch_size, seq_len, bc_state_dim]. If b_input and c_input + are given, no need to compute bc_proj. + c_input: [batch_size, seq_len, bc_state_dim]. If b_input and c_input + are given, no need to compute bc_proj. + + Exposing the computation of `b` and `c` is useful to keep track the conv1d input for + `b_conv` and `c_conv`. During training, `b_input` and `c_input` should be None as + they represent the results after short conv. During inference, they represent the + input of short conv. + + Returns: + An instance of SSMParameters. + + Raises: + ValueError: If only one of b_input and c_input is provided. + + TODO (bailin-wang): merge b_conv and c_conv for better efficiency. + """ + cfg = self.config + if (b_input is None) != (c_input is None): + raise ValueError("Either both or none of b_input and c_input should be provided.") + + if b_input is None or c_input is None: + bc = self.bc_proj(inputs) # [batch_size, seq_len, 2, bc_state_dim] + bc = rearrange(bc, "b s n d -> b s (n d)") + b, c = jnp.split(bc, 2, axis=-1) + else: + b = b_input + c = c_input + + b = jax.nn.silu(self.b_conv(b)) + c = jax.nn.silu(self.c_conv(c)) + + b = rearrange(b, "b s (g d) -> b g s d", d=cfg.state_dim) + c = rearrange(c, "b s (g d) -> b g s d", d=cfg.state_dim) + + if "b_norm" in self.children and "c_norm" in self.children: + b = self.b_norm(b) + c = self.c_norm(c) + + # `dt` is in float32 for better precision of softplus for the delta term which later will + # be combined with float32 `log_a`. See also the following link: + # https://github.com/state-spaces/mamba/blob/6b72c122713bb769cc82c6b8e6d019c53d27d6a1/mamba_ssm/ops/triton/ssd_combined.py#L603. 
+ dt = self.dt_proj(inputs) + jnp.expand_dims( + _at_least_float32(self.parameters["dt_bias"]), axis=(0, 1) + ) + delta = jax.nn.softplus(dt) # [batch_size, seq_len, num_heads] + delta = rearrange(delta, "b s h -> b h s") # [batch_size, num_heads, seq_len] + + log_a = -jnp.exp( + _at_least_float32(self.parameters["llog_a"]) + ) # a = exp(-exp(llog_a)), log_a = -exp(llog_a * delta) + + return Mamba2MixerLayer.SSDParameters( + log_a=log_a, b=b, c=c, delta=delta, d=self.parameters["d"] + ) + + def _output_from_states(self, inputs: Tensor, *, z: Tensor) -> Tensor: + """Projects recurrence output back to input dimension. + + Args: + inputs: [batch_size, num_heads, seq_len, head_dim] + z: [batch_size, num_heads, seq_len, head_dim] + + Returns: + A tensor of shape [batch_size, seq_len, input_dim] + + Note that the num_heads/num_groups dim is contracted in the output. + """ + cfg = self.config + y = inputs * jax.nn.silu(z) + y_for_gnorm = rearrange(y, "b nh l d -> b l (nh d)", nh=cfg.num_heads) + y_for_proj = self.pre_out_proj_norm(y_for_gnorm) + return self.out_proj(y_for_proj) + + def forward(self, query: Tensor) -> Mamba2Output: + """Computes the Mamba2 recurrence over the provided inputs. + + Args: + query: [batch_size, input_length, input_dim] + + Returns: + A Mamba2Output instance where .data is the same shape as `inputs`. + """ + _, output = self._forward_for_mode(mode=ForwardMode.FORWARD, query=query) + return output + + def _forward_for_mode( + self, + *, + mode: ForwardMode, + query: Tensor, + cache: Optional[Mamba2Cache] = None, + ) -> tuple[Optional[Nested[Tensor]], Tensor]: + """Computes MambaMixerLayer outputs. + + Args: + mode: {FORWARD, INIT_STATES, EXTEND_STEP} + query: A Tensor of shape [batch_size, seq_len, input_dim] + cache: Optional NestedTensor as produced by `prefill_states`. + + Returns: + An optional cache, depending on `mode`. + A Mamba2Output instance, where .data is of the same shape as `inputs`. + + Raises: + ValueError: If `mode` is unsupported. + """ + self.vlog(3, "mamba2.input=%s", query.sum()) + if mode == ForwardMode.FORWARD: + mamba_cache, mamba_output = self._full_sequence_forward( + query, recurrence=self.recurrence + ) + elif mode == ForwardMode.INIT_STATES: + assert cache is not None + mamba_cache, mamba_output = self.prefill_states( + time_step=cache, + query=query, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cache is not None + mamba_cache, mamba_output = self.extend_step(cache, query) + else: + raise ValueError(f"Unrecognized mode {mode}.") + self.vlog(3, "mamba2.output=%s", mamba_output.data.sum()) + return dict(mamba_layer=mamba_cache), mamba_output + + def _full_sequence_forward( + self, inputs: Tensor, *, recurrence: BaseSSDRecurrence + ) -> tuple[Optional[Mamba2Cache], Mamba2Output]: + """Computes the Mamba2 layer output from a full sequence of inputs. + + Args: + inputs: A tensor of shape [batch_size, seq_len, input_dim]. + recurrence: A BaseMambaRecurrence to use for computing the recurrence. + + Returns: + An optional Mamba2Cache instance. Currently, it is always None. + A Mamba2Output instance. 
+ """ + cfg = self.config + + x, z = self._project_input(inputs) + x_conv = jax.nn.silu(self.x_conv(x)) + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + log_a, b, c, delta, d = self._ssm_parameters(inputs) + recurrence_output = recurrence(x_conv_w_head, log_a=log_a, b=b, c=c, delta=delta, d=d) + output = self._output_from_states(recurrence_output.data, z=z_w_head) + + ssd_state = recurrence_output.states + if ssd_state is not None: + ssd_state = ssd_state.astype(cfg.cache_dtype) + + mamba_cache = None + mamba_output = Mamba2MixerLayer.Mamba2Output(data=output, ssd_state=ssd_state) + return mamba_cache, mamba_output + + # pylint: disable=unused-argument + def init_states(self, *, target_batch_size: int, target_max_len: int) -> Mamba2Cache: + """Initializes cache for autoregressive cached decoding. + + Args: + batch_size: The batch size of the target to be decoded. + target_max_len: The maximum length of the target to be decoded. + + Returns: + A Mamba2Cache instance. + """ + cfg = self.config + dtype = cfg.cache_dtype or cfg.dtype + cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=jnp.zeros( + (target_batch_size, cfg.x_conv.window, self.inner_dim), dtype=dtype + ), + b_conv_state=jnp.zeros( + (target_batch_size, cfg.b_conv.window, self.bc_state_dim), dtype=dtype + ), + c_conv_state=jnp.zeros( + (target_batch_size, cfg.c_conv.window, self.bc_state_dim), dtype=dtype + ), + ssd_state=jnp.zeros( + (target_batch_size, cfg.num_heads, cfg.state_dim, self.head_dim), dtype=dtype + ), + time_step=jnp.zeros(target_batch_size, dtype=jnp.int32), + ) + return cache + + def prefill_states( + self, + *, + time_step: Tensor, + query: Tensor, + ) -> tuple[Mamba2Cache, Mamba2Output]: + """Initializes cache for autoregressive cached decoding. It refines the mamba state + returned from `_full_sequence_forward` to the state at `time_step` for the + incremental decoding later. + + Args: + time_step: A Tensor of shape [batch_size]. Each value is an index into the length + dimension indicating where decoding will start from. + query: Tensor of shape [batch_size, target_length, target_dim] corresponding to input + vector up to `time_step` indices. For batch index `i`, only + `inputs[i, :time_step[i], ...]` will affect subsequent decoding. + + Returns: + A Mamba2Cache instance containing updated convolution state, ssm state and time_step. + A Mamba2Output instance where .data is the same shape as query. + """ + cfg = self.config + cache_dtype = cfg.cache_dtype or cfg.dtype + + x, z = self._project_input(query) + x_conv = jax.nn.silu(self.x_conv(x)) + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + # Run `bc_proj` outside of `_ssm_parameters` so that we can keep track of the conv1d input. 
+ bc_input = self.bc_proj(query) # [batch_size, seq_len, 2, bc_state_dim] + bc_input = rearrange(bc_input, "b s n d -> b s (n d)") + b_input, c_input = jnp.split(bc_input, 2, axis=-1) + log_a, b, c, delta, d = self._ssm_parameters(query, b_input=b_input, c_input=c_input) + + recurrence_output = self.inference_recurrence( + x_conv_w_head, log_a=log_a, b=b, c=c, delta=delta, d=d, time_step=time_step + ) + output = self._output_from_states(recurrence_output.data, z=z_w_head) + mamba_output = Mamba2MixerLayer.Mamba2Output( + data=output, ssd_state=recurrence_output.states.astype(cache_dtype) + ) + + # Collect and refine conv states and ssd states. + x_conv_state = x + b_conv_state = b_input + c_conv_state = c_input + + # For the full sequence, always in float32, will be down-cast based on cache_dtype. + cont_ssd_state = recurrence_output.states.astype(cache_dtype) + + batch_size = query.shape[0] + batch_range = jnp.arange(batch_size) + + # Pad conv input so we can take the last window timesteps that precede time_step. + x_time_step_range = time_step[:, None] + jnp.arange(cfg.x_conv.window)[None, :] + padded_x_conv_state = jnp.pad( + x_conv_state, ((0, 0), (cfg.x_conv.window, 0), (0, 0)) + ) # [batch_size, target_length+window, input_dim] + cont_x_conv_state = padded_x_conv_state[batch_range[:, None], x_time_step_range] + + b_time_step_range = time_step[:, None] + jnp.arange(cfg.b_conv.window) + padded_b_conv_state = jnp.pad(b_conv_state, ((0, 0), (cfg.b_conv.window, 0), (0, 0))) + cont_b_conv_state = padded_b_conv_state[batch_range[:, None], b_time_step_range] + + c_time_step_range = time_step[:, None] + jnp.arange(cfg.c_conv.window) + padded_c_conv_state = jnp.pad(c_conv_state, ((0, 0), (cfg.c_conv.window, 0), (0, 0))) + cont_c_conv_state = padded_c_conv_state[batch_range[:, None], c_time_step_range] + + init_cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=cont_x_conv_state.astype(cache_dtype), + b_conv_state=cont_b_conv_state.astype(cache_dtype), + c_conv_state=cont_c_conv_state.astype(cache_dtype), + ssd_state=cont_ssd_state.astype(cache_dtype), + time_step=time_step, + ) + return init_cache, mamba_output + + def _single_step_conv_update( + self, + inputs: Tensor, + *, + conv_state: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor]: + """Updates cache of convolutional inputs and returns updated state. + + Args: + inputs: [batch_size, inner_dim] + conv_state: [batch_size, width, inner_dim] + weight: [width, 1, inner_dim] + bias: [inner_dim] + + Returns: + A tensor of shape [batch_size, inner_dim]. + A tensor of shape [batch_size, width, inner_dim], representing the new conv state. + """ + new_conv_state = jnp.roll(conv_state, shift=-1, axis=1) + new_conv_state = new_conv_state.at[:, -1].set(inputs) + + conv_output = jnp.sum( + new_conv_state * jnp.squeeze(_at_least_float32(weight), axis=1), axis=1 + ).astype( + inputs.dtype + ) # [batch_size, inner_dim] + if bias is not None: + conv_output = conv_output + bias + return conv_output, new_conv_state + + def _single_step_ssm_update( + self, + x: Tensor, + *, + ssm_state: Tensor, + log_a: Tensor, + b: Tensor, + c: Tensor, + d: Tensor, + delta: Tensor, + ) -> tuple[Tensor, Tensor]: + """Moves the SSM state forward by a single step. 
+ + Args: + x: [batch_size, num_heads, 1, head_dim] + ssm_state: [batch_size, num_heads, state_dim, head_dim] + log_a: [1, num_heads, 1], always float32 + b: [batch_size, num_groups, 1, state_dim] + c: [batch_size, num_groups, 1, state_dim] + delta: [batch_size, num_heads, 1], always float32 + d: [1, head_dim, 1, 1] + + Returns: + A tensor of shape [batch_size, num_heads, 1, head_dim] for the new output. + A tensor of shape [batch_size, num_heads, state_dim, head_dim] for the updated state. + """ + cfg = self.config + num_head_per_group = cfg.num_heads // cfg.num_groups + + orig_dtype = x.dtype + acc_dtype = cfg.cache_dtype or cfg.dtype + + # x: [batch_size, num_heads, head_dim] + # b and c: [batch_size, num_groups, state_dim] + # d: [batch_size, num_heads] + x, b, c, d = map(lambda x: jnp.squeeze(x, axis=2), (x, b, c, d)) + + # [batch_size, num_heads, state_dim] + b = repeat(b, "b ng d -> b (ng ngh) d", ngh=num_head_per_group) + c = repeat(c, "b ng d -> b (ng ngh) d", ngh=num_head_per_group) + + # [batch_size, num_heads, head_dim] + x_bar = x * delta + # [batch_size, num_heads, 1] + loga_bar = log_a * delta + # [batch_size, num_heads, 1] + a = jnp.exp(loga_bar) + # [batch_size, num_heads, state_dim, head_dim] + a = jnp.expand_dims(a, axis=-1) + + new_ssm_state = a * ssm_state + jnp.einsum("...i,...j->...ij", b, x_bar) + output = jnp.einsum("...ij,...i->...j", new_ssm_state, c) + d * x + + output = jnp.expand_dims(output.astype(orig_dtype), axis=2) + new_ssm_state = new_ssm_state.astype(acc_dtype) + return output, new_ssm_state + + def extend_step( + self, + cache: Mamba2Cache, + query: Tensor, + ) -> tuple[Mamba2Cache, Mamba2Output]: + """Computes the next state given the query of the current step. This function is used + in autoregressive decoding. + + Args: + cached_states: A Nested[Tensor] containing previous state of shape and index. + query: Tensor of shape [batch_size, 1, inner_dim] + + Returns: + A Mamba2Cache instance containing the convolution state, ssm state and time_step. + A Mamba2Output instance, where .data is the same shape as query. + """ + time_step: Tensor = cache.time_step + assert time_step.ndim == 1 + cfg = self.config + + x, z = self._project_input(query) + x_conv, new_x_conv_state = self._single_step_conv_update( + jnp.squeeze(x, axis=1), + conv_state=cache.x_conv_state, + weight=self.parameters["x_conv"]["weight"], + bias=self.parameters["x_conv"]["bias"], + ) + x_conv = jnp.expand_dims(jax.nn.silu(x_conv), axis=1) # [batch_size, 1, inner_dim] + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + # Obtain ssm parameters. 
+ bc = self.bc_proj(query) # [batch_size, seq_len, 2, bc_state_dim] + bc = rearrange(bc, "b s n d -> b s (n d)") + b, c = jnp.split(bc, 2, axis=-1) + + b_conv, new_b_conv_state = self._single_step_conv_update( + jnp.squeeze(b, axis=1), + conv_state=cache.b_conv_state, + weight=self.parameters["b_conv"]["weight"], + bias=self.parameters["b_conv"]["bias"], + ) + b = jnp.expand_dims(jax.nn.silu(b_conv), axis=1) # [batch_size, 1, bc_inner_dim] + + c_conv, new_c_conv_state = self._single_step_conv_update( + jnp.squeeze(c, axis=1), + conv_state=cache.c_conv_state, + weight=self.parameters["c_conv"]["weight"], + bias=self.parameters["c_conv"]["bias"], + ) + c = jnp.expand_dims(jax.nn.silu(c_conv), axis=1) # [batch_size, 1, bc_inner_dim] + + b = rearrange(b, "b s (g d) -> b g s d", d=cfg.state_dim) + c = rearrange(c, "b s (g d) -> b g s d", d=cfg.state_dim) + + if cfg.bc_norm: + b = self.b_norm(b) + c = self.c_norm(c) + + dt = self.dt_proj(query) + jnp.expand_dims( + _at_least_float32(self.parameters["dt_bias"]), axis=(0, 1) + ) + delta = jax.nn.softplus(dt) # [batch_size, 1, num_heads] + delta = rearrange(delta, "b s h -> b h s") # [batch_size, num_heads, 1] + + log_a = -jnp.exp( + _at_least_float32(self.parameters["llog_a"]) + ) # a = exp(-exp(llog_a)), log_a = -exp(llog_a) + d = self.parameters["d"] + + y, new_ssd_state = self._single_step_ssm_update( + x_conv_w_head, + ssm_state=cache.ssd_state, + log_a=log_a, + b=b, + c=c, + d=d, + delta=delta, + ) + output = self._output_from_states(y, z=z_w_head) + + new_cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=new_x_conv_state, + b_conv_state=new_b_conv_state, + c_conv_state=new_c_conv_state, + ssd_state=new_ssd_state, + time_step=time_step + 1, + ) + mamba2output = Mamba2MixerLayer.Mamba2Output( + data=output, + ssd_state=new_ssd_state, + ) + return new_cache, mamba2output + + +class JambaMamba2Block(JambaMambaBlock): + """A JambaMamba2Block along with RMN norm and a feed-forward layer.""" + + @config_class + class Config(JambaMambaBlock.Config): + """Configures a JambaMamba2Block.""" + + num_heads: Required[int] = REQUIRED + num_groups: Required[int] = REQUIRED + + @classmethod + def default_config(cls) -> Config: + cfg = super().default_config() + cfg.mamba_layer = Mamba2MixerLayer.default_config() + return cfg + + def __init__(self, cfg: Config, *, parent: Module): + cfg.mamba_layer = cfg.mamba_layer.set(num_heads=cfg.num_heads, num_groups=cfg.num_groups) + super().__init__(cfg, parent=parent) + + +def set_double_shard_weights_config_mamba2( + cfg: Union[JambaMamba2Block.Config, Sequence[JambaMamba2Block.Config]], + *, + batch_axis_names: Union[str, Sequence[str]] = ("data", "expert", "fsdp"), + fsdp_axis_names: Union[str, Sequence[str]] = "fsdp", + tp_axis_names: Union[str, Sequence[str]] = "model", + seq_axis_names: Union[str, Sequence[str]] = "seq", +): + """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes. + + Args: + cfg: (A sequence of) Transformer layer config to apply sharding spec to. + batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors. + fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors. + tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors. + seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors. + """ + + def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): + # Shard weights. 
+ ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) + ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) + # Encourage the right activation sharding. + ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + + def set_mamba2_partition_specs(mamba_layer: Mamba2MixerLayer.Config): + mamba_layer.xz_proj.param_partition_spec = (fsdp_axis_names, None, tp_axis_names) + mamba_layer.bc_proj.param_partition_spec = (fsdp_axis_names, None, tp_axis_names) + mamba_layer.b_conv.param_partition_spec = (None, None, tp_axis_names) + mamba_layer.c_conv.param_partition_spec = (None, None, tp_axis_names) + mamba_layer.dt_proj.param_partition_spec = (fsdp_axis_names, tp_axis_names) + mamba_layer.out_proj.param_partition_spec = (tp_axis_names, fsdp_axis_names) + + mamba_layer.dt_proj.output_partition_spec = ( + batch_axis_names, + seq_axis_names, + tp_axis_names, + ) + mamba_layer.out_proj.output_partition_spec = ( + batch_axis_names, + seq_axis_names, + tp_axis_names, + ) + + if not isinstance(cfg, Sequence): + cfg = [cfg] + + for layer_cfg in cfg: + set_mamba2_partition_specs(layer_cfg.mamba_layer) + if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config): + set_ffn_partition_specs(layer_cfg.feed_forward) diff --git a/axlearn/common/ssm_kernels/ssd_kernels.py b/axlearn/common/ssm_kernels/ssd_kernels.py new file mode 100644 index 00000000..ce537652 --- /dev/null +++ b/axlearn/common/ssm_kernels/ssd_kernels.py @@ -0,0 +1,785 @@ +# Copyright © 2024 Apple Inc. + +""" Pallas kernels for Mamba2 + +High-level idea: this kernel implements a two-level chunking algorithm to +balance memory consumption and running speed. Intuitively, we store chunk-level +hidden states to avoid recomputation, and subchunk-level states are recomputed based +on the chunk-level states. + + +Notations: + nb: number of chunks + ns: number of subchunks + bl: subchunk size + dkn: number of tiles in the dk dim + dvn: number of tiles in the dv dim + dk: state_dim (corresponds to dim of qk heads) + dv: head_dim (corresponds to dim of v heads) + +q/k/v is used as it's more intuitive than b/c/x of SSD in the orginal implementation, +see section 7.2 https://arxiv.org/pdf/2405.21060. Accordingly, dk/dv is used instead +of state_dim/head_dim. This notation is also used in linear attention models. +However, state_dim/head_dim is used in the model file to be consistent with Mamba1 +and the original implementation. + +""" + +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from einops import rearrange, repeat +from jax import lax +from jax._src.lax.control_flow import for_loop +from jax.experimental import pallas as pl + +from axlearn.common.utils import Tensor + + +def _matmul_fp32(lhs: Tensor, rhs: Tensor) -> Tensor: + """A wrapper around jax.lax.dot to conduct float32 matmul""" + return jax.lax.dot(lhs, rhs, precision="float32", preferred_element_type=jnp.float32) + + +@jax.custom_vjp +def _ssd(q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Tensor) -> Tensor: + """A differentiable function that computes the output of SSD. 
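+
+    The forward and backward passes are implemented as Pallas TPU kernels and are tied together
+    via `_ssd.defvjp(_ssd_forward, _ssd_backward)` later in this file. Prefer the public `ssd`
+    wrapper below; a minimal usage sketch (shapes are illustrative and chosen to satisfy the
+    kernel's tiling constraints: seq_len divisible by the chunk size of 512 and dk/dv divisible
+    by 128):
+
+    ```python
+    q = k = jnp.zeros((2, 2, 1024, 128), jnp.bfloat16)  # [batch, num_groups, seq_len, dk]
+    v = jnp.zeros((2, 4, 1024, 128), jnp.float32)  # [batch, num_heads, seq_len, dv]
+    log_alpha = jnp.zeros((2, 4, 1024), jnp.float32)  # always float32
+    o = ssd(q, k, v, log_alpha)  # h0 defaults to zeros; o: [2, 4, 1024, 128]
+    ```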
+ + Args: + q: [bs, num_heads, seq_len, dk] + k: [bs, num_heads, seq_len, dk] + v: [bs, num_heads, seq_len, dv] + log_alpha: [bs, num_heads, seq_len] + h0: [bs, num_heads, dk, dv] + + Returns: + o: [bs, num_heads, seq_len, dv] + """ + ( + o, + _, + ) = _ssd_forward(q, k, v, log_alpha, h0) + return o + + +def _ssd_forward_kernel( + q_ref: Tensor, + k_ref: Tensor, + v_ref: Tensor, + cum_log_alpha_ref: Tensor, + initial_state_ref: Tensor, + gamma_ref: Tensor, + mutable_ch_ref: Tensor, + mutable_final_state_ref: Tensor, + mutable_o_ref: Tensor, +): + """Forward kernel for SSD. + + Args: + q_ref: tensor reference of shape [ns, bl, singleton_dim] + k_ref: tensor reference of shape [ns, bl, singleton_dim] + v_ref: tensor reference of shape [ns, bl, singleton_dim] + cum_log_alpha_ref: tensor reference of shape [ns, bl] + initial_state_ref: tensor reference of shape [singleton_dim, singleton_dim] + gamma_ref: tensor reference of shape [ns, bl, singleton_dim] + + Output via mutable tensors: + mutable_ch_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + mutable_final_state_ref: tensor reference of shape [singleton_dim, singleton_dim] + mutable_o_ref: tensor reference of shape [ns, bl, singleton_dim] + + Note on intial_state and final_state: + * initial_state is at seq-level and not updated during the forward pass + * final_state is used to pass chunk-level states across different chunks + - it will be initialized to initial_state at the beginning of each chunk + - it will be updated after processing each chunk + - in the end, it will return as the seq-level final state + """ + subchunk_dim, subchunk_size = cum_log_alpha_ref.shape[0], cum_log_alpha_ref.shape[1] + casual_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0) + + # In our grid definition, axis 4 is the chunk index. + @pl.when(pl.program_id(axis=4) == 0) + def init_carry(): + mutable_final_state_ref[:, :] = initial_state_ref[:, :] + + def _ssd_forward_chunk_loop_body(t: int, h_carry_ref: Tensor): + subchunk_idx = t + prev_state = h_carry_ref[:, :] + + q_block = q_ref[subchunk_idx, :].astype(jnp.float32) + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + + # Notation mapping wrt. the paper: lambda -> Lambda, gamma -> gamma, beta -> Gamma. + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [bl, 1] + beta_block = ( + jnp.expand_dims(gamma_block, axis=0) - lambda_block + ) # [bl, singleton_dim] after broadcasting + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * casual_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + q_tilde_block = q_block * lambda_block + k_tilde_block = k_block * beta_block + + o_block_inter = _matmul_fp32(q_tilde_block, prev_state) + intra_att = _matmul_fp32(q_block, k_block.T) + attn_mask = casual_mask * ssd_mask_block + o_block_intra = _matmul_fp32((intra_att * attn_mask), v_block) + o_block = o_block_inter + o_block_intra + + cur_state = prev_state * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + k_tilde_block.T, v_block + ) # [d_k, d_v] + h_carry_ref[:, :] = cur_state + mutable_o_ref[subchunk_idx, :] = o_block.astype(mutable_o_ref.dtype) + + # Obtain final state from previous chunk. 
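+    # The chunk's incoming state is also saved to `mutable_ch_ref` so the backward pass can
+    # recompute subchunk-level states from it; the subchunks are then scanned with `for_loop`
+    # and the resulting state is written back for the next chunk in the grid.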
+ h_carry = mutable_final_state_ref[:, :] + mutable_ch_ref[:, :] = mutable_final_state_ref[:, :] + final_state = for_loop.for_loop( + subchunk_dim, + _ssd_forward_chunk_loop_body, + h_carry, + ) + mutable_final_state_ref[:, :] = final_state + + +@jax.jit +def _ssd_forward( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, initial_state: Tensor +) -> Tuple: + """Forward pass for SSD. + + Args: + q, k: [bs, num_heads, seq_len, dk] + v: [bs, num_heads, seq_len, dv] + log_alpha: [bs, num_heads, seq_len] + initial_state: [singleton_dim, singleton_dim] + + Returns: + o: [bs, num_heads, seq_len, dv] + residuals: Tuple of tensors to be used in the backward + """ + bs, num_qk_heads, seq_len, k_head_dim = q.shape + _, num_v_heads, _, v_head_dim = v.shape + # TODO (bailin-wang): the following defaults works best for v5p, but they may not be optimal + # for others tpu types. We may need to expose them as arguments in the future. + singleton_dim = 128 + chunk_size, subchunk_size = 512, 64 + acc_dtype, orig_dtype = jnp.float32, q.dtype + + assert seq_len % chunk_size == 0 and chunk_size % subchunk_size == 0 + + assert num_v_heads % num_qk_heads == 0 + num_heads = num_v_heads + num_head_per_group = num_v_heads // num_qk_heads + + assert k_head_dim % singleton_dim == 0 + assert v_head_dim % singleton_dim == 0 + num_k_tiles = k_head_dim // singleton_dim + num_v_tiles = v_head_dim // singleton_dim + + # Add two extra dims for chunk-wise computation. + chunk_dim = seq_len // chunk_size + subchunk_dim = chunk_size // subchunk_size + + grid = (bs, num_heads, num_k_tiles, num_v_tiles, chunk_dim) + + # q/k/v tensors are kept in bf16 and converted later to fp32 in VMEM. + log_alpha = log_alpha.astype(jnp.float32) + initial_state = initial_state.astype(jnp.float32) + + # None is effectively 1, but the dim will be squeezed out. + qk_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + qk_spec = pl.BlockSpec( + lambda b, h, k, v, m: (b, lax.div(h, num_head_per_group), m, 0, k), qk_tiling + ) + v_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + v_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, v), v_tiling) + + alpha_tiling = (None, None, None, subchunk_dim, subchunk_size) + alpha_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, 0), alpha_tiling) + + # Initial hidden states. + is_tiling = (None, None, singleton_dim, singleton_dim) + is_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, v), is_tiling) + + # Chunk-wise states (not subchunk-wise states). + ch_tiling = (None, None, None, singleton_dim, singleton_dim) + ch_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, k, v), ch_tiling) + + # Chunk-wise final states help pass states from the previous chunk to the next. + fs_spec = is_spec + + ch_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, chunk_dim, k_head_dim, v_head_dim), dtype=acc_dtype + ) + fs_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, k_head_dim, v_head_dim), dtype=jnp.float32 + ) + + # Pre-compute the cumulative sum of log_alpha. + log_alpha = rearrange( + log_alpha, "b h (nb ns bl) -> b h nb ns bl", nb=chunk_dim, ns=subchunk_dim + ) + cum_log_alpha = jnp.cumsum(log_alpha, axis=-1) + + q = rearrange(q, "b h (nb bl) dk -> b h nb bl dk", bl=subchunk_size) + k = rearrange(k, "b h (nb bl) dk -> b h nb bl dk", bl=subchunk_size) + v = rearrange(v, "b h (nb bl) dv -> b h nb bl dv", bl=subchunk_size) + + # Pallas kernels operate on tiles of size at least [8, 128]. 
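+    # gamma is the per-subchunk total decay (the last entry of the cumulative sum); it is
+    # broadcast from a trailing dim of 1 to `singleton_dim` only to satisfy that tile shape.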
+ gamma = cum_log_alpha[:, :, :, :, subchunk_size - 1 :] # [b, h, nb, ns, 1] + gamma_expanded = jnp.repeat(gamma, singleton_dim, axis=-1) # [b, h, nb, ns, singleton_dim] + gamma_tiling = (None, None, None, subchunk_dim, singleton_dim) + gamma_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, 0), gamma_tiling) + + o_tiling = (None, None, None, subchunk_dim, subchunk_size, singleton_dim) + o_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, m, 0, v), o_tiling) + o_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, num_k_tiles, chunk_dim * subchunk_dim, subchunk_size, v_head_dim), + dtype=orig_dtype, + ) + + chunk_states, final_state, o = pl.pallas_call( + _ssd_forward_kernel, + in_specs=(qk_spec, qk_spec, v_spec, alpha_spec, is_spec, gamma_spec), + out_specs=(ch_spec, fs_spec, o_spec), + out_shape=(ch_shape, fs_shape, o_shape), + grid=grid, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary") + ) + ), + )(q, k, v, cum_log_alpha, initial_state, gamma_expanded) + + o = jnp.sum(o, axis=2) # sum over dkn dim + o = rearrange(o, "b h nb bl dv -> b h (nb bl) dv") + + # Input tensors q/k/v stored in the residual list for backward pass are reshaped, and + # cum_log_alpha and gamma are upcasted to float32. + final_state = final_state.astype(orig_dtype) + return o, (q, k, v, cum_log_alpha, gamma_expanded, chunk_states, final_state) + + +def _ssd_backward_kernel( + q_ref: Tensor, + k_ref: Tensor, + v_ref: Tensor, + cum_log_alpha_ref: Tensor, + gamma_ref: Tensor, + ch_ref: Tensor, + mutable_do_ref: Tensor, + mutable_dq_ref: Tensor, + mutable_dk_ref: Tensor, + mutable_dv_ref: Tensor, + mutable_dh_carry_ref: Tensor, +): + """Backward kernel for SSD. + + Args: + q_ref: tensor reference of shape [ns, bl, singleton_dim] + k_ref: tensor reference of shape [ns, bl, singleton_dim] + v_ref: tensor reference of shape [ns, bl, singleton_dim] + cum_log_alpha_ref: tensor reference of shape [ns, bl] + gamma_ref: tensor reference of shape [ns, bl, singleton_dim] + ch_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + + Output via mutable tensors: + mutable_do_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dq_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dk_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dv_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dh_carry_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + + Note: similar to final_state in the forward pass, dh_carry is used to pass gradients wrt. + hidden states across different chunks. It will be initalized to zero at the last chunk. + The final gradient wrt. hidden states will be returned as the gradient wrt. initial_state. + """ + subchunk_dim, subchunk_size = cum_log_alpha_ref.shape[0], cum_log_alpha_ref.shape[1] + causal_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0).astype(jnp.float32) + + # In our grid definition, axis 4 is the chunk index. 
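+    # The backward BlockSpecs map chunk index m to `chunk_dim - 1 - m`, so chunks are visited
+    # in reverse order; m == 0 is therefore the last chunk, where the gradient carry starts at
+    # zero.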
+ @pl.when(pl.program_id(axis=4) == 0) + def init_carry(): + mutable_dh_carry_ref[:, :] = jnp.zeros_like(mutable_dh_carry_ref, dtype=jnp.float32) + + def _ssd_backward_dq_chunk_loop_body(t: int, h_carry_ref: Tensor): + subchunk_idx = t + h_block = h_carry_ref[:, :] # final states from previous chunk + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + do_block = mutable_do_ref[subchunk_idx, :].astype(jnp.float32) + + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [nb, 1] + beta_block = gamma_block - lambda_block # [nb, d_k] + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * causal_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + k_tilde_block = k_block * beta_block + + attn_mask = causal_mask * ssd_mask_block + d_intra_att = _matmul_fp32(do_block, v_block.T) * attn_mask + + dq_tilde_block = _matmul_fp32(do_block, h_block.T) + dq_block_1 = dq_tilde_block * lambda_block + dq_block_2 = _matmul_fp32(d_intra_att, k_block) + dq_block = dq_block_1 + dq_block_2 + mutable_dq_ref[subchunk_idx, :] = dq_block + + next_h_block = h_block * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + k_tilde_block.T, v_block + ) + h_carry_ref[:, :] = next_h_block + + def _ssd_backward_dkv_chunk_loop_body(t: int, dh_carry_ref: Tensor): + subchunk_idx = t + dh_block = dh_carry_ref[:, :] + q_block = q_ref[subchunk_idx, :].astype(jnp.float32) + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + do_block = mutable_do_ref[subchunk_idx, :].astype(jnp.float32) + causal_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0).astype(jnp.float32) + + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [nb, 1] + beta_block = gamma_block - lambda_block # [nb, d_k] + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * causal_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + q_tilde_block = q_block * lambda_block + k_tilde_block = k_block * beta_block + + intra_att = _matmul_fp32(q_block, k_block.T) + attn_mask = causal_mask * ssd_mask_block + d_intra_att = _matmul_fp32(do_block, v_block.T) * attn_mask + + dk_block_1 = _matmul_fp32(d_intra_att.T, q_block) + dk_tilde_block = _matmul_fp32(v_block, dh_block.T) + dk_block_2 = dk_tilde_block * beta_block + dk_block = dk_block_1 + dk_block_2 + mutable_dk_ref[subchunk_idx, :] = dk_block + + dv_block_1 = _matmul_fp32((intra_att * attn_mask).T, do_block) + dv_block_2 = _matmul_fp32(k_tilde_block, dh_block) + dv_block = dv_block_1 + dv_block_2 + mutable_dv_ref[subchunk_idx, :] = dv_block + + prev_dh_block = dh_block * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + q_tilde_block.T, do_block + ) + dh_carry_ref[:, :] = prev_dh_block + + h_carry = ch_ref[:, :] + _ = for_loop.for_loop(subchunk_dim, _ssd_backward_dq_chunk_loop_body, h_carry) + + dh_carry = mutable_dh_carry_ref[:, :] + dinitial_state = for_loop.for_loop( + subchunk_dim, _ssd_backward_dkv_chunk_loop_body, dh_carry, reverse=True + ) + mutable_dh_carry_ref[:, :] = 
dinitial_state + + +@jax.jit +def _ssd_backward(residuals: Tuple, do: Tensor) -> Tuple: + """Backward pass for SSD. + + Args: + residuals: Tuple of tensors returned from the forward pass + do: [bs, num_heads, seq_len, dv] + + Returns: + dq: [bs, num_heads, seq_len, dk] + dk: [bs, num_heads, seq_len, dk] + dv: [bs, num_heads, seq_len, dv] + dlog_alpha: [bs, num_heads, seq_len] + dinitial_state: [bs, num_heads, dk, dv] + """ + q, k, v, cum_log_alpha, gamma_expanded, chunk_states, final_state = residuals + + # `final_state` preserves the original dtype (e.g., bfloat16). + orig_dtype = final_state.dtype + + singleton_dim = 128 + bs, num_heads, chunk_dim, subchunk_dim, subchunk_size = cum_log_alpha.shape + k_dim, v_dim = q.shape[-1], v.shape[-1] + num_k_tiles, num_v_tiles = k_dim // singleton_dim, v_dim // singleton_dim + num_qk_heads = q.shape[1] + num_head_per_group = num_heads // num_qk_heads + + grid = (bs, num_heads, num_k_tiles, num_v_tiles, chunk_dim) + + qk_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + qk_spec = pl.BlockSpec( + lambda b, h, k, v, m: (b, lax.div(h, num_head_per_group), chunk_dim - 1 - m, 0, k), + qk_tiling, + ) + v_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + v_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, v), v_tiling) + + alpha_tiling = (None, None, None, subchunk_dim, subchunk_size) + alpha_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, 0), alpha_tiling) + gamma_tiling = (None, None, None, subchunk_dim, singleton_dim) + gamma_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, 0), gamma_tiling) + + ch_tiling = (None, None, None, singleton_dim, singleton_dim) + ch_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, k, v), ch_tiling) + + do_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + do_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, v), do_tiling) + + dqk_tiling = (None, None, None, None, subchunk_dim, subchunk_size, singleton_dim) + dqk_spec = pl.BlockSpec( + lambda b, h, k, v, m: ( + b, + lax.div(h, num_head_per_group), + lax.rem(h, num_head_per_group), + v, + chunk_dim - 1 - m, + 0, + k, + ), + dqk_tiling, + ) + dqk_shape = jax.ShapeDtypeStruct( + shape=( + bs, + num_qk_heads, + num_head_per_group, + num_v_tiles, + chunk_dim * subchunk_dim, + subchunk_size, + k_dim, + ), + dtype=jnp.float32, + ) + + dv_tiling = (None, None, None, subchunk_dim, subchunk_size, singleton_dim) + dv_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, chunk_dim - 1 - m, 0, v), dv_tiling) + dv_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, num_k_tiles, chunk_dim * subchunk_dim, subchunk_size, v_dim), + dtype=jnp.float32, + ) + + dh_carry_tiling = (None, None, singleton_dim, singleton_dim) + dh_carry_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, v), dh_carry_tiling) + dh_carry_shape = jax.ShapeDtypeStruct(shape=(bs, num_heads, k_dim, v_dim), dtype=jnp.float32) + + do = rearrange(do, "b h (nb bl) dv -> b h nb bl dv", bl=subchunk_size) + + dq, dk, dv, dinitial_state = pl.pallas_call( + _ssd_backward_kernel, + in_specs=(qk_spec, qk_spec, v_spec, alpha_spec, gamma_spec, ch_spec, do_spec), + out_specs=(dqk_spec, dqk_spec, dv_spec, dh_carry_spec), + out_shape=(dqk_shape, dqk_shape, dv_shape, dh_carry_shape), + grid=grid, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary") + ) + ), + )(q, k, v, cum_log_alpha, gamma_expanded, chunk_states, do) 
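+
+    # dq/dk come back with separate group / per-group-head dims plus a v-tile dim, and dv with
+    # a k-tile dim; these are reduced below. dlog_alpha is recovered from the per-position terms
+    # sum_dk(q * dq - k * dk) followed by a reversed cumulative sum over the sequence.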
+
+    # Sum over dvn dim.
+    dq = jnp.sum(dq, axis=3)
+    dk = jnp.sum(dk, axis=3)
+    dq = rearrange(dq, "b ng nhg nb bl dk -> b ng nhg (nb bl) dk")
+    dk = rearrange(dk, "b ng nhg nb bl dk -> b ng nhg (nb bl) dk")
+
+    # Compute dlog_alpha via `q * dq - k * dk`.
+    dq_ = rearrange(dq, "b ng nhg l dk -> b (ng nhg) l dk")
+    dk_ = rearrange(dk, "b ng nhg l dk -> b (ng nhg) l dk")
+
+    q_ = repeat(q, "b ng nb bl dk -> b (ng nhg) nb bl dk", nhg=num_head_per_group)
+    k_ = repeat(k, "b ng nb bl dk -> b (ng nhg) nb bl dk", nhg=num_head_per_group)
+    q_ = rearrange(q_, "b h nb bl dk -> b h (nb bl) dk")
+    k_ = rearrange(k_, "b h nb bl dk -> b h (nb bl) dk")
+
+    dlog_alpha_ = jnp.sum(dq_ * q_ - dk_ * k_, axis=-1)
+    dlog_alpha = lax.cumsum(dlog_alpha_, axis=2, reverse=True)
+
+    # Sum over dkn dim.
+    dv = jnp.sum(dv, axis=2)
+    dv = rearrange(dv, "b h nb bl dv -> b h (nb bl) dv")
+
+    # Sum over nhg dim.
+    dq = jnp.sum(dq, axis=2)
+    dk = jnp.sum(dk, axis=2)
+    # `dlog_alpha` is always in float32, `dv` is also in float32.
+    dq, dk = dq.astype(orig_dtype), dk.astype(orig_dtype)
+
+    dinitial_state = dinitial_state.astype(orig_dtype)
+    return dq, dk, dv, dlog_alpha, dinitial_state
+
+
+_ssd.defvjp(_ssd_forward, _ssd_backward)
+
+
+@jax.jit
+@jax.named_call  # `named_call` ensures the name is used in tracing, which is useful for profiling.
+def ssd(q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Optional[Tensor] = None) -> Tensor:
+    """Differentiable function that computes the output of SSD.
+
+    Args:
+        q: [batch_size, num_groups, seq_len, dk]
+        k: [batch_size, num_groups, seq_len, dk]
+        v: [batch_size, num_heads, seq_len, dv]
+        log_alpha: [batch_size, num_heads, seq_len]
+        h0: [batch_size, num_heads, dk, dv]
+
+    Returns:
+        output: [batch_size, num_heads, seq_len, dv]
+
+    The notion of groups is similar to the group in multi-group attention (or more precisely
+    multi-value attention) -- one group of q/k corresponds to multiple v heads.
+    """
+
+    bs, ng, _, dk = q.shape
+    bs, nh, _, dv = v.shape
+    assert nh % ng == 0
+    assert v.dtype == jnp.float32
+    assert log_alpha.dtype == jnp.float32
+
+    if h0 is None:
+        h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32)
+
+    output = _ssd(q, k, v, log_alpha, h0)
+    return output
+
+
+def ssd_linear_scan(
+    q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Union[Tensor, None] = None
+) -> Tensor:
+    """LinearScan-based reference implementation for testing SSD kernels.
+
+    Args:
+        q: [batch_size, num_groups, seq_len, dk]
+        k: [batch_size, num_groups, seq_len, dk]
+        v: [batch_size, num_heads, seq_len, dv]
+        log_alpha: [batch_size, num_heads, seq_len]
+        h0: [batch_size, num_heads, dk, dv] or None
+
+    Returns:
+        output: [batch_size, num_heads, seq_len, dv]
+        hidden_state: [batch_size, num_heads, dk, dv]
+    """
+    bs, ng, _, dk = q.shape
+    bs, nh, _, dv = v.shape
+    assert nh % ng == 0
+
+    # The linearscan kernel assumes that nh == ng, so we need to repeat q/k.
+    num_head_per_group = nh // ng
+    q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group)
+    k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group)
+
+    # It's more convenient for vmap to have internal states of size [dv, dk].
+    if h0 is None:
+        h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32)
+    else:
+        # To be consistent with the Pallas API, h0 is passed in as [dk, dv].
+        h0 = rearrange(h0, "b h dk dv -> b h dv dk")
+
+    # All inputs are upcast to float32, making this function a good reference function to
+    # test the pallas kernel's numerical precision in the case of bf16 inputs.
+ dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + q_t, k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t, preferred_element_type=jnp.float32) + o_t = jnp.einsum("ij,j->i", h_next, q_t, preferred_element_type=jnp.float32) + return h_next, o_t.astype(q_t.dtype) + + def single_head_scan(q_head, k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (q_head, k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + + # Note: if dk > 128 (e.g., 256), somehow jax jvp would fail; a work-around + # is to add another dim to ensure that minor dk is always 128. + q = rearrange(q, "b h l (dkn dks) -> dkn b h l dks", dks=128) + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, 0, None, None, 0), out_axes=(0, 0)) + final_state, output = batched_scan(q, k, v, log_alpha, h0) + final_state = rearrange(final_state, "dkn b h dv dks -> b h dv (dkn dks)") + output = jnp.sum(output, axis=0) + + final_state = rearrange(final_state, "b h dv dk -> b h dk dv") + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + final_state = final_state.astype(jnp.bfloat16) + + return output, final_state + + +def ssd_linear_scan_w_hidden_states( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Union[Tensor, None] = None +) -> Tensor: + """LinearScan based reference implementations for testing SSD kernels. + + This version additionally returns the hidden states of all tokens. 
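+    Materializing per-token states of shape [batch_size, num_heads, seq_len, dk, dv] is
+    memory-intensive, so this is primarily intended for testing and verification rather than
+    as a training path.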
+ + Args: + q: [batch_size, num_groups, seqlen, dk] + k: [batch_size, num_groups, seqlen, dk] + v: [batch_size, num_heads, seqlen, dv] + log_alpha: [batch_size, num_heads, seqlen] + h0: [batch_size, num_heads, dk, dv] or None + + Returns: + output: [batch_size, num_heads, seq_len, dv] + hidden_states: [batch_size, num_heads, seq_len, dk, dv] + """ + bs, ng, _, dk = q.shape + bs, nh, _, dv = v.shape + assert nh % ng == 0 + + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + if h0 is None: + h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) + else: + # to be consistent with pallas api, h0 is in dk x dv as input + h0 = rearrange(h0, "b h dk dv -> b h dv dk") + + dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t) + return h_next, h_next + + def single_head_scan(k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0), out_axes=(0, 0)) + + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, None, None, 0), out_axes=(0, 0)) + final_state, hidden_states = batched_scan(k, v, log_alpha, h0) + assert final_state is not None + + hidden_states = rearrange(hidden_states, "dkn b h l dv dks -> b h l (dkn dks) dv") + output = jnp.einsum( + "b h l s, b h l s d -> b h l d", q, hidden_states, preferred_element_type=jnp.float32 + ) + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + return output, hidden_states + + +def ssd_linear_scan_w_timestep( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, timestep: Tensor, h0=None +) -> Tensor: + """LinearScan that takes timestep as input and masks useless k/v based on timestep. + + This function is used during inference where decoding might start from different timesteps. 
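+    Positions at or beyond `timestep` are zeroed out of k/v/log_alpha, so only tokens before
+    the given timestep contribute to the accumulated state.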
+ + Args: + q: [batch_size, num_groups, seqlen, dk] + k: [batch_size, num_groups, seqlen, dk] + v: [batch_size, num_heads, seqlen, dv] + log_alpha: [batch_size, num_heads, seqlen] + h0: [batch_size, num_heads, dk, dv] or None + timestep: [batch_size, seqlen] or None + + Returns: + output: [batch_size, num_heads, seq_len, dv] + hidden_states: [batch_size, num_heads, dk, dv] + + """ + bs, ng, l, dk = q.shape + bs, nh, l, dv = v.shape + assert nh % ng == 0 + + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + timestep_mask = jnp.arange(l)[None, :] >= timestep[:, None] + k = jnp.where(timestep_mask[:, None, :, None], 0.0, k) + v = jnp.where(timestep_mask[:, None, :, None], 0.0, v) + log_alpha = jnp.where(timestep_mask[:, None, :], 0.0, log_alpha) + + if h0 is None: + h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) + else: + # to be consistent with pallas api, h0 is in dk x dv as input + h0 = rearrange(h0, "b h dk dv -> b h dv dk") + + dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + q_t, k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t, preferred_element_type=jnp.float32) + o_t = jnp.einsum("ij,j->i", h_next, q_t, preferred_element_type=jnp.float32) + return h_next, o_t.astype(q_t.dtype) + + def single_head_scan(q_head, k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (q_head, k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + + q = rearrange(q, "b h l (dkn dks) -> dkn b h l dks", dks=128) + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, 0, None, None, 0), out_axes=(0, 0)) + final_state, output = batched_scan(q, k, v, log_alpha, h0) + final_state = rearrange(final_state, "dkn b h dv dks -> b h dv (dkn dks)") + output = jnp.sum(output, axis=0) + + final_state = rearrange(final_state, "b h dv dk -> b h dk dv") + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + + return output, final_state diff --git a/axlearn/common/ssm_kernels/ssd_kernels_test.py b/axlearn/common/ssm_kernels/ssd_kernels_test.py new file mode 100644 index 00000000..2cc5c05a --- /dev/null +++ b/axlearn/common/ssm_kernels/ssd_kernels_test.py @@ -0,0 +1,389 @@ +# Copyright © 2024 Apple Inc. 
+ +"""Tests SSD Pallas kernels.""" +from typing import Union + +import jax +import jax.nn as jnn +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from absl.testing import parameterized +from einops import rearrange, repeat +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, PartitionSpec +from torch.nn import functional as F + +from axlearn.common.ssm_kernels.ssd_kernels import _ssd_backward, _ssd_forward, ssd, ssd_linear_scan +from axlearn.common.test_utils import TestCase, assert_allclose + +if jax.default_backend() != "tpu": + pytest.skip(reason="Incompatible hardware", allow_module_level=True) + + +def _ssd_reference(q, k, v, log_alpha, h0): + """Reference implementation of SSD for comparison. + + Args: + q/k: [batch_size, num_heads, seq_len, dk] + v: [batch_size, num_heads, seq_len, dv] + log_alpha: [batch_size, num_heads, seq_len] + h0: [batch_size, num_heads, dk, dv] + + Returns: + o: [batch_size, num_heads, seq_len, dv] + """ + return ssd_linear_scan(q, k, v, log_alpha, h0)[0] + + +def _ssd_naive_reference(q, k, v, log_alpha, h0=None): + """For-loop reference implementation of SSD. + + Note that this implementation somehow have worse + numerical stability than the vmap version above. + + Args: + q/k: [batch_size, num_heads, seq_len, dk] + v: [batch_size, num_heads, seq_len, dv] + log_alpha: [batch_size, num_heads, seq_len] + h0: [batch_size, num_heads, dk, dv] + + Returns: + o: [batch_size, num_heads, seq_len, dv] + h: [batch_size, num_heads, dk, dv] + """ + bs, ng, l, dk = q.shape + _, _, _, dv = v.shape + + bs, ng, l, dk = q.shape + bs, nh, l, dv = v.shape + assert nh % ng == 0 + + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + if h0 is None: + h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32) + + o_list = [] + h = h0 + for t in range(l): + q_t = q[:, :, t] + k_t = k[:, :, t] + v_t = v[:, :, t] + alpha_t = jnp.exp(log_alpha[:, :, t, None, None]) + + h = alpha_t * h + jnp.einsum( + "...i,...j->...ij", k_t, v_t, preferred_element_type=jnp.float32 + ) + o_t = jnp.einsum("...ij,...i->...j", h, q_t, preferred_element_type=jnp.float32) + o_list.append(o_t) + o = jnp.stack(o_list, axis=2) + return o, h + + +# disable some pylint checks to allow copied code to pass checks + +# pylint: disable=line-too-long +# pylint: disable=invalid-name +# pylint: disable=unused-variable + + +def segsum(x): + """More stable segment sum calculation. Helper function copied from + https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py. + """ + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_chunk_tri(X, A, B, C, chunk_size=16, initial_states=None): + """Reference implementation of SSD with chunked computation, copied from + https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py. 
+ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + + X, A, B, C corresponds to V, \alpha, K, Q in linear attention + (H_t = \alpha H_{t-1)+ K_t^\top V_t, O_t = Q_t S_t). + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % chunk_size == 0 + + # Rearrange into blocks/chunks + X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=chunk_size) for x in (X, A, B, C)] + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +@jax.jit +def _ssd_reference_vjp( + q: jax.Array, + k: jax.Array, + v: jax.Array, + alpha: jax.Array, + h0: Union[jax.Array, None], + do: jax.Array, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + o, vjp = jax.vjp(_ssd_reference, q, k, v, alpha, h0) + return o, vjp(do) + + +def _generate_ssd_inputs(shape, dtype, seed, paramn="gla", zero_h0=True): + """ + Args: + shape: [bs, ng, nh, l, dk, dv] + dtype: float32, bfloat16 + seed: random seed + paramn: "mamba" or "gla" + zero_h0: whether to generate zero initial hidden state + + Returns: + q, k, v, log_alpha, h0, do + """ + bs, ng, nh, l, dk, dv = shape + rng = jax.random.PRNGKey(seed) + q_key, k_key, v_key, alpha_key, h_key, dh_key = jax.random.split(rng, 6) + + if paramn == "mamba": + q = jax.random.uniform(q_key, (bs, ng, l, dk), dtype=dtype) + k = jax.random.uniform(k_key, (bs, ng, l, dk), dtype=dtype) + v = jax.random.uniform(v_key, (bs, nh, l, dv), dtype=jnp.float32) + + log_alpha = -jnp.exp(jax.random.uniform(alpha_key, (bs, nh, l), dtype=jnp.float32)) + dt = jax.random.normal(alpha_key, (bs, nh, l), dtype=jnp.float32) + dt = jnp.log(1.0 + jnp.exp(dt - 4)) + + log_alpha = dt * log_alpha + v = v * dt[..., None] + elif paramn == "gla": + q = jax.random.normal(q_key, (bs, ng, l, dk), dtype=dtype) + k = jax.random.normal(k_key, (bs, ng, l, dk), dtype=dtype) + v = jax.random.normal(v_key, (bs, nh, l, dv), dtype=dtype) + + # shortconv (skipped) and non-linear activation + q = jnn.silu(q) + k = jnn.silu(k) + v = jnn.silu(v) + + # l2 norm (help reduces the range of dq/dk -> better precision for 
bfloat16) + q = q / jnp.linalg.norm(q, axis=-1, keepdims=True) + k = k / jnp.linalg.norm(k, axis=-1, keepdims=True) + + log_alpha = ( + jnn.log_sigmoid(jax.random.normal(alpha_key, (bs, nh, l), dtype=jnp.float32)) / 16.0 + ) + else: + raise ValueError(f"Unsupported param: {paramn}") + + if zero_h0: + h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32) + else: + h0 = jax.random.normal(h_key, (bs, nh, dk, dv), dtype=jnp.float32) + + do = jax.random.normal(dh_key, (bs, nh, l, dv), dtype=dtype) + + # log_alpha is always in float32 + log_alpha = log_alpha.astype(jnp.float32) + return q, k, v, log_alpha, h0, do + + +class SSDPallasKernelTest(TestCase): + @parameterized.product( + batch_size=[2, 4], + num_heads=[4, 8], + seq_len=[1024, 2048], + dk=[128, 256], + dv=[128, 256], + seed=[0, 1], + ) + def test_ssd_forward( + self, batch_size: int, num_heads: int, seq_len: int, dk: int, dv: int, seed: int + ) -> None: + """Test SSD forward pass against Tri's torch reference implementation.""" + # Set the device to CPU + device = "cpu" + + # Set the random seed for reproducibility + np.random.seed(seed) + + # Generate random input data + x = np.random.rand(batch_size, seq_len, num_heads, dk).astype(np.float32) + dt = np.random.rand(batch_size, seq_len, num_heads).astype(np.float32) + dt = np.log(1.0 + np.exp(dt - 4)) + A = -np.exp(np.random.rand(batch_size, seq_len, num_heads).astype(np.float32)) + B = np.random.rand(batch_size, seq_len, num_heads, dv).astype(np.float32) + C = np.random.rand(batch_size, seq_len, num_heads, dv).astype(np.float32) + + # Compute intermediate variables + x_bar = x * dt[..., None] + A_bar = A * dt + + # Convert numpy arrays to torch tensors + x_torch = torch.tensor(x, dtype=torch.float32) + dt_torch = torch.tensor(dt, dtype=torch.float32) + A_torch = torch.tensor(A, dtype=torch.float32) + B_torch = torch.tensor(B, dtype=torch.float32) + C_torch = torch.tensor(C, dtype=torch.float32) + x_bar_torch = torch.tensor(x_bar, dtype=torch.float32) + A_bar_torch = torch.tensor(A_bar, dtype=torch.float32) + + # Compute the torch reference output + y_torch, _ = ssd_chunk_tri(x_bar_torch, A_bar_torch, B_torch, C_torch) + + # Convert numpy arrays to jax arrays + x_jax = jnp.array(x, dtype=jnp.float32) + dt_jax = jnp.array(dt, dtype=jnp.float32) + A_jax = jnp.array(A, dtype=jnp.float32) + B_jax = jnp.array(B, dtype=jnp.float32) + C_jax = jnp.array(C, dtype=jnp.float32) + x_bar_jax = jnp.array(x_bar, dtype=jnp.float32) + A_bar_jax = jnp.array(A_bar, dtype=jnp.float32) + + # Reshape jax arrays for comparison + x_jax = rearrange(x_jax, "b t h d -> b h t d") + dt_jax = rearrange(dt_jax, "b t h -> b h t") + A_jax = rearrange(A_jax, "b t h -> b h t") + B_jax = rearrange(B_jax, "b t h n -> b h t n") + C_jax = rearrange(C_jax, "b t h n -> b h t n") + x_bar_jax = rearrange(x_bar_jax, "b t h d -> b h t d") + A_bar_jax = rearrange(A_bar_jax, "b t h -> b h t") + + # Compute the jax output + y_jax = ssd(C_jax, B_jax, x_bar_jax, A_bar_jax, h0=None) + y_jax = rearrange(y_jax, "b h t d -> b t h d") + + assert_allclose(y_torch.numpy(), np.asarray(y_jax), atol=1e-3, rtol=1e-3) + + @parameterized.product( + batch_size=[2, 4], + num_heads=[4, 8], + seq_len=[1024, 2048], + dk=[128, 256], + dv=[128, 256], + dtype=["float32", "bfloat16"], + seed=[0, 1], + ) + def test_forward_and_backward(self, batch_size, num_heads, seq_len, dk, dv, dtype, seed): + try: + self.ssd_forward_and_backward(batch_size, num_heads, seq_len, dk, dv, dtype, seed) + except Exception as e: + # breakpoint() # uncomment for debugging failed 
conditions + raise e + + def ssd_forward_and_backward(self, batch_size, num_heads, seq_len, dk, dv, dtype, seed): + num_groups = num_heads + shape = (batch_size, num_groups, num_heads, seq_len, dk, dv) + q, k, v, log_alpha, h0, do = _generate_ssd_inputs(shape, dtype, seed) + if dtype == "float32": + tol = 1e-3 + elif dtype == "bfloat16": + tol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + o_pallas, residuals = _ssd_forward(q, k, v, log_alpha, h0) + final_state_pallas = residuals[-1] + o_ref, final_state_ref = ssd_linear_scan(q, k, v, log_alpha, h0) + + assert_allclose(o_pallas, o_ref, atol=tol, rtol=tol) + assert_allclose(final_state_pallas, final_state_ref, atol=tol, rtol=tol) + + dq_pallas, dk_pallas, dv_pallas, dlog_alpha_pallas, dh0_pallas = _ssd_backward( + residuals, do + ) + _, ssd_reference_grad_ = jax.vjp(_ssd_reference, q, k, v, log_alpha, h0) + dq_ref, dk_ref, dv_ref, dlog_alpha_ref, dh0_ref = ssd_reference_grad_(do) + + assert_allclose(dq_pallas, dq_ref, atol=tol, rtol=tol) + assert_allclose(dk_pallas, dk_ref, atol=tol, rtol=tol) + assert_allclose(dv_pallas, dv_ref, atol=tol, rtol=tol) + assert_allclose(dlog_alpha_pallas, dlog_alpha_ref, atol=tol, rtol=tol) + assert_allclose(dh0_pallas, dh0_ref, atol=tol, rtol=tol) + + +class ShardSSDPallasKernelTest(TestCase): + # this test only works for four devices + @pytest.mark.skipif(jax.device_count() != 4, reason="Requires 4 devices") + def test_sharded_ssd_wo_sp(self): + batch, ngroups, nheads, seqlen, k_head_dim, v_head_dim = 8, 4, 4, 1024, 256, 128 + dtype = "float32" + q, k, v, log_alpha, _, _ = _generate_ssd_inputs( + (batch, ngroups, nheads, seqlen, k_head_dim, v_head_dim), dtype, 0 + ) + + o_ref, _ = ssd_linear_scan(q, k, v, log_alpha) + + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + + def get_sharded_ssd(mesh): + """ + Note: current version assumes that h0 is None, for which you don't + need to provide partition spec. + """ + sharded_ssd = shard_map( + ssd, + mesh=mesh, + in_specs=( + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + PartitionSpec( + ("data", "expert", "fsdp"), + ("seq", "model"), + None, + None, + ), + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None), + ), + out_specs=PartitionSpec(("data", "expert", "fsdp"), "model", "seq", None), + check_rep=False, + ) + return sharded_ssd + + sharded_ssd = get_sharded_ssd(mesh) + o_pallas = sharded_ssd(q, k, v, log_alpha) + + assert_allclose(o_pallas, o_ref, atol=1e-3, rtol=1e-3) diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 6ce8bb49..5774ba40 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -14,16 +14,19 @@ # Licensed under the Apache License, Version 2.0 (the "License"). 
-"""Tests Mamba and Jamba implementations.""" - +"""Tests Mamba/Mamba2 and Jamba implementations.""" import math from typing import Optional import jax import jax.numpy as jnp import numpy as np +import pytest import torch from absl.testing import parameterized +from jax._src.mesh import ResourceEnv, thread_resources +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig @@ -33,17 +36,28 @@ from axlearn.common.ssm import ( AssociativeScanMambaRecurrence, BlockResidualMode, + JambaMamba2Block, JambaMambaBlock, LinearScanMambaRecurrence, + Mamba2MixerLayer, MambaBlock, MambaMixerLayer, + PallasSSDRecurrence, RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, ) +from axlearn.common.ssm_kernels.ssd_kernels import ssd from axlearn.common.test_utils import TestCase, assert_allclose from axlearn.common.utils import Nested, Tensor, TensorSpec, cast_floats +try: + from mamba_ssm.modules.mamba2_simple import Mamba2Simple # pytype: disable=import-error + + MAMBA_INSTALLED = True +except ModuleNotFoundError: + MAMBA_INSTALLED = False + # The following PyTorch Mamba implementations are adapted from: # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/mamba/modeling_mamba.py # and @@ -941,3 +955,589 @@ def test_prefill(self, dtype: jnp.dtype): cfg.layer.self_attention.attention.num_heads = num_heads cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None) _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class Mamba2RecurrenceTest(TestCase): + """Test the correctness of the Mamba2 recurrence for decoding.""" + + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + def test_ssd_parameterization(self): + batch_size, num_heads, seq_len, state_dim, head_dim = 2, 4, 1024, 128, 256 + key = jax.random.PRNGKey(0) + dtype = jnp.float32 + + # note that construct random params requires that log_a <= 0 and delta > 0. 
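+        # log_a = -exp(llog_a) is strictly negative and softplus keeps delta strictly positive,
+        # satisfying the constraints above.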
+ x = jax.random.normal(key, (batch_size, num_heads, seq_len, head_dim), dtype=dtype) + llog_a = jax.random.uniform(key, (1, num_heads, 1), dtype=dtype) + log_a = -jnp.exp(llog_a) + b = jax.random.normal(key, (batch_size, num_heads, seq_len, state_dim), dtype=dtype) + c = jax.random.normal(key, (batch_size, num_heads, seq_len, state_dim), dtype=dtype) + delta = jax.nn.softplus( + jax.random.uniform(key, (batch_size, num_heads, seq_len), dtype=dtype) - 4.0 + ) + d = jax.random.normal(key, (1, num_heads, 1, 1), dtype=dtype) + + mamba2_dim_to_partition_spec = { + "bhtd": PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + "bht": PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None), + } + output_partition_spec = PartitionSpec(("data", "expert", "fsdp"), "model", "seq", None) + + cfg = PallasSSDRecurrence.default_config().set( + name="test", + mamba2_dim_to_partition_spec=mamba2_dim_to_partition_spec, + output_partition_spec=output_partition_spec, + ) + layer = cfg.instantiate(parent=None) + o_module, _ = F( + layer, + inputs=dict(x=x, log_a=log_a, b=b, c=c, delta=delta, d=d), + state=None, + is_training=False, + prng_key=key, + ) + + # alternative input to the kernel; delta by default is applied to x to get x_bar, here we can + # also apply it to b to get b_bar first. + b_bar = b * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + o_alternative = ssd(c, b_bar, x, loga_bar) + d * x + assert_allclose(o_module.data, o_alternative, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class Mamba2MixerLayerTest(TestCase): + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + inference_mode=(True, False), + ) + def test_extend_step(self, dtype: jnp.dtype, inference_mode: bool): + batch_size = 2 + input_dim = 512 + state_dim = 128 + num_heads = 2 + seq_len = 1024 + num_groups = 2 + expansion_factor = 1 + output_dim = input_dim + cache_dtype = dtype + + cfg = Mamba2MixerLayer.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + expansion_factor=expansion_factor, + dtype=dtype, + cache_dtype=cache_dtype, + ) + + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + if inference_mode: + # inference recurrence can return the ssd states for testing + layer.recurrence = layer.inference_recurrence + + inputs_data = jax.random.uniform( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + mamba2_cache = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + self.assertEqual(mamba2_cache.x_conv_state.dtype, cache_dtype) + self.assertEqual(mamba2_cache.b_conv_state.dtype, cache_dtype) + 
self.assertEqual(mamba2_cache.c_conv_state.dtype, cache_dtype) + self.assertEqual(mamba2_cache.ssd_state.dtype, cache_dtype) + self.assertEqual(forward_outputs.data.dtype, dtype) + + inputs = dict(cache=mamba2_cache) + decoder_output = jnp.zeros(shape=[seq_len, batch_size, output_dim], dtype=dtype) + for t in range(seq_len): + inputs["query"] = inputs_data[:, t : t + 1, :] + (mamba2_cache, mamba2output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cache"] = mamba2_cache + decoder_output = decoder_output.at[t].set(jnp.squeeze(mamba2output.data, axis=1)) + + decoder_output_transposed = jnp.transpose(decoder_output, [1, 0, 2]) + + if dtype == jnp.float32: + final_state_diff_tol = 1e-2 + output_tol = 1e-1 + else: + final_state_diff_tol = 1e-1 + output_tol = 2e0 + + if inference_mode: + forward_final_state = forward_outputs.ssd_state[:, :, -1] + final_state_diff = jnp.abs((forward_final_state - mamba2_cache.ssd_state)).max() + self.assertTrue(final_state_diff < final_state_diff_tol) + + # ssm output diff will get a bit amplified by the ffn layer + assert_allclose( + decoder_output_transposed, forward_outputs.data, atol=output_tol, rtol=output_tol + ) + + @parameterized.product(dtype=(jnp.float32, jnp.bfloat16)) + def test_prefill_states(self, dtype: jnp.dtype): + batch_size = 2 + input_dim = 512 + state_dim = 256 + num_heads = 4 + seq_len = 1024 + num_groups = 2 + expansion_factor = 2 + cache_dtype = jnp.float32 + + cfg = Mamba2MixerLayer.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + expansion_factor=expansion_factor, + dtype=dtype, + cache_dtype=cache_dtype, + ) + + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + # full forward pass as reference + inputs_data = jax.random.uniform( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + # prefill stage + time_step = jnp.arange(batch_size) + (initial_state, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=dict(time_step=time_step, query=inputs_data), + method="prefill_states", + ) + self.assertTrue(initial_state.x_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.b_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.c_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.ssd_state.dtype, cache_dtype) + self.assertTrue(initial_output.data.dtype, dtype) + + time_step_mask = (jnp.arange(seq_len) < time_step[:, None]).astype(dtype) + decoder_output = initial_output.data * time_step_mask[..., None] + + inputs = dict(cache=initial_state) + while jnp.any(time_step < seq_len): + inputs["query"] = jnp.take_along_axis( + inputs_data, time_step[:, None, None], axis=1, mode="clip" + ) + (updated_state, outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(4), + inputs=inputs, + method="extend_step", + ) + inputs["cache"] = updated_state + + # [batch_size, 1, output_dim] + cur_outputs = outputs.data + + # [batch_size, seq_len, 1] + oh_indices = jax.nn.one_hot(time_step, seq_len, dtype=dtype)[..., 
None] + decoder_output = decoder_output + cur_outputs * oh_indices + + time_step = time_step + 1 + + assert_allclose(decoder_output, forward_outputs.data, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class JambaMamba2BlockTest(TestCase): + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + input_dim=[1024, 2048], + state_dim=[128, 256], + num_heads=[2, 4], + num_groups=[2, 4], + dtype=[jnp.float32, jnp.bfloat16], + ) + def forward( + self, input_dim: int, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype + ): + mamba2block_cfg = JambaMamba2Block.default_config().set( + name="test", + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + mamba2block_cfg.feed_forward = mamba2block_cfg.feed_forward.set(hidden_dim=2 * input_dim) + mamba2block = mamba2block_cfg.instantiate(parent=None) + mamba2block_params = mamba2block.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(0) + ) + + batch_size, tgt_len = 2, 1024 + x = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, tgt_len, input_dim], dtype=dtype) + + outputs, _ = F( + mamba2block, + inputs=(x,), + state=mamba2block_params, + is_training=True, + prng_key=jax.random.PRNGKey(2), + ) + + self.assertEqual(outputs.data.shape, x.shape) + self.assertEqual(outputs.data.dtype, x.dtype) + + @parameterized.product( + batch_size=[2, 4], + input_dim=[1024, 2048], + seq_len=[1024, 2048], + state_dim=[128, 256], + num_heads=[2, 4], + num_groups=[2, 4], + dtype=[jnp.float32, jnp.bfloat16], + ) + def extend_step( + self, + batch_size: int, + input_dim: int, + seq_len: int, + state_dim: int, + num_heads: int, + num_groups: int, + dtype: jnp.dtype, + ): + cfg = JambaMamba2Block.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + cfg.feed_forward = cfg.feed_forward.set(hidden_dim=2 * input_dim) + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + inputs_data = jax.random.normal( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(data=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + init_state = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + self.assertEqual(init_state["mamba_block"].x_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].b_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].c_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].ssd_state.dtype, dtype) + + inputs = dict(cached_states=init_state) + decoder_output = jnp.zeros(shape=[seq_len, batch_size, input_dim]) + for t in range(seq_len): + inputs["data"] = inputs_data[:, t : t + 1, :] + extend_step_output, _ = F( + layer, + state=layer_params, 
+ is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cached_states"] = extend_step_output[0] + decoder_output = decoder_output.at[t].set( + jnp.squeeze(extend_step_output[1].data, axis=1) + ) + + decoder_output_transposed = jnp.transpose(decoder_output, [1, 0, 2]) + assert_allclose(decoder_output_transposed, forward_outputs.data, atol=1e-1, rtol=1e-1) + + @parameterized.product( + batch_size=[2], + input_dim=[1024], + state_dim=[256], + num_heads=[2], + seq_len=[1024], + num_groups=[2], + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_prefill_states( + self, + batch_size: int, + input_dim: int, + seq_len: int, + state_dim: int, + num_heads: int, + num_groups: int, + dtype: jnp.dtype, + ): + cfg = JambaMamba2Block.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + cfg.feed_forward = cfg.feed_forward.set(hidden_dim=2 * input_dim) + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + inputs_data = jax.random.normal( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(data=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + time_step = jnp.arange(batch_size) + (initial_state, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=dict(time_step=time_step, data=inputs_data), + method="prefill_states", + ) + + time_step_mask = (jnp.arange(seq_len) < time_step[:, None]).astype(dtype) + decoder_output = initial_output.data * time_step_mask[..., None] + + inputs = dict(cached_states=initial_state) + for _ in range(seq_len): + inputs["data"] = jnp.take_along_axis( + inputs_data, time_step[:, None, None], axis=1, mode="clip" + ) + (updated_state, outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cached_states"] = updated_state + + # [batch_size, 1, output_dim] + cur_outputs = outputs.data + + # [batch_size, seq_len, 1] + oh_indices = jax.nn.one_hot(time_step, seq_len, dtype=dtype)[..., None] + decoder_output = decoder_output + cur_outputs * oh_indices + + time_step = time_step + 1 + + assert_allclose(decoder_output, forward_outputs.data, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "gpu" or not MAMBA_INSTALLED, + reason="Test requires mamba_ssm to be installed on a GPU machine", +) +class GPUMamba2MixerLayerTest(TestCase): + @classmethod + def setup_class(cls): + num_devices = jax.device_count() + devices = mesh_utils.create_device_mesh((1, 1, 1, 1, num_devices)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + batch_size=[2, 4], + seq_len=[512, 1024], + expansion_factor=[1, 2], + ) + def test_forward(self, batch_size: int, seq_len: int, expansion_factor: int): + if self.mamba_ssm is None: + self.skipTest("mamba_ssm needs to be installed on a GPU machine for testing") + + d_model, 
d_state, expansion_factor = 512, 128, 2 + head_dim, num_groups = 128, 4 + d_inner = expansion_factor * d_model + num_heads = d_inner // head_dim + + def _j2t(param): + """Convert jax array to torch tensor.""" + return torch.from_numpy(np.array(param)) + + inputs_data = jax.random.normal(jax.random.PRNGKey(1), [batch_size, seq_len, d_model]) + inputs_torch = _j2t(inputs_data) + + # pylint: disable=undefined-variable + ref_model = Mamba2Simple( + d_model=d_model, + d_state=d_state, + headdim=head_dim, + ngroups=num_groups, + expand=expansion_factor, + use_mem_eff_path=False, + ) + + jax_model = ( + Mamba2MixerLayer.default_config() + .set( + input_dim=d_model, + state_dim=d_state, + num_groups=num_groups, + num_heads=num_heads, + expansion_factor=expansion_factor, + bc_norm=None, + dtype=jnp.float32, + cache_dtype=jnp.float32, + ) + .set(name="test") + .instantiate(parent=None) + ) + jax_params = jax_model.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + jax_params = cast_floats(jax_params, to_dtype=jnp.float32) + + # use linearscan kernel which is already tested against pallas kernel. + jax_model.recurrence = jax_model.inference_recurrence + + # copying the weights from the jax model to the ref model + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + jax_model, + state=jax_params, + is_training=True, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + jax_output_np = np.array(forward_outputs.data) + + # in_proj <-> [z, x, B, C, dt] + xz_w = _j2t(jax_params["xz_proj"]["weight"]) # [d_model, 2, d_inner] + bc_w = _j2t(jax_params["bc_proj"]["weight"]) # [d_model, 2, dk] + dt_w = _j2t(jax_params["dt_proj"]["weight"]) # [d_model, num_heads] + zxBCdt_w = torch.cat([xz_w[:, 1], xz_w[:, 0], bc_w[:, 0], bc_w[:, 1], dt_w], dim=1) + ref_model.in_proj.weight.data.copy_(zxBCdt_w.T) + + # conv1d <-> [x_conv, b_conv, c_conv] + x_conv_w = _j2t(jax_params["x_conv"]["weight"]) + x_conv_bias = _j2t(jax_params["x_conv"]["bias"]) + b_conv_w = _j2t(jax_params["b_conv"]["weight"]) + b_conv_bias = _j2t(jax_params["b_conv"]["bias"]) + c_conv_w = _j2t(jax_params["c_conv"]["weight"]) + c_conv_bias = _j2t(jax_params["c_conv"]["bias"]) + xbc_conv_w = torch.cat([x_conv_w, b_conv_w, c_conv_w], dim=2) + xbc_conv_bias = torch.cat([x_conv_bias, b_conv_bias, c_conv_bias], dim=0) + ref_model.conv1d.weight.data.copy_(xbc_conv_w.T) + ref_model.conv1d.bias.data.copy_(xbc_conv_bias) + + # out_proj <-> out_proj + out_w = _j2t(jax_params["out_proj"]["weight"]) + ref_model.out_proj.weight.data.copy_(out_w.T) + + # A_log <-> llog_a + a_w = _j2t(jax_params["llog_a"]) # [1, num_heads, 1] + ref_model.A_log.data.copy_(a_w[0, :, 0]) + + # dt_bias <-> dt_bias + dt_bias = _j2t(jax_params["dt_bias"]) + ref_model.dt_bias.data.copy_(dt_bias) + + # D <-> d + d = _j2t(jax_params["d"]) # [1, 1, num_heads, 1] + ref_model.D.data.copy_(d[0, 0, :, 0]) + + # norm <-> pre_out_proj_norm + norm_scale = _j2t(jax_params["pre_out_proj_norm"]["scale"]) + ref_model.norm.weight.data.copy_(norm_scale) + + device = "cuda:0" + ref_model = ref_model.to(device) + inputs_torch = inputs_torch.to(device) + torch_output = ref_model(inputs_torch) + torch_output_np = torch_output.cpu().detach().numpy() + + assert_allclose(torch_output_np, jax_output_np, atol=1e-2, rtol=1e-2) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index a6056076..06720378 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -200,6 +200,14 @@ class Config(Module.Config): # The provided config should instantiate to a thunk 
that returns the context manager. context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None + # If False, assume the train step may need to be recompiled, so it goes through lowering + # and compilation on every train step and relies on the JAX compilation cache to avoid + # excessive recompilation. Note: this can add per-step overhead from pre-compilation + # checks (such as sharding checks), which increases the step time for some models. Note + # that this cache is always bypassed at steps where xsc is enabled. + # Defaults to None, which is interpreted as True. + cache_compiled_train_step: Optional[bool] = None + def __init__( self, cfg: Config, @@ -273,6 +281,7 @@ def __init__( else: xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None # Create all children within the mesh context so that utils.input_partition_spec() works # properly. @@ -964,7 +973,8 @@ def _get_compiled_train_step_fn( ) -> Callable[[TrainerState, NestedTensor], tuple[TrainerState, NestedTensor]]: """Build a fully compiled train step function. - Relies on the JAX pjit cache to avoid recompilation where possible. + Relies on the JAX pjit cache to avoid recompilation when `with_xsc=True` or + `cache_compiled_train_step=False`. Args: train_state: A TrainerState instance. @@ -977,8 +987,17 @@ def _get_compiled_train_step_fn( Raises: RuntimeError: If `with_xsc` is requested on heterogenous device kinds. """ + if ( + self.config.cache_compiled_train_step is not False + and not with_xsc + and self._compiled_train_step is not None + ): + return self._compiled_train_step if not with_xsc: - return self.compile_train_step(trainer_state=trainer_state, input_batch=input_batch) + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch + ) + return self._compiled_train_step # Get device kinds and assert that they are homogenous. device_kinds = set(d.device_kind for d in jax.devices()) if len(device_kinds) != 1: @@ -1103,7 +1122,6 @@ def compile_train_step( Returns: A compiled training step, with signature matching self._pjit_train_step's return. """ - with self.mesh(), self._context_manager(): if trainer_state is None: # Do not run init(), which requires real devices. @@ -1114,13 +1132,7 @@ def compile_train_step( if input_batch is None: # Infer input batch shapes from input element spec. # N.B. in a multi-process setting these will be host-local (per process). - # TODO(markblee): This path currently assumes input_tf_data; fix for generic inputs. - input_batch = jax.tree.map( - lambda tf_spec: jax.ShapeDtypeStruct( - shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype - ), - self.input.dataset().element_spec, # pytype: disable=attribute-error - ) + input_batch = self.input.element_spec() # Rely on the instance handle to ensure that we hit the compilation cache if possible.
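For context on the `element_spec()` indirection introduced above: `compile_train_step` previously built abstract input shapes inline from a `tf.data` element spec and now asks the input object for them. Below is a minimal standalone sketch of that mapping under assumed shapes and a toy `tf.data` pipeline; it is not the AXLearn `Input` API, just the same `jax.ShapeDtypeStruct` conversion used for ahead-of-time lowering.

```python
import jax
import jax.numpy as jnp
import tensorflow as tf  # Assumed available; only used to build a toy element_spec.

# Hypothetical input pipeline; the batch shape [8, 16] is illustrative only.
ds = tf.data.Dataset.from_tensors({"input_ids": tf.zeros([8, 16], dtype=tf.int32)}).repeat()

# Convert the tf.data element_spec into abstract jax.ShapeDtypeStruct leaves, mirroring
# the logic that moved from compile_train_step into the input's element_spec().
abstract_batch = jax.tree.map(
    lambda tf_spec: jax.ShapeDtypeStruct(
        shape=tuple(tf_spec.shape), dtype=tf_spec.dtype.as_numpy_dtype
    ),
    ds.element_spec,
)

# Abstract values are enough to lower and compile without real data.
compiled = jax.jit(lambda batch: jnp.sum(batch["input_ids"])).lower(abstract_batch).compile()
print(compiled({"input_ids": jnp.zeros((8, 16), dtype=jnp.int32)}))
```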
jit_train_step = self._jit_train_step or self._pjit_train_step() # Note(Jan 2022): diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py index 39177429..7ce772ea 100644 --- a/axlearn/common/trainer_test.py +++ b/axlearn/common/trainer_test.py @@ -2,10 +2,12 @@ """Tests SpmdTrainer.""" -# pylint: disable=no-self-use import copy import dataclasses import math + +# pylint: disable=no-self-use +import os import os.path import shutil import tempfile @@ -70,6 +72,8 @@ NUM_CLASSES = 16 +os.environ["TPU_SKIP_MDS_QUERY"] = "1" + class DummyInput(Module): """A dummy input.""" @@ -172,6 +176,14 @@ def __iter__(self): # guaranteed to be savable). yield from self.dataset() + def element_spec(self): + return jax.tree.map( + lambda tf_spec: jax.ShapeDtypeStruct( + shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype + ), + self.dataset().element_spec, + ) + class DummyModel(BaseModel): """A dummy model.""" @@ -452,20 +464,23 @@ def test_trainer( # The prng_key per step is deterministic. np.testing.assert_array_equal(output_a["aux"]["prng_key"], output_b["aux"]["prng_key"]) - @parameterized.parameters( - {"platform": "cpu", "mesh_shape": (1, 1)}, - {"platform": "tpu", "mesh_shape": (4, 1)}, + @parameterized.product( + [{"platform": "cpu", "mesh_shape": (1, 1)}, {"platform": "tpu", "mesh_shape": (4, 1)}], + enable_python_cache=[True, False], ) # pylint: enable=duplicate-code - def test_xsc_check_policy( + def test_xsc_check_policy_and_compilation_cache( self, *, platform, mesh_shape, + enable_python_cache, ): if not test_utils.is_supported_platform(platform): return - cfg = SpmdTrainer.default_config().set(name="test_trainer", train_dtype=jnp.bfloat16) + cfg: SpmdTrainer.Config = SpmdTrainer.default_config().set( + name="test_trainer", train_dtype=jnp.bfloat16 + ) cfg.dir = tempfile.mkdtemp() cfg.mesh_axis_names = ("data", "model") cfg.mesh_shape = mesh_shape @@ -485,6 +500,7 @@ def test_xsc_check_policy( cfg.vlog = 2 # Set XSC policy. cfg.xsc_check_policy = lambda step: (step in [7, 8]) + cfg.cache_compiled_train_step = enable_python_cache # Test training run. trainer: SpmdTrainer = cfg.set(max_step=12).instantiate(parent=None) @@ -508,13 +524,25 @@ def mock_compile_train_step(*args, compiler_options=None, **kwargs): output_a = trainer.run(prng_key=jax.random.PRNGKey(123)) end_cache_hits = pjit_lib._pjit_lower_cached.cache_info().hits # pylint: enable=protected-access - # We expect to have hit the lowering cache on all but one step. - self.assertEqual(end_cache_hits - start_cache_hits, cfg.max_step - 1) - self.assertEqual(mocked_compile_fn.call_count, cfg.max_step) if platform == "tpu": + if not enable_python_cache: + # We expect to have hit the lowering cache on all but one step. + self.assertEqual(end_cache_hits - start_cache_hits, cfg.max_step - 1) + self.assertEqual(mocked_compile_fn.call_count, cfg.max_step) + else: + # We expect to have hit the lowering cache on xsc steps. + self.assertEqual(end_cache_hits - start_cache_hits, 2) + self.assertEqual(mocked_compile_fn.call_count, 3) # Should have been called with compile options on two steps. self.assertEqual(compiled_with_options_call_count[0], 2) else: + if not enable_python_cache: + self.assertEqual(end_cache_hits - start_cache_hits, cfg.max_step - 1) + self.assertEqual(mocked_compile_fn.call_count, cfg.max_step) + else: + # We won't hit any cache since we have python cache. + self.assertEqual(end_cache_hits - start_cache_hits, 0) + self.assertEqual(mocked_compile_fn.call_count, 1) # XSC check should be disabled. 
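The expected counts asserted above follow from the Python-level memoization added in `_get_compiled_train_step_fn`: once a compiled step is stored on the trainer instance, subsequent non-XSC steps return it without re-lowering, so the pjit lowering cache sees fewer (or zero) hits. A self-contained sketch of the same pattern in plain JAX, with a toy step function standing in for the trainer's train step:

```python
import jax
import jax.numpy as jnp


class _TinyStepRunner:
    """Toy stand-in for the trainer's compiled-step cache; not AXLearn code."""

    def __init__(self, cache_compiled_step: bool = True):
        self._cache_compiled_step = cache_compiled_step
        self._compiled = None  # Holds a jax.stages.Compiled once available.

    def _get_compiled_step(self, batch, *, with_xsc: bool = False):
        # Reuse the cached Compiled unless caching is disabled or this is an XSC step
        # (which may be compiled with different compiler options).
        if self._cache_compiled_step and not with_xsc and self._compiled is not None:
            return self._compiled
        compiled = jax.jit(lambda x: x * 2.0).lower(batch).compile()
        if not with_xsc:
            self._compiled = compiled
        return compiled


runner = _TinyStepRunner(cache_compiled_step=True)
batch = jnp.ones((4, 4))
for _ in range(3):
    out = runner._get_compiled_step(batch)(batch)  # Lowered/compiled once, then reused.
```

With `cache_compiled_step=False`, every call goes through `lower(...).compile()` again, which is what the `enable_python_cache=False` branch of the test exercises via the pjit cache-hit counters.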
self.assertEqual(compiled_with_options_call_count[0], 0) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 0846f4d8..ffd1b4d8 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -23,23 +23,23 @@ import types from collections.abc import Mapping, Sequence from enum import Enum -from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, NamedTuple, Optional, Protocol, TypeVar, Union import jax import numpy as np from absl import logging from jax import numpy as jnp from jax._src.ad_checkpoint import name_p -from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.mesh import thread_resources from jax._src.tree_util import KeyEntry, KeyPath +from jax.ad_checkpoint import Offloadable, Recompute, Saveable from jax.core import Primitive from jax.experimental import mesh_utils, multihost_utils from jax.sharding import PartitionSpec from axlearn.common import serialization -from axlearn.common.config import is_named_tuple +from axlearn.common.config import ConfigOr, is_named_tuple, maybe_instantiate # New code should use Nested[XX] instead of NestedXX. # Old definitions are provided for backwards compatibility. @@ -105,8 +105,13 @@ def sharding(self) -> jax.sharding.Sharding: NestedTensorSpec = Optional[Union[TensorSpec, dict[str, Any]]] +RematType = Union[type(Saveable), Offloadable, type(Recompute)] SavePattern = Union[str, re.Pattern, None] -OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] + + +class RematPolicy(Protocol): + def __call__(self, prim: Primitive, *args: Any, **params: Any) -> Union[RematType, bool]: + ... def save_and_offload_only_these_names_regex( @@ -115,8 +120,9 @@ def save_and_offload_only_these_names_regex( names_which_can_be_offloaded: SavePattern, offload_src: str, offload_dst: str, -) -> OffloadPolicy: +) -> RematPolicy: """Adapted from jax source code to support regex. + Reference: https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 @@ -132,19 +138,19 @@ def save_and_offload_only_these_names_regex( """ def policy(prim, *_, **params): - if prim is name_p: + if str(prim) == str(name_p): if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): - return pe.Saveable + return Saveable if names_which_can_be_offloaded and re.fullmatch( names_which_can_be_offloaded, params["name"] ): - return pe.Offloadable(src=offload_src, dst=offload_dst) - return pe.Recompute # not saveable unless it's in the allow-list + return Offloadable(src=offload_src, dst=offload_dst) + return Recompute # not saveable unless it's in the allow-list return policy -def offload_dots_saveable(offload_src: str, offload_dst: str) -> OffloadPolicy: +def offload_dots_saveable(offload_src: str, offload_dst: str) -> RematPolicy: """Extract from offload_dot_with_no_batch_dims and remove no-batch-dims limit. 
https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/ad_checkpoint.py#L81C1-L90C1 @@ -160,9 +166,109 @@ def offload_dots_saveable(offload_src: str, offload_dst: str) -> OffloadPolicy: # pylint: disable-next=unused-argument def policy(prim, *_, **params): - if prim is lax_internal.dot_general_p: - return pe.Offloadable(src=offload_src, dst=offload_dst) - return pe.Recompute + if str(prim) == str(lax_internal.dot_general_p): + return Offloadable(src=offload_src, dst=offload_dst) + return Recompute + + return policy + + +class RematCombineFn(Protocol): + def __call__( + self, + p1: RematType, + p2: RematType, + *, + prim: Primitive, + args: tuple[Any], + kwargs: dict[str, Any], + ) -> RematType: + """Protocol for remat policy combine function. + + Args: + p1: Remat type returned by policy 1 for `prim`. + p2: Remat type returned by policy 2 for `prim`. + prim: The jax primitive for which the remat type will be applied. + args: Positional arguments passed to RematPolicy + kwargs: Keyword arguments passed to RematPolicy. + + Returns: + RematType: The final remat type for `prim`. + """ + + +def default_remat_combine_fn(preferred_remat_type: Optional[RematType] = None) -> RematCombineFn: + """The default remat policy combine function. + + If the two policies return conflicting remat types and neither is `Recompute`: + - If `preferred_remat_type` is None, raises `RuntimeError`. + - If `preferred_remat_type` is not None, `preferred_remat_type` will be the resulting + remat type. + + Args: + preferred_remat_type: Indicates how to resolve remat type conflicts. + + Returns: + A `RematCombineFn` for use in `combine_remat_policies`. + """ + + def combine_fn( + p1: RematType, + p2: RematType, + *, + prim: Primitive, + args: tuple[Any], + kwargs: dict[str, Any], + ): + del args, kwargs + if p1 is not Recompute and p2 is not Recompute: + if p1 is not p2: + if preferred_remat_type is None: + raise RuntimeError( + f"Conflict in remat policies for primitive {prim}. " + f"Got policy 1 = {p1}, policy 2 = {p2}. " + "Please specify preferred_remat_type to resolve conflicts." + ) + else: + return preferred_remat_type + return p1 + else: + if p1 is not Recompute: + return p1 + return p2 + + return combine_fn + + +def combine_remat_policies( + policy_1: RematPolicy, + policy_2: RematPolicy, + *, + combine_fn: ConfigOr[RematCombineFn] = default_remat_combine_fn(), +): + """Returns a remat policy that combines the two policies with `combine_fn`. + + Args: + policy_1: Remat policy 1. + policy_2: Remat policy 2. + combine_fn: A function that combines and potentially resolves conflicts of the remat types + from the two policies. The default `combine_fn` chooses the policy that does not return + `Recompute` and raises `RuntimeError` if both are not `Recompute` and are different. + + Returns: + A `RematPolicy`. 
+ """ + combine_fn = maybe_instantiate(combine_fn) + + def convert_to_enum(p: Union[RematType, bool]) -> RematType: + if isinstance(p, bool): + p = Saveable if p else Recompute + return p + + def policy(prim, *args, **kwargs): + p1 = convert_to_enum(policy_1(prim, *args, **kwargs)) + p2 = convert_to_enum(policy_2(prim, *args, **kwargs)) + return combine_fn(p1, p2, prim=prim, args=args, kwargs=kwargs) return policy @@ -170,6 +276,7 @@ def policy(prim, *_, **params): extended_checkpoint_policies = types.SimpleNamespace( offload_dots_saveable=offload_dots_saveable, save_and_offload_only_these_names_regex=save_and_offload_only_these_names_regex, + combine_remat_policies=combine_remat_policies, ) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 001a873e..259e6a96 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -149,7 +149,7 @@ def get_trainer_kwargs( extended_checkpoint_policies.save_and_offload_only_these_names_regex ).set( names_which_can_be_saved=None, - names_which_can_be_offloaded=RematRegexSavePatterns.SELF_ATTENTION.value, + names_which_can_be_offloaded=RematRegexSavePatterns.NATIVE_ATTENTION.value, offload_src="device", offload_dst="pinned_host", )