Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

32 b #121

Draft
wants to merge 118 commits into
base: main
Choose a base branch
from
Draft

32 b #121

Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
b94e702
Save more often
dirkgr Dec 8, 2024
368abb8
Don't check for cancelation all the time
dirkgr Dec 8, 2024
c277d54
Make sure we use the same CE loss that we used for the 13B
dirkgr Dec 8, 2024
7c74d8b
We're going to 5T!
dirkgr Dec 8, 2024
53d61fe
We can live with a bigger eval batch size.
dirkgr Dec 8, 2024
514abb8
Add MMLU downstream eval
dirkgr Dec 9, 2024
011113e
Module isn't callable
dirkgr Dec 9, 2024
2577397
Qwen-ish
dirkgr Dec 9, 2024
93637a1
Make model bigger
dirkgr Dec 9, 2024
784377d
It's now a 32B.
dirkgr Dec 10, 2024
eec7e10
6T tokens
dirkgr Dec 10, 2024
bd5edee
Official save folder
dirkgr Dec 10, 2024
f516f09
6.5T tokens
dirkgr Dec 10, 2024
49264f5
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4bb5d5c
Merged
dirkgr Dec 10, 2024
1ff1371
Change project name and location
dirkgr Dec 10, 2024
4375612
Revert "Merged"
dirkgr Dec 10, 2024
20b9b08
Revert "Module isn't callable"
dirkgr Dec 10, 2024
7736198
Revert "Make sure we use the same CE loss that we used for the 13B"
dirkgr Dec 10, 2024
8e0613f
We still want it fused!
dirkgr Dec 10, 2024
5652953
One-in-two activation checkpointing
dirkgr Dec 10, 2024
323c786
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4f676e2
Smaller microbatch
dirkgr Dec 10, 2024
d4e63fa
Wrap 3 in 4 blocks
dirkgr Dec 10, 2024
7c22386
Don't compile the loss.
dirkgr Dec 10, 2024
f38bff4
Turn off broken eval
dirkgr Dec 11, 2024
3bf2440
Go back to mbsz of 4
dirkgr Dec 11, 2024
ab5afcf
Set drop_last for DownstreamEvaluator to False
2015aroras Dec 11, 2024
47f9545
Bring back Copa now that we have Shane's fix
dirkgr Dec 11, 2024
ee6aa90
Merge remote-tracking branch 'origin/32B' into 32B
dirkgr Dec 11, 2024
c656a41
Check if beaker loading issues are due to beaker changes by updating …
2015aroras Dec 11, 2024
7852e1e
Try hsdp with 2 nodes per replica
2015aroras Dec 11, 2024
b19e76d
Revert "Try hsdp with 2 nodes per replica"
2015aroras Dec 11, 2024
a02dd95
Try activation checkpointing 3 in 4
2015aroras Dec 12, 2024
6eaa5a3
Try activation checkpointing 3 in 4 + all feedforwards checkpointed
2015aroras Dec 12, 2024
b2a07de
Decrease microbatch size
2015aroras Dec 13, 2024
9985d31
Try activation checkpointing on just feed forwards
2015aroras Dec 13, 2024
4cc6a62
Fix name
dirkgr Dec 16, 2024
1060499
Try to run with hybrid sharding.
dirkgr Dec 16, 2024
fb2a274
More batch
dirkgr Dec 16, 2024
1073613
Revert "More batch"
dirkgr Dec 16, 2024
c553b98
There is something wrong with how the `common` object is set up.
dirkgr Dec 16, 2024
e49d4b7
We need a less sharded checkpoint and I guess this is the only way we…
dirkgr Dec 16, 2024
9608482
Revert "We need a less sharded checkpoint and I guess this is the onl…
dirkgr Dec 16, 2024
4804004
Async checkpointer may have problems with large checkpoints?
dirkgr Dec 16, 2024
fd4edb8
For loading checkpoints, it seems we need a longer timeout
dirkgr Dec 16, 2024
1f79446
Revert "Async checkpointer may have problems with large checkpoints?"
dirkgr Dec 16, 2024
072c616
Flight to safety
dirkgr Dec 16, 2024
6ba3e23
Increase microbatch size up to 2 * 4096
2015aroras Dec 17, 2024
07cc66c
Watching the 32B in a notebook
dirkgr Dec 18, 2024
18e9a32
Merge branch '32B' of https://github.com/allenai/OLMo-core into 32B
dirkgr Dec 18, 2024
2150b36
Merge branch 'main' into 32B
2015aroras Dec 19, 2024
c8cf403
Enable HSDP with pre-downloading
2015aroras Dec 19, 2024
d9cb6cf
Turn off hsdp
2015aroras Dec 19, 2024
5f2cf19
Revert "Turn off hsdp"
2015aroras Dec 19, 2024
19c8758
Add option to set thread_count
2015aroras Dec 19, 2024
9a12202
Run formatter
2015aroras Dec 19, 2024
d5e6e2b
Limit thread count
2015aroras Dec 19, 2024
ea0acce
Decrease microbatch size
2015aroras Dec 19, 2024
d2a00a7
Increase microbatch size, increase activation checkpointing
2015aroras Dec 19, 2024
016e426
Decrease microbatch size
2015aroras Dec 20, 2024
a28ca37
Decrease thread_count
2015aroras Dec 20, 2024
1c33794
Thread count 1
2015aroras Dec 20, 2024
484d01c
Back to FSDP
2015aroras Dec 20, 2024
275364c
Back to HSDP, but with less replicas
2015aroras Dec 20, 2024
54d5623
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
4644e6e
Microbatch size back to 1
2015aroras Dec 20, 2024
d7ed30e
Revert "Microbatch size back to 1"
2015aroras Dec 20, 2024
0c47992
Back to FSDP
2015aroras Dec 20, 2024
246eff6
Revert "Back to FSDP"
2015aroras Dec 20, 2024
b956e3f
Enable NCCL debug
2015aroras Dec 20, 2024
f877907
More debug info
2015aroras Dec 20, 2024
58bef95
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
c84708f
Disable pre_download, set higher thread count
2015aroras Dec 20, 2024
56c4ab3
FSDP with AC of selected ops
2015aroras Dec 20, 2024
b5f3a86
Back to AC of just feedforward layers
2015aroras Dec 21, 2024
3fbdeb0
Add new inloop evals
2015aroras Dec 21, 2024
b335cdf
Turn off NCCL debug
2015aroras Dec 21, 2024
30f8f59
Merge branch 'main' into 32B
2015aroras Dec 21, 2024
e17e4b8
Make checkpoint writing respect thread count config
2015aroras Dec 22, 2024
ba49cc4
Add skip step optimizer changes
2015aroras Dec 22, 2024
25ede33
Update 32B config with skip step adamw
2015aroras Dec 22, 2024
ac01e83
Try fix skip step optimizer
2015aroras Dec 22, 2024
ddd61ac
Try manual _std_mean impl
2015aroras Dec 22, 2024
973a26c
Add skip step fixes
2015aroras Dec 22, 2024
baf5700
Have separate save and load thread counts
2015aroras Dec 22, 2024
b6762d8
Decrease threads used for saving
2015aroras Dec 22, 2024
d98f06d
Skipped steps and automatic spike analysis
dirkgr Dec 22, 2024
4a68e9e
Use compile=True for optimizer
2015aroras Dec 22, 2024
d81cd12
Make gcs upload pass generation
2015aroras Dec 23, 2024
0a04034
Update CHANGELOG
2015aroras Dec 23, 2024
5acc7eb
Run formatter
2015aroras Dec 23, 2024
213b03e
Make generation 0 when object does not exist
2015aroras Dec 23, 2024
b4994b0
Merge branch 'shanea/fix-upload-retries' into 32B
2015aroras Dec 23, 2024
3b84351
Run formatting
2015aroras Dec 23, 2024
178d9ad
Remove unneeded import
2015aroras Dec 23, 2024
0b737aa
Add missing reload
2015aroras Dec 23, 2024
3e6f9f1
Updated notebook
dirkgr Dec 23, 2024
663d63a
Updated dashboard
dirkgr Dec 24, 2024
496919b
Update the notebook
dirkgr Dec 24, 2024
a1854bd
Updated notebook
dirkgr Dec 27, 2024
f2de5f4
Retry on bad request
dirkgr Dec 28, 2024
33c0f58
Add some more retries
dirkgr Dec 28, 2024
86afc43
Updated the notebook
dirkgr Dec 29, 2024
2e45a79
Update the dashboard
dirkgr Dec 30, 2024
e4e8fbb
Fix the way we use the step in the optimizer
dirkgr Dec 31, 2024
146caaf
Dashboard update
dirkgr Dec 31, 2024
393a462
Update dashboard
dirkgr Jan 3, 2025
d39c59d
New report
dirkgr Jan 6, 2025
16983c4
Dashboard update
dirkgr Jan 7, 2025
5e4d04f
No more ephemeral checkpoints
dirkgr Jan 8, 2025
eba0418
Don't eval so much
dirkgr Jan 8, 2025
5605001
When you wait on someone, you bring them water.
dirkgr Jan 8, 2025
7ce7efa
Updating the dashboard
dirkgr Jan 8, 2025
05aa94f
Reorder ranks in GCP
dirkgr Jan 9, 2025
9c86bf9
Rank 0 needs to remain rank 0
dirkgr Jan 9, 2025
e27b91d
Slightly less checkpointing
dirkgr Jan 9, 2025
52b9b77
Revert "Slightly less checkpointing"
dirkgr Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/olmo_core/nn/functional/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def cross_entropy_loss(
_fused_cross_entropy_loss: Optional[Callable] = None

try:
import olmo_core.triton.cross_entropy_loss as triton_ce_loss
# import olmo_core.triton.cross_entropy_loss as triton_ce_loss
#_fused_cross_entropy_loss = triton_ce_loss.cross_entropy_loss

# import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore

_fused_cross_entropy_loss = triton_ce_loss.cross_entropy_loss
import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our in-house triton CE loss was copied directly from the flash-attn repo, so I don't see the point of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I took this back out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I want compiling and fused loss at the same time?

_fused_cross_entropy_loss = flash_attn_ce.cross_entropy_loss
except ModuleNotFoundError:
pass

Expand Down
15 changes: 9 additions & 6 deletions src/olmo_core/nn/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
)

@classmethod
def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
"""
A 26B OLMo model config.
A 32B OLMo model config.
"""
d_model = 5120
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a very narrow model then... are you sure about that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a clone of Qwen 32. The tradeoffs are, narrow d_model, wide FFN, GQA, lots of layers.

return cls.llama_like(
vocab_size=vocab_size,
d_model=7168,
n_layers=kwargs.pop("n_layers", 40),
n_heads=kwargs.pop("n_heads", 56),
d_model=d_model,
n_layers=kwargs.pop("n_layers", 64),
n_heads=kwargs.pop("n_heads", 40),
n_kv_heads=kwargs.pop("n_kv_heads", 8),
block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm),
qk_norm=kwargs.pop("qk_norm", True),
rope_theta=kwargs.pop("rope_theta", 500_000),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512),
hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)),
layer_norm_eps=1e-6,
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion src/olmo_core/train/callbacks/evaluator_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]:
eval_batch_size = (
self.eval_batch_size
if self.eval_batch_size is not None
else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group)
else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could instead passed an updated evaluator callback in OLMo2-32B.py:

.with_callback(
    "lm_evaluator",
    LMEvaluatorCallbackConfig(
        eval_batch_size=<whatever you want>,
        eval_dataset=NumpyDatasetConfig.from_data_mix(
            DataMix.v3_small_ppl_validation,
            name=NumpyDatasetType.padded_fsl,
            mix_base_dir=root_dir,
            sequence_length=dataset_config.effective_sequence_length,
            tokenizer=tokenizer_config,
            work_dir=get_work_dir(root_dir),
        ),
        eval_interval=1000,
    ),

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but I think this is better. I think we can default to 2x the training batch size. It should always work.

)
dataset = self.eval_dataset.build()
if not isinstance(dataset, NumpyPaddedFSLDataset):
Expand Down
69 changes: 60 additions & 9 deletions src/scripts/train/OLMo2-26B.py → src/scripts/train/OLMo2-32B.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
TransformerDataParallelConfig,
)
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback
from olmo_core.train import TrainerConfig, Duration, DurationUnit
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback, \
DownstreamEvaluatorCallbackConfig

log = logging.getLogger(__name__)


def build_model_config(common: CommonComponents) -> TransformerConfig:
compile = True
return TransformerConfig.olmo2_26B(
return TransformerConfig.olmo2_32B(
vocab_size=common.tokenizer.padded_vocab_size(),
compile=compile,
fused_ops=False,
Expand Down Expand Up @@ -52,20 +53,23 @@ def build_optim_config(common: CommonComponents) -> AdamWConfig:


def build_trainer_config(common: CommonComponents) -> TrainerConfig:
project_name = "peteish32"
return (
TrainerConfig(
save_folder=common.save_folder,
save_folder=f"gs://ai2-llm/checkpoints/{project_name}/",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It defaults to something under my name? Not what we want for an official run?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially if we swap babysitting responsibilities during the run

rank_microbatch_size=4 * 4096,
save_overwrite=True,
metrics_collect_interval=10,
cancel_check_interval=1,
cancel_check_interval=10,
z_loss_multiplier=1e-5,
compile_loss=True,
fused_loss=True,
compile_loss=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the trepidation about the different loss implementations, but the way it was before was the most performant. This way will be slower and have a higher memory footprint.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have some certainty that this will do the right thing? What happens if we take the 13B from a late checkpoint and run it?

max_duration=Duration(int(6.5e12), DurationUnit.tokens)
)
.with_callback(
"checkpointer",
CheckpointerCallback(
save_interval=10_000,
save_interval=1000,
ephemeral_save_interval=250,
save_async=True,
),
Expand All @@ -75,7 +79,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
CometCallback(
name=common.run_name,
workspace="ai2",
project="OLMo-core-26B",
project=project_name,
enabled=True,
cancel_check_interval=10,
),
Expand All @@ -85,10 +89,57 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
WandBCallback(
name=common.run_name,
entity="ai2-llm",
project="OLMo-core-26B",
project=project_name,
enabled=False,
dirkgr marked this conversation as resolved.
Show resolved Hide resolved
cancel_check_interval=10,
),
).with_callback(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just add this to the common callbacks.

"lm_evaluator": LMEvaluatorCallbackConfig(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know that we want these for everything. Default should probably be only the new, blessed ones.

"downstream_evaluator",
DownstreamEvaluatorCallbackConfig(
tasks=[
# MMLU for backwards compatibility
"mmlu_stem_mc_5shot",
"mmlu_humanities_mc_5shot",
"mmlu_social_sciences_mc_5shot",
"mmlu_other_mc_5shot",

# MMLU test
"mmlu_stem_mc_5shot_test",
"mmlu_humanities_mc_5shot_test",
"mmlu_social_sciences_mc_5shot_test",
"mmlu_other_mc_5shot_test",

# Core 12 tasks for backwards compatibility
"arc_challenge",
"arc_easy",
"basic_arithmetic",
"boolq",
"commonsense_qa",
"copa",
"hellaswag",
"openbook_qa",
"piqa",
"sciq",
"social_iqa",
"winogrande",

# Core 12 tasks 5-shot
"arc_challenge_rc_5shot",
"arc_easy_rc_5shot",
#"basic_arithmetic_rc_5shot", # doesn't exist
#"boolq_rc_5shot", # we don't like it
"csqa_rc_5shot",
#"copa_rc_5shot", # doesn't exist
"hellaswag_rc_5shot",
"openbookqa_rc_5shot",
"piqa_rc_5shot",
#"sciq_rc_5shot", # doesn't exist
"socialiqa_rc_5shot",
"winogrande_rc_5shot"
],
tokenizer=common.tokenizer,
eval_interval=1000,
),
)
)

Expand Down
Loading