diff --git a/CHANGELOG.md b/CHANGELOG.md index 8276edc6..13f3c37a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,14 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`. - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint. - Added a callback for sending Slack notifications. +- Added `SkipStepAdamW` optimizer. ### Changed - Changed storage of shared shard state in sharded checkpoints from smallest shard to lowest rank (normally 0). +- Changed underlying AdamW implementation. ### Fixed - Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch. +- Fixed bug where GCS upload does not retry on transient failures. ## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27 diff --git a/src/olmo_core/distributed/checkpoint/__init__.py b/src/olmo_core/distributed/checkpoint/__init__.py index 70d0e542..068fc3e8 100644 --- a/src/olmo_core/distributed/checkpoint/__init__.py +++ b/src/olmo_core/distributed/checkpoint/__init__.py @@ -63,6 +63,7 @@ def save_state_dict( state_dict: Dict[str, Any], process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, ): """ Save an arbitrary state dictionary to a distributed format that can loaded again with @@ -80,7 +81,7 @@ def save_state_dict( dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite) dist_cp.state_dict_saver.save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count), process_group=process_group, ) @@ -93,6 +94,7 @@ def save_model_and_optim_state( *, process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, ) -> None: """ Save model and optimizer state dictionaries. The model state can be a sharded model, in which @@ -123,7 +125,7 @@ def save_model_and_optim_state( planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) dist_cp.state_dict_saver.save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count), process_group=process_group, planner=planner, ) @@ -137,6 +139,7 @@ def async_save_model_and_optim_state( *, process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, ) -> Future[None]: """ An async version of :func:`save_model_and_optim_state()`. @@ -148,7 +151,7 @@ def async_save_model_and_optim_state( planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) return dist_cp.state_dict_saver.async_save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count), process_group=process_group, planner=planner, ) @@ -164,6 +167,7 @@ def load_model_and_optim_state( key_mapping: Optional[Dict[str, str]] = None, pre_download: bool = False, work_dir: Optional[PathOrStr] = None, + thread_count: Optional[int] = None, ): """ Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`. @@ -201,10 +205,13 @@ def load_model_and_optim_state( This dictionary should map current keys to keys in the checkpoint to be loaded. 
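+        For example, a (hypothetical) mapping ``{"w_q.weight": "att_proj.weight"}``
+        would load the checkpoint's ``att_proj.weight`` into the model's ``w_q.weight``.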
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them. :param work_dir: A working directory for caching files/directories. + :param thread_count: Set the number of threads used for certain operations. """ dir = normalize_path(dir) state_dict = _prepare_state_dict(model, optim, process_group=process_group) - reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir) + reader = RemoteFileSystemReader( + dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir + ) if key_mapping is not None: metadata = reader.read_metadata() diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 93b03fb8..24a0a784 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -92,6 +92,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut "enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0", ) set_env_var("NCCL_SOCKET_IFNAME", "enp0s12") + set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET") if backend_supports_cuda(backend): # Set CUDA device. diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py index 1c2d426a..d660f0a0 100644 --- a/src/olmo_core/internal/common.py +++ b/src/olmo_core/internal/common.py @@ -102,6 +102,7 @@ def build_launch_config( # Setup python environment. "conda shell.bash activate base", "pip install -e '.[all]'", + "pip install --upgrade beaker-py", # Quickly try a new version of PyTorch like this # "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121", "pip freeze", diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 4d1e9ee7..1015a60d 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -130,6 +130,7 @@ def build_common_components( root_dir=root_dir, cmd=[script, cmd_to_launch, run_name, cluster, *overrides], cluster=cluster, + nccl_debug=False, ) beaker_user = get_beaker_username() diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index 5fda2741..ccd7816c 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -533,8 +533,13 @@ def _get_gcs_client(): def _gcs_is_retriable(exc: Exception) -> bool: from google.api_core.retry import if_transient_error + from google.api_core.exceptions import BadRequest - return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout) + return ( + if_transient_error(exc) or + isinstance(exc, requests.exceptions.Timeout) or + isinstance(exc, BadRequest) # Weird choice, but Google throws this transiently + ) def _get_gcs_retry(): @@ -577,7 +582,7 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) try: - blob.reload() + blob.reload(retry=_get_gcs_retry()) except NotFound: raise FileNotFoundError(f"gs://{bucket_name}/{key}") return blob.download_as_bytes( @@ -590,11 +595,21 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = storage_client = _get_gcs_client() bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) - if not save_overwrite and blob.exists(): - raise FileExistsError( - f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." 
- ) - blob.upload_from_filename(source, retry=_get_gcs_conditional_retry()) + + generation: int = 0 + if blob.exists(retry=_get_gcs_retry()): + if not save_overwrite: + raise FileExistsError( + f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." + ) + + blob.reload(retry=_get_gcs_retry()) + assert blob.generation is not None + generation = blob.generation + + blob.upload_from_filename( + source, if_generation_match=generation, retry=_get_gcs_conditional_retry() + ) @retriable() diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 435142f3..db660faf 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -323,6 +323,14 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: ] if torchrun: + entrypoint_script.append( + "export BEAKER_REPLICA_RANK=$(" + "python src/scripts/reorder_ranks_in_gcp.py " + "${BEAKER_REPLICA_RANK} " + "${BEAKER_REPLICA_COUNT} " + "${BEAKER_LEADER_REPLICA_HOSTNAME}" + ")" + ) entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"') else: entrypoint_script.append('python "$@"') diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index f77ec4e2..44d34252 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": ) @classmethod - def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ - A 26B OLMo model config. + A 32B OLMo model config. """ + d_model = 5120 return cls.llama_like( vocab_size=vocab_size, - d_model=7168, - n_layers=kwargs.pop("n_layers", 40), - n_heads=kwargs.pop("n_heads", 56), + d_model=d_model, + n_layers=kwargs.pop("n_layers", 64), + n_heads=kwargs.pop("n_heads", 40), + n_kv_heads=kwargs.pop("n_kv_heads", 8), block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm), qk_norm=kwargs.pop("qk_norm", True), rope_theta=kwargs.pop("rope_theta", 500_000), - hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024), + hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512), + hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)), layer_norm_eps=1e-6, **kwargs, ) diff --git a/src/olmo_core/optim/__init__.py b/src/olmo_core/optim/__init__.py index 0e1cf986..c514a24d 100644 --- a/src/olmo_core/optim/__init__.py +++ b/src/olmo_core/optim/__init__.py @@ -1,5 +1,5 @@ from .adam import AdamConfig -from .adamw import AdamWConfig +from .adamw import AdamW, AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig from .config import OptimConfig, OptimGroupOverride from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig from .scheduler import ( @@ -19,6 +19,9 @@ "SkipStepOptimizer", "AdamWConfig", "AdamConfig", + "AdamW", + "SkipStepAdamWConfig", + "SkipStepAdamW", "LionConfig", "Lion", "SkipStepLionConfig", diff --git a/src/olmo_core/optim/adamw.py b/src/olmo_core/optim/adamw.py index bc5f1e46..a9e014da 100644 --- a/src/olmo_core/optim/adamw.py +++ b/src/olmo_core/optim/adamw.py @@ -1,14 +1,14 @@ -import math from dataclasses import dataclass from typing import Optional, Tuple, Type import torch import torch.nn as nn +from torch.optim.optimizer import Optimizer from .config import OptimConfig +from .skip_step_optimizer import SkipStepOptimizer -# TODO: use this when we implement a "skip step" version of AdamW. 
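+# NOTE: `step` and `step_factor` are tensors rather than Python scalars so the
+# update math, including the skip gating below, stays on-device without a
+# host-device sync.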
def adamw_step( p: nn.Parameter, *, @@ -18,7 +18,7 @@ def adamw_step( weight_decay: float, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, - step: int, + step: torch.Tensor, step_factor: torch.Tensor, ): if p.grad is None: @@ -34,19 +34,139 @@ def adamw_step( exp_avg_sq.mul_(1 - step_factor * (1 - beta2)) exp_avg_sq.add_(step_factor * p.grad * p.grad, alpha=1 - beta2) - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step + bias_correction1 = 1 - beta1**(step + 1) + bias_correction2 = 1 - beta2**(step + 1) step_size = lr / bias_correction1 - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + denom = (exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) update = -step_size * torch.div(exp_avg, denom) update.mul_(step_factor) p.add_(update) +class AdamW(Optimizer): + """ + An implementation of the AdamW optimizer. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + ): + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, fused=fused + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None) -> None: + if closure is not None: + with torch.enable_grad(): + closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0.0, dtype=torch.float32, device=p.device) + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + + adamw_step( + p, + lr=group["lr"], + betas=group["betas"], + eps=group["eps"], + weight_decay=group["weight_decay"], + exp_avg=state["exp_avg"], + exp_avg_sq=state["exp_avg_sq"], + step=state["step"], + step_factor=torch.tensor(1.0, device=p.device), + ) + + +class SkipStepAdamW(SkipStepOptimizer): + """ + A "skip step" version of :class:`AdamW`. 
+ """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + rolling_interval_length: int = 128, + sigma_factor: int = 6, + ) -> None: + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, fused=fused + ) + super().__init__( + params, + defaults, + rolling_interval_length=rolling_interval_length, + sigma_factor=sigma_factor, + ) + self._step_skipped: Optional[torch.Tensor] = None + + @property + def step_skipped(self) -> torch.Tensor: + if self._step_skipped is not None: + return self._step_skipped + else: + return torch.tensor(0.0) + + @torch.no_grad() + def step(self, closure=None) -> None: + if closure is not None: + with torch.enable_grad(): + closure() + + step_factor = self.get_step_factor() + self._step_skipped = 1 - step_factor + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0.0, dtype=torch.float32, device=p.device) + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + + adamw_step( + p, + lr=group["lr"], + betas=group["betas"], + eps=group["eps"], + weight_decay=group["weight_decay"], + exp_avg=state["exp_avg"], + exp_avg_sq=state["exp_avg_sq"], + step=state["step"], + step_factor=step_factor, + ) + + @dataclass class AdamWConfig(OptimConfig): # NOTE: omagaconf doesn't like "OptimConfig[torch.optim.AdamW]" """ @@ -61,5 +181,25 @@ class AdamWConfig(OptimConfig): # NOTE: omagaconf doesn't like "OptimConfig[tor fused: Optional[bool] = None @classmethod - def optimizer(cls) -> Type[torch.optim.AdamW]: - return torch.optim.AdamW + def optimizer(cls) -> Type[AdamW]: + return AdamW + + +@dataclass +class SkipStepAdamWConfig(OptimConfig): + """ + Configuration class for building a :class:`SkipStepAdamW` optimizer. + """ + + lr: float = 1e-3 + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 1e-2 + foreach: Optional[bool] = None + fused: Optional[bool] = None + rolling_interval_length: int = 128 + sigma_factor: int = 6 + + @classmethod + def optimizer(cls) -> Type[SkipStepAdamW]: + return SkipStepAdamW diff --git a/src/olmo_core/optim/skip_step_optimizer.py b/src/olmo_core/optim/skip_step_optimizer.py index 98ada1bd..40b0b034 100644 --- a/src/olmo_core/optim/skip_step_optimizer.py +++ b/src/olmo_core/optim/skip_step_optimizer.py @@ -91,17 +91,20 @@ def get_step_factor(self) -> torch.Tensor: The tensor can be used within the optimizer's step computation to essentially skip a step without a host-device sync. 
""" - if len(self._losses) < max(20, self.rolling_interval_length // 2): + if len(self._losses) < max(2, self.rolling_interval_length // 2): return torch.tensor(1.0).to(device=self.device, non_blocking=True) loss_std, loss_mean = torch.std_mean(torch.stack(self._losses[:-1])) if self._grad_norms: grad_norm_std, grad_norm_mean = torch.std_mean(torch.stack(self._grad_norms[:-1])) - return ((self.latest_loss - loss_mean) <= self.sigma_factor * loss_std) and ( - (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std + step_factor = torch.logical_and( + (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std, + (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std, ) else: - return (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + step_factor = (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + + return step_factor.float() @property def step_skipped(self) -> torch.Tensor: diff --git a/src/olmo_core/train/__init__.py b/src/olmo_core/train/__init__.py index ba59008b..e14f3dc7 100644 --- a/src/olmo_core/train/__init__.py +++ b/src/olmo_core/train/__init__.py @@ -75,7 +75,7 @@ def prepare_training_environment( *, seed: Optional[int] = None, backend: Optional[str] = "cpu:gloo,cuda:nccl", - timeout: timedelta = timedelta(minutes=10), + timeout: timedelta = timedelta(minutes=30), log_filter_type: Optional[LogFilterType] = None, ): """ diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index ea2bfa58..556492b7 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -129,7 +129,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]: eval_batch_size = ( self.eval_batch_size if self.eval_batch_size is not None - else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) + else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) ) dataset = self.eval_dataset.build() if not isinstance(dataset, NumpyPaddedFSLDataset): diff --git a/src/olmo_core/train/callbacks/grad_clipper.py b/src/olmo_core/train/callbacks/grad_clipper.py index 0a0ebbcb..97ad3b8d 100644 --- a/src/olmo_core/train/callbacks/grad_clipper.py +++ b/src/olmo_core/train/callbacks/grad_clipper.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from olmo_core.distributed.utils import get_local_tensor from olmo_core.optim import SkipStepOptimizer from .callback import Callback @@ -26,6 +27,8 @@ def pre_optim_step(self): self.trainer.model.parameters(), self.max_grad_norm ) + grad_norm = get_local_tensor(grad_norm.detach()) + # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`. 
self.trainer.record_metric("optim/total grad norm", grad_norm, reduce_type=None) if isinstance(self.trainer.optim, SkipStepOptimizer): diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index b262fe81..a10496ca 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -48,6 +48,8 @@ class CheckpointerConfig(Config): work_dir: Optional[str] = None save_overwrite: Optional[bool] = None pre_download: bool = False + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None def build(self, process_group: Optional[dist.ProcessGroup] = None, **kwargs) -> "Checkpointer": kwargs = {**self.as_dict(exclude_none=True, recurse=False), **kwargs} @@ -75,6 +77,8 @@ class Checkpointer: save_overwrite: bool = False pre_download: bool = False process_group: Optional[dist.ProcessGroup] = None + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None def __post_init__(self): self.work_dir = Path(self.work_dir) @@ -100,6 +104,7 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, ) self._save_metadata(dir, CheckpointMetadata()) @@ -129,6 +134,7 @@ def save_async( optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, ) def done_callback(fut: Future): @@ -179,6 +185,7 @@ def load( key_mapping=key_mapping, pre_download=is_url(dir) and self.pre_download, work_dir=self.work_dir, + thread_count=self.load_thread_count, ) return trainer_state @@ -299,7 +306,7 @@ def _save_train_state(self, dir: PathOrStr, wd: Path, train_state: Dict[str, Any # NOTE: if 'dir' is a URL, the 'wd' will be a different temp dir for each rank. 
         if is_url(dir) or get_fs_local_rank() == 0:
             train_dir.mkdir(exist_ok=True, parents=True)
-        wait_for(train_dir.exists, description=f"Waiting on '{train_dir}' to be created...")
+        wait_for(train_dir.exists, description=f"Waiting for '{train_dir}' to be created...")
         torch.save(train_state, train_dir / f"rank{get_rank()}.pt")
 
     def _save_metadata(self, dir: PathOrStr, metadata: CheckpointMetadata):
diff --git a/src/scripts/reorder_ranks_in_gcp.py b/src/scripts/reorder_ranks_in_gcp.py
new file mode 100644
index 00000000..dea8f0d4
--- /dev/null
+++ b/src/scripts/reorder_ranks_in_gcp.py
@@ -0,0 +1,71 @@
+"""Reorder Beaker replica ranks by GCP physical host ID so that nearby hosts get adjacent ranks (rank 0 stays rank 0). Prints this node's new rank."""
+
+import argparse
+import sys
+
+import requests
+import torch.distributed as dist
+from urllib3.exceptions import MaxRetryError, NameResolutionError
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("rank", type=int, help="Worker number")
+    parser.add_argument("world_size", type=int, help="Total number of workers")
+    parser.add_argument("master_addr", help="Hostname of worker 0")
+    parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore")
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)")
+    args = parser.parse_args()
+
+    # Create or connect to the store
+    store = dist.TCPStore(
+        host_name=args.master_addr,
+        port=args.master_port,
+        world_size=args.world_size,
+        is_master=(args.rank == 0)
+    )
+
+    # Get our own host id
+    if args.debug:
+        import socket
+        host_id = f"{socket.gethostname()}_{args.rank}"
+    else:
+        try:
+            response = requests.get(
+                "http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host",
+                headers={"Metadata-Flavor": "Google"}
+            )
+            assert response.status_code == 200
+            host_id = response.text.strip()
+        except requests.exceptions.ConnectionError as e:
+            # Unwrap the exception
+            e = e.args[0]
+            if not isinstance(e, MaxRetryError):
+                raise
+            e = e.reason
+            if not isinstance(e, NameResolutionError):
+                raise
+            # Seems we called this outside of GCP, so we do nothing and just print our original rank.
+            print(args.rank)
+            sys.exit(0)
+
+    # Find the index of our host id
+    store.set(f"node_{args.rank}_hostid", host_id)
+    store.wait([f"node_{i}_hostid" for i in range(args.world_size)])
+    all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)]
+    assert len(set(all_host_ids)) == len(all_host_ids)
+    assert host_id in all_host_ids
+    rank0_host_id = all_host_ids[0]
+    all_host_ids.sort()
+    # Rank 0 needs to remain rank 0, so we reshuffle around it
+    rank0_index = all_host_ids.index(rank0_host_id)
+    all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index]
+    print(all_host_ids.index(host_id))
+
+    # Make sure we're all done before exiting
+    store.set(f"node_{args.rank}_done", host_id)
+    store.wait([f"node_{i}_done" for i in range(args.world_size)])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scripts/train/OLMo2-26B.py b/src/scripts/train/OLMo2-26B.py
deleted file mode 100644
index 6453407c..00000000
--- a/src/scripts/train/OLMo2-26B.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-Train a 26B OLMo model. Run this script without any arguments to see usage info.
-""" - -import logging - -from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelType -from olmo_core.float8 import Float8Config -from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import ( - TransformerActivationCheckpointingConfig, - TransformerActivationCheckpointingMode, - TransformerConfig, - TransformerDataParallelConfig, -) -from olmo_core.optim import AdamWConfig, OptimGroupOverride -from olmo_core.train import TrainerConfig -from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback - -log = logging.getLogger(__name__) - - -def build_model_config(common: CommonComponents) -> TransformerConfig: - compile = True - return TransformerConfig.olmo2_26B( - vocab_size=common.tokenizer.padded_vocab_size(), - compile=compile, - fused_ops=False, - use_flash=not compile, - dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 - ), - ac_config=TransformerActivationCheckpointingConfig( - mode=TransformerActivationCheckpointingMode.full - ), - float8_config=Float8Config(compile=compile, enabled=False), - ) - - -def build_optim_config(common: CommonComponents) -> AdamWConfig: - del common - return AdamWConfig( - lr=6e-4, - weight_decay=0.1, - betas=(0.9, 0.95), - group_overrides=[ - OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) - ], - fused=True, - ) - - -def build_trainer_config(common: CommonComponents) -> TrainerConfig: - return ( - TrainerConfig( - save_folder=common.save_folder, - rank_microbatch_size=4 * 4096, - save_overwrite=True, - metrics_collect_interval=10, - cancel_check_interval=1, - z_loss_multiplier=1e-5, - compile_loss=True, - ) - .with_callback( - "checkpointer", - CheckpointerCallback( - save_interval=10_000, - ephemeral_save_interval=250, - save_async=True, - ), - ) - .with_callback( - "comet", - CometCallback( - name=common.run_name, - workspace="ai2", - project="OLMo-core-26B", - enabled=True, - cancel_check_interval=10, - ), - ) - .with_callback( - "wandb", - WandBCallback( - name=common.run_name, - entity="ai2-llm", - project="OLMo-core-26B", - enabled=False, - cancel_check_interval=10, - ), - ) - ) - - -if __name__ == "__main__": - main( - global_batch_size=2048 * 4096, - model_config_builder=build_model_config, - optim_config_builder=build_optim_config, - trainer_config_builder=build_trainer_config, - ) diff --git a/src/scripts/train/OLMo2-32B.ipynb b/src/scripts/train/OLMo2-32B.ipynb new file mode 100644 index 00000000..07d9e673 --- /dev/null +++ b/src/scripts/train/OLMo2-32B.ipynb @@ -0,0 +1,1791 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-08T08:22:29.647389Z", + "start_time": "2025-01-08T08:22:29.367715Z" + } + }, + "source": [ + "import os\n", + "from comet_ml.api import API\n", + "\n", + "comet_api = API(os.environ[\"COMETML_API_KEY\"])\n" + ], + "outputs": [], + "execution_count": 16 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-08T08:22:30.546725Z", + "start_time": "2025-01-08T08:22:29.653627Z" + } + }, + "cell_type": "code", + "source": [ + "exps = {\n", + " \"peteish32\": comet_api.get_experiments(\"ai2\", \"peteish32\", \"peteish32\"),\n", + " \"peteish13\": comet_api.get_experiments(\"ai2\", \"olmo-2-1124-13b\", \"OLMo-2-1124-13B-stage-1\"),\n", + " \"peteish7\": comet_api.get_experiments(\"ai2\", \"olmo-core-7b\", 
\"peteish7\")\n", + "}\n", + "\n", + "print(repr({k: len(v) for k, v in exps.items()}))" + ], + "id": "2c17abe415dabf07", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'peteish32': 40, 'peteish13': 75, 'peteish7': 13}\n" + ] + } + ], + "execution_count": 17 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-08T08:22:44.498869Z", + "start_time": "2025-01-08T08:22:30.594052Z" + } + }, + "cell_type": "code", + "source": [ + "# print available metrics\n", + "\n", + "for name, es in exps.items():\n", + " metrics = set()\n", + " for exp in es:\n", + " for summary in exp.get_metrics_summary():\n", + " metrics.add(summary[\"name\"])\n", + " metrics = list(metrics)\n", + " metrics.sort()\n", + "\n", + " print(f\"{name}:\")\n", + " for metric in metrics:\n", + " print(\"\\t\", metric)" + ], + "id": "dc7c5e3c92741b89", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "peteish32:\n", + "\t data/sequence length\n", + "\t eval/downstream/arc_challenge (BPB)\n", + "\t eval/downstream/arc_challenge (CE loss)\n", + "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge (log soft loss)\n", + "\t eval/downstream/arc_challenge (soft loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy (BPB)\n", + "\t eval/downstream/arc_easy (CE loss)\n", + "\t eval/downstream/arc_easy (accuracy)\n", + "\t eval/downstream/arc_easy (log soft loss)\n", + "\t eval/downstream/arc_easy (soft loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (CE loss)\n", + "\t 
eval/downstream/arc_easy_test_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (soft loss)\n", + "\t eval/downstream/basic_arithmetic (BPB)\n", + "\t eval/downstream/basic_arithmetic (CE loss)\n", + "\t eval/downstream/basic_arithmetic (accuracy)\n", + "\t eval/downstream/basic_arithmetic (log soft loss)\n", + "\t eval/downstream/basic_arithmetic (soft loss)\n", + "\t eval/downstream/boolq (BPB)\n", + "\t eval/downstream/boolq (CE loss)\n", + "\t eval/downstream/boolq (accuracy)\n", + "\t eval/downstream/boolq (log soft loss)\n", + "\t eval/downstream/boolq (soft loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (BPB)\n", + "\t eval/downstream/boolq_val_mc_5shot (CE loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (accuracy)\n", + "\t eval/downstream/boolq_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (soft loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (BPB)\n", + "\t eval/downstream/boolq_val_rc_5shot (CE loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (accuracy)\n", + "\t eval/downstream/boolq_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (soft loss)\n", + "\t eval/downstream/commonsense_qa (BPB)\n", + "\t eval/downstream/commonsense_qa (CE loss)\n", + "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", + "\t eval/downstream/commonsense_qa (log soft loss)\n", + "\t eval/downstream/commonsense_qa (soft loss)\n", + "\t eval/downstream/copa (BPB)\n", + "\t eval/downstream/copa (CE loss)\n", + "\t eval/downstream/copa (accuracy)\n", + "\t eval/downstream/copa (log soft loss)\n", + "\t eval/downstream/copa (soft loss)\n", + "\t eval/downstream/csqa_rc_5shot (BPB)\n", + "\t eval/downstream/csqa_rc_5shot (CE loss)\n", + "\t eval/downstream/csqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/csqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_rc_5shot (soft loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/csqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/csqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/csqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/csqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag (BPB)\n", + "\t eval/downstream/hellaswag 
(CE loss)\n", + "\t eval/downstream/hellaswag (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag (log soft loss)\n", + "\t eval/downstream/hellaswag (soft loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (BPB)\n", + "\t eval/downstream/hellaswag_rc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag_rc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (BPB)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (accuracy)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (BPB)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (CE loss)\n", + "\t 
eval/downstream/mmlu_other_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (soft loss)\n", + "\t eval/downstream/openbook_qa (BPB)\n", + "\t eval/downstream/openbook_qa (CE loss)\n", + "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", + "\t eval/downstream/openbook_qa (log soft loss)\n", + "\t eval/downstream/openbook_qa (soft loss)\n", + "\t eval/downstream/openbookqa_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_rc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot 
(BPB)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (accuracy)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/piqa (BPB)\n", + "\t eval/downstream/piqa (CE loss)\n", + "\t eval/downstream/piqa (length-normalized accuracy)\n", + "\t eval/downstream/piqa (log soft loss)\n", + "\t eval/downstream/piqa (soft loss)\n", + "\t eval/downstream/piqa_rc_5shot (BPB)\n", + "\t eval/downstream/piqa_rc_5shot (CE loss)\n", + "\t eval/downstream/piqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/piqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/piqa_rc_5shot (soft loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/piqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/piqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/piqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/piqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/sciq (BPB)\n", + "\t eval/downstream/sciq (CE loss)\n", + "\t eval/downstream/sciq (accuracy)\n", + "\t eval/downstream/sciq (log soft loss)\n", + "\t eval/downstream/sciq (soft loss)\n", + "\t eval/downstream/social_iqa (BPB)\n", + "\t eval/downstream/social_iqa (CE loss)\n", + "\t eval/downstream/social_iqa (length-normalized accuracy)\n", + "\t eval/downstream/social_iqa (log soft loss)\n", + "\t eval/downstream/social_iqa (soft loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_rc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/socialiqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (soft loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (length-normalized accuracy)\n", + "\t 
eval/downstream/socialiqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/winogrande (BPB)\n", + "\t eval/downstream/winogrande (CE loss)\n", + "\t eval/downstream/winogrande (accuracy)\n", + "\t eval/downstream/winogrande (log soft loss)\n", + "\t eval/downstream/winogrande (soft loss)\n", + "\t eval/downstream/winogrande_rc_5shot (BPB)\n", + "\t eval/downstream/winogrande_rc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_rc_5shot (accuracy)\n", + "\t eval/downstream/winogrande_rc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_rc_5shot (soft loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (BPB)\n", + "\t eval/downstream/winogrande_val_mc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (accuracy)\n", + "\t eval/downstream/winogrande_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (soft loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (BPB)\n", + "\t eval/downstream/winogrande_val_rc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/winogrande_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (soft loss)\n", + "\t eval/lm/c4_en-validation/CE loss\n", + "\t eval/lm/c4_en-validation/PPL\n", + "\t eval/lm/dolma_books-validation/CE loss\n", + "\t eval/lm/dolma_books-validation/PPL\n", + "\t eval/lm/dolma_common-crawl-validation/CE loss\n", + "\t eval/lm/dolma_common-crawl-validation/PPL\n", + "\t eval/lm/dolma_pes2o-validation/CE loss\n", + "\t eval/lm/dolma_pes2o-validation/PPL\n", + "\t eval/lm/dolma_reddit-validation/CE loss\n", + "\t eval/lm/dolma_reddit-validation/PPL\n", + "\t eval/lm/dolma_stack-validation/CE loss\n", + "\t eval/lm/dolma_stack-validation/PPL\n", + "\t eval/lm/dolma_wiki-validation/CE loss\n", + "\t eval/lm/dolma_wiki-validation/PPL\n", + "\t eval/lm/ice-validation/CE loss\n", + "\t eval/lm/ice-validation/PPL\n", + "\t eval/lm/m2d2_s2orc-validation/CE loss\n", + "\t eval/lm/m2d2_s2orc-validation/PPL\n", + "\t eval/lm/pile-validation/CE loss\n", + "\t eval/lm/pile-validation/PPL\n", + "\t eval/lm/wikitext_103-validation/CE loss\n", + "\t eval/lm/wikitext_103-validation/PPL\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/step skipped\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.gpu.0.free_memory\n", + "\t sys.gpu.0.gpu_utilization\n", + "\t sys.gpu.0.memory_utilization\n", + "\t sys.gpu.0.percent.used_memory\n", + "\t sys.gpu.0.power_usage\n", + "\t sys.gpu.0.temperature\n", + "\t sys.gpu.0.total_memory\n", + "\t sys.gpu.0.used_memory\n", + "\t sys.gpu.1.free_memory\n", + "\t sys.gpu.1.gpu_utilization\n", + "\t sys.gpu.1.memory_utilization\n", + "\t sys.gpu.1.percent.used_memory\n", + "\t sys.gpu.1.power_usage\n", + "\t sys.gpu.1.temperature\n", + "\t sys.gpu.1.total_memory\n", + "\t sys.gpu.1.used_memory\n", + "\t sys.gpu.2.free_memory\n", + "\t sys.gpu.2.gpu_utilization\n", + "\t sys.gpu.2.memory_utilization\n", + "\t sys.gpu.2.percent.used_memory\n", + "\t sys.gpu.2.power_usage\n", + "\t sys.gpu.2.temperature\n", + "\t sys.gpu.2.total_memory\n", + "\t sys.gpu.2.used_memory\n", + "\t sys.gpu.3.free_memory\n", + "\t sys.gpu.3.gpu_utilization\n", + "\t sys.gpu.3.memory_utilization\n", + "\t 
sys.gpu.3.percent.used_memory\n", + "\t sys.gpu.3.power_usage\n", + "\t sys.gpu.3.temperature\n", + "\t sys.gpu.3.total_memory\n", + "\t sys.gpu.3.used_memory\n", + "\t sys.gpu.4.free_memory\n", + "\t sys.gpu.4.gpu_utilization\n", + "\t sys.gpu.4.memory_utilization\n", + "\t sys.gpu.4.percent.used_memory\n", + "\t sys.gpu.4.power_usage\n", + "\t sys.gpu.4.temperature\n", + "\t sys.gpu.4.total_memory\n", + "\t sys.gpu.4.used_memory\n", + "\t sys.gpu.5.free_memory\n", + "\t sys.gpu.5.gpu_utilization\n", + "\t sys.gpu.5.memory_utilization\n", + "\t sys.gpu.5.percent.used_memory\n", + "\t sys.gpu.5.power_usage\n", + "\t sys.gpu.5.temperature\n", + "\t sys.gpu.5.total_memory\n", + "\t sys.gpu.5.used_memory\n", + "\t sys.gpu.6.free_memory\n", + "\t sys.gpu.6.gpu_utilization\n", + "\t sys.gpu.6.memory_utilization\n", + "\t sys.gpu.6.percent.used_memory\n", + "\t sys.gpu.6.power_usage\n", + "\t sys.gpu.6.temperature\n", + "\t sys.gpu.6.total_memory\n", + "\t sys.gpu.6.used_memory\n", + "\t sys.gpu.7.free_memory\n", + "\t sys.gpu.7.gpu_utilization\n", + "\t sys.gpu.7.memory_utilization\n", + "\t sys.gpu.7.percent.used_memory\n", + "\t sys.gpu.7.power_usage\n", + "\t sys.gpu.7.temperature\n", + "\t sys.gpu.7.total_memory\n", + "\t sys.gpu.7.used_memory\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t system/GPU active mem (%)\n", + "\t system/GPU active mem (GiB)\n", + "\t system/GPU reserved mem (%)\n", + "\t system/GPU reserved mem (GiB)\n", + "\t throughput/device/BPS\n", + "\t throughput/device/BPS (actual avg)\n", + "\t throughput/device/TPS\n", + "\t throughput/device/TPS (actual avg)\n", + "\t throughput/device/data loading (%)\n", + "\t throughput/device/data loading (s)\n", + "\t throughput/total tokens\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n", + "peteish13:\n", + "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", + "\t eval/downstream/arc_easy (accuracy)\n", + "\t eval/downstream/basic_arithmetic (accuracy)\n", + "\t eval/downstream/boolq (accuracy)\n", + "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", + "\t eval/downstream/copa (accuracy)\n", + "\t eval/downstream/hellaswag (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_var (length-normalized accuracy)\n", + "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", + "\t eval/downstream/piqa (length-normalized accuracy)\n", + "\t eval/downstream/sciq (accuracy)\n", + "\t eval/downstream/social_iqa (length-normalized accuracy)\n", + "\t 
eval/downstream/winogrande (accuracy)\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.gpu.0.free_memory\n", + "\t sys.gpu.0.gpu_utilization\n", + "\t sys.gpu.0.memory_utilization\n", + "\t sys.gpu.0.percent.used_memory\n", + "\t sys.gpu.0.power_usage\n", + "\t sys.gpu.0.temperature\n", + "\t sys.gpu.0.total_memory\n", + "\t sys.gpu.0.used_memory\n", + "\t sys.gpu.1.free_memory\n", + "\t sys.gpu.1.gpu_utilization\n", + "\t sys.gpu.1.memory_utilization\n", + "\t sys.gpu.1.percent.used_memory\n", + "\t sys.gpu.1.power_usage\n", + "\t sys.gpu.1.temperature\n", + "\t sys.gpu.1.total_memory\n", + "\t sys.gpu.1.used_memory\n", + "\t sys.gpu.2.free_memory\n", + "\t sys.gpu.2.gpu_utilization\n", + "\t sys.gpu.2.memory_utilization\n", + "\t sys.gpu.2.percent.used_memory\n", + "\t sys.gpu.2.power_usage\n", + "\t sys.gpu.2.temperature\n", + "\t sys.gpu.2.total_memory\n", + "\t sys.gpu.2.used_memory\n", + "\t sys.gpu.3.free_memory\n", + "\t sys.gpu.3.gpu_utilization\n", + "\t sys.gpu.3.memory_utilization\n", + "\t sys.gpu.3.percent.used_memory\n", + "\t sys.gpu.3.power_usage\n", + "\t sys.gpu.3.temperature\n", + "\t sys.gpu.3.total_memory\n", + "\t sys.gpu.3.used_memory\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t throughput/device/BPS\n", + "\t throughput/device/TPS\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n", + "peteish7:\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t throughput/device/BPS\n", + "\t throughput/device/TPS\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n" + ] + } + ], + "execution_count": 18 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-08T08:23:16.363876Z", + "start_time": "2025-01-08T08:22:44.510810Z" + } + }, + "cell_type": "code", + "source": [ + "from tqdm.notebook import tqdm\n", + "\n", + "def download_metric(exps, metric_name):\n", + " result = {}\n", + " for exp in tqdm(exps):\n", + " metrics = exp.get_metrics(metric_name)\n", + " for values in metrics:\n", + " result[values['step']] = float(values['metricValue'])\n", + " result = dict(sorted(result.items()))\n", + " return result\n", + "\n", + "loss = {\n", + " name: download_metric(es, \"train/CE loss\")\n", + " for name, es in exps.items()\n", + "}\n", + "\n", + "skipped_steps = {\n", + " name: download_metric(es, \"optim/step skipped\")\n", + " for name, es in exps.items()\n", + "}\n", + "\n", + "speed = {\n", + " name: download_metric(es, \"train/CE loss\")\n", + " for name, es in exps.items()\n", + "}" + ], + "id": "6aa86a5638253061", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/40 [00:00 0])" + ], + "id": "277e0e889edb7b16", + "outputs": [ + { + "data": { + 
"text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-08T00:23:16.473138\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Steps skipped for the 32B: 50\n", + "[80788, 81072, 84048, 85129, 87386, 92844, 107316, 111491, 113030, 114230, 118668, 121925, 126863, 127493, 128136, 129747, 134843, 136385, 142362, 142815, 144303, 144548, 147139, 147455, 148216, 148703, 150206, 154267, 159678, 159881, 160407, 163682, 167141, 167784, 175621, 187888, 188783, 194308, 204820, 205830, 206617, 212691, 217589, 226667, 230116, 231534, 232070, 232547, 233702, 241716]\n" + ] + } + ], + "execution_count": 20 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Downstream", + "id": "83cbde8bd1160629" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-08T08:26:51.700791Z", + "start_time": "2025-01-08T08:23:16.536088Z" + } + }, + "cell_type": "code", + "source": [ + "aggregate_metric_definitions = {\n", + " \"MMLU 5-shot MC\": {\n", + " \"eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\": 0.215,\n", + " \"eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\": 0.335,\n", + " \"eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\": 0.219,\n", + " \"eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\": 0.231\n", + " },\n", + " \"Average of core 12\": {\n", + " \"eval/downstream/arc_challenge (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/arc_easy (accuracy)\": 1 / 12,\n", + " \"eval/downstream/basic_arithmetic (accuracy)\": 1 / 12,\n", + " \"eval/downstream/boolq (accuracy)\": 1 / 12,\n", + " \"eval/downstream/commonsense_qa (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/copa (accuracy)\": 1 / 12,\n", + " 
\"eval/downstream/hellaswag (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/openbook_qa (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/piqa (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/sciq (accuracy)\": 1 / 12,\n", + " \"eval/downstream/social_iqa (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/winogrande (accuracy)\": 1 / 12,\n", + " },\n", + " \"Hellswag\": {\n", + " \"eval/downstream/hellaswag (length-normalized accuracy)\": 1\n", + " }\n", + "}\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'svg'\n", + "import numpy as np\n", + "\n", + "fig, axs = plt.subplots(nrows=len(aggregate_metric_definitions), sharex=True, figsize=(10, len(aggregate_metric_definitions)*3))\n", + "\n", + "for ax, agg_metric_name in zip(axs, aggregate_metric_definitions):\n", + " metric_to_weight = aggregate_metric_definitions[agg_metric_name]\n", + " for run_name, run_exps in exps.items():\n", + " metric_to_values = {}\n", + " for metric in metric_to_weight.keys():\n", + " metric_to_values[metric] = download_metric(run_exps, metric)\n", + "\n", + " all_steps = set.union(*[set(v.keys()) for v in metric_to_values.values()])\n", + " minimal_steps = set.intersection(*[set(v.keys()) for v in metric_to_values.values()])\n", + " if all_steps != minimal_steps:\n", + " print(f\"Missing steps for {run_name} / {agg_metric_name}: {all_steps - minimal_steps}\")\n", + "\n", + " aggregated_values = {}\n", + " for step in minimal_steps:\n", + " value = 0.0\n", + " for metric, weight in metric_to_weight.items():\n", + " value += metric_to_values[metric][step] * weight\n", + " aggregated_values[step] = value\n", + " if len(aggregated_values) == 0:\n", + " continue\n", + "\n", + " print(f\"{run_name} / {agg_metric_name} max: {max(aggregated_values.values())}\")\n", + "\n", + " xs = np.array(list(aggregated_values.keys()))\n", + " ys = np.array(list(aggregated_values.values()))\n", + " order = np.argsort(xs)\n", + " xs = xs[order]\n", + " ys = ys[order]\n", + " xs *= (2048 * 4096)\n", + " ax.plot(xs, ys, linewidth=0.5)\n", + " ax.set_ylabel(agg_metric_name)\n", + "\n", + "plt.xlabel(\"step\")\n", + "plt.show()" + ], + "id": "8b310d9cc68ad856", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/40 [00:00" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-08T00:26:51.675515\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 21 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Spike Analysis", + "id": "744574cd19bbe369" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-08T08:26:52.300841Z", + "start_time": "2025-01-08T08:26:51.735362Z" + } + }, + "cell_type": "code", + "source": [ + "window_size = 128\n", + "losses = np.array(list(loss[\"peteish32\"].values()))\n", + "steps = np.array(list(loss[\"peteish32\"].keys()))\n", + "\n", + "from numpy.lib.stride_tricks import sliding_window_view\n", + "windows = sliding_window_view(losses, window_size)\n", + "\n", + "stds = windows.std(axis=1)\n", + "means = windows.mean(axis=1)\n", + "losses = losses[window_size - 1 :]\n", + "steps = steps[window_size - 1 :]\n", + "spike_steps = steps[np.argwhere(losses > means + stds * 6)].flatten()\n", + "print(f\"Steps with spikes: {spike_steps}\")\n", + "\n", + "fig, axes = plt.subplots(\n", + " nrows=len(spike_steps),\n", + " figsize=(7, len(spike_steps)*3),\n", + " sharex=False\n", + ")\n", + "\n", + "for ax, spike in zip(axes, spike_steps):\n", + " for name, values in loss.items():\n", + " xs = np.array(list(values.keys()))\n", + " ys = np.array(list(values.values()))\n", + " ax.plot(xs, ys, linewidth=0.5)\n", + " ax.set_ylim(2.1, 2.5)\n", + " ax.set_xlim(spike-1000, spike+1000)\n", + " plt.yscale('log')\n", + " plt.xlabel(\"step\")\n", + " plt.ylabel(\"loss\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n" + ], + "id": "6eb5abfb647663a5", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Steps with spikes: [ 29645 38677 49089 54503 66257 73019 144302]\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-08T00:26:52.236753\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + 
"output_type": "display_data" + } + ], + "execution_count": 22 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/scripts/train/OLMo2-32B.py b/src/scripts/train/OLMo2-32B.py new file mode 100644 index 00000000..25f6eca1 --- /dev/null +++ b/src/scripts/train/OLMo2-32B.py @@ -0,0 +1,198 @@ +""" +Train a 32B OLMo model. Run this script without any arguments to see usage info. +""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.float8 import Float8Config +from olmo_core.internal.experiment import CommonComponents, main +from olmo_core.nn.transformer import ( + TransformerActivationCheckpointingConfig, + TransformerActivationCheckpointingMode, + TransformerConfig, + TransformerDataParallelConfig, +) +from olmo_core.optim import OptimGroupOverride, SkipStepAdamWConfig +from olmo_core.train import Duration, DurationUnit, TrainerConfig +from olmo_core.train.callbacks import ( + CheckpointerCallback, + CometCallback, + DownstreamEvaluatorCallbackConfig, + WandBCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig + +log = logging.getLogger(__name__) + + +def build_model_config(common: CommonComponents) -> TransformerConfig: + compile = True + return TransformerConfig.olmo2_32B( + vocab_size=common.tokenizer.padded_vocab_size(), + compile=compile, + fused_ops=False, + use_flash=not compile, + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + ), + # dp_config=TransformerDataParallelConfig( + # name=DataParallelType.hsdp, + # param_dtype=DType.bfloat16, + # reduce_dtype=DType.float32, + # num_replicas=64 // 16, # common.launch.num_nodes // 2, + # ), + # ac_config=TransformerActivationCheckpointingConfig(TransformerActivationCheckpointingMode.full), + ac_config=TransformerActivationCheckpointingConfig( + mode=TransformerActivationCheckpointingMode.selected_modules, + modules=[f"blocks.{i}.feed_forward" for i in range(64)], + ), + float8_config=Float8Config(compile=compile, enabled=False), + ) + + +def build_optim_config(common: CommonComponents) -> SkipStepAdamWConfig: + del common + return SkipStepAdamWConfig( + lr=6e-4, + weight_decay=0.1, + betas=(0.9, 0.95), + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) + ], + # fused=True, + compile=True, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + project_name = "peteish32" + return ( + TrainerConfig( + save_folder=f"gs://ai2-llm/checkpoints/{project_name}/", + rank_microbatch_size=2 * 4096, + checkpointer=CheckpointerConfig(save_thread_count=1, load_thread_count=32), + save_overwrite=True, + metrics_collect_interval=10, + cancel_check_interval=10, + z_loss_multiplier=1e-5, + compile_loss=False, + fused_loss=True, + max_duration=Duration(int(6.5e12), DurationUnit.tokens), + ) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=1000, + save_async=True, + ), + ) + .with_callback( + "comet", + CometCallback( + name=common.run_name, + workspace="ai2", + project=project_name, + enabled=True, + 
cancel_check_interval=10, + ), + ) + .with_callback( + "wandb", + WandBCallback( + name=common.run_name, + entity="ai2-llm", + project=project_name, + enabled=False, + cancel_check_interval=10, + ), + ) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=[ + # MMLU for backwards compatibility + "mmlu_stem_mc_5shot", + "mmlu_humanities_mc_5shot", + "mmlu_social_sciences_mc_5shot", + "mmlu_other_mc_5shot", + # MMLU test + "mmlu_stem_mc_5shot_test", + "mmlu_humanities_mc_5shot_test", + "mmlu_social_sciences_mc_5shot_test", + "mmlu_other_mc_5shot_test", + ## Core 12 tasks for backwards compatibility + #"arc_challenge", + #"arc_easy", + #"basic_arithmetic", + #"boolq", + #"commonsense_qa", + #"copa", + #"hellaswag", + #"openbook_qa", + #"piqa", + #"sciq", + #"social_iqa", + #"winogrande", + ## Core 12 tasks 5-shot + #"arc_challenge_rc_5shot", + #"arc_easy_rc_5shot", + ## "basic_arithmetic_rc_5shot", # doesn't exist + ## "boolq_rc_5shot", # we don't like it + #"csqa_rc_5shot", + ## "copa_rc_5shot", # doesn't exist + #"hellaswag_rc_5shot", + #"openbookqa_rc_5shot", + #"piqa_rc_5shot", + ## "sciq_rc_5shot", # doesn't exist + #"socialiqa_rc_5shot", + #"winogrande_rc_5shot", + ## New in-loop evals + #"arc_challenge_val_rc_5shot", + #"arc_challenge_val_mc_5shot", + "arc_challenge_test_rc_5shot", + #"arc_challenge_test_mc_5shot", + #"arc_easy_val_rc_5shot", + #"arc_easy_val_mc_5shot", + "arc_easy_test_rc_5shot", + #"arc_easy_test_mc_5shot", + #"boolq_val_rc_5shot", + #"boolq_val_mc_5shot", + "csqa_val_rc_5shot", + #"csqa_val_mc_5shot", + "hellaswag_val_rc_5shot", + #"hellaswag_val_mc_5shot", + #"openbookqa_val_rc_5shot", + #"openbookqa_val_mc_5shot", + "openbookqa_test_rc_5shot", + #"openbookqa_test_mc_5shot", + "piqa_val_rc_5shot", + #"piqa_val_mc_5shot", + "socialiqa_val_rc_5shot", + #"socialiqa_val_mc_5shot", + #"winogrande_val_rc_5shot", + #"winogrande_val_mc_5shot", + #"mmlu_stem_val_rc_5shot", + #"mmlu_stem_val_mc_5shot", + #"mmlu_humanities_val_rc_5shot", + #"mmlu_humanities_val_mc_5shot", + #"mmlu_social_sciences_val_rc_5shot", + #"mmlu_social_sciences_val_mc_5shot", + #"mmlu_other_val_rc_5shot", + #"mmlu_other_val_mc_5shot", + ], + tokenizer=common.tokenizer, + eval_interval=1000, + ), + ) + ) + + +if __name__ == "__main__": + main( + global_batch_size=2048 * 4096, + model_config_builder=build_model_config, + optim_config_builder=build_optim_config, + trainer_config_builder=build_trainer_config, + ) diff --git a/src/test/optim/adamw_test.py b/src/test/optim/adamw_test.py index 5756f9a6..8f3c58e5 100644 --- a/src/test/optim/adamw_test.py +++ b/src/test/optim/adamw_test.py @@ -1,7 +1,10 @@ +from test.utils import DEVICES + +import pytest import torch import torch.nn as nn -from olmo_core.optim import AdamWConfig, OptimGroupOverride +from olmo_core.optim import AdamW, AdamWConfig, OptimGroupOverride, SkipStepAdamWConfig class MyModel(nn.Module): @@ -20,7 +23,7 @@ def test_adamw_config_to_optim(): model = MyModel() optim = config.build(model) - assert isinstance(optim, torch.optim.AdamW) + assert isinstance(optim, AdamW) assert len(optim.param_groups) == 1 assert config.merge(["lr=1e-1"]).lr == 0.1 @@ -33,7 +36,7 @@ def test_adamw_config_to_optim_with_group_overrides(): model = MyModel() optim = config.build(model) - assert isinstance(optim, torch.optim.AdamW) + assert isinstance(optim, AdamW) assert len(optim.param_groups) == 2 assert optim.param_groups[0]["weight_decay"] == 0.0 assert len(optim.param_groups[0]["params"]) == 1 @@ -43,3 +46,33 @@ def 
test_adamw_config_to_optim_with_group_overrides(): for group in optim.param_groups: assert "initial_lr" in group + + +@pytest.mark.parametrize("device", DEVICES) +def test_adamw(device: torch.device): + config = AdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. + optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step() + + +@pytest.mark.parametrize("device", DEVICES) +def test_skip_step_adamw(device: torch.device): + config = SkipStepAdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. + optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step()