load optim states in CPU and move them to GPU after 1st fwd-bwd to avoid peak memory
NouamaneTazi committed Nov 21, 2024
1 parent 312b759 commit 9e1d76f
Showing 7 changed files with 64 additions and 35 deletions.
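
For context, the resume flow after this commit, reduced to a minimal sketch. Function arguments are elided, the loop is simplified, and forward_backward is a hypothetical stand-in for the pipeline-engine call; see the trainer.py diff below for the real call sites. The idea: the checkpointed optimizer states stay on CPU through the first forward-backward pass and are only moved to the GPU right before the first optimizer step.

import torch
from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device

def resume_and_train(optimizer, parallel_context, checkpoint_path,
                     initial_iter_step, last_iter_step, forward_backward):
    # New default: optimizer states are materialized on CPU, so the checkpoint
    # load does not add to GPU peak memory.
    load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=checkpoint_path)

    for step in range(initial_iter_step, last_iter_step + 1):
        forward_backward(step)  # 1st fwd-bwd peaks without optimizer states resident on GPU
        if step == initial_iter_step:
            # One-time, in-place move right before the first optimizer step after resuming.
            state_dict_to_device(optimizer.state_dict(), "cuda")
        optimizer.step()
        optimizer.zero_grad()
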
4 changes: 2 additions & 2 deletions src/nanotron/optim/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union

import torch

@@ -34,7 +34,7 @@ def state_dict(self) -> dict:
...

@abstractmethod
def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
...

@abstractmethod
6 changes: 3 additions & 3 deletions src/nanotron/optim/inherit_from_other_optimizer.py
@@ -1,5 +1,5 @@
from functools import cache
from typing import Callable, Dict, Optional, Set
from typing import Callable, Dict, Optional, Set, Union

import torch

@@ -33,8 +33,8 @@ def state_dict_additional_keys(self) -> Set[str]:
def state_dict(self) -> dict:
return self.optimizer.state_dict()

def load_state_dict(self, state_dict: dict) -> None:
return self.optimizer.load_state_dict(state_dict)
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
return self.optimizer.load_state_dict(state_dict, map_location=map_location)

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
return self.optimizer.step(closure=closure)
6 changes: 3 additions & 3 deletions src/nanotron/optim/named_optimizer.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

@@ -58,7 +58,7 @@ def state_dict(self) -> dict:
}
return optim_state_dict

def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"
@@ -71,4 +71,4 @@ def load_state_dict(self, state_dict: dict) -> None:
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

return super().load_state_dict(state_dict)
return super().load_state_dict(state_dict, map_location=map_location)
4 changes: 2 additions & 2 deletions src/nanotron/optim/optimizer_from_gradient_accumulator.py
@@ -67,7 +67,7 @@ def state_dict(self) -> dict:
state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict()
return state_dict

def load_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator")
super().load_state_dict(state_dict)
super().load_state_dict(state_dict, map_location=map_location)
self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict)
9 changes: 6 additions & 3 deletions src/nanotron/sanity_checks.py
@@ -164,6 +164,7 @@ def before_optim_step_sanity_checks(
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
optimizer: optim.BaseOptimizer,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Test tied weights gradients are synchronized
@@ -232,6 +233,9 @@ def before_optim_step_sanity_checks(
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)

# SANITY CHECK: Check that optimizer states are synchronized across DP
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)

# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()

@@ -259,12 +263,11 @@ def after_optim_step_sanity_checks(
unwrapped_model.after_optim_step_sanity_checks()


def check_optim_state_in_sync(optimizer: optim.BaseOptimizer, pg: dist.ProcessGroup):
for _, optim_state in sorted(optimizer.state_dict()["state"].items(), key=lambda x: x[0]):
def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup):
for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
if name == "step":
continue

assert_tensor_synced_across_pg(
tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}"
)
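
The refactor above lets the sync check operate on a plain state dict, so it can be run whether the states live on CPU or GPU. assert_tensor_synced_across_pg is nanotron's own helper; purely as an illustration of what "synchronized across DP" means, a stand-in check using plain torch.distributed could look like the following (the broadcast-and-compare strategy is an assumption, not nanotron's implementation):

import torch
import torch.distributed as dist

def assert_synced_across_pg(tensor: torch.Tensor, pg: dist.ProcessGroup, name: str) -> None:
    # Broadcast rank 0's copy within the group, then compare locally on every rank.
    reference = tensor.clone()
    dist.broadcast(reference, src=dist.get_global_rank(pg, 0), group=pg)
    if not torch.equal(tensor, reference):
        raise AssertionError(f"{name} is not synchronized across the data-parallel group")
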
43 changes: 29 additions & 14 deletions src/nanotron/serialize/optimizer.py
@@ -2,7 +2,7 @@
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
from torch import nn
@@ -19,7 +19,6 @@
)
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.sanity_checks import check_optim_state_in_sync
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors

@@ -125,18 +124,34 @@ def save_lr_scheduler(
)


# Helper functions to move optimizer states
@torch.no_grad()
def state_dict_to_device(state_dict: Dict, device: str) -> Dict:
assert (
state_dict["state"][0]["exp_avg"].device.type == "cpu"
), "Optimizer states should be on CPU to avoid extra memory usage when loading from checkpoint"
torch.cuda.empty_cache()

for _, optim_state in sorted(state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
optim_state[name] = tensor.to(device)

assert (
state_dict["state"][0]["exp_avg"].device.type == "cuda"
), "Optimizer states should be on GPU because model is on GPU"
torch.cuda.empty_cache()


@torch.no_grad()
def load_optimizer(
optimizer: optim.BaseOptimizer,
parallel_context: ParallelContext,
root_folder: Path,
map_location: Optional[str] = None,
map_location: Optional[str] = "cpu",
param_shard_metadata: Tuple[Tuple[int, int], TensorMetadata] = None, # (pp_rank, tp_rank) -> TensorMetadata
model: Optional[nn.Module] = None,
):
root_folder = root_folder / "optimizer"
# `load_state_dict` copies the state dict which can be very large in case of Zero-0 so we load to cpu and then move to the right device
map_location = "cpu" if not optimizer.inherit_from(optim.ZeroDistributedOptimizer) else map_location
ckp_optimizer_config_path = root_folder / "optimizer_config.json"
with open(ckp_optimizer_config_path, "r") as file:
ckp_optimizer_config = json.load(file)
@@ -149,9 +164,10 @@ def load_optimizer(
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
if int(ckp_pp_size) != int(parallel_context.pp_pg.size()):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
@@ -241,8 +257,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
# TODO: maybe better to allocate memory for all states at once
buffer = torch.zeros_like(param, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(new_unshared_shape, device="cuda", dtype=OPTIMIZER_STATE_DTYPE)
buffer = torch.zeros_like(param, device=map_location, dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(
new_unshared_shape, device=map_location, dtype=OPTIMIZER_STATE_DTYPE
)

for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
@@ -333,10 +351,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
)
state_dict["state"][param_index][state_name] = sliced_tensor

optimizer.load_state_dict(state_dict)

if not optimizer.inherit_from(optim.ZeroDistributedOptimizer):
check_optim_state_in_sync(optimizer, parallel_context.dp_pg)
optimizer.load_state_dict(state_dict, map_location="cpu")


def load_lr_scheduler(
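
As the comment shown in this file's diff notes, load_state_dict copies the state dict, so materializing the checkpointed states directly on GPU briefly holds an extra GPU copy; loading them on CPU defers that cost until after the first forward-backward. A rough, hypothetical harness (not part of the repo) to compare the two paths:

import torch

def peak_gpu_mib(fn) -> float:
    # Run fn() and report the peak CUDA memory allocated during it, in MiB.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 2**20

# e.g. compare the old and new behavior (arguments elided for brevity):
#   peak_gpu_mib(lambda: load_optimizer(optimizer, parallel_context, ckpt_dir, map_location="cuda"))
#   peak_gpu_mib(lambda: load_optimizer(optimizer, parallel_context, ckpt_dir))  # new default: "cpu"
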
27 changes: 19 additions & 8 deletions src/nanotron/trainer.py
@@ -93,7 +93,7 @@
save_random_states,
)
from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata
from nanotron.serialize.optimizer import load_optimizer
from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device

logger = logging.get_logger(__name__)

@@ -432,10 +432,17 @@ def train(
# Fix the root_model
self.unwrapped_model.module_id_to_prefix[id(self.unwrapped_model)] = ""

self.initial_iter_step = self.metadata.last_train_step + 1
self.last_iter_step = self.config.tokens.train_steps

prof = get_profiler(config=self.config)
# free memory
import gc

gc.collect()
torch.cuda.empty_cache()
with prof:
for self.iteration_step in range(self.metadata.last_train_step + 1, self.config.tokens.train_steps + 1):
for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
if isinstance(prof, torch.profiler.profile):
prof.step()

@@ -474,7 +481,7 @@ def training_step(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
)

if self.iteration_step < 5:
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)

outputs = self.pipeline_engine.train_batch_iter(
@@ -485,7 +492,7 @@
grad_accumulator=self.grad_accumulator,
)

if self.iteration_step < 5:
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)

after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
@@ -531,10 +538,6 @@ def training_step(
max_norm=self.config.optimizer.clip_grad,
)

before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator
)

# Compute DP average loss and overlap with optimizer step
if isinstance(outputs[0]["loss"], torch.Tensor):
# This is an average on only one data rank.
@@ -547,6 +550,14 @@
loss_avg = None
handle = None

# Move optimizer states back to GPU before optimizer step
if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step:
state_dict_to_device(self.optimizer.state_dict(), "cuda")

before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
)

# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
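
A note on the gc.collect() / torch.cuda.empty_cache() pair added before the training loop in train(): empty_cache only returns cached allocator blocks that are no longer referenced by any live tensor, which is why the Python-level gc.collect() runs first. A small, hypothetical snippet (not from the repo) to observe the difference between live and cached memory:

import gc
import torch

def report_cuda_memory(tag: str) -> None:
    allocated = torch.cuda.memory_allocated() / 2**20   # memory held by live tensors
    reserved = torch.cuda.memory_reserved() / 2**20     # memory held by the caching allocator
    print(f"{tag}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")

report_cuda_memory("before cleanup")
gc.collect()                   # drop unreachable Python references to tensors
torch.cuda.empty_cache()       # return now-unused cached blocks to the driver
report_cuda_memory("after cleanup")
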
