diff --git a/axlearn/cloud/gcp/tpu_health_check.py b/axlearn/cloud/gcp/tpu_health_check.py index 5f7c18999..067503928 100644 --- a/axlearn/cloud/gcp/tpu_health_check.py +++ b/axlearn/cloud/gcp/tpu_health_check.py @@ -31,6 +31,7 @@ from contextlib import contextmanager from datetime import datetime from typing import Literal, Optional, Union +import pytest import tensorflow as tf from absl import flags, logging @@ -61,10 +62,9 @@ def _parse_spec_and_check_if_should_skip( timeout = float(check_split[1]) break else: - logging.info( - "Skipping %s slice health check because check spec is %s.", check_type, check_spec + pytest.skip( + reason=f"Skipping {check_type} slice health check because check spec is {check_spec}.", ) - return None # These environment variables are set by GKE. if "MEGASCALE_NUM_SLICES" not in os.environ or "NODE_NAME" not in os.environ: @@ -74,12 +74,9 @@ def _parse_spec_and_check_if_should_skip( total_slices = int(os.environ["MEGASCALE_NUM_SLICES"]) if total_slices < num_slices_lower_bound: - logging.info( - "Skipping %s slice health check since num_slices < %d.", - check_type, - num_slices_lower_bound, + pytest.skip( + reason=f"Skipping {check_type} slice health check since num_slices < {num_slices_lower_bound}.", ) - return None return timeout diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b1a9a7e3b..5dd43aded 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -111,6 +111,7 @@ get_or_none, shapes, split_prng_key, + maybe_shard, ) NEG_INF = -1e15 @@ -1260,24 +1261,46 @@ def apply_rotary_position_embeddings( """ # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] + # sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) + # # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + # sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape) + # # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + # cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape) + # # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + # rotate_half_query = jnp.reshape( + # jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape + # ) + # query = query * cos_pos + rotate_half_query * sin_pos + # # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + # rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape) + # key = key * cos_pos + rotate_half_key * sin_pos + # if rotary_value: + # # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] + # rotate_half_value = jnp.reshape( + # jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape + # ) + # value = value * cos_pos + rotate_half_value * sin_pos + # return query, key, value + + def _rotate_half(x: jnp.ndarray) -> jnp.ndarray: + halves = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-halves[1], halves[0]), axis=-1) + sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape) # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape) # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query = jnp.reshape( - jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape - ) + 
rotate_half_query = _rotate_half(query) + query = query * cos_pos + rotate_half_query * sin_pos # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape) + rotate_half_key = _rotate_half(key) key = key * cos_pos + rotate_half_key * sin_pos if rotary_value: # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] - rotate_half_value = jnp.reshape( - jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape - ) + rotate_half_value = _rotate_half(value) value = value * cos_pos + rotate_half_value * sin_pos return query, key, value @@ -2525,6 +2548,12 @@ class Config(BaseLayer.Config): # Ref: https://github.com/google/praxis/blob/main/praxis/layers/transformers.py#L1129 # TODO (bwzhang@) Adding a unittest for the hybridnorm. structure: str = "prenorm" + # If not None, how to partition pre norm activation values. + prenorm_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition pre attention activation values. + preattention_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition post attention activation values. + postattention_partition_spec: Optional[tuple[Optional[str]]] = None def __init__(self, cfg: Config, *, parent: Module): super().__init__(cfg, parent=parent) @@ -2648,10 +2677,14 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: return dict(attention=atten_state), atten_output if cfg.structure == "prenorm": + target = maybe_shard(target, cfg.prenorm_partition_spec) skip_input = target # pre-norm: where normalization happens within the residual part. norm_target = self.norm(target) + norm_target = maybe_shard(norm_target, cfg.preattention_partition_spec) atten_state, atten_output = attention_thunk(norm_target) + atten_output = maybe_shard(atten_output, cfg.postattention_partition_spec) data = skip_input + self.stochastic_depth(self.dropout(atten_output.data)) + data = self._remat_name(data, 'residual_add') elif cfg.structure == "postnorm": # This is the structure used by the original Transformer, BERT, and RoBERTa. atten_state, atten_output = attention_thunk(target) @@ -2878,6 +2911,13 @@ class Config(BaseLayer.Config): # TODO(tlei3): deprecate this feature since we use TensorStats. add_value_rms_norm_summary: Sequence[str] = [] + # If not None, how to partition pre norm activation values. + prenorm_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition pre MLP activation values. + premlp_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition post MLP activation values. 
+ postmlp_partition_spec: Optional[tuple[Optional[str]]] = None + def __init__(self, cfg: Config, *, parent: Module): super().__init__(cfg, parent=parent) cfg: TransformerFeedForwardLayer.Config = self.config @@ -2942,17 +2982,21 @@ def _linear2(x): remat_pt1 = "activation" remat_pt2 = "linear2" if cfg.structure == "prenorm": + inputs = maybe_shard(inputs, cfg.prenorm_partition_spec) x = self.norm(inputs) + x = maybe_shard(x, cfg.premlp_partition_spec) x = self._linear1_activation(x) x = self._remat_name(x, remat_pt1) x = self.dropout1(x) x = _linear2(x) x = self._remat_name(x, remat_pt2) + x = maybe_shard(x, cfg.postmlp_partition_spec) x = self.dropout2(x) x = self.stochastic_depth(x) if cfg.residual_weight != 1: x *= cfg.residual_weight x += inputs + x=self._remat_name(x, 'mlp_residual') elif cfg.structure == "postnorm": x = self._linear1_activation(inputs) x = self._remat_name(x, remat_pt1) @@ -2998,7 +3042,7 @@ def _linear1_activation(self, x: Tensor) -> Tensor: if isinstance(cfg.activation, tuple): activations = [ self._get_activation( - self.children[f"linear1_{i}"](x), activation_fn_name=activation + self._remat_name(self.children[f"linear1_{i}"](x), f"linear1_{i}"), activation_fn_name=activation ) for i, activation in enumerate(cfg.activation) ] @@ -3477,7 +3521,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) # Encourage the right activation sharding. ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) - ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + # ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, None) if not isinstance(cfg, Sequence): cfg = [cfg] @@ -3488,6 +3533,21 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): set_attn_partition_specs(layer_cfg.cross_attention.attention) if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config): set_ffn_partition_specs(layer_cfg.feed_forward) + + # Neuron backend needs fine grained activation sharding. + if jax.default_backend() == 'neuron': + prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None) + preattention_partition_spec = (fsdp_axis_names, None, None) + postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None) + + layer_cfg.self_attention.set( + prenorm_partition_spec=prenorm_partition_spec, + preattention_partition_spec=preattention_partition_spec, + postattention_partition_spec=postattention_partition_spec) + layer_cfg.feed_forward.set( + prenorm_partition_spec=prenorm_partition_spec, + premlp_partition_spec=preattention_partition_spec, + postmlp_partition_spec=postattention_partition_spec) # pytype: enable=attribute-error @@ -4071,6 +4131,20 @@ def forward( # TODO(sneha): extend_step +def save_only_these(*names_to_save): + # Save all values, including unnamed ones, excluding the specified names. 
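+    # In other words: a value is saved only if it carries a checkpoint name (e.g. one tagged via
+    # _remat_name) that appears in `names_to_save`; everything else, named or unnamed, is
+    # rematerialized during the backward pass.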
+ names_to_save = frozenset(names_to_save) + def policy(prim, *_, **params): + if 'name' in params and params['name'] in names_to_save: + print(f"[WIP] Saving {params['name']}") + return True + elif 'name' in params: + print(f"[WIP] Not saving tensor: {params['name']}") + return False + else: + print("[WIP] Not saving unnamed tensor") + return False + return policy def build_remat_spec( stack_cfg: Union[ @@ -4106,20 +4180,40 @@ def build_remat_spec( if stack_cfg.klass is PipelinedTransformerLayer: return None + backend = jax.default_backend() checkpoints = [] if self_attention: attention_name = stack_cfg.layer.self_attention.attention.klass.__name__ - checkpoints.extend( - [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]] - ) - + if backend != "neuron": + checkpoints.extend( + [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]] + ) + elif jax.device_count() > (64 * 8): + checkpoints.extend( + [f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ["TransformerAttentionLayer.residual_add", "TransformerFeedForwardLayer.mlp_residual"] + ) + else: + checkpoints.extend( + [f"{attention_name}.{el}" for el in ['q_proj', 'k_proj', 'v_proj']] + ) if feed_forward and hasattr(stack_cfg.layer, "feed_forward"): ffn_name = stack_cfg.layer.feed_forward.klass.__name__ - checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]]) + if backend != "neuron": + checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]]) + elif jax.device_count() > (64 * 8): + checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0", "linear1_1"]]) + else: + checkpoints.extend([f"{ffn_name}.{el}" for el in ["linear1_0"]]) + + if backend != "neuron": + policy = config_for_function(jax_remat_policies.save_only_these_names).set( + names_which_can_be_saved=checkpoints + ) + else: + policy = config_for_function(save_only_these).set( + names_to_save=checkpoints + ) - policy = config_for_function(jax_remat_policies.save_only_these_names).set( - names_which_can_be_saved=checkpoints - ) if offload_dst: policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set( names_which_can_be_saved=[], diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index d7b59f62e..ae013b769 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -5041,6 +5041,70 @@ def test_set_double_shard_weights_config_for_list_of_configs( (fsdp_axis_names, tp_axis_names, None), ) + @parameterized.product( + self_attention_input_linear_cfg=( + QKVLinear.default_config(), + FusedQKVLinear.default_config(), + ), + cross_attention_cfg=(None, TransformerAttentionLayer.default_config()), + batch_axis_names=("data", ("replica", "data", "fsdp")), + fsdp_axis_names=("fsdp",), + tp_axis_names=("model",), + seq_axis_names=("seq",), + ) + def test_set_activation_shardings_config_for_list_of_configs( + self, + self_attention_input_linear_cfg, + cross_attention_cfg, + batch_axis_names, + fsdp_axis_names, + tp_axis_names, + seq_axis_names, + ): + cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set( + cross_attention=cross_attention_cfg + ) + cfg_layer.self_attention.structure = "prenorm" + cfg_layer.feed_forward.structure = "prenorm" + cfg_layer.self_attention.attention.input_linear = self_attention_input_linear_cfg + cfg_layers = [cfg_layer, cfg_layer] + + cfg_layer.self_attention.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,) + 
cfg_layer.self_attention.preattention_partition_spec = (fsdp_axis_names, None, None) + cfg_layer.self_attention.postattention_partition_spec = (fsdp_axis_names, tp_axis_names, None) + + cfg_layer.feed_forward.prenorm_partition_spec = (fsdp_axis_names, tp_axis_names, None,) + cfg_layer.feed_forward.premlp_partition_spec = (fsdp_axis_names, None, None) + cfg_layer.feed_forward.postmlp_partition_spec = (fsdp_axis_names, tp_axis_names, None) + + for cfg in cfg_layers: + self_atten = cfg.self_attention + feed_forward = cfg.feed_forward + self.assertSequenceEqual( + self_atten.prenorm_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + self_atten.preattention_partition_spec, + (fsdp_axis_names, None, None), + ) + self.assertSequenceEqual( + self_atten.postattention_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + + self.assertSequenceEqual( + feed_forward.prenorm_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) + self.assertSequenceEqual( + feed_forward.premlp_partition_spec, + (fsdp_axis_names, None, None), + ) + self.assertSequenceEqual( + feed_forward.postmlp_partition_spec, + (fsdp_axis_names, tp_axis_names, None), + ) class PositionalEmbeddingTest(TestCase): """Tests PositionalEmbedding.""" diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index 0513d8718..9790d661f 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -298,6 +298,8 @@ def _metrics( target_labels: Tensor, target_num_bytes: Optional[Tensor], ) -> dict[str, Tensor]: + if logits.dtype in (jnp.bfloat16, jnp.float16): + logits = logits.astype(jnp.float32) live_targets = (target_labels != self.decoder.config.pad_token_id) & (target_labels >= 0) num_targets = live_targets.sum() accuracy = ( diff --git a/axlearn/common/evaler.py b/axlearn/common/evaler.py index 391946f75..592a44eb4 100644 --- a/axlearn/common/evaler.py +++ b/axlearn/common/evaler.py @@ -31,6 +31,7 @@ from axlearn.common.module import Module, OutputCollection from axlearn.common.module import functional as F from axlearn.common.utils import ( + DataPartitionType, NestedPartitionSpec, NestedTensor, Tensor, @@ -81,6 +82,11 @@ class Config(Module.Config): # evalers, not setting prefix will show the accuracies on the same plot for comparison # across evalers. prefix: Optional[str] = None + # Subset of mesh axis names over which the leaves of the input batch are sharded. + batch_axis_names: Union[str, Sequence[str]] = "data" + # The input partition: + # Options: FULL (default), BATCH, REPLICATED + input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL def __init__( self, @@ -188,11 +194,11 @@ def _pjit(self, fn: Callable) -> Callable: in_shardings=( self._model_param_partition_specs, # model_params. None, # replicated_inputs (e.g., prng_key). - utils.input_partition_spec(), # per_example_inputs. + utils.data_partition_type_to_spec(partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names), # per_example_inputs. ), out_shardings=dict( replicated=None, - per_example=utils.input_partition_spec(), + per_example=utils.data_partition_type_to_spec( partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names), ), ) @@ -574,6 +580,11 @@ class Config(Module.Config): metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config() # If not None, writes input batches and `metric_calculator` forward outputs. 
output_writer: Optional[BaseOutputWriter.Config] = None + # Subset of mesh axis names over which the leaves of the input batch are sharded. + batch_axis_names: Union[str, Sequence[str]] = "data" + # The input partition: + # Options: FULL (default), BATCH, REPLICATED + input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL def __init__( self, @@ -595,7 +606,7 @@ def __init__( self._add_child("input", maybe_set_config(cfg.input, is_training=False)) self._add_child( "metric_calculator", - cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype), + cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype, batch_axis_names=cfg.batch_axis_names, input_partition_type=cfg.input_partition_type), model=model, model_param_partition_specs=model_param_partition_specs, ) @@ -691,7 +702,7 @@ def eval_step( with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step): with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"): - global_input_batch = utils.host_to_global_device_array(input_batch) + global_input_batch = utils.host_to_global_device_array(input_batch, partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names) forward_outputs = self.metric_calculator.forward( global_input_batch, model_params=model_params, diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 085eb39dc..2408968bd 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -25,7 +25,7 @@ from axlearn.common.flash_attention.utils import mha_reference if jax.default_backend() != "gpu": - pytest.skip(reason="Incompatible hardware", allow_module_level=True) + pytestmark = pytest.mark.skip(reason="Incompatible hardware, GPU only test.") @pytest.mark.parametrize( diff --git a/axlearn/common/flash_attention/neuron_attention.py b/axlearn/common/flash_attention/neuron_attention.py new file mode 100644 index 000000000..46e689a1c --- /dev/null +++ b/axlearn/common/flash_attention/neuron_attention.py @@ -0,0 +1,130 @@ +from absl import logging +from functools import partial +import jax +import jax.numpy as jnp +import jax.numpy as jnp +from jax import custom_vjp +import jax_neuronx +import os + +lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1 + +@partial(custom_vjp, nondiff_argnums=(4, 5)) +def flash_attention(query, key, value, bias, causal, softmax_scale): + # NOTE : Merge with upstream. Old code supports both 2d and 4d bias but upstream code only supports 4d. + # We no longer need 2d logit_bias but should sync how we merge this check with upstream. 
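+    # Expected shapes (see _mha_forward below): query/key/value are
+    # [batch_size, seq_len, num_heads, per_head_dim]; bias, when not None, must be rank 4.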
+ out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale) + return out + + +def _mha_forward(query, key, value, bias, causal, softmax_scale): + # Get the batch size, sequence lengths, number of heads, and hidden dimension + batch_size, q_seq_len, num_heads, d_model = query.shape + + # Transpose the query, key, and value tensors + q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len] + k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len] + v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model] + + import neuronxcc.nki.language as nl + from neuronxcc.nki.kernels.attention import flash_fwd + seed = jnp.array([1]) + + # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads + if (num_heads % 2) == 0 and (num_heads // 2 > 0): + grid = batch_size, nl.nc(lnc) * (num_heads // lnc) + else: + grid = batch_size, num_heads + + if bias != None: + assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" + attn_output, lse = flash_fwd[grid]( + q, + k, + v, + seed, + bias, + use_causal_mask=causal, + softmax_scale=softmax_scale, + mixed_precision=True, + dropout_p=0.0, + ) + else: + attn_output, lse = flash_fwd[grid]( + q, + k, + v, + seed, + use_causal_mask=causal, + softmax_scale=softmax_scale, + mixed_precision=True, + dropout_p=0.0, + ) + # Transpose the output back to the original shape + attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model] + + return attn_output, (lse, attn_output, q, k, v, bias) + + +def _mha_backward(causal, softmax_scale, res, d_attn_output): + lse, o, q, k, v, bias = res + batch_size, num_heads, d_model, seq_len = q.shape + + # Transpose the input tensors + o = o.transpose(0, 2, 3, 1) + dy = d_attn_output.transpose(0, 2, 3, 1) + + # Transpose v tensor + v = jnp.transpose(v, axes=(0, 1, 3, 2)) + seed = jnp.array([1]) + + from neuronxcc.nki.kernels.attention import flash_attn_bwd + import neuronxcc.nki.language as nl + + # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads + if (num_heads % 2) == 0 and (num_heads // 2 > 0): + grid = batch_size, nl.nc(lnc) * (num_heads // lnc) + else: + grid = batch_size, num_heads + + if bias != None: + assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" + d_query, d_key, d_value = flash_attn_bwd[grid]( + q, + k, + v, + o, + dy, + lse, + seed, + bias, + use_causal_mask=causal, + mixed_precision=True, + dropout_p=0.0, + softmax_scale=softmax_scale, + ) + else: + d_query, d_key, d_value = flash_attn_bwd[grid]( + q, + k, + v, + o, + dy, + lse, + seed, + use_causal_mask=causal, + mixed_precision=True, + dropout_p=0.0, + softmax_scale=softmax_scale, + ) + + # Batch seq_len heads, head_dim + # Transpose the gradients back to the original shape + d_query = d_query.transpose(0, 3, 1, 2) + d_key = d_key.transpose(0, 3, 1, 2) + d_value = d_value.transpose(0, 3, 1, 2) + + return d_query, d_key, d_value, None + + +flash_attention.defvjp(_mha_forward, _mha_backward) \ No newline at end of file diff --git a/axlearn/common/flash_attention/neuron_attention_test.py b/axlearn/common/flash_attention/neuron_attention_test.py new file mode 100644 index 000000000..8e90e7438 --- /dev/null +++ b/axlearn/common/flash_attention/neuron_attention_test.py @@ -0,0 +1,132 @@ +# Copyright © 2024 Amazon Inc. +"""Tests for Flash attention on Neuron. 
Tested on trn1.""" +import functools + +import chex +import jax +import jax.numpy as jnp +import pytest + +from axlearn.common.flash_attention.neuron_attention import flash_attention +from axlearn.common.flash_attention.utils import mha_reference + + +if jax.default_backend() != "neuron": + pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.") + + +@pytest.mark.parametrize( + "batch_size,seq_len,num_heads,per_head_dim", + [ + (1, 2048, 1, 64), + (2, 2048, 2, 64), + (1, 2048, 1, 128), + (2, 2048, 2, 128), + (1, 2048, 8, 128), + (2, 2048, 8, 128), + ], +) +@pytest.mark.parametrize("use_fwd", [True, False]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32]) +def test_fwd_against_ref( + batch_size: int, + seq_len: int, + num_heads: int, + per_head_dim: int, + use_fwd: bool, + causal: bool, + input_dtype: jnp.dtype, +): + sm_scale = 1.0 / (per_head_dim**0.5) + k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) + q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) + k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) + v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) + + bias = None + segment_ids = None + + if use_fwd: + + @jax.jit + def impl(q, k, v, bias): + fn = functools.partial( + flash_attention, + causal=causal, + softmax_scale=sm_scale, + ) + out, _ = jax.vjp(fn, q, k, v, bias) + return out + + else: + impl = functools.partial( + flash_attention, + causal=causal, + softmax_scale=sm_scale, + ) + + o = impl(q, k, v, bias) + o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale) + chex.assert_trees_all_close(o, o_ref, atol=0.05) + + +@pytest.mark.parametrize( + "batch_size,num_heads,seq_len,per_head_dim", + [ + (1, 1, 2048, 64), + (2, 2, 2048, 64), + (1, 1, 2048, 128), + (2, 2, 2048, 128), + (1, 8, 2048, 128), + (2, 8, 2048, 128), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) +def test_bwd_against_ref( + batch_size: int, + num_heads: int, + seq_len: int, + per_head_dim: int, + causal: bool, + input_dtype: jnp.dtype, +): + sm_scale = 1.0 / (per_head_dim**0.5) + q = jax.random.normal( + jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype + ) + k = jax.random.normal( + jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype + ) + v = jax.random.normal( + jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype + ) + + bias = None + segment_ids = None + + def fn(q, k, v, bias): + return flash_attention( + q, + k, + v, + bias, + causal=causal, + softmax_scale=sm_scale, + ).sum() + + def ref_fn(q, k, v, bias, segment_ids): + return mha_reference( + q, + k, + v, + bias, + segment_ids, + causal=causal, + softmax_scale=sm_scale, + ).sum() + + jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias) + jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids) + chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07) diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index 1ba38d7a4..8d968ec30 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -23,7 +23,7 @@ 
from axlearn.common.utils import Tensor if jax.default_backend() != "tpu": - pytest.skip(reason="Incompatible hardware", allow_module_level=True) + pytestmark = pytest.mark.skip(reason="Incompatible hardware, TPU only test.") def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor: diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 7da0543f1..070bed859 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -11,6 +11,7 @@ from axlearn.common.attention import NEG_INF, MaskFn, causal_mask, softmax_with_biases from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention +from axlearn.common.flash_attention.neuron_attention import flash_attention as neuron_flash_attention from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention from axlearn.common.utils import Tensor @@ -75,7 +76,7 @@ def mha_reference( def flash_attention_implementation( - backend: Literal["cpu", "tpu", "gpu", "xla"], + backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"], *, mask: Optional[MaskFn] = None, softmax_scale: float, @@ -159,6 +160,17 @@ def jit_attn(query, key, value, bias, segment_ids): return jit_attn + elif backend == "neuron": + # shard_map-decorated function needs to be jitted. + @jax.jit + def jit_attn(query, key, value, bias, segment_ids): + if segment_ids != None: + raise Exception("Sequence Packing is not supported on Neuron backend") + return neuron_flash_attention( + query, key, value, bias, causal, softmax_scale) + + return jit_attn + elif backend in ("cpu", "xla"): if backend == "cpu": logging.warning("Flash attention CPU backend is for testing only.") diff --git a/axlearn/common/gda_test.py b/axlearn/common/gda_test.py index edf415517..03b820384 100644 --- a/axlearn/common/gda_test.py +++ b/axlearn/common/gda_test.py @@ -28,7 +28,7 @@ class GDATest(TestCase): itertools.product( ((1, 1), (8, 1), (4, 2)), # mesh_shape (1, 16), # per_host_batch_size - (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition + (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition ) ) def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition): @@ -43,11 +43,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition devices = mesh_utils.create_device_mesh(mesh_shape) if data_partition == DataPartitionType.FULL: global_batch_size = per_host_batch_size * jax.process_count() + elif data_partition == DataPartitionType.BATCH: + global_batch_size = per_host_batch_size * jax.process_count() else: assert data_partition == DataPartitionType.REPLICATED global_batch_size = per_host_batch_size if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count(): return + # first axis is assumed to be batch axis + if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] == 0: + return per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32)) with jax.sharding.Mesh(devices, ("data", "model")): global_input_batch = host_to_global_device_array( diff --git a/axlearn/common/host_array_test.py b/axlearn/common/host_array_test.py index 1a3adc417..73b27ddc9 100644 --- a/axlearn/common/host_array_test.py +++ b/axlearn/common/host_array_test.py @@ -12,6 +12,7 @@ from absl.testing import absltest, parameterized from 
jax import numpy as jnp from jax.experimental import mesh_utils +import pytest from axlearn.common.test_utils import TestCase, is_supported_mesh_shape, is_supported_platform from axlearn.common.utils import ( @@ -28,26 +29,24 @@ def is_supported( global_batch_size: int, data_partition: DataPartitionType, ): - return ( - is_supported_platform(platform) - and is_supported_mesh_shape(mesh_shape) - and ( - data_partition == DataPartitionType.REPLICATED - or global_batch_size % jax.device_count() == 0 - ) - ) + if not is_supported_platform(platform): + return False, f'Platform "{platform}" not supported with devices {jax.devices()}.' + if not is_supported_mesh_shape(mesh_shape): + return False, f'Mesh shape "{mesh_shape}" not supported with device_count "{jax.device_count()}".' + if data_partition != DataPartitionType.REPLICATED: + return False, f'Data partition is "{data_partition}", expected "DataPartitionType.REPLICATED".' + if global_batch_size % jax.device_count() != 0: + return False, 'Global batch has to be divisible with number of devices. Global batch is "{global_batch_size}", number of devices is "{jax.device_count()}".' + return True , "" class HostArrayTest(TestCase): @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "tpu"), # platform, - ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape - (1, 16), # global_batch_size - (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition - ), + itertools.product( + ("cpu", "tpu"), # platform, + ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape + (1, 16), # global_batch_size + (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition ) ) def test_global_host_array_conversion( @@ -64,6 +63,9 @@ def test_global_host_array_conversion( global_batch_size, data_partition, ) + supported, reason = is_supported(platform, mesh_shape, global_batch_size, data_partition) + if not supported: + pytest.skip(reason) devices = mesh_utils.create_device_mesh(mesh_shape) mesh = jax.sharding.Mesh(devices, ("data", "model")) logging.info("Global mesh: %s", mesh) diff --git a/axlearn/common/inference_test.py b/axlearn/common/inference_test.py index d6d946ed3..e25314783 100644 --- a/axlearn/common/inference_test.py +++ b/axlearn/common/inference_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest, parameterized from jax.experimental import mesh_utils from jax.experimental.pjit import pjit +import pytest from axlearn.common import layers, test_utils, utils from axlearn.common.base_model import BaseModel @@ -168,22 +169,22 @@ def predict_batch(self, input_batch: NestedTensor) -> NestedTensor: def is_supported( platform: str, mesh_shape: tuple[int, int], - param_dtype: jnp.dtype, inference_dtype: Optional[jnp.dtype], global_batch_size: int, data_partition: DataPartitionType, - use_ema: bool = False, ): - del param_dtype, use_ema # not used # TODO(xuan-zou): jax 0.4.25 breaks bfloat16 on CPU due to high variance on # the final result (up to 10% precision diff), will re-enable when fixed. # NOTE: bfloat16 test on GPU is added and verified. - return ( - test_utils.is_supported_platform(platform) - and np.prod(mesh_shape) == jax.device_count() - and (data_partition != DataPartitionType.FULL or global_batch_size >= jax.device_count()) - and ((inference_dtype != jnp.bfloat16) or platform != "cpu") - ) + if not test_utils.is_supported_platform(platform): + return False, f'Platform "{platform}" not supported with devices "{jax.devices()}".' 
+ if np.prod(mesh_shape) != jax.device_count(): + return False, f'mesh_shape "{mesh_shape}" mush add up to number of devices "{jax.device_count()}"' + if not (data_partition != DataPartitionType.FULL or global_batch_size >= jax.device_count()): + return False, f"Either data_partition ('{data_partition}') is not DataPartitionType.FULL or global_batch_size ('{global_batch_size}') is >= to jax.device_count() ({jax.device_count()})." + if not ((inference_dtype != jnp.bfloat16) or platform != "cpu"): + return False, f"Either inference_dtype ('{inference_dtype}') is jnp.bfloat16 or platform ('{platform}') is not cpu." + return True, "" class InferenceTest(test_utils.TestCase): @@ -292,17 +293,14 @@ def init_state(prng_key): return state, ckpt_dir @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "gpu", "tpu"), # platform, - ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape - (jnp.float32, jnp.bfloat16), # param_dtype - (None, jnp.float32, jnp.bfloat16), # inference_dtype - (1, 16), # global_batch_size - (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition - (True, False), # whether use ema weight - ), + itertools.product( + ("cpu", "gpu", "tpu"), # platform, + ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape + (jnp.float32, jnp.bfloat16), # param_dtype + (None, jnp.float32, jnp.bfloat16), # inference_dtype + (1, 16), # global_batch_size + (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition + (True, False), # whether use ema weight ) ) def test_runner( @@ -322,6 +320,9 @@ def test_runner( global_batch_size, data_partition, ) + supported, reason = is_supported(platform, mesh_shape, inference_dtype, global_batch_size, data_partition) + if not supported: + pytest.skip(reason=reason) with tempfile.TemporaryDirectory() as local_tmp_dir: prng_key = jax.random.PRNGKey(11) local_run = jax.process_count() == 1 @@ -402,16 +403,13 @@ def test_runner( self.assertNestedAllClose(global_outputs, expected_outputs) @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "gpu"), # platform, - ((1, 1), (4, 1), (8, 1)), # mesh_shape - (jnp.float32,), # param_dtype - (jnp.float32,), # inference_dtype - (16,), # global_batch_size - (DataPartitionType.FULL,), # data_partition - ), + itertools.product( + ("cpu", "gpu"), # platform, + ((1, 1), (4, 1), (8, 1)), # mesh_shape + (jnp.float32,), # param_dtype + (jnp.float32,), # inference_dtype + (16,), # global_batch_size + (DataPartitionType.FULL,), # data_partition ) ) def test_runner_module_outputs( @@ -430,6 +428,9 @@ def test_runner_module_outputs( global_batch_size, data_partition, ) + supported, reason = is_supported(platform, mesh_shape, inference_dtype, global_batch_size, data_partition) + if not supported: + pytest.skip(reason=reason) with tempfile.TemporaryDirectory() as local_tmp_dir: prng_key = jax.random.PRNGKey(11) local_run = jax.process_count() == 1 @@ -486,17 +487,14 @@ def test_runner_module_outputs( self.assertEqual(output.module_outputs, {}) @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "gpu", "tpu"), # platform, - ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape - (jnp.float32, jnp.bfloat16), # param_dtype - (None, jnp.float32, jnp.bfloat16), # inference_dtype - (1, 16), # global_batch_size - (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition - (True, False), # whether use ema weight - ), + 
itertools.product( + ("cpu", "gpu", "tpu"), # platform, + ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape + (jnp.float32, jnp.bfloat16), # param_dtype + (None, jnp.float32, jnp.bfloat16), # inference_dtype + (1, 16), # global_batch_size + (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition + (True, False), # whether use ema weight ) ) def test_pipeline( @@ -509,6 +507,9 @@ def test_pipeline( data_partition: DataPartitionType, use_ema: bool, ): + supported, reason = is_supported(platform, mesh_shape, inference_dtype, global_batch_size, data_partition) + if not supported: + pytest.skip(reason=reason) del platform # only used by is_supported_platform(). local_run = jax.process_count() == 1 with tempfile.TemporaryDirectory() as local_tmp_dir: @@ -634,16 +635,13 @@ def test_merge_with_string_tensors_bad_input( merge_with_string_tensors(batch, batch_str_tensors) @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "gpu", "tpu"), # platform, - ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape - (jnp.float32,), # param_dtype - (jnp.float32,), # inference_dtype - (1, 64), # global_batch_size - (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition - ), + itertools.product( + ("cpu", "gpu", "tpu"), # platform, + ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape + (jnp.float32,), # param_dtype + (jnp.float32,), # inference_dtype + (1, 64), # global_batch_size + (DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition ) ) def test_pipeline_with_string_tensors( @@ -655,6 +653,9 @@ def test_pipeline_with_string_tensors( global_batch_size: int, data_partition: DataPartitionType, ): + supported, reason = is_supported(platform, mesh_shape, inference_dtype, global_batch_size, data_partition) + if not supported: + pytest.skip(reason=reason) del platform # only used by is_supported_platform(). local_run = jax.process_count() == 1 mesh_axis_names = ("data", "model") @@ -748,20 +749,17 @@ def decode_fn(record_bytes): ) @parameterized.parameters( - filter( - lambda params: is_supported(*params), - itertools.product( - ("cpu", "gpu"), # platform, - ( - (1, 1), - (4, 1), - (8, 1), - ), # mesh_shape - (jnp.float32,), # param_dtype - (jnp.float32,), # inference_dtype - (16,), # global_batch_size - (DataPartitionType.FULL,), # data_partition - ), + itertools.product( + ("cpu", "gpu"), # platform, + ( + (1, 1), + (4, 1), + (8, 1), + ), # mesh_shape + (jnp.float32,), # param_dtype + (jnp.float32,), # inference_dtype + (16,), # global_batch_size + (DataPartitionType.FULL,), # data_partition ) ) def test_pipeline_summary_writer( @@ -773,6 +771,9 @@ def test_pipeline_summary_writer( global_batch_size: int, data_partition: DataPartitionType, ): + supported, reason = is_supported(platform, mesh_shape, inference_dtype, global_batch_size, data_partition) + if not supported: + pytest.skip(reason=reason) del platform # only used by is_supported_platform(). local_run = jax.process_count() == 1 mesh_axis_names = ("data", "model") diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index 6736fb580..f4c2520cc 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -56,6 +56,7 @@ Tensor, partial_with_fn_metadata, with_sharding_constraint, + maybe_shard ) # The padding type for jax.lax.conv_general_dilated API. 
Either the strings ‘SAME’, or ‘VALID’, or @@ -2464,6 +2465,12 @@ class Config(BaseLayer.Config): num_embeddings: Required[int] = REQUIRED # Maximum number of embeddings in table. dim: Required[int] = REQUIRED # Embedding vector dimensionality. + # If not None, how to partition pre gather activation values. + pregather_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition embedding table. + embedding_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition post gather activation values. + postgather_partition_spec: Optional[tuple[Optional[str]]] = None @classmethod def default_config(cls): @@ -2498,8 +2505,14 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: ) def forward(self, x: Tensor) -> Tensor: + cfg = self.config + x = maybe_shard(x, cfg.pregather_partition_spec) emb = self.parameters["weight"] - return emb[x] + emb = maybe_shard(emb, cfg.embedding_partition_spec) + activation = emb[x] + activation = maybe_shard(activation, cfg.postgather_partition_spec) + return activation + def attend(self, x: Tensor) -> Tensor: """Apply query array 'x' to the embedding weight array. diff --git a/axlearn/common/ssm_kernels/mamba_kernels_test.py b/axlearn/common/ssm_kernels/mamba_kernels_test.py index 1127c58b2..d482b8d88 100644 --- a/axlearn/common/ssm_kernels/mamba_kernels_test.py +++ b/axlearn/common/ssm_kernels/mamba_kernels_test.py @@ -14,7 +14,7 @@ from axlearn.common.utils import Tensor if jax.default_backend() != "tpu": - pytest.skip(reason="Incompatible hardware", allow_module_level=True) + pytestmark = pytest.mark.skip(reason="Incompatible hardware, TPU only test") # Use higher precision matmuls for testing. diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index a60560769..e9b4852e5 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -47,6 +47,7 @@ from axlearn.common.summary_writer import BaseWriter, SummaryWriter from axlearn.common.update_transformation import ForwardOutputs from axlearn.common.utils import ( + DataPartitionType, HybridMeshShape, MeshShape, Nested, @@ -199,6 +200,9 @@ class Config(Module.Config): # An additional context manager to run the training loop and initialization inside of. # The provided config should instantiate to a thunk that returns the context manager. context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None + # The input partition: + # Options: FULL (default), BATCH, REPLICATED + input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL def __init__( self, @@ -343,7 +347,7 @@ def trainer_state_partition_specs(self): def _train_step_input_partition_specs(self): # By default, each input tensor is fully partitioned along the batch axis. 
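+        # The configured `input_partition_type` can relax this to BATCH (shard only the leading
+        # batch dim over `batch_axis_names`) or REPLICATED (no input sharding).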
- return utils.input_partition_spec() + return utils.data_partition_type_to_spec(self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names) def model_params_for_eval(self): state = self.trainer_state @@ -568,7 +572,7 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) output = self._run_step( - utils.host_to_global_device_array(input_batch), + utils.host_to_global_device_array(input_batch, partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names), force_run_evals=( force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None ), diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 1401337f8..1e141318e 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -443,6 +443,10 @@ def with_sharding_constraint(x, shardings): return x return jax.lax.with_sharding_constraint(x, shardings) +def maybe_shard(x, partition_spec) -> Tensor: + if partition_spec is None: + return x + return with_sharding_constraint(x, PartitionSpec(*partition_spec)) def replicate_to_local_data(x: NestedTensor) -> NestedTensor: """Replicates and converts Tensors in `x` to local DeviceArrays. @@ -591,14 +595,18 @@ class DataPartitionType(Enum): FULL = "full" # Data are fully replicated across all devices. REPLICATED = "replicated" + # Data are partitioned across batch axis only. + BATCH = "batch" -def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec: +def data_partition_type_to_spec(partition: DataPartitionType, * , batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp")) -> PartitionSpec: """Returns a PartitionSpec for the given partition type.""" if partition == DataPartitionType.FULL: return input_partition_spec() elif partition == DataPartitionType.REPLICATED: return None + elif partition == DataPartitionType.BATCH: + return PartitionSpec(batch_axis_names) else: raise NotImplementedError(f"Unsupported partition: {partition}") @@ -607,6 +615,7 @@ def host_to_global_device_array( host_arrays: Nested[Union[np.ndarray, Tensor]], *, partition: DataPartitionType = DataPartitionType.FULL, + batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"), ) -> NestedTensor: """Converts the given host device arrays to global device arrays. @@ -625,7 +634,7 @@ def host_to_global_device_array( NotImplementedError: if the given `partition` type is not supported. 
""" mesh = thread_resources.env.physical_mesh - partition_spec = data_partition_type_to_spec(partition) + partition_spec = data_partition_type_to_spec(partition, batch_axis_names=batch_axis_names) partition_specs = complete_partition_spec_tree( jax.tree_util.tree_structure(host_arrays), partition_spec ) @@ -636,6 +645,8 @@ def make_gda(x, partition_spec): global_shape = (x.shape[0] * process_count, *x.shape[1:]) elif partition == DataPartitionType.REPLICATED: global_shape = (x.shape[0], *x.shape[1:]) + elif partition == DataPartitionType.BATCH: + global_shape = (x.shape[0] * process_count, *x.shape[1:]) else: raise NotImplementedError(f"Unsupported partition: {partition}") return jax.make_array_from_process_local_data( diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index f4c06b47a..374d85e66 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -42,6 +42,7 @@ ) from axlearn.common.trainer import SpmdTrainer from axlearn.common.utils import ( + DataPartitionType, PHYSICAL_TO_LOGICAL_DISPATCH_KEY, HybridMeshShape, MeshShape, @@ -1701,6 +1702,31 @@ def test_length(self): class HostToGlobalArrayTest(TestCase): """Tests host_to_global_device_array.""" + @pytest.mark.neuron + def test_partition_batch(self): + """Test a case where each process produces a slice.""" + device_count = jax.device_count() + process_count = jax.process_count() + print(f"{device_count=}, {process_count=}") + assert device_count > 1 + + global_shape = (device_count // 2, 1) + assert global_shape[0] % process_count == 0 + per_feed_size = global_shape[0] // process_count + feed_index = jax.process_index() + + with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")): + start = feed_index * per_feed_size + local_x = jnp.arange(start, start + per_feed_size)[:, None] + + # Construct global array. + global_x = host_to_global_device_array(local_x, partition=DataPartitionType.BATCH, batch_axis_names="x") + + # Compare against expected. + expected = jnp.arange(global_shape[0])[:, None] + self.assertEqual(jnp.mean(expected), jnp.mean(global_x)) + self.assertNestedEqual(expected, replicate_to_local_data(global_x)) + @pytest.mark.tpu def test_partition_full(self): """Test a case where each process produces a slice.""" diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index d32120d25..a1dc30371 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -13,7 +13,7 @@ import math from collections.abc import Sequence from typing import Literal, Optional, Protocol, Union - +import jax import jax.numpy as jnp import tensorflow as tf from jax.sharding import PartitionSpec @@ -57,7 +57,7 @@ from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer from axlearn.common.summary_writer import BaseWriter from axlearn.common.trainer import MeshShape, SpmdTrainer -from axlearn.common.utils import HybridMeshShape, Nested, get_data_dir +from axlearn.common.utils import DataPartitionType, HybridMeshShape, Nested, get_data_dir from axlearn.experiments.text.common import DataMixtureComponent, tfds_text_source from axlearn.experiments.trainer_config_utils import TrainerConfigFn @@ -201,10 +201,10 @@ def update_model_remat_config( Raises: NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer. 
""" - if stack_cfg.klass is not RepeatedTransformerLayer: - raise NotImplementedError( - f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}" - ) + # if stack_cfg.klass is not RepeatedTransformerLayer: + # raise NotImplementedError( + # f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}" + # ) if layer_cfg.self_attention.attention.klass is not FlashAttention: # Enable remat to reduce memory usage for larger models. @@ -288,8 +288,8 @@ def model_config( layer_cfg.self_attention.attention.input_linear = attention_qkv_linear layer_cfg.self_attention.structure = atten_structure layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap - if stack_cfg.klass is RepeatedTransformerLayer: - update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) + # if stack_cfg.klass is RepeatedTransformerLayer: + update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) # Stack. transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg) decoder_cfg = Decoder.default_config().set( @@ -326,6 +326,15 @@ def model_config( seq_axis_names="seq", ) cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model") + # Neuron backend require fine grained sharding around embedding gather op. + if jax.default_backend() == 'neuron': + cfg.decoder.emb.token_emb.param_partition_spec = ("model", ("expert", "fsdp", "seq")) # shard vocab + if lm_head_cfg != None: + cfg.decoder.lm_head.param_partition_spec = ("model", ("expert", "fsdp", "seq")) # shard vocab + cfg.decoder.emb.token_emb.pregather_partition_spec = ('fsdp', None) + cfg.decoder.emb.token_emb.embedding_partition_spec = ('model', None) + cfg.decoder.emb.token_emb.postgather_partition_spec = ('fsdp', None, None) + set_bias_recursively(cfg, False) set_norm_recursively(cfg, normalization) cfg.z_loss_scale = z_loss_scale @@ -638,6 +647,7 @@ def get_trainer_config_fn( train_input_source: InstantiableConfig[input_tf_data.BuildDatasetFn], evalers: dict[str, SpmdEvaler.Config], mesh_shape: Union[MeshShape, HybridMeshShape], + input_partition_type: Optional[DataPartitionType] = None, mesh_axis_names: Sequence[str] = MESH_AXIS_NAMES, mesh_rules: Optional[Sequence[tuple[str, Optional[Union[MeshShape, HybridMeshShape]]]]] = None, eval_every_n_steps: int = 5000, @@ -689,9 +699,27 @@ def config_fn() -> InstantiableConfig: pad_example_fn=input_tf_data.default_pad_example_fn, ), ) + if input_partition_type: + cfg.input_partition_type = input_partition_type + if len(mesh_axis_names) != len(mesh_shape): + raise ValueError( + f"Number of mesh axis names ({mesh_axis_names}) " + f"must match number of mesh dims ({mesh_shape})." + ) + cfg.mesh_axis_names = mesh_axis_names + cfg.mesh_shape = mesh_shape + # Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and + # "pipeline" axis (for pipeline parallelism). 
+ cfg.batch_axis_names = tuple( + el for el in mesh_axis_names if el not in ("model", "pipeline") + ) + cfg.mesh_rules = mesh_rules cfg.evalers = {} for name, evaler_cfg in evalers.items(): evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size) + if input_partition_type: + evaler_cfg.set(input_partition_type=input_partition_type) + evaler_cfg.set(batch_axis_names=cfg.batch_axis_names) evaler_cfg.set( eval_policy=config_for_function(eval_every_n_steps_policy).set( n=eval_every_n_steps, @@ -708,19 +736,6 @@ def config_fn() -> InstantiableConfig: cfg.checkpointer.keep_last_n = 3 cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 - if len(mesh_axis_names) != len(mesh_shape): - raise ValueError( - f"Number of mesh axis names ({mesh_axis_names}) " - f"must match number of mesh dims ({mesh_shape})." - ) - cfg.mesh_axis_names = mesh_axis_names - cfg.mesh_shape = mesh_shape - # Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and - # "pipeline" axis (for pipeline parallelism). - cfg.batch_axis_names = tuple( - el for el in mesh_axis_names if el not in ("model", "pipeline") - ) - cfg.mesh_rules = mesh_rules # Maybe load state. if init_state_builder: cfg.init_state_builder = init_state_builder diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 69f6b1102..09e1ef8ea 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -15,16 +15,20 @@ import itertools from typing import Any, Optional, Union +import jax from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from axlearn.common import causal_lm, config from axlearn.common.attention import ( BaseStackedTransformerLayer, FusedGroupedQKVLinear, + GroupedQKVLinear, FusedQKVLinear, GroupedQueryAttention, MultiheadAttention, RepeatedTransformerLayer, + StackedTransformerLayer, + TransformerLayer, RoFormerQKVLinear, ) from axlearn.common.base_layer import RematSpec @@ -39,7 +43,7 @@ MeshShapeModifier, RematSpecModifier, ) -from axlearn.common.utils import extended_checkpoint_policies +from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies from axlearn.experiments.text.gpt.common import ( STEP_DTYPE, SourceBuilder, @@ -113,7 +117,7 @@ def get_trainer_kwargs( *, vocab_size: int, version: Version, - flash_attention: bool = False, + flash_attention: bool = True, ) -> dict[str, Any]: """Construct default trainer kwargs given a model size.""" tokens_per_batch = 4 * (1024**2) # 4M tokens. 
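# A minimal sketch of the batch-axis derivation in get_trainer_config_fn above, assuming
# MESH_AXIS_NAMES is the usual six-axis GPT tuple used in this file:
#     mesh_axis_names = ("pipeline", "data", "expert", "fsdp", "seq", "model")
#     batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "pipeline"))
#     assert batch_axis_names == ("data", "expert", "fsdp", "seq")
# With DataPartitionType.BATCH, input batches are then sharded over exactly these axes.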
@@ -127,6 +131,7 @@ def get_trainer_kwargs( num_kv_heads = None if version == Version.V3: num_kv_heads = 8 + backend = jax.default_backend() rope_theta = ROPE_THETA[version] @@ -150,8 +155,8 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=6e-4, weight_decay=0.01), max_sequence_length=64, - train_batch_size=32, - eval_batch_size=32, + train_batch_size=64, + eval_batch_size=64, max_step=3000, eval_every_n_steps=1500, save_every_n_steps=500, @@ -174,6 +179,12 @@ def get_trainer_kwargs( train_batch_size=train_batch_size, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_rules=( + ( + "neuron-(trn2|trn2n).48xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=4), + ), + ) ) elif model_size == "3B": trainer_kwargs = dict( @@ -192,6 +203,12 @@ def get_trainer_kwargs( train_batch_size=train_batch_size, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_rules=( + ( + "neuron-(trn2|trn2n).48xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=4), + ), + ) ) elif model_size == "7B": trainer_kwargs = dict( @@ -207,6 +224,7 @@ def get_trainer_kwargs( learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, train_batch_size=train_batch_size, + input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), mesh_rules=( @@ -287,6 +305,14 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=4), + ), + ( + "neuron-(trn1|trn1n).32xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=8), + ), ), ) elif model_size == "8B": @@ -367,6 +393,10 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=4), + ), ), ) elif model_size == "70B": @@ -381,11 +411,12 @@ def get_trainer_kwargs( ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256), rope_theta=rope_theta, shared_lm_head=False, - flash_attention=flash_attention, + flash_attention=True, ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, train_batch_size=train_batch_size, + input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH, max_step=max_step, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( @@ -417,12 +448,17 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", mesh_shape_from_axes(data=-1, fsdp=128), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + mesh_shape_from_axes(fsdp=-1, model=4), + ), ), ) else: raise NotImplementedError(f"Unknown model size {model_size}.") model_kwargs = trainer_kwargs.pop("model_kwargs") model_kwargs.setdefault("vocab_size", vocab_size) + model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config()) trainer_kwargs["model_cfg"] = model_config(**model_kwargs) trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config( max_step=trainer_kwargs["max_step"], @@ -443,7 +479,7 @@ def model_config( shared_lm_head: bool, dropout_rate: float = 0.0, ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None, - flash_attention: bool = False, + flash_attention: bool = True, stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based 
on the given hyperparams. @@ -473,7 +509,10 @@ def model_config( ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256) if num_kv_heads: atten_cfg = GroupedQueryAttention.default_config() - atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads) + backend = jax.default_backend() + + qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear + atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads) else: atten_cfg = MultiheadAttention.default_config() atten_input_linear = FusedQKVLinear.default_config() @@ -499,6 +538,7 @@ def model_config( emb_cfg=TransformerTextEmbeddings.default_config().set(pos_emb=None), lm_head_cfg=LmHead.default_config() if not shared_lm_head else None, attention_cfg=flash_attention_config() if flash_attention else atten_cfg, + # layer_cfg=layer_cfg, attention_qkv_linear=atten_qkv_linear, ) return cfg
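A minimal usage sketch of the DataPartitionType.BATCH option and the maybe_shard helper added in
axlearn/common/utils.py, assuming a toy ("data", "fsdp") mesh axis naming for illustration:

import jax.numpy as jnp
from axlearn.common.utils import DataPartitionType, data_partition_type_to_spec, maybe_shard

# BATCH shards only the leading (batch) dimension of each input over the given mesh axes;
# all remaining dimensions stay replicated.
spec = data_partition_type_to_spec(DataPartitionType.BATCH, batch_axis_names=("data", "fsdp"))
# spec == PartitionSpec(("data", "fsdp"))

# maybe_shard is a no-op when the configured spec is None, so the new *_partition_spec
# config fields are backwards compatible when left unset.
x = jnp.zeros((8, 16))
assert maybe_shard(x, None) is x
# With a mesh installed, a non-None spec applies a sharding constraint, e.g.:
#   maybe_shard(x, ("fsdp", "model")) == with_sharding_constraint(x, PartitionSpec("fsdp", "model"))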