add Fuji v3 405b model config
samos123 committed Dec 2, 2024
1 parent 2056857 commit 7532d9e
Showing 3 changed files with 55 additions and 2 deletions.
1 change: 1 addition & 0 deletions axlearn/common/attention.py
@@ -3116,6 +3116,7 @@ def _forward_for_mode(
        Raises:
            ValueError: If `mode` is unsupported.
        """
+       data = self._remat_name(data, "input")
        if isinstance(data, Tensor):
            self.vlog(3, "transformer.input=%s", data.sum())  # pytype: disable=attribute-error
        self_attention_return_aux = set()
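Note (illustrative, not part of this commit): tagging the layer input with _remat_name is what lets the remat policy added in fuji.py offload that activation by name. A minimal plain-JAX sketch of the same mechanism, assuming the tag resolves to "TransformerLayer.input" and using toy shapes (host offload needs a backend that supports it):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w):
    # Name the activation so a remat policy can match it.
    x = checkpoint_name(x, "TransformerLayer.input")
    return jnp.tanh(x @ w)

# Save nothing on device; move the named activation to pinned host memory
# instead of recomputing it during the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["TransformerLayer.input"],
    offload_src="device",
    offload_dst="pinned_host",
)
layer_remat = jax.checkpoint(layer, policy=policy)

x, w = jnp.ones((8, 16)), jnp.ones((16, 16))
grads = jax.grad(lambda w: layer_remat(x, w).sum())(w)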
2 changes: 1 addition & 1 deletion axlearn/common/layers.py
@@ -784,7 +784,7 @@ class Config(BaseLayer.Config):
    @classmethod
    def default_config(cls):
        cfg = super().default_config()
-       cfg.param_partition_spec = (None, "model")
+       cfg.param_partition_spec = ("fsdp", "model")
        # By default, initialize to Gaussian with std=1/sqrt(dim), e.g., 0.036 when dim=768.
        #
        # This is the same as:
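Note (illustrative, not part of this commit): the new ("fsdp", "model") spec shards the parameter's first axis over the fsdp mesh axis instead of replicating it, while the second axis stays sharded over model. A minimal jax.sharding sketch of what that spec means, with an assumed mesh and weight shape:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Tiny ("fsdp", "model") mesh built from whatever devices are available.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "model"))

# ("fsdp", "model"): rows sharded across fsdp, columns across model.
# The previous (None, "model") spec replicated the rows on every device.
spec = NamedSharding(mesh, PartitionSpec("fsdp", "model"))
n_fsdp = devices.shape[0]
weight = jax.device_put(np.zeros((8 * n_fsdp, 768), dtype=np.float32), spec)
print(weight.sharding)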
54 changes: 53 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -54,7 +54,7 @@
from axlearn.experiments.text.gpt.common import scaled_hidden_dim
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")
MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "405B")


class Version(enum.Enum):
@@ -104,6 +104,7 @@ class Version(enum.Enum):
"3B": 15 * (1024**4), # 15T tokens
"8B": 15 * (1024**4), # 15T tokens
"70B": 15 * (1024**4), # 15T tokens
"405B": 15 * (1024**4), # 15T tokens
},
}

@@ -117,6 +118,9 @@ def get_trainer_kwargs(
) -> dict[str, Any]:
    """Construct default trainer kwargs given a model size."""
    tokens_per_batch = 4 * (1024**2)  # 4M tokens.
+   if version == Version.V3 and model_size == "405B":
+       tokens_per_batch = 16 * (1024**2)  # 16M tokens.

    if model_size not in TOTAL_TOKENS[version]:
        return {}
    max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
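With these settings the 405B run feeds a 16M-token batch against the 15T-token budget, i.e. max_step = (15 * 1024**4) // (16 * 1024**2) = 983,040 steps, whereas the other V3 sizes keep the 4M-token batch and run 3,932,160 steps.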
@@ -419,6 +423,54 @@ def get_trainer_kwargs(
                ),
            ),
        )
    elif model_size == "405B":
        remat_policy = config_for_function(
            jax_remat_policies.save_and_offload_only_these_names
        ).set(
            names_which_can_be_saved=[],
            names_which_can_be_offloaded=[
                "TransformerLayer.input",
            ],
            offload_src="device",
            offload_dst="pinned_host",
        )
        trainer_kwargs = dict(
            model_kwargs=dict(
                num_layers=126,
                hidden_dim=53248,
                num_heads=128,
                num_kv_heads=8,
                rope_theta=rope_theta,
                flash_attention=flash_attention,
            ),
            learner_kwargs=dict(peak_lr=8e-5, weight_decay=0.1),
            max_sequence_length=max_sequence_length,
            train_batch_size=train_batch_size,
            max_step=max_step,
            mesh_shape=mesh_shape_from_axes(fsdp=-1),
            mesh_rules=(
                # tpu-v6e.
                (
                    "tpu-v6e-.*",
                    ChainConfigModifier.default_config().set(
                        config_modifiers=[
                            MeshShapeModifier.default_config().set(
                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
                            ),
                            RematSpecModifier.default_config().set(
                                remat_policies={
                                    "model.decoder.transformer.layer": RematSpec(
                                        prevent_cse=True,
                                        policy=remat_policy,
                                    ),
                                }
                            ),
                        ],
                    ),
                ),
            ),
        )

    else:
        raise NotImplementedError(f"Unknown model size {model_size}.")
    model_kwargs = trainer_kwargs.pop("model_kwargs")
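Note (illustrative, not part of this commit): the default 405B mesh is fsdp=-1, and the tpu-v6e rule overrides it to data=-1, fsdp=256, which presumably keeps FSDP sharding within a 256-chip v6e slice and uses data parallelism across slices. A small sketch of how such a spec could resolve against a device count, assuming -1 means "use all remaining devices" (resolve_mesh is a hypothetical helper, not axlearn's mesh_shape_from_axes):

import math

def resolve_mesh(num_devices: int, **axes: int) -> dict[str, int]:
    # Multiply the explicitly sized axes, then give the leftover devices to the -1 axis.
    fixed = math.prod(size for size in axes.values() if size != -1)
    assert num_devices % fixed == 0, "fixed axis sizes must divide the device count"
    return {name: (num_devices // fixed if size == -1 else size) for name, size in axes.items()}

print(resolve_mesh(1024, data=-1, fsdp=256))  # {'data': 4, 'fsdp': 256}
print(resolve_mesh(256, fsdp=-1))             # {'fsdp': 256}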
