Commit

Merge branch 'apple:main' into rope_emb_pos
Firenze11 authored Jan 17, 2025
2 parents 2bb2a2b + b0ee05e commit bd32156
Showing 25 changed files with 3,446 additions and 81 deletions.
31 changes: 21 additions & 10 deletions axlearn/common/attention.py
@@ -134,8 +134,8 @@
from axlearn.common.utils import (
Nested,
NestedTensor,
OffloadPolicy,
PartitionSpec,
RematPolicy,
SavePattern,
Tensor,
TensorSpec,
@@ -4003,8 +4003,6 @@ def forward(
# TODO(sneha): extend_step


# Adapted from jax source code to support regex. Reference:
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
# TODO(kelvin-zou): deprecated, kept here to minimize disruption to the golden configs.
# Please use axlearn.common.utils.extended_checkpoint_policies instead.
def _save_and_offload_only_these_names_regex(
@@ -4013,7 +4011,7 @@ def _save_and_offload_only_these_names_regex(
names_which_can_be_offloaded: SavePattern,
offload_src: str,
offload_dst: str,
) -> OffloadPolicy:
) -> RematPolicy:
return save_and_offload_only_these_names_regex(
names_which_can_be_saved=names_which_can_be_saved,
names_which_can_be_offloaded=names_which_can_be_offloaded,
@@ -4029,30 +4027,43 @@ class RematRegexSavePatterns(enum.Enum):
CONTEXT = r".*context"
LINEAR1_X = r".*linear1_[01]"
LINEAR2_X = r".*linear2_[01]"
SELF_ATTENTION = ".*([qkvo]_proj|context)"
# This is called native attention because the "context" remat point only exists when using
# native attention, e.g. `MultiheadAttention` or `GroupedQueryAttention`.
NATIVE_ATTENTION = ".*([qkvo]_proj|context)"
FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
save_pattern: SavePattern = RematRegexSavePatterns.NATIVE_ATTENTION.value,
offload_pattern: SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
"""Configures how the Transformer or Conformer stack will save the linearization points.
We try to save activations from the forward pass that are inefficient to recompute on the
backward pass. We choose the linearization points in the MultiHeadAttention layer, as that
demonstrated (empirically) the best throughput, allowing us to train with a batch size of 16 on
gpt2-10b with adamw and full sharding across 4 TPU v4 chips and a RepeatedTransformerLayer,
with 1.8x the step time of a stacked layer with a batch size of 8 and the same sharding config.
backward pass, which are mainly matrix multiplications. By default, we don't save the linear
layers' outputs due to their large expansion factor.
For Conformer models, we start from the same remat policy as for language models.
TODO(zhiyunlu): investigate the Conformer model's memory/step-time tradeoffs. Possibly we
need to save points in the LConv module.
Note that the default `save_pattern`, `NATIVE_ATTENTION`, doesn't save the context tensor when
using `FlashAttention`. To save it when using `FlashAttention`, use the policy from the module
`axlearn.common.flash_attention.remat`:
```python
from axlearn.common.utils import save_and_offload_only_these_names_regex
from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy
combine_remat_policies(
save_and_offload_only_these_names_regex(...),
save_or_offload_flash_attention_policy()
)
```
Args:
stack_cfg: A transformer config.
save_pattern: Activation regex pattern to save in HBM.
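For orientation, here is a minimal usage sketch of `build_remat_spec` under the new default pattern. It assumes the usual axlearn config conventions (`RepeatedTransformerLayer` and the per-layer `remat_spec` field); those names and the chosen dimensions are illustrative, not part of this diff.

```python
from axlearn.common.attention import (
    RematRegexSavePatterns,
    RepeatedTransformerLayer,
    build_remat_spec,
)

# Configure a repeated transformer stack (dimensions are placeholders).
stack_cfg = RepeatedTransformerLayer.default_config().set(input_dim=1024, num_layers=12)

# Save q/k/v/o projections (and, for native attention, the "context" tensor) in HBM,
# and offload feed-forward activations to host memory.
stack_cfg.layer.remat_spec = build_remat_spec(
    stack_cfg,
    save_pattern=RematRegexSavePatterns.NATIVE_ATTENTION.value,
    offload_pattern=RematRegexSavePatterns.FEED_FORWARD.value,
    offload_dst="pinned_host",
)
```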
4 changes: 4 additions & 0 deletions axlearn/common/compiler_options.py
@@ -174,6 +174,9 @@ def infer_tpu_version(tpu_type: str) -> str:
"""
tpu_type = infer_tpu_type(tpu_type)
tpu_version = tpu_type.rsplit("-", 1)[0] # split from the last occurrence of '-'
# Resolve aliases like v5e to v5litepod, since in some cases (e.g. aot compilation) v5e is
# expected.
tpu_version = _TPU_VERSION_ALIASES.get(tpu_version, tpu_version)
if tpu_version not in _TPU_VERSIONS:
raise ValueError(f"Unknown TPU version {tpu_version}. Expected one of {_TPU_VERSIONS}")
return tpu_version
@@ -238,4 +241,5 @@ def infer_xsc_compiler_options(
return options


_TPU_VERSION_ALIASES = {"v5e": "v5litepod"}
_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e")
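The alias table above lets the short form `v5e` resolve to its canonical version string. A quick illustration (mirroring the new test added below), assuming only that `infer_tpu_version` accepts accelerator types like `v5e-16`:

```python
from axlearn.common import compiler_options

# "v5e" is mapped to "v5litepod" via _TPU_VERSION_ALIASES before validation
# against _TPU_VERSIONS.
assert compiler_options.infer_tpu_version("v5e-16") == "v5litepod"
```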
7 changes: 7 additions & 0 deletions axlearn/common/compiler_options_test.py
@@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

from axlearn.common import compiler_options, test_utils
from axlearn.common.utils import Tensor
@@ -48,3 +49,9 @@ def test_xsc_compiler_options(self):
)
for name, option in options.items():
self.assertEqual(option, expected_options[name])

@parameterized.parameters(
dict(tpu_type="v5e-16", expected="v5litepod"),
)
def test_tpu_version_alias(self, tpu_type: str, expected: str):
self.assertEqual(expected, compiler_options.infer_tpu_version(tpu_type))
1 change: 1 addition & 0 deletions axlearn/common/decoder.py
@@ -659,6 +659,7 @@ def prefill_states(
),
**kwargs,
)
self.add_module_output("prefill_hidden_states", outputs["hidden_states"])
states = dict(time_step=time_step, input_ids=input_ids, **states)
return states, outputs

8 changes: 7 additions & 1 deletion axlearn/common/decoder_test.py
@@ -613,8 +613,14 @@ def tokens_to_scores(token_ids, cache):
side_effect=mock_tokens_to_scores,
)

# Drop any module outputs added in `method` to avoid leaking tracers via checkify.
def method_fn(*args, **kwargs):
out = getattr(decoder, method)(*args, **kwargs)
decoder.get_invocation_context().get_module_outputs().clear()
return out

# Checkify the decoding method being called.
decoder._checked_method = checkify.checkify(getattr(decoder, method))
decoder._checked_method = checkify.checkify(method_fn)

# pylint: enable=protected-access
with mock_ctx:
3 changes: 2 additions & 1 deletion axlearn/common/dit.py
@@ -234,9 +234,10 @@ def forward(self, input: Tensor) -> Tensor:
Each tensor has shape [batch_size, 1|num_length, dim].
"""
cfg = self.config
if input.ndim not in (2, 3):
raise ValueError(f"The input must be rank 2 or 3, but got the {input.shape} tensor.")
x = get_activation_fn(cfg.activation)(input)
output = self.linear(x)
assert output.ndim in (2, 3)
if output.ndim == 2:
output = einops.rearrange(output, "b d -> b 1 d")
output = jnp.split(output, cfg.num_outputs, axis=-1)
3 changes: 3 additions & 0 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -94,6 +94,9 @@ def _mha_forward_kernel(
See also `_mha_backward_kernel` for the backward pass.
Note: the kernel name is used to do string matching for rematerialization in `remat.py`. Be
careful when renaming this.
Args:
q_ref: Input query ref.
k_ref: Input key ref.
53 changes: 26 additions & 27 deletions axlearn/common/flash_attention/layer_test.py
@@ -146,6 +146,32 @@ def _prepare_layers(
return test_layer, ref_layer, params, hidden_dim


class DummyModel(BaseLayer):
"""A dummy model."""

@config_class
class Config(BaseLayer.Config):
layer: GroupedQueryAttention.Config = GroupedQueryAttention.default_config()

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
cfg = self.config
self._add_child("layer", cfg.layer)

def forward(self, *, query, key, value, attention_logit_biases, segment_ids):
# [batch, target_length, target_dim].
x = self.layer(
query,
key=key,
value=value,
attention_logit_biases=attention_logit_biases,
segment_ids=segment_ids,
)
# TODO(markblee,zhaoyi-zhang): The atol needs to increase significantly if using
# jnp.sum, as we no longer scale by the size of the data dims.
return jnp.mean(x.data, dtype=query.dtype)


class TestFlashAttention(TestCase):
"""Tests FlashAttention layer."""

@@ -512,34 +538,7 @@ def test_backward(
pytest.skip(reason="Only one of causal and use_bias can be True.")

with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):

class DummyModel(BaseLayer):
"""A dummy model."""

@config_class
class Config(BaseLayer.Config):
layer: GroupedQueryAttention.Config = GroupedQueryAttention.default_config()

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
cfg = self.config
self._add_child("layer", cfg.layer)

def forward(self, *, query, key, value, attention_logit_biases, segment_ids):
# [batch, target_length, target_dim].
x = self.layer(
query,
key=key,
value=value,
attention_logit_biases=attention_logit_biases,
segment_ids=segment_ids,
)
# TODO(markblee,zhaoyi-zhang): The atol needs to increase significantly if using
# jnp.sum, as we no longer scale by the size of the data dims.
return jnp.mean(x.data, dtype=query.dtype)

hidden_dim = num_heads * per_head_dim

if sliding_window_size is not None:
mask_fn = config_for_function(sliding_window_causal_mask).set(
sliding_window_size=sliding_window_size
59 changes: 59 additions & 0 deletions axlearn/common/flash_attention/remat.py
@@ -0,0 +1,59 @@
# Copyright © 2025 Apple Inc.
"""Remat policy for FlashAttention kernels."""

from jax._src.cudnn.fused_attention_stablehlo import _dot_product_attention_fwd_p_wrapper
from jax.custom_derivatives import custom_vjp_call_jaxpr_p
from jax.experimental.pallas import pallas_call_p

from axlearn.common.utils import Recompute, RematPolicy, RematType, Saveable


def save_or_offload_flash_attention_policy(remat_type: RematType = Saveable) -> RematPolicy:
"""Returns a remat policy for FlashAttention output.
This remat policy allows saving attention output, which is the tensor before out projection
commonly named "context". More precisely, it saves the attention output of GPU Pallas kernel,
TPU Legacy Pallas kernel, TPU SplashAttention kernel, and cuDNN FlashAttention kernel.
Because cuDNN FlashAttention and TPU SplashAttention invocations are in Jax source code, it's
not feasible to save the output using `checkpoint_name`. Therefore, we match the Jax primitives
to implement this save policy.
Note for users: for context lengths >= 4096, the FlashAttention kernel takes noticeably longer
to execute than o_proj on both TPU and GPU. Therefore, saving the output of FlashAttention is
more advantageous than saving o_proj, since the two have roughly the same memory footprint, when
HBM capacity doesn't allow saving both.
Args:
remat_type: Remat type. Defaults to Saveable (save to HBM) and only supports Saveable.
Returns:
A RematPolicy. Users can combine this remat policy with any existing policy with
`axlearn.common.utils.combine_remat_policies`.
"""
# Jax bug: https://github.com/jax-ml/jax/issues/25841.
# TODO(hanzhi-zhou): add support for Offloadable when jax supports it.
if remat_type is not Saveable:
raise NotImplementedError(f"{remat_type=} is not implemented.")

def policy(prim, *_, **params):
src_info = ""
# Primitives could be copies if modules are reinitialized, so `is` check is unreliable.
# Use string equality instead.
prim_s = str(prim)
if prim_s == str(pallas_call_p):
src_info = str(params.get("name_and_src_info", ""))
if prim_s == str(custom_vjp_call_jaxpr_p):
src_info = str(params.get("fun_jaxpr", ""))
# GPU Pallas kernel.
if "_mha_forward_kernel" in src_info:
return remat_type
# TPU new and legacy Pallas kernel.
if "flash_attention_kernel" in src_info:
return remat_type
# cuDNN kernel.
if prim_s == str(_dot_product_attention_fwd_p_wrapper):
return remat_type
return Recompute

return policy
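As the docstring suggests, this policy is meant to be combined with an existing name-based policy. A minimal sketch, assuming `combine_remat_policies` and `extended_checkpoint_policies` are exposed from `axlearn.common.utils` as referenced above, with illustrative regex and offload arguments:

```python
from axlearn.common.flash_attention.remat import save_or_offload_flash_attention_policy
from axlearn.common.utils import combine_remat_policies, extended_checkpoint_policies

# Save q/k/v/o projection outputs by checkpoint name, and the FlashAttention kernel
# output by matching the underlying Jax primitives (the policy defined above).
policy = combine_remat_policies(
    extended_checkpoint_policies.save_and_offload_only_these_names_regex(
        names_which_can_be_saved=".*[qkvo]_proj",
        names_which_can_be_offloaded=None,
        offload_src="device",
        offload_dst="pinned_host",
    ),
    save_or_offload_flash_attention_policy(),
)

# `policy` can then be passed wherever a remat policy is accepted, e.g.
# jax.checkpoint(fn, policy=policy) or an axlearn RematSpec.
```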