diff --git a/src/zeroband/config.py b/src/zeroband/config.py
index ee2d615e..7076771e 100644
--- a/src/zeroband/config.py
+++ b/src/zeroband/config.py
@@ -44,8 +44,13 @@ class SoapConfig(BaseConfig):
 OptimizersConfig: TypeAlias = AdamConfig | SoapConfig
 
 
+class PowerSGDConfig(BaseConfig):
+    rank: int = 1
+    warmup_steps: int = 1000
+
 class OptimConfig(BaseConfig):
     optim: OptimizersConfig = AdamConfig()
+    power_sgd: PowerSGDConfig | None = None
     lr: float = 4e-4
     weight_decay: float = 0.1
@@ -144,6 +149,7 @@ def validate_remote_data_path(self):
         return self
 
 
+
 class Config(BaseConfig):
     # main config
     name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py
index cb767790..3a3b96a7 100644
--- a/src/zeroband/models/llama/model.py
+++ b/src/zeroband/models/llama/model.py
@@ -23,7 +23,8 @@
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
-_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+_flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
+# _flex_attention_compiled = flex_attention
 
 # copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
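
The new `power_sgd` field is optional, so existing configs keep plain DDP all-reduce. A minimal sketch of how the field is expected to be used (illustrative, not part of the patch; it assumes the remaining OptimConfig fields keep their defaults):

    from zeroband.config import OptimConfig, PowerSGDConfig

    # Enable PowerSGD gradient compression with rank 4 after 500 warmup steps.
    cfg = OptimConfig(power_sgd=PowerSGDConfig(rank=4, warmup_steps=500))

    # Leaving the field unset keeps the default (None), i.e. plain DDP all-reduce.
    assert OptimConfig().power_sgd is None
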
diff --git a/src/zeroband/optimizers.py b/src/zeroband/optimizers.py
index 6faf5c01..792d2715 100644
--- a/src/zeroband/optimizers.py
+++ b/src/zeroband/optimizers.py
@@ -3,7 +3,7 @@
 from distributed_shampoo import (
     DefaultEigenvalueCorrectedShampooConfig,
     DistributedShampoo,
-    FullyShardShampooConfig,
+    DDPShampooConfig,
     ShampooPT2CompileConfig,
 )
 from zeroband.config import AdamConfig, SoapConfig, OptimizersConfig
@@ -29,7 +29,7 @@ def get_optimizer(params: Iterable[torch.nn.Parameter], config: OptimizersConfig
             # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
             # less expensive and might thereby allow for a smaller `precondition_frequency`.
             preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
-            distributed_config=FullyShardShampooConfig(),
+            distributed_config=DDPShampooConfig(),
             shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
         )
     else:
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 32f9cdd0..376975c5 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -1,11 +1,15 @@
 import os
 import time
 from typing import TYPE_CHECKING
-from multiprocessing.process import _children # type: ignore
+from multiprocessing.process import _children  # type: ignore
 
 import torch
 import torch.distributed as dist
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy  # type: ignore
+
+# from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy  # type: ignore
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook
+
 from torch.autograd.profiler import record_function
 from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -26,7 +30,7 @@
     get_tensor_list_signature,
     get_peak_flops,
     get_num_params,
-    get_num_flop_per_token
+    get_num_flop_per_token,
 )
 from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
 from zeroband.utils.monitor import HttpMonitor
@@ -65,7 +69,7 @@ def log_hash_training_state(
     if config.diloco is not None and diloco is not None:
         outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
-        outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) # type: ignore
+        outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu)  # type: ignore
 
         logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
         logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
@@ -148,28 +152,41 @@ def train(config: Config):
         enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
-    mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+    # mp_policy = MixedPrecisionPolicy(
+    #     param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+    # )
+
+    # for layer_id, transformer_block in model.layers.items():
+    #     if config.train.reshard_after_forward:
+    #         reshard_after_forward = int(layer_id) < len(model.layers) - 1
+    #     else:
+    #         reshard_after_forward = False
+    #     fully_shard(
+    #         transformer_block,
+    #         mp_policy=mp_policy,
+    #         mesh=elastic_device_mesh.cuda_local_mesh,
+    #         reshard_after_forward=reshard_after_forward,
+    #     )
+    # fully_shard(
+    #     model,
+    #     mp_policy=mp_policy,
+    #     mesh=elastic_device_mesh.cuda_local_mesh,
+    #     reshard_after_forward=config.train.reshard_after_forward,
+    # )
+    model: DDP = DDP(
+        model, device_ids=[world_info.local_rank], broadcast_buffers=False, gradient_as_bucket_view=True
     )
 
-    for layer_id, transformer_block in model.layers.items():
-        if config.train.reshard_after_forward:
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
-            reshard_after_forward = False
-        fully_shard(
-            transformer_block,
-            mp_policy=mp_policy,
-            mesh=elastic_device_mesh.cuda_local_mesh,
-            reshard_after_forward=reshard_after_forward,
+    if config.optim.power_sgd is not None:
+        state = powerSGD_hook.PowerSGDState(
+            process_group=None,  # Default process group
+            matrix_approximation_rank=config.optim.power_sgd.rank,  # Adjust rank based on compression needs
+            start_powerSGD_iter=config.optim.power_sgd.warmup_steps,  # When to start compression
         )
-    fully_shard(
-        model,
-        mp_policy=mp_policy,
-        mesh=elastic_device_mesh.cuda_local_mesh,
-        reshard_after_forward=config.train.reshard_after_forward,
-    )
-    logger.debug("model fsdped")
+
+        model.register_comm_hook(state, powerSGD_hook.powerSGD_hook)
+
+    logger.debug("model ddped")
 
     # Setup optimizers
     with record_function("Set up Optimizers"):
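
For reference, the DDP + PowerSGD wiring above in isolation (a minimal sketch, assuming torch.distributed is already initialized with a CUDA backend; `local_rank` and the Linear module are placeholders, the hook API is PyTorch's powerSGD_hook module used in the hunk):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

    net = torch.nn.Linear(1024, 1024).cuda()  # placeholder module
    ddp_net = DDP(net, device_ids=[local_rank], gradient_as_bucket_view=True)

    state = powerSGD_hook.PowerSGDState(
        process_group=None,           # default (global) process group
        matrix_approximation_rank=1,  # mirrors config.optim.power_sgd.rank
        start_powerSGD_iter=1000,     # plain all-reduce until this step (warmup_steps)
    )
    ddp_net.register_comm_hook(state, powerSGD_hook.powerSGD_hook)
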
@@ -195,15 +212,15 @@
             dataloader=train_dataloader,
             training_progress=training_progress,
             data_rank=config.data.data_rank,
-            diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, # type: ignore
-            diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore
+            diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None,  # type: ignore
+            diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,  # type: ignore
         )
 
     if world_info.rank == 0:
         logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
         metric_logger = logger_cls(
             project=config.project,
-            config={"config": config.model_dump(), "world_info": world_info.json()},
+            logger_config={"config": config.model_dump(), "world_info": world_info.json()},
             resume=config.wandb_resume,
         )
     else:
@@ -300,7 +317,8 @@
         for grad_acc_step in range(gradient_accumulation_steps):
             is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
             # no sync if we are accumulating gradients
-            model.set_requires_gradient_sync(not is_accumulating)
+            # model.set_requires_gradient_sync(not is_accumulating)
+            model.require_backward_grad_sync = not is_accumulating
 
             with record_function("Load batch"):
                 # TODO/NOTE: We could overlap sending the batch with communication
@@ -315,7 +333,9 @@
                 block_mask = None
 
             with record_function("Run model"):
-                logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+
                 flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
                 flatten_labels = rearrange(labels, "b seq -> (b seq)")
@@ -324,7 +344,7 @@
                     flatten_logits,
                     flatten_labels,
                     z_weight=config.optim.z_loss_weight if config.optim.z_loss else None,
-                    num_chunks=config.optim.num_chunks
+                    num_chunks=config.optim.num_chunks,
                 )
 
             del logits
@@ -491,7 +511,7 @@
 if __name__ == "__main__":
     # Allow eager fallback during production so that that the training runs dont die
     # However, in development, we want to know that we broke torch compile
-    torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ # type: ignore
+    torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ  # type: ignore
     torch.set_float32_matmul_precision("high")
     torch.manual_seed(42)
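
The inner loop now relies on DDP's gradient-sync flag plus a bf16 autocast forward instead of FSDP's MixedPrecisionPolicy. A condensed sketch of one accumulation step (illustrative; `model`, `batches`, `grad_acc_steps`, and `loss_fn` are placeholders):

    for step, batch in enumerate(batches):
        # Skip the all-reduce (and the PowerSGD hook) on accumulation micro-steps;
        # only the last micro-step of the window synchronizes gradients.
        model.require_backward_grad_sync = (step % grad_acc_steps == grad_acc_steps - 1)

        # Parameters stay fp32; the forward runs in bf16 via autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(tokens=batch["input_ids"], block_mask=None)

        loss = loss_fn(logits.float(), batch["labels"])
        loss.backward()
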
logger.debug("Running train() with profiler.") prof = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - #profile_memory=True, - #with_stack=True, - ) + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + # profile_memory=True, + # with_stack=True, + ) try: prof.__enter__() train(config) @@ -546,8 +565,8 @@ def pretty_dict(d, indent=2): logger.info("\n" + "*" * width + " GPU MEM " + "*" * width) logger.info(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)) - #logger.info("Exporting memory timeline.") - #prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0") + # logger.info("Exporting memory timeline.") + # prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0") else: train(config) except Exception as e: diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index 0a47dc3f..85847925 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -2,11 +2,9 @@ from typing import Any, Protocol import importlib.util -from zeroband.config import get_env_config - class MetricLogger(Protocol): - def __init__(self, project, config): ... + def __init__(self, project, logger_config): ... def log(self, metrics: dict[str, Any]): ... @@ -14,16 +12,16 @@ def finish(self): ... class WandbMetricLogger(MetricLogger): - def __init__(self, project, config, resume: bool): + def __init__(self, project, logger_config, resume: bool): if importlib.util.find_spec("wandb") is None: raise ImportError("wandb is not installed. Please install it to use WandbMonitor.") import wandb - run_name = get_env_config(config, "run_name") + run_name = logger_config["config"]["run_name"] wandb.init( - project=project, config=config, name=run_name, resume="auto" if resume else None + project=project, config=logger_config, name=run_name, resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): @@ -38,9 +36,9 @@ def finish(self): class DummyMetricLogger(MetricLogger): - def __init__(self, project, config, *args, **kwargs): + def __init__(self, project, logger_config, *args, **kwargs): self.project = project - self.config = config + self.logger_config = logger_config open(self.project, "a").close() # Create an empty file to append to self.data = []