diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 37baf3d8b..dffbf8c1c 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -2508,16 +2508,19 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
             atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype))
             return dict(attention=atten_state), atten_output

+        remat_pt1 = "attention_output"
         if cfg.structure == "prenorm":
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.norm(target)
             atten_state, atten_output = attention_thunk(norm_target)
             data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
+            data = self._remat_name(data, remat_pt1)
         elif cfg.structure == "postnorm":
             # This is the structure used by the original Transformer, BERT, and RoBERTa.
             atten_state, atten_output = attention_thunk(target)
             # Post-norm: norm applied on the sum of input and attention output.
             data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data)))
+            data = self._remat_name(data, remat_pt1)
         elif cfg.structure == "hybridnorm":
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.prenorm(target)
@@ -2525,6 +2528,7 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
             data = skip_input + self.stochastic_depth(
                 self.dropout(self.postnorm(atten_output.data))
             )
+            data = self._remat_name(data, remat_pt1)
         else:
             raise NotImplementedError(cfg.structure)
         return dict(attention=atten_state), self.Output(
@@ -2801,6 +2805,7 @@ def _linear2(x):
         self._add_tensor_stats("inputs", inputs)

         remat_pt2 = "linear2"
+        remat_pt3 = "feed_forward_output"
         if cfg.structure == "prenorm":
             x = self.norm(inputs)
             x = self._linear1_activation(x)
@@ -2812,6 +2817,7 @@ def _linear2(x):
             if cfg.residual_weight != 1:
                 x *= cfg.residual_weight
             x += inputs
+            x = self._remat_name(x, remat_pt3)
         elif cfg.structure == "postnorm":
             x = self._linear1_activation(inputs)
             x = _linear2(x)
@@ -2821,6 +2827,7 @@ def _linear2(x):
             if cfg.residual_weight != 1:
                 x *= cfg.residual_weight
             x = self.norm(x + inputs)
+            x = self._remat_name(x, remat_pt3)
         elif cfg.structure == "hybridnorm":
             x = self.prenorm(inputs)
             x = self._linear1_activation(x)
@@ -2833,6 +2840,7 @@ def _linear2(x):
             if cfg.residual_weight != 1:
                 x *= cfg.residual_weight
             x += inputs
+            x = self._remat_name(x, remat_pt3)
         elif cfg.structure == "nonorm":
             x = inputs
             x = self._linear1_activation(x)
@@ -2845,6 +2853,7 @@ def _linear2(x):
             # this layer, e.g., in ParallelTransformerLayer.
             if cfg.residual_weight != 1:
                 x *= cfg.residual_weight
+            x = self._remat_name(x, remat_pt3)
         else:
             raise NotImplementedError(cfg.structure)
         return x
@@ -3956,15 +3965,21 @@ def policy(prim, *_, **params):
     return policy


-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*\.?(k|q|v)_proj"
+    LINEAR1_X = r".*\.?linear1_[01]"
+    ATTENTION_OUTPUT = r"TransformerAttentionLayer\.attention_output"
+    FEED_FORWARD_OUTPUT = r"TransformerFeedForwardLayer\.feed_forward_output"
+    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    FEED_FORWARD = ".*linear[12]_.*"


 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
     offload_pattern: _SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index 1e188ecc0..41d2513c8 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -41,7 +41,6 @@

 from axlearn.common import attention, attention_bias, test_utils, utils
 from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     BaseTransformerLayer,
     BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
     PipelinedTransformerLayer,
     QKVLinear,
     QLinear,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
     StackedTransformerLayer,
@@ -3420,7 +3420,7 @@ def f(x, layer_params):
             jax.remat(
                 f,
                 policy=_save_and_offload_only_these_names_regex(
-                    names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+                    names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
                     names_which_can_be_offloaded=None,
                     offload_src="device",
                     offload_dst="pinned_host",
@@ -3875,6 +3875,72 @@ def f(x, layer_params):
             5,
         )

+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(_save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                        RematRegexSavePatterns.ATTENTION_OUTPUT.value,
+                        RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Eliminated the remat of qkv_proj, o_proj and linear1_0 = 5 dots. This assumes
+        # FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            5,
+        )
+

 class TestStackModel(BaseLayer):
     """A dummy transformer stack."""
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb9234..c01789a39 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
     BaseQKVLinear,
     MultiheadAttention,
     RepeatedTransformerLayer,
+    StackedTransformerLayer,
     TransformerLayer,
     build_remat_spec,
     set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.

-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.

-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)

@@ -277,7 +270,7 @@ def model_config(
     layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
+    if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
         update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..61a8ff073 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -24,8 +24,10 @@
     FusedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    _save_and_offload_only_these_names_regex,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -85,7 +87,6 @@ class Version(enum.Enum):
     Version.V3: 5e5,
 }

-
 # Mapping from Fuji versions to total number of tokens used in training.
 TOTAL_TOKENS = {
     Version.V1: {
@@ -417,6 +418,38 @@
                 (
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            _save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved="|".join(
+                                                [
+                                                    RematRegexSavePatterns.QKV_PROJ.value,
+                                                    RematRegexSavePatterns.LINEAR1_X.value,
+                                                    RematRegexSavePatterns.ATTENTION_OUTPUT.value,
+                                                    RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
+                                                ]
+                                            ),
+                                            names_which_can_be_offloaded=None,
+                                            offload_src=None,
+                                            offload_dst=None,
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
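
Usage sketch (not part of the patch): a minimal example, based on the calls already used in this diff, of combining the new RematRegexSavePatterns values into a remat policy like the Neuron mesh rule above. The Q/K/V projections, the first feed-forward linear, and the newly remat-named residual outputs are saved; everything else is rematerialized in the backward pass.

from axlearn.common.attention import (
    RematRegexSavePatterns,
    _save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function

# Union of the per-tensor regexes; tensors whose remat names match are saved.
save_pattern = "|".join(
    [
        RematRegexSavePatterns.QKV_PROJ.value,
        RematRegexSavePatterns.LINEAR1_X.value,
        RematRegexSavePatterns.ATTENTION_OUTPUT.value,
        RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
    ]
)
remat_spec = RematSpec(
    prevent_cse=True,
    policy=config_for_function(_save_and_offload_only_these_names_regex).set(
        names_which_can_be_saved=save_pattern,
        names_which_can_be_offloaded=None,
        offload_src=None,
        offload_dst=None,
    ),
)
# The spec is then attached to a transformer layer config, e.g.
# layer_cfg.set(remat_spec=remat_spec), as done in update_model_remat_config.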