Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special remat for Neuron #898

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4001,15 +4001,22 @@ def _save_and_offload_only_these_names_regex(
)


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
# Regex patterns for matching remat names
class RematRegexSavePatterns(enum.Enum):
QKV_PROJ = r".*[kqv]_proj"
O_PROJ = r".*o_proj"
CONTEXT = r".*context"
LINEAR1_X = r".*linear1_[01]"
LINEAR2_X = r".*linear2_[01]"
SELF_ATTENTION = ".*([qkvo]_proj|context)"
FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: SavePattern = SELF_ATTENTION_SAVE_PATTERN,
save_pattern: SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
offload_pattern: SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
Expand Down
71 changes: 67 additions & 4 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from axlearn.common import attention, attention_bias, test_utils, utils
from axlearn.common.attention import (
FEED_FORWARD_SAVE_PATTERN,
BaseStackedTransformerLayer,
BaseTransformerLayer,
BottleNeckAdapterTransformerLayer,
Expand All @@ -58,14 +57,14 @@
PipelinedTransformerLayer,
QKVLinear,
QLinear,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer,
TransformerAttentionLayer,
TransformerFeedForwardLayer,
TransformerLayer,
_next_power_of_two,
_save_and_offload_only_these_names_regex,
apply_attention_logit_biases,
apply_rotary_position_embeddings,
build_remat_spec,
Expand Down Expand Up @@ -124,6 +123,7 @@
VDict,
as_tensor,
flatten_items,
save_and_offload_only_these_names_regex,
shapes,
)

Expand Down Expand Up @@ -3445,8 +3445,8 @@ def f(x, layer_params):
_, save_name_backward = jax.linearize(
jax.remat(
f,
policy=_save_and_offload_only_these_names_regex(
names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
policy=save_and_offload_only_these_names_regex(
names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
Expand Down Expand Up @@ -3901,6 +3901,69 @@ def f(x, layer_params):
5,
)

def test_build_remat_spec_neuron(self):
model_dim, num_heads = 6, 2
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.feed_forward.hidden_dim = model_dim * 4
cfg.vlog = 5

layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

batch_size, tgt_len = 2, 5
rng = np.random.default_rng(seed=123)
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

def f(x, layer_params):
forward_outputs, _ = F(
layer,
inputs=dict(
data=x,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return forward_outputs

# Ignore type errors.
spec: Any = build_remat_spec(mock.MagicMock())

policy = (
config_for_function(save_and_offload_only_these_names_regex)
.set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
)
.instantiate()
)

_, default_policy_backward = jax.linearize(
jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
jnp.asarray(target),
layer_params,
)
_, full_remat_backward = jax.linearize(
jax.remat(f),
jnp.asarray(target),
layer_params,
)

# Eliminated the remat of qkv_proj and linear1_0 = 4 dots.
self.assertEqual(
str(full_remat_backward).count(" dot_general")
- str(default_policy_backward).count(" dot_general"),
4,
)


class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ mesh_rules[3][1][2]: 1
mesh_rules[3][1][3]: 128
mesh_rules[3][1][4]: 1
mesh_rules[3][1][5]: 1
mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down
11 changes: 2 additions & 9 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BaseQKVLinear,
MultiheadAttention,
RepeatedTransformerLayer,
StackedTransformerLayer,
TransformerLayer,
build_remat_spec,
set_double_shard_weights_config,
Expand Down Expand Up @@ -190,20 +191,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.

Only applied if the stack_cfg is a RepeatedTransformerLayer.

Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkptoing offloading.

Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
Expand Down Expand Up @@ -277,7 +270,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
Expand Down
40 changes: 36 additions & 4 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from axlearn.common import causal_lm, config
from axlearn.common.attention import (
SELF_ATTENTION_SAVE_PATTERN,
BaseStackedTransformerLayer,
FusedGroupedQKVLinear,
FusedQKVLinear,
GroupedQueryAttention,
MultiheadAttention,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
)
Expand All @@ -40,7 +40,10 @@
MeshShapeModifier,
RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
from axlearn.common.utils import (
extended_checkpoint_policies,
save_and_offload_only_these_names_regex,
)
from axlearn.experiments.text.gpt.common import (
STEP_DTYPE,
SourceBuilder,
Expand Down Expand Up @@ -86,7 +89,6 @@ class Version(enum.Enum):
Version.V3: 5e5,
}


# Mapping from Fuji versions to total number of tokens used in training.
TOTAL_TOKENS = {
Version.V1: {
Expand Down Expand Up @@ -147,7 +149,7 @@ def get_trainer_kwargs(
extended_checkpoint_policies.save_and_offload_only_these_names_regex
).set(
names_which_can_be_saved=None,
names_which_can_be_offloaded=SELF_ATTENTION_SAVE_PATTERN,
names_which_can_be_offloaded=RematRegexSavePatterns.SELF_ATTENTION.value,
offload_src="device",
offload_dst="pinned_host",
)
Expand Down Expand Up @@ -492,6 +494,36 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
save_and_offload_only_these_names_regex
).set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
hanzhi713 marked this conversation as resolved.
Show resolved Hide resolved
]
),
names_which_can_be_offloaded=None,
ruomingp marked this conversation as resolved.
Show resolved Hide resolved
offload_src=None,
offload_dst=None,
),
),
}
),
],
),
),
),
)
else:
Expand Down
Loading