add Fuji v3 405b and solve HBM OOMs for larger models #766

Draft: wants to merge 1 commit into main
1 change: 1 addition & 0 deletions axlearn/common/attention.py
@@ -3116,6 +3116,7 @@ def _forward_for_mode(
Raises:
ValueError: If `mode` is unsupported.
"""
+data = self._remat_name(data, "input")
if isinstance(data, Tensor):
self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error
self_attention_return_aux = set()
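The `_remat_name` tag gives this activation a stable name that a remat policy can reference. As a rough, standalone sketch (not axlearn code; it assumes a recent JAX release that provides `jax.checkpoint_policies.save_and_offload_only_these_names` and a backend with host-offload support), the same mechanism in plain JAX looks like this:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint, checkpoint_name

# Offload (rather than keep in HBM or recompute) any residual tagged "input".
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["input"],
    offload_src="device",
    offload_dst="pinned_host",
)

def layer(w, x):
    x = checkpoint_name(x, "input")  # analogous to self._remat_name(data, "input")
    return jnp.tanh(x @ w)

@jax.jit
def loss(w, x):
    return jnp.sum(checkpoint(layer, policy=policy)(w, x))

grads = jax.grad(loss)(jnp.ones((8, 8)), jnp.ones((4, 8)))
```

This is why the 405B remat policy further down in this diff can refer to the tagged tensor by name ("TransformerLayer.input").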
2 changes: 1 addition & 1 deletion axlearn/common/layers.py
@@ -784,7 +784,7 @@ class Config(BaseLayer.Config):
@classmethod
def default_config(cls):
cfg = super().default_config()
-cfg.param_partition_spec = (None, "model")
+cfg.param_partition_spec = ("fsdp", "model")
Contributor Author: @kelvin-zou it is unclear whether this is safe to change by default for everyone; please review this line specifically. It was needed to fit 405B in HBM.

# By default, initialize to Gaussian with std=1/sqrt(dim), e.g., 0.036 when dim=768.
#
# This is the same as:
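For reviewers weighing the default change above: a small standalone sketch (not axlearn code; mesh and weight shapes are made up) of what moving the first axis of `param_partition_spec` from `None` to `"fsdp"` does to per-device memory:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange whatever devices are available into a (fsdp, model) mesh.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "model"))

weight = jax.numpy.zeros((1024, 512))
old_spec = PartitionSpec(None, "model")    # rows replicated on every fsdp shard
new_spec = PartitionSpec("fsdp", "model")  # rows split across the fsdp axis too

for spec in (old_spec, new_spec):
    sharded = jax.device_put(weight, NamedSharding(mesh, spec))
    print(spec, "-> per-device shard shape:", sharded.sharding.shard_shape(weight.shape))
```

With the 256-way fsdp mesh used for 405B below, sharding the first axis cuts the per-device footprint of this parameter by roughly 256x, which is where much of the HBM saving comes from.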
55 changes: 54 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -54,7 +54,7 @@
from axlearn.experiments.text.gpt.common import scaled_hidden_dim
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

-MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")
+MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "405B")


class Version(enum.Enum):
@@ -104,6 +104,7 @@ class Version(enum.Enum):
"3B": 15 * (1024**4), # 15T tokens
"8B": 15 * (1024**4), # 15T tokens
"70B": 15 * (1024**4), # 15T tokens
"405B": 15 * (1024**4), # 15T tokens
},
}

@@ -117,6 +118,9 @@ def get_trainer_kwargs(
) -> dict[str, Any]:
"""Construct default trainer kwargs given a model size."""
tokens_per_batch = 4 * (1024**2) # 4M tokens.
+if version == Version.V3 and model_size == "405B":
+    tokens_per_batch = 16 * (1024**2)  # 16M tokens.

if model_size not in TOTAL_TOKENS[version]:
return {}
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
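A quick sanity check on the step count these numbers imply (standalone arithmetic, not part of the diff):

```python
total_tokens = 15 * (1024**4)                         # "15T" v3 token budget
assert total_tokens // (16 * (1024**2)) == 983_040    # 405B: 16M-token batches
assert total_tokens // (4 * (1024**2)) == 3_932_160   # other sizes: 4M-token batches
```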
@@ -419,6 +423,55 @@ def get_trainer_kwargs(
),
),
)
elif model_size == "405B":
remat_policy = config_for_function(
jax_remat_policies.save_and_offload_only_these_names
).set(
names_which_can_be_saved=[],
names_which_can_be_offloaded=[
"TransformerLayer.input",
],
offload_src="device",
offload_dst="pinned_host",
)
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=126,
hidden_dim=53248,
num_heads=128,
num_kv_heads=8,
rope_theta=rope_theta,
shared_lm_head=False,
flash_attention=flash_attention,
),
learner_kwargs=dict(peak_lr=8e-5, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
Contributor: The training batch is not 4M tokens for the 405B model; it is 16M according to this.

Contributor Author: fixed
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
mesh_rules=(
# tpu-v6e.
(
"tpu-v6e-.*",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
Contributor Author: Note that this requires the optimizer state weight-only offloading PR to be merged: #789
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=remat_policy,
),
}
),
],
),
),
),
)

else:
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
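For the tpu-v6e mesh rule in the 405B branch above, a short sketch of how the (data=-1, fsdp=256) shape resolves, assuming the usual convention that -1 absorbs all remaining devices (the slice size is illustrative, not prescribed by this PR):

```python
# Hypothetical 1024-chip v6e slice; -1 on the data axis takes whatever is left.
num_devices = 1024
fsdp = 256
data = num_devices // fsdp
assert (data, fsdp) == (4, 256)
# i.e., parameters (including the ("fsdp", "model") spec above) shard 256 ways,
# with 4-way data parallelism across the fsdp groups.
```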