add Fuji v3 405b and solve HBM OOMs for larger models #766
base: main
Changes from all commits
@@ -54,7 +54,7 @@
 from axlearn.experiments.text.gpt.common import scaled_hidden_dim
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn

-MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")
+MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "405B")


 class Version(enum.Enum):
@@ -104,6 +104,7 @@ class Version(enum.Enum):
         "3B": 15 * (1024**4),  # 15T tokens
         "8B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
+        "405B": 15 * (1024**4),  # 15T tokens
     },
 }
@@ -117,6 +118,9 @@ def get_trainer_kwargs(
 ) -> dict[str, Any]:
     """Construct default trainer kwargs given a model size."""
     tokens_per_batch = 4 * (1024**2)  # 4M tokens.
+    if version == Version.V3 and model_size == "405B":
+        tokens_per_batch = 16 * (1024**2)  # 16M tokens.
+
     if model_size not in TOTAL_TOKENS[version]:
         return {}
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
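As a sanity check on the step budget this override implies (plain arithmetic, not part of the diff):

total_tokens = 15 * (1024**4)      # TOTAL_TOKENS[Version.V3]["405B"], i.e. 15T tokens
tokens_per_batch = 16 * (1024**2)  # 16M tokens per batch for the 405B config
max_step = total_tokens // tokens_per_batch
assert max_step == 983_040         # roughly 983k training steps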
@@ -419,6 +423,55 @@ def get_trainer_kwargs(
                 ),
             ),
         )
+    elif model_size == "405B":
+        remat_policy = config_for_function(
+            jax_remat_policies.save_and_offload_only_these_names
+        ).set(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=[
+                "TransformerLayer.input",
+            ],
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
+        trainer_kwargs = dict(
+            model_kwargs=dict(
+                num_layers=126,
+                hidden_dim=53248,
+                num_heads=128,
+                num_kv_heads=8,
+                rope_theta=rope_theta,
+                shared_lm_head=False,
+                flash_attention=flash_attention,
+            ),
+            learner_kwargs=dict(peak_lr=8e-5, weight_decay=0.1),
+            max_sequence_length=max_sequence_length,
+            train_batch_size=train_batch_size,

[Review comment] Training batch is not 4M tokens for 400B model, it is 16M according to this
[Reply] fixed

+            max_step=max_step,
+            mesh_shape=mesh_shape_from_axes(fsdp=-1),
+            mesh_rules=(
+                # tpu-v6e.
+                (
+                    "tpu-v6e-.*",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)

[Review comment] note this requires optimizer state weight only offloading PR to be merged: #789

+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=remat_policy,
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
+            ),
+        )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
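The tpu-v6e rule above replaces the default fsdp=-1 mesh with a two-axis (data=-1, fsdp=256) layout. As a rough standalone illustration of what such a mesh means in plain JAX (a sketch under assumed axis semantics and an evenly divisible device count; it is not AXLearn's mesh_shape_from_axes or the config-modifier machinery):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
fsdp_size = min(256, devices.size)      # 256-way FSDP groups on a large slice
data_size = devices.size // fsdp_size   # "-1": data-parallel axis takes the remainder
mesh = Mesh(devices.reshape(data_size, fsdp_size), axis_names=("data", "fsdp"))

# Shard a weight across the fsdp axis and replicate it across the data axis.
weight = jnp.zeros((1024, 1024))
weight = jax.device_put(weight, NamedSharding(mesh, P("fsdp", None)))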
[Review comment] @kelvin-zou unclear if this is safe to change by default for everyone. Please review this specifically. It was needed to fit 405b on HBM memory.
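For reviewers unfamiliar with the offloading remat policy involved, here is a minimal standalone sketch of the same idea in plain JAX (assumes a recent JAX release that ships jax.checkpoint_policies.save_and_offload_only_these_names and a backend with pinned-host memory; it is not the AXLearn RematSpec path used in this PR):

import functools
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["TransformerLayer.input"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=policy, prevent_cse=True)
def layer(x, w):
    x = checkpoint_name(x, "TransformerLayer.input")  # tag the offloadable value
    return jnp.tanh(x @ w)

x, w = jnp.ones((8, 16)), jnp.ones((16, 16))
grads = jax.grad(lambda x, w: layer(x, w).sum(), argnums=1)(x, w)

Values tagged with checkpoint_name and listed under names_which_can_be_offloaded are copied to host memory on the forward pass and fetched back for the backward pass, which is the HBM-for-host-RAM trade that lets the 405B config fit.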