From 7299fdc7fd7c9b7e7d5ea12f075ad11c7d29248e Mon Sep 17 00:00:00 2001
From: NouamaneTazi
Date: Thu, 21 Nov 2024 16:07:59 +0000
Subject: [PATCH] move load custom func to base

---
 src/nanotron/helpers.py    | 125 +---------------------------------
 src/nanotron/optim/base.py | 134 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 134 insertions(+), 125 deletions(-)

diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 6a46551e..6ff564a8 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -4,22 +4,16 @@
 import math
 import os
 import time
-from collections import defaultdict
-from copy import deepcopy
 from datetime import datetime
 from functools import partial
-from itertools import chain
 from math import ceil
 from typing import (
     Any,
-    DefaultDict,
     Dict,
-    Hashable,
     Iterable,
     List,
     Optional,
     Tuple,
-    Union,
 )
 
 import numpy as np
@@ -28,7 +22,6 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import LambdaLR
 from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
-from typing_extensions import TypeAlias
 
 from nanotron import distributed as dist
 from nanotron import logging
@@ -36,7 +29,7 @@
 from nanotron.distributed import ProcessGroup
 from nanotron.logging import LogItem, log_rank
 from nanotron.models.base import NanotronModel
-from nanotron.optim.base import BaseOptimizer, Optimizer
+from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
 from nanotron.optim.gradient_accumulator import (
     FP32GradBucketManager,
     FP32GradientAccumulator,
@@ -58,11 +51,6 @@
 from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
 from nanotron.serialize.metadata import TrainingMetadata
 
-Args: TypeAlias = Tuple[Any, ...]
-Kwargs: TypeAlias = Dict[str, Any]
-StateDict: TypeAlias = Dict[str, Any]
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -310,117 +298,6 @@ def merge_named_param_groups(
     return named_param_groups
 
 
-# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
-@staticmethod
-def _process_value_according_to_param_policy(
-    param: torch.Tensor,
-    value: torch.Tensor,
-    param_id: int,
-    param_groups: List[Dict[Any, Any]],
-    map_location: Optional[Union[str, torch.device]],
-    key: Hashable = None,
-) -> torch.Tensor:
-    # If map_location is specified, use it instead of param.device
-    target_device = map_location if map_location is not None else param.device
-
-    fused = False
-    capturable = False
-    assert param_groups is not None
-    for pg in param_groups:
-        if param_id in pg["params"]:
-            fused = pg["fused"] if "fused" in pg else False
-            capturable = pg["capturable"] if "capturable" in pg else False
-            break
-
-    if key == "step":
-        if capturable or fused:
-            return value.to(dtype=torch.float32, device=target_device)
-        else:
-            return value
-    else:
-        if param.is_floating_point():
-            return value.to(dtype=param.dtype, device=target_device)
-        else:
-            return value.to(device=target_device)
-
-
-# Modified from torch.optim.Optimizer.load_state_dict
-@torch._disable_dynamo
-def custom_load_state_dict(
-    self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu"
-) -> None:
-    r"""Loads the optimizer state.
-
-    Args:
-        state_dict (dict): optimizer state. Should be an object returned
-            from a call to :meth:`state_dict`.
-        map_location (str or torch.device, optional): Device where to load the optimizer states.
-            If None, states will be loaded to the same device as their corresponding parameters.
-            Default: None
-    """
-
-    # shallow copy, to be consistent with module API
-    state_dict = state_dict.copy()
-
-    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
-        hook_result = pre_hook(self, state_dict)
-        if hook_result is not None:
-            state_dict = hook_result
-
-    # Validate the state_dict
-    groups = self.param_groups
-
-    # Deepcopy as we write into saved_groups later to update state
-    saved_groups = deepcopy(state_dict["param_groups"])
-
-    if len(groups) != len(saved_groups):
-        raise ValueError("loaded state dict has a different number of " "parameter groups")
-    param_lens = (len(g["params"]) for g in groups)
-    saved_lens = (len(g["params"]) for g in saved_groups)
-    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
-        raise ValueError(
-            "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
-        )
-
-    # Update the state
-    id_map = dict(
-        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
-    )
-
-    def _cast(param, value, param_id=None, param_groups=None, key=None):
-        r"""Make a deep copy of value, casting all tensors to device of param."""
-        if isinstance(value, torch.Tensor):
-            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
-        elif isinstance(value, dict):
-            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
-        elif isinstance(value, Iterable):
-            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
-        else:
-            return value
-
-    # Copy state assigned to params (and cast tensors to appropriate types).
-    # State that is not assigned to params is copied as is (needed for
-    # backward compatibility).
-    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
-    for k, v in state_dict["state"].items():
-        if k in id_map:
-            param = id_map[k]
-            state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
-        else:
-            state[k] = v
-
-    # Update parameter groups, setting their 'params' value
-    def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
-        new_group["params"] = group["params"]
-        return new_group
-
-    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
-    self.__setstate__({"state": state, "param_groups": param_groups})
-
-    for post_hook in self._optimizer_load_state_dict_post_hooks.values():
-        post_hook(self)
-
-
 def init_optimizer_and_grad_accumulator(
     parametrization_method: ParametrizationMethod,
     model: nn.Module,
diff --git a/src/nanotron/optim/base.py b/src/nanotron/optim/base.py
index 34c33f42..9418b44a 100644
--- a/src/nanotron/optim/base.py
+++ b/src/nanotron/optim/base.py
@@ -1,7 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
+from typing_extensions import TypeAlias
+
+Args: TypeAlias = Tuple[Any, ...]
+Kwargs: TypeAlias = Dict[str, Any]
+StateDict: TypeAlias = Dict[str, Any]
 
 
 class BaseOptimizer(ABC):
@@ -46,3 +67,114 @@ def inherit_from(self, cls) -> bool:
 
 
 Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer)
+
+
+# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
+@staticmethod
+def _process_value_according_to_param_policy(
+    param: torch.Tensor,
+    value: torch.Tensor,
+    param_id: int,
+    param_groups: List[Dict[Any, Any]],
+    map_location: Optional[Union[str, torch.device]],
+    key: Hashable = None,
+) -> torch.Tensor:
+    # If map_location is specified, use it instead of param.device
+    target_device = map_location if map_location is not None else param.device
+
+    fused = False
+    capturable = False
+    assert param_groups is not None
+    for pg in param_groups:
+        if param_id in pg["params"]:
+            fused = pg["fused"] if "fused" in pg else False
+            capturable = pg["capturable"] if "capturable" in pg else False
+            break
+
+    if key == "step":
+        if capturable or fused:
+            return value.to(dtype=torch.float32, device=target_device)
+        else:
+            return value
+    else:
+        if param.is_floating_point():
+            return value.to(dtype=param.dtype, device=target_device)
+        else:
+            return value.to(device=target_device)
+
+
+# Modified from torch.optim.Optimizer.load_state_dict
+@torch._disable_dynamo
+def custom_load_state_dict(
+    self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu"
+) -> None:
+    r"""Loads the optimizer state.
+
+    Args:
+        state_dict (dict): optimizer state. Should be an object returned
+            from a call to :meth:`state_dict`.
+        map_location (str or torch.device, optional): Device where to load the optimizer states.
+            If None, states will be loaded to the same device as their corresponding parameters.
+            Default: None
+    """
+
+    # shallow copy, to be consistent with module API
+    state_dict = state_dict.copy()
+
+    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
+        hook_result = pre_hook(self, state_dict)
+        if hook_result is not None:
+            state_dict = hook_result
+
+    # Validate the state_dict
+    groups = self.param_groups
+
+    # Deepcopy as we write into saved_groups later to update state
+    saved_groups = deepcopy(state_dict["param_groups"])
+
+    if len(groups) != len(saved_groups):
+        raise ValueError("loaded state dict has a different number of " "parameter groups")
+    param_lens = (len(g["params"]) for g in groups)
+    saved_lens = (len(g["params"]) for g in saved_groups)
+    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+        raise ValueError(
+            "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
+        )
+
+    # Update the state
+    id_map = dict(
+        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
+    )
+
+    def _cast(param, value, param_id=None, param_groups=None, key=None):
+        r"""Make a deep copy of value, casting all tensors to device of param."""
+        if isinstance(value, torch.Tensor):
+            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
+        elif isinstance(value, dict):
+            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
+        elif isinstance(value, Iterable):
+            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
+        else:
+            return value
+
+    # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self)
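After this move, custom_load_state_dict is importable from nanotron.optim.base (which is exactly what the updated import in helpers.py does) and stays a module-level function whose first argument plays the role of self. A minimal sketch of how it might be exercised, assuming Python 3.10+, a torch release recent enough to carry the optimizer load_state_dict hook registries the function iterates over, and a hypothetical AdamW optimizer plus checkpoint dict ckpt_state that are not part of this patch:

import types

import torch

from nanotron.optim.base import custom_load_state_dict

# Hypothetical setup: a toy model, an AdamW optimizer, and one step so the
# optimizer actually owns some state (step counters, exp_avg, exp_avg_sq).
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(4, 8)).sum().backward()
optimizer.step()
ckpt_state = optimizer.state_dict()  # stand-in for a state dict read back from a checkpoint

# Direct call: the optimizer instance is passed explicitly as `self`, and
# map_location="cpu" loads the states onto CPU rather than onto each
# parameter's device.
custom_load_state_dict(optimizer, ckpt_state, map_location="cpu")

# Alternatively, bind it as the instance's load_state_dict so existing call
# sites keep working while gaining the map_location keyword.
optimizer.load_state_dict = types.MethodType(custom_load_state_dict, optimizer)
optimizer.load_state_dict(ckpt_state, map_location="cpu")

One caveat: the bare @staticmethod kept on the module-level _process_value_according_to_param_policy relies on staticmethod objects being callable, which is only true on Python 3.10 and later; on older interpreters the tensor-casting path inside _cast would raise a TypeError.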