Power sgd #196

Open
wants to merge 9 commits into base: main
6 changes: 6 additions & 0 deletions src/zeroband/config.py
@@ -44,8 +44,13 @@ class SoapConfig(BaseConfig):
OptimizersConfig: TypeAlias = AdamConfig | SoapConfig


class PowerSGDConfig(BaseConfig):
rank: int = 1
warmup_steps: int = 1000

class OptimConfig(BaseConfig):
optim: OptimizersConfig = AdamConfig()
power_sgd: PowerSGDConfig | None = None

lr: float = 4e-4
weight_decay: float = 0.1
@@ -144,6 +149,7 @@ def validate_remote_data_path(self):
return self



class Config(BaseConfig):
# main config
name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
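
The two fields added here feed straight into the PowerSGD DDP communication hook that train.py registers later in this diff. A minimal sketch of constructing the new config object; the values are illustrative, not recommendations:

```python
from zeroband.config import PowerSGDConfig

# Illustrative values; the class defaults above are rank=1, warmup_steps=1000.
cfg = PowerSGDConfig(rank=4, warmup_steps=500)

# train.py forwards these to torch's PowerSGD DDP hook as:
#   matrix_approximation_rank = cfg.rank          # rank of the low-rank gradient factors
#   start_powerSGD_iter       = cfg.warmup_steps  # plain all-reduce until this iteration
print(cfg.rank, cfg.warmup_steps)
```

Leaving `OptimConfig.power_sgd` at its default of `None` keeps the uncompressed DDP all-reduce path.
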
3 changes: 2 additions & 1 deletion src/zeroband/models/llama/model.py
@@ -23,7 +23,8 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
from torch.nn.attention import SDPBackend, sdpa_kernel

_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
_flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
# _flex_attention_compiled = flex_attention


# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
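
Context for the `dynamic=True` switch: with `dynamic=False`, torch.compile specializes the FlexAttention kernel to the exact shapes it first sees, so every new sequence length can trigger a recompile; `dynamic=True` traces symbolic shapes so one compiled artifact can serve varying lengths. A small sketch, assuming a CUDA device and a PyTorch build with FlexAttention:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# dynamic=True keeps the sequence length symbolic instead of specializing,
# so changing lengths should not force a recompile.
compiled_flex = torch.compile(flex_attention, dynamic=True)

def attn(seq_len: int) -> torch.Tensor:
    # (batch, heads, seq, head_dim); lengths kept as multiples of 128 to stay
    # within FlexAttention's default block size.
    q = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
    return compiled_flex(q, k, v)

out_a = attn(256)  # first call compiles
out_b = attn(128)  # different length, ideally reuses the same compiled graph
```
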
4 changes: 2 additions & 2 deletions src/zeroband/optimizers.py
@@ -3,7 +3,7 @@
from distributed_shampoo import (
DefaultEigenvalueCorrectedShampooConfig,
DistributedShampoo,
FullyShardShampooConfig,
DDPShampooConfig,
ShampooPT2CompileConfig,
)
from zeroband.config import AdamConfig, SoapConfig, OptimizersConfig
@@ -29,7 +29,7 @@ def get_optimizer(params: Iterable[torch.nn.Parameter], config: OptimizersConfig
# This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
# less expensive and might thereby allow for a smaller `precondition_frequency`.
preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
distributed_config=FullyShardShampooConfig(),
distributed_config=DDPShampooConfig(),
shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
)
else:
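
For orientation, a sketch of what the SOAP branch of `get_optimizer` roughly builds after this swap. Only `preconditioner_config`, `distributed_config`, and `shampoo_pt2_compile_config` mirror the diff; the toy model, learning rate, betas, and preconditioning frequency are placeholders, and `DDPShampooConfig` assumes `torch.distributed` is initialized with DDP-style replicated parameters:

```python
import torch
from distributed_shampoo import (
    DDPShampooConfig,
    DefaultEigenvalueCorrectedShampooConfig,
    DistributedShampoo,
    ShampooPT2CompileConfig,
)

model = torch.nn.Linear(64, 64)  # placeholder model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=4e-4,                     # placeholder hyperparameters
    betas=(0.9, 0.95),
    precondition_frequency=10,
    # Eigenvalue-corrected Shampoo (SOAP); DefaultSOAPConfig would use QR instead.
    preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
    # Distribute optimizer state across DDP replicas, replacing
    # FullyShardShampooConfig now that the FSDP wrapping is gone.
    distributed_config=DDPShampooConfig(),
    shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
)
```
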
101 changes: 60 additions & 41 deletions src/zeroband/train.py
@@ -1,11 +1,15 @@
import os
import time
from typing import TYPE_CHECKING
from multiprocessing.process import _children # type: ignore
from multiprocessing.process import _children # type: ignore

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy # type: ignore

# from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy # type: ignore
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

from torch.autograd.profiler import record_function

from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -26,7 +30,7 @@
get_tensor_list_signature,
get_peak_flops,
get_num_params,
get_num_flop_per_token
get_num_flop_per_token,
)
from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
from zeroband.utils.monitor import HttpMonitor
@@ -65,7 +69,7 @@ def log_hash_training_state(

if config.diloco is not None and diloco is not None:
outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) # type: ignore
outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) # type: ignore

logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
@@ -148,28 +152,41 @@ def train(config: Config):
enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
)

mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
# mp_policy = MixedPrecisionPolicy(
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
# )

# for layer_id, transformer_block in model.layers.items():
# if config.train.reshard_after_forward:
# reshard_after_forward = int(layer_id) < len(model.layers) - 1
# else:
# reshard_after_forward = False
# fully_shard(
# transformer_block,
# mp_policy=mp_policy,
# mesh=elastic_device_mesh.cuda_local_mesh,
# reshard_after_forward=reshard_after_forward,
# )
# fully_shard(
# model,
# mp_policy=mp_policy,
# mesh=elastic_device_mesh.cuda_local_mesh,
# reshard_after_forward=config.train.reshard_after_forward,
# )
model: DDP = DDP(
model, device_ids=[world_info.local_rank], broadcast_buffers=False, gradient_as_bucket_view=True
)

for layer_id, transformer_block in model.layers.items():
if config.train.reshard_after_forward:
reshard_after_forward = int(layer_id) < len(model.layers) - 1
else:
reshard_after_forward = False
fully_shard(
transformer_block,
mp_policy=mp_policy,
mesh=elastic_device_mesh.cuda_local_mesh,
reshard_after_forward=reshard_after_forward,
if config.optim.power_sgd is not None:
state = powerSGD_hook.PowerSGDState(
process_group=None, # Default process group
matrix_approximation_rank=config.optim.power_sgd.rank, # Adjust rank based on compression needs
start_powerSGD_iter=config.optim.power_sgd.warmup_steps, # When to start compression
)
fully_shard(
model,
mp_policy=mp_policy,
mesh=elastic_device_mesh.cuda_local_mesh,
reshard_after_forward=config.train.reshard_after_forward,
)
logger.debug("model fsdped")

model.register_comm_hook(state, powerSGD_hook.powerSGD_hook)

logger.debug("model ddped")

# Setup optimizers
with record_function("Set up Optimizers"):
@@ -195,15 +212,15 @@ def train(config: Config):
dataloader=train_dataloader,
training_progress=training_progress,
data_rank=config.data.data_rank,
diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, # type: ignore
diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore
diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, # type: ignore
diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore
)

if world_info.rank == 0:
logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
metric_logger = logger_cls(
project=config.project,
config={"config": config.model_dump(), "world_info": world_info.json()},
logger_config={"config": config.model_dump(), "world_info": world_info.json()},
resume=config.wandb_resume,
)
else:
@@ -300,7 +317,8 @@ def train(config: Config):
for grad_acc_step in range(gradient_accumulation_steps):
is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
# no sync if we are accumulating gradients
model.set_requires_gradient_sync(not is_accumulating)
# model.set_requires_gradient_sync(not is_accumulating)
model.require_backward_grad_sync = not is_accumulating

with record_function("Load batch"):
# TODO/NOTE: We could overlap sending the batch with communication
@@ -315,7 +333,9 @@ def train(config: Config):
block_mask = None

with record_function("Run model"):
logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(tokens=input_ids, block_mask=block_mask).contiguous()

flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
flatten_labels = rearrange(labels, "b seq -> (b seq)")

@@ -324,7 +344,7 @@ def train(config: Config):
flatten_logits,
flatten_labels,
z_weight=config.optim.z_loss_weight if config.optim.z_loss else None,
num_chunks=config.optim.num_chunks
num_chunks=config.optim.num_chunks,
)
del logits

@@ -491,7 +511,7 @@ def train(config: Config):
if __name__ == "__main__":
# Allow eager fallback during production so that the training runs don't die
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ # type: ignore
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ # type: ignore
torch.set_float32_matmul_precision("high")
torch.manual_seed(42)

@@ -514,21 +534,20 @@ def pretty_dict(d, indent=2):

try:
if config.train.torch_profiler and world_info.rank == 0:

# NOTE(apaz-cli): I cannot seem to get the memory profiler to work.
# Running into this issue: https://github.com/pytorch/pytorch/issues/64345
# In the meantime, we can use the memory snapshotter.

logger.debug("Running train() with profiler.")
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
#profile_memory=True,
#with_stack=True,
)
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
# profile_memory=True,
# with_stack=True,
)
try:
prof.__enter__()
train(config)
@@ -546,8 +565,8 @@ def pretty_dict(d, indent=2):
logger.info("\n" + "*" * width + " GPU MEM " + "*" * width)
logger.info(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))

#logger.info("Exporting memory timeline.")
#prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0")
# logger.info("Exporting memory timeline.")
# prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0")
else:
train(config)
except Exception as e:
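
Putting the train.py changes together: the model is now wrapped in DistributedDataParallel instead of FSDP's `fully_shard`, a PowerSGD communication hook compresses gradient buckets after a warmup of plain all-reduce, and gradient synchronization during accumulation is gated with `require_backward_grad_sync` rather than `set_requires_gradient_sync`. Below is a self-contained single-process sketch of that wiring; the gloo backend, toy model, and hyperparameters are illustrative only (a real run would use NCCL on GPUs and the values from `config.optim.power_sgd`):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook
from torch.nn.parallel import DistributedDataParallel as DDP


def main() -> None:
    # Single-process process group so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = DDP(
        torch.nn.Linear(32, 32),  # toy stand-in for the transformer
        broadcast_buffers=False,
        gradient_as_bucket_view=True,
    )

    # PowerSGD keeps low-rank factors plus an error-feedback buffer per bucket.
    state = powerSGD_hook.PowerSGDState(
        process_group=None,           # default (global) process group
        matrix_approximation_rank=1,  # rank of the compressed gradient approximation
        start_powerSGD_iter=2,        # plain all-reduce warmup before compressing
    )
    model.register_comm_hook(state, powerSGD_hook.powerSGD_hook)

    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    grad_accum_steps = 2
    for _ in range(4):
        for micro_step in range(grad_accum_steps):
            # Only sync (and compress) gradients on the last micro-batch,
            # mirroring `model.require_backward_grad_sync = not is_accumulating`.
            model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x = torch.randn(8, 32)
            loss = model(x).square().mean()
            (loss / grad_accum_steps).backward()
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

One related observation: unlike the removed FSDP `MixedPrecisionPolicy`, DDP does not cast parameters itself, which is presumably why the forward pass in the diff is now wrapped in `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`.
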
14 changes: 6 additions & 8 deletions src/zeroband/utils/metric_logger.py
@@ -2,28 +2,26 @@
from typing import Any, Protocol
import importlib.util

from zeroband.config import get_env_config


class MetricLogger(Protocol):
def __init__(self, project, config): ...
def __init__(self, project, logger_config): ...

def log(self, metrics: dict[str, Any]): ...

def finish(self): ...


class WandbMetricLogger(MetricLogger):
def __init__(self, project, config, resume: bool):
def __init__(self, project, logger_config, resume: bool):
if importlib.util.find_spec("wandb") is None:
raise ImportError("wandb is not installed. Please install it to use WandbMonitor.")

import wandb

run_name = get_env_config(config, "run_name")
run_name = logger_config["config"]["run_name"]

wandb.init(
project=project, config=config, name=run_name, resume="auto" if resume else None
project=project, config=logger_config, name=run_name, resume="auto" if resume else None
) # make wandb reuse the same run id if possible

def log(self, metrics: dict[str, Any]):
@@ -38,9 +36,9 @@ def finish(self):


class DummyMetricLogger(MetricLogger):
def __init__(self, project, config, *args, **kwargs):
def __init__(self, project, logger_config, *args, **kwargs):
self.project = project
self.config = config
self.logger_config = logger_config
open(self.project, "a").close() # Create an empty file to append to

self.data = []
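
The loggers now take the nested dict that train.py builds (`logger_config={"config": ..., "world_info": ...}`) instead of a flat config object resolved through `get_env_config`. A small sketch of the expected shape, using the dummy logger and made-up values:

```python
from zeroband.utils.metric_logger import DummyMetricLogger

# Illustrative stand-ins for config.model_dump() and world_info.json().
logger_config = {
    "config": {"run_name": "powersgd-debug"},  # WandbMetricLogger reads ["config"]["run_name"]
    "world_info": {"rank": 0, "world_size": 1},
}

# The dummy logger just appends metrics to a local file named by `project`.
metric_logger = DummyMetricLogger(project="metrics_dummy.jsonl", logger_config=logger_config)
metric_logger.log({"step": 0, "loss": 1.23})
metric_logger.finish()
```
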