From 0e74dc880a6c9803179e9025fc5af63458b24d99 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 12 Nov 2024 19:48:21 +0000 Subject: [PATCH] in mem ckpt --- axlearn/common/checkpointer.py | 30 +- axlearn/common/checkpointer_orbax.py | 653 +++++++++++++++++++++- axlearn/common/checkpointer_orbax_test.py | 285 +++++++++- axlearn/common/checkpointer_test.py | 18 +- axlearn/common/launch.py | 18 +- pyproject.toml | 2 +- 6 files changed, 984 insertions(+), 22 deletions(-) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 0eb48ae3..1aa411a5 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -718,6 +718,8 @@ class Config(Module.Config): every_n_steps_policy ) + # TODO(hanzhi-zhou): deprecate all checkpoint_paths related class methods in favor of + # checkpoint_steps. @classmethod def checkpoint_paths(cls, base_dir: str) -> list[str]: """Returns complete checkpoint paths under base dir. @@ -744,6 +746,24 @@ def latest_checkpoint_path(cls, base_dir: str) -> str: # Note: checkpoint_paths should already filter incomplete checkpoints. return sorted(cls.checkpoint_paths(base_dir)).pop() + @classmethod + def checkpoint_steps(cls, base_dir: str) -> list[int]: + """Returns complete checkpoint steps under base dir. + + Args: + base_dir: Path to checkpoints dir. + + Returns: + A list of committed checkpoint steps. Incomplete checkpoints are dropped. + """ + raise NotImplementedError(cls) + + @classmethod + def latest_checkpoint_step(cls, base_dir: str) -> int: + """Returns the most recent (highest step count) checkpoint step under base dir.""" + # Note: checkpoint_steps should already filter incomplete checkpoints. + return max(cls.checkpoint_steps(base_dir)) + def __init__(self, cfg: Module.Config, *, parent: Optional[Module]): super().__init__(cfg, parent=parent) self._within_context = False @@ -850,7 +870,11 @@ class Config(BaseCheckpointer.Config): @classmethod def checkpoint_paths(cls, base_dir: str) -> list[str]: """See `BaseCheckpointer.checkpointer_paths`.""" - + logging.log_first_n( + logging.WARNING, + msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.", + n=1, + ) # The default checkpointer commits under "/_/index". Using a # concurrent `exists` check for the index file can be several times faster than `glob` on # gcs when there are many checkpoint files, even if using a "native" solution like @@ -867,6 +891,10 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]: index_exists = pool.map(fs.exists, paths) return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed] + @classmethod + def checkpoint_steps(cls, base_dir: str) -> list[int]: + return [parse_step_from_dir(path) for path in cls.checkpoint_paths(base_dir)] + @classmethod def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True): """Removes ckpt_dir if it exists. diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2f714605..24cb0ed5 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -9,30 +9,56 @@ import copy import dataclasses import functools +import hashlib import os +import time from concurrent import futures +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Process from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import jax +import jax.lib import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as oecp import tensorflow as tf from absl import logging +from jax._src.distributed import global_state +from jax._src.mesh import thread_resources +from jax.experimental.array_serialization import serialization +# TODO(hanzhi-zhou): fix this after orbax move metadata back to public API. +from orbax.checkpoint._src.metadata.value import Metadata +from orbax.checkpoint._src.multihost import multihost as multihost_src + +from axlearn.common import file_system as fs from axlearn.common import utils from axlearn.common.checkpointer import ( STEP_NUM_DIGITS, STEP_PREFIX, BaseCheckpointer, + Checkpointer, + CheckpointPolicy, CheckpointValidationType, + InstantiableConfig, + StateStorage, + StateStorageCommitCallback, async_save_tf_savables, check_state_structure, + config_for_function, + every_n_steps_policy, maybe_restore_grain_savables, maybe_save_grain_savables, + multihost_utils, + parse_step_from_dir, + read_index_file, restore_tf_savables, + write_index_file, ) -from axlearn.common.config import config_class +from axlearn.common.config import REQUIRED, Required, config_class from axlearn.common.module import Module from axlearn.common.utils import Nested, Tensor, TensorSpec +from axlearn.common.utils_spmd import setup try: # The import also registers the checkpoint handlers. @@ -45,7 +71,7 @@ _GRAIN_INSTALLED = False -class _TfIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): +class _TfIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes tf.data.Iterator. Reference: @@ -94,10 +120,8 @@ async def deserialize( ] return await asyncio.gather(*futs) - async def metadata( - self, infos: Sequence[ocp.type_handlers.ParamInfo] - ) -> Sequence[ocp.metadata.Metadata]: - return [ocp.metadata.Metadata(name=info.name, directory=info.path) for info in infos] + async def metadata(self, infos: Sequence[ocp.type_handlers.ParamInfo]) -> Sequence[Metadata]: + return [Metadata(name=info.name, directory=info.path) for info in infos] ocp.type_handlers.register_type_handler(tf.data.Iterator, _TfIteratorHandler(), override=True) @@ -105,7 +129,7 @@ async def metadata( if _GRAIN_INSTALLED: - class _GrainDatasetIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): + class _GrainDatasetIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes grain dataset iterators.""" @dataclasses.dataclass @@ -143,8 +167,8 @@ async def deserialize( async def metadata( self, infos: Sequence[ocp.type_handlers.ParamInfo] - ) -> Sequence[ocp.metadata.Metadata]: - return [ocp.metadata.Metadata(name=info.name, directory=info.path) for info in infos] + ) -> Sequence[Metadata]: + return [Metadata(name=info.name, directory=info.path) for info in infos] ocp.type_handlers.register_type_handler( grain.DatasetIterator, _GrainDatasetIteratorHandler(), override=True @@ -182,8 +206,18 @@ class Config(BaseCheckpointer.Config): @classmethod def checkpoint_paths(cls, base_dir: str) -> List[str]: """See `BaseCheckpointer.checkpointer_paths`.""" + logging.log_first_n( + logging.WARNING, + msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.", + n=1, + ) return [str(path) for path in ocp.utils.checkpoint_steps_paths(base_dir)] + @classmethod + def checkpoint_steps(cls, base_dir) -> list[int]: + """See `BaseCheckpointer.checkpointer_steps`.""" + return ocp.utils.checkpoint_steps(base_dir) + def __init__(self, cfg: Config, *, parent: Optional[Module]): super().__init__(cfg, parent=parent) @@ -355,3 +389,604 @@ def wait_until_finished(self): def stop(self): """See `BaseCheckpointer.stop` for details.""" self._manager.close() + + +class _TFSavablesStateStorage(StateStorage): + """A StateStorage implementation that only saves the index file and tf savables.""" + + @config_class + class Config(StateStorage.Config): + timeout_secs: int = 300 + + def __init__(self, cfg: Config): + super().__init__(cfg) + # One thread is sufficient because `async_save_tf_savables` only creates one future. + self._executor = ThreadPoolExecutor(1) + self._manager = serialization.AsyncManager(timeout_secs=cfg.timeout_secs) + + def _get_spec(self, *, step: int, state: Nested[Any]) -> Nested[Any]: + spec = {"index": [("step", int(step))], "tf_ckpt_map": {}} + for path, value in utils.flatten_items(state): + if isinstance(value, (Tensor, TensorSpec)): + dtype = getattr(value.dtype, "dtype", value.dtype) + spec["index"].append( + (path, {"dtype": str(dtype), "shape": str(tuple(value.shape))}) + ) + elif isinstance(value, tf.data.Iterator): + spec["index"].append((path, str(type(value)))) + spec["tf_ckpt_map"][path] = value + else: + spec["index"].append((path, value)) + logging.log_first_n(logging.INFO, "TF savables spec: %s", 1, str(spec)) + return spec + + def save_to_dir( + self, + *, + step: int, + state: Nested[Tensor], + ckpt_dir: str, + on_commit_callback: StateStorageCommitCallback = write_index_file, + ): + start_time = time.perf_counter() + # We write data files directly to `ckpt_dir`. `index` is written into `ckpt_dir` in + # `on_commit_callback` to finalize the checkpoint. + spec = self._get_spec(step=step, state=state) + self.wait_until_finished() + + save_tf_future = async_save_tf_savables( + spec["tf_ckpt_map"], + executor=self._executor, + dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}"), + ) + + def commit(): + on_commit_callback(ckpt_dir=ckpt_dir, index=spec["index"]) + logging.info( + "Serialization of TF savables to %s completed in %s seconds.", + ckpt_dir, + time.perf_counter() - start_time, + ) + + # pylint: disable=protected-access + self._manager._add_futures([save_tf_future]) + self._manager._start_async_commit(commit) + + def wait_until_finished(self): + self._manager.wait_until_finished() + + def restore_from_dir( + self, + step: int, + state: Union[Nested[Tensor], Nested[TensorSpec]], + *, + ckpt_dir: str, + validation: CheckpointValidationType = CheckpointValidationType.EXACT, + ) -> Nested[Tensor]: + spec = self._get_spec(step=step, state=state) + logging.info("Restoring TF savables from directory %s", ckpt_dir) + check_state_structure( + read_index_file(ckpt_dir), target_structure=spec["index"], validation=validation + ) + restore_tf_savables( + spec["tf_ckpt_map"], dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}") + ) + multihost_utils.sync_global_devices(ckpt_dir) + return state + + def stop(self): + self._executor.shutdown(wait=True) + + +def _initialize_runtime_to_distributed_ids(timeout: int): + """Initializes orbax's internal process index mapping. + + This function is ported from orbax's source code to make the timeout configurable. + https://github.com/google/orbax/blob/073880cae248fc721fe31e46bf1bb386346d3aa5/checkpoint/orbax/checkpoint/_src/multihost/multihost.py#L52 + """ + # pylint: disable=protected-access + client = multihost_src.get_jax_distributed_client() + + # Index is distributed id. + # Value is runtime id. + multihost_src._RUNTIME_TO_DISTRIBUTED_ID = [0 for _ in range(jax.process_count())] + own_runtime_id = jax.process_index() + own_distributed_id = ( + jax._src.distributed.global_state.process_id # pytype: disable=module-attr + ) # pylint: disable=protected-access + dir_key = "jax/process_id/" + key = dir_key + str(own_runtime_id) + client.key_value_set(key, str(own_distributed_id)) + client.wait_at_barrier("orbax_global_discovery", timeout_in_ms=timeout * 1000) + ids = client.key_value_dir_get(dir_key) + for key, distributed_id in ids: + runtime_id = int(key.split("/")[-1]) + multihost_src._RUNTIME_TO_DISTRIBUTED_ID[runtime_id] = int(distributed_id) + logging.info( + "[process=%s][thread=%s] runtime_to_distributed_id: %s", + multihost_src.process_index(), + multihost_src.threading.current_thread().name, + multihost_src._RUNTIME_TO_DISTRIBUTED_ID, + ) + + +_PROCESS_ID_FILE_NAME: str = "process_id.txt" + + +def _get_previous_process_id(local_dir: str, *, unique_str: str) -> int: + """Gets previous process id from local checkpoint directory. Returns -1 if file isn't found.""" + path = os.path.join(local_dir, _get_unique_id(unique_str), _PROCESS_ID_FILE_NAME) + if not fs.exists(path): + return -1 + + with fs.open(path) as f: + proc_id = int(f.read()) + return proc_id + + +def _dump_process_id(local_dir: str, *, unique_str: str, process_index: int): + """Dumps process id to local checkpoint directory.""" + local_dir = os.path.join(local_dir, _get_unique_id(unique_str)) + fs.makedirs(local_dir) + process_id_file = os.path.join(local_dir, _PROCESS_ID_FILE_NAME) + with fs.open(process_id_file, "w") as f: + f.write(str(process_index)) + + +def _get_unique_id(unique_str: str) -> str: + return hashlib.sha256(unique_str.encode(), usedforsecurity=False).hexdigest() + + +def _init_consistent_proc_ids( + *, + distributed_coordinator: Optional[str] = None, + num_processes: Optional[int] = None, + process_id: Optional[int] = None, + initialization_timeout: Optional[int] = None, + barrier_timeout_ms: int = 300 * 1000, + jax_backend: str, + trainer_dir: str, + local_ckpt_dir: str, +): + """Reads local process id file and assigns globally consistent process ids through rank 0. + + During failover, healthy nodes will read their locally stored process id file, but failed nodes + will lost their process ids. To assign ids that are free in the global id range (i.e. 0 to + num_processes - 1), we let each node report its process id (-1 if missing) to rank 0, and rank + 0 will figure out suitable IDs to assign to each failed node. We reuse Jax's distributed client + to avoid writing our own coordinator. + """ + setup( + jax_backend=jax_backend, + distributed_coordinator=distributed_coordinator, + num_processes=num_processes, + process_id=process_id, + initialization_timeout=initialization_timeout, + ) + client: jax.lib.xla_extension.DistributedRuntimeClient = global_state.client + prev_process_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir) + prefix = "axlearn/id_reassign" + # Local key just needs to be unique for each process. + local_set_key = f"{prefix}/{jax.process_index()}" + # For TPU backend, only GKE is supported for now. + if jax.default_backend() == "tpu": + # For TPUs, we have the additional requirement that process ids in slice id X must be in + # range [X * num_processes_per_slice, (X + 1) * num_processes_per_slice). Therefore, we + # first identify the healthy slices' ids and then figure out the slice ids to assign to + # failed slices. Each process in the failed slice will then get id `new_slice_id * + # num_proc_per_slice + worker_id`. + client.key_value_set( + local_set_key, + f"{os.environ['MEGASCALE_SLICE_ID']}|{prev_process_id}|{os.environ['TPU_WORKER_ID']}", + ) + client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=barrier_timeout_ms) + if jax.process_index() == 0: + ids = client.key_value_dir_get(prefix) + parsed_ids: list[tuple[int, int, int]] = [] + for _, v in ids: + data = v.split("|") + assert len(data) == 3 + parsed_ids.append(tuple(int(x) for x in data)) + + num_proc_per_slice = len(str(os.environ.get("TPU_WORKER_HOSTNAMES", None)).split(",")) + failed_slices_new_ids = {} + for slice_id, prev_proc_id, _ in parsed_ids: + if prev_proc_id == -1: + failed_slices_new_ids[slice_id] = -1 + + already_assigned_slice_ids = set() + for slice_id, prev_proc_id, _ in parsed_ids: + if slice_id not in failed_slices_new_ids: + already_assigned_slice_ids.add(prev_proc_id // num_proc_per_slice) + + to_be_assigned_slice_ids = ( + set(range(int(os.environ["MEGASCALE_NUM_SLICES"]))) - already_assigned_slice_ids + ) + assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) + for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): + failed_slices_new_ids[k] = new_id + + for (k, _), (slice_id, prev_proc_id, worker_id) in zip(ids, parsed_ids): + if (new_slice_id := failed_slices_new_ids.get(slice_id)) is not None: + client.key_value_set( + k + "/get", str(new_slice_id * num_proc_per_slice + worker_id) + ) + else: + client.key_value_set(k + "/get", str(prev_proc_id)) + elif jax.default_backend() == "gpu": + # For GPU backend, failed nodes are assigned with ids that are missing in the global id + # range with arbitrary order. + client.key_value_set(local_set_key, str(prev_process_id)) + client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=barrier_timeout_ms) + if jax.process_index() == 0: + ids = client.key_value_dir_get(prefix) + to_be_assigned_proc_ids = list( + set(range(num_processes)) - set(int(value) for _, value in ids if int(value) != -1) + ) + counter = 0 + for k, value in ids: + if int(value) == -1: + client.key_value_set(k + "/get", str(to_be_assigned_proc_ids[counter])) + counter += 1 + else: + client.key_value_set(k + "/get", value) + assert counter == len(to_be_assigned_proc_ids) + else: + raise RuntimeError(f"Unsupported backend {jax.default_backend()}") + + _dump_process_id( + local_ckpt_dir, + unique_str=trainer_dir, + process_index=int( + client.blocking_key_value_get(local_set_key + "/get", timeout_in_ms=barrier_timeout_ms) + ), + ) + # Block to avoid coordinator exiting too early. + client.wait_at_barrier("axlearn/id-reassign-finalize", timeout_in_ms=barrier_timeout_ms) + jax.distributed.shutdown() + + +def get_consistent_proc_id( + *, + jax_backend: str, + distributed_coordinator: Optional[str] = None, + num_processes: Optional[int] = None, + process_id: Optional[int] = None, + initialization_timeout: Optional[int] = None, + barrier_timeout_ms: int = 300 * 1000, + trainer_dir: str, + local_ckpt_dir: str, +) -> int: + """Returns process id so that process id <-> node mapping stays the same for health nodes. + + This is required to preserve shard order for in-memory checkpoint recovery. For GPU training, + all healthy nodes will have their process id unchanged. For TPU, all nodes in the healthy + slices will have their process id unchanged. See docstring of `_init_consistent_proc_ids` for + implementation details. + """ + proc = Process( + target=_init_consistent_proc_ids, + kwargs=dict( + jax_backend=jax_backend, + distributed_coordinator=distributed_coordinator, + num_processes=num_processes, + process_id=process_id, + initialization_timeout=initialization_timeout, + barrier_timeout_ms=barrier_timeout_ms, + trainer_dir=trainer_dir, + local_ckpt_dir=local_ckpt_dir, + ), + ) + proc.start() + proc.join() + assert proc.exitcode == 0 + + proc_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir) + assert proc_id != -1 + return proc_id + + +class OrbaxEmergencyCheckpointer(BaseCheckpointer): + """Checkpointer implementation that uses Orbax emergency checkpoint. + + This checkpointer is intended for multi-slice training that uses data-parallelism across + slices. Orbax emergency checkpoint works by exploiting the following properties: + 1. Tensors are replicated across data-parallel replicas. + 2. When a slice fails in a multi-slice training and failover is started, only nodes + corresponding to the non-healthy slice may be restarted. Healthy nodes from healthy slices + will not restart. + + Hence, all slices can write checkpoints to node's memory or disk, providing us with redundancy + when there's a failure. This checkpoint frequency can be much higher than remote filesystem, + which has limited bandwidth to support high frequency saving. Checkpoints on nodes are referred + as local checkpoints. Checkpoints on remote filesystem are referred as persistent checkpoints. + + When a failure occurs, Orbax checkpointer will find the latest step from all local and + persistent checkpoints. If the checkpoint is local, the slice on which that checkpoint is + stored will read the checkpoint and broadcast the read values to other slices. + + However, the above procedure doesn't apply to some non-tensor states such as data iterators. + Data iterators are unique across jax processes, and thus cannot be stored on nodes. Orbax + emergency checkpointer doesn't support non-tensor states. Therefore, we reuse axlearn + Checkpointer to save, restore and garbage collect those states, which include the index file + and tf iterators. These non-tensor states will be saved whenever local or persistent checkpoint + need to be saved. As the result, the persistent checkpoint structure looks like this: + + ├── path_prefix + │ ├── non-tensors + │ │ └── step_00000010 + │ │ ├── index + │ │ └── tf_xxx + │ └── tensors + │ └── step_00000010 + │ └── orbax_files_xxx + + A persistent training checkpoint `step_xxx` is commited when `non-tensors/step_xxx/index` + exists and `tensors/step_xxx` is commited by Orbax. Refer to the docstring of + `OrbaxCheckpointer` for Orbax's commit criteria. + + To abstract the details of the checkpoint layout, the `checkpoint_steps` API returns all steps + for which both Tensor and non-Tensor states have been fully committed. + """ + + _NON_TENSORS_PREFIX: str = "non-tensors" + _TENSORS_PREFIX: str = "tensors" + + @config_class + class Config(BaseCheckpointer.Config): + """Configures OrbaxEmergencyCheckpointer. + + Attributes: + keep_last_n: Keep this many past ckpts. + keep_every_n_steps: If > 0, keeps at least one persistent checkpoint every N steps. + local_keep_last_n: Keep this many past ckpts in local storage (e.g. node memory). + This should almost always set to 1 to avoid OOM. + local_dir: Ckpt base path for local storage. The content in this path must persist + across pod restarts unless the restart is caused by node failure. `local_dir` must + be the same for all processes or processes may hang. + unqiue_str: A string that's unique for the current run. Typically, this is set to + trainer_dir. Local checkpoint will be stored in local_dir/sha256(unique_str). + During init, all other folders in local_dir will be removed. + save_policy: Save policy for persistent checkpoints. + local_save_policy: Save policy for local checkpoints. This should be more frequent than + `save_policy`. Note that data iterator will be saved with either `save_policy` or + `local_save_policy` indicate we should save. + non_tensor_async_timeout_secs: Timeout for async barrier in seconds when saving + non-tensor states. + async_timeout_secs: Timeout for async barrier in seconds when saving tensors. + replica_axis_index: The index of the "data" axis. + """ + + keep_last_n: int = 1 + keep_every_n_steps: Optional[int] = None + local_keep_last_n: int = 1 + local_save_policy: InstantiableConfig[CheckpointPolicy] = config_for_function( + every_n_steps_policy + ).set(n=10) + local_dir: str = "/host-tmp/checkpoints" + unique_str: Required[str] = REQUIRED + non_tensor_async_timeout_secs: int = 300 + async_timeout_secs: int = 3600 + replica_axis_index: Required[int] = REQUIRED + + @classmethod + def checkpoint_paths(cls, base_dir: str) -> List[str]: + """See `BaseCheckpointer.checkpointer_paths`. + + Only persistent checkpoint paths are returned. There's no guarantee that the paths returned + have committed TF savables. Use `checkpoint_steps` to get steps with both tensors and + committed TF savables. + """ + logging.log_first_n( + logging.WARNING, + msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.", + n=1, + ) + tensors_dir = os.path.join(base_dir, cls._TENSORS_PREFIX) + return [str(path) for path in ocp.utils.checkpoint_steps_paths(tensors_dir)] + + @classmethod + def checkpoint_steps(cls, base_dir) -> list[int]: + """See `BaseCheckpointer.checkpointer_steps`. + + Only persistent checkpoint steps are returned. + """ + return list( + set( + ocp.utils.checkpoint_steps(os.path.join(base_dir, cls._TENSORS_PREFIX)) + ).intersection( + set(Checkpointer.checkpoint_steps(os.path.join(base_dir, cls._NON_TENSORS_PREFIX))) + ) + ) + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg: OrbaxEmergencyCheckpointer.Config = self.config + self._name_format = ocp.step.standard_name_format( + step_prefix=STEP_PREFIX, + step_format_fixed_length=STEP_NUM_DIGITS, + ) + if jax.process_index() == 0: + fs.makedirs(os.path.join(cfg.dir, self._NON_TENSORS_PREFIX)) + fs.makedirs(os.path.join(cfg.dir, self._TENSORS_PREFIX)) + # Cleanup local checkpoints from different runs. + unique_id = _get_unique_id(cfg.unique_str) + for fd in fs.listdir(cfg.local_dir): + if not fd.startswith(".") and fd != unique_id: + fs.rmtree(os.path.join(cfg.local_dir, fd)) + self._local_dir = os.path.join(cfg.local_dir, unique_id) + fs.makedirs(self._local_dir) + # Orbax emergency ckpt requires this function to be called prior to checkpointer + # operations. This function also serves as a barrier. + _initialize_runtime_to_distributed_ids(cfg.non_tensor_async_timeout_secs) + ckpt_cfg: Checkpointer.Config = Checkpointer.default_config() + # TODO(hanzhi-zhou): this `keep_last_n` may not be what users expect since non-tensor + # states will save when either local or persistent checkpoint will save. + ckpt_cfg.keep_last_n = cfg.keep_last_n + ckpt_cfg.keep_every_n_steps = cfg.keep_every_n_steps + ckpt_cfg.storage = _TFSavablesStateStorage.default_config() + ckpt_cfg.storage.timeout_secs = cfg.non_tensor_async_timeout_secs + ckpt_cfg.dir = os.path.join(cfg.dir, self._NON_TENSORS_PREFIX) + ckpt_cfg.name = "non-tensors-checkpointer" + + save_policy = cfg.save_policy.instantiate() + local_save_policy = cfg.local_save_policy.instantiate() + + # Non-tensor states must save when either local or persistent ckpt needs to be saved in + # order for restore from either to succeed. + def _composite_save_policy(*, step: int, evaler_summaries: dict[str, Any]): + return save_policy(step=step, evaler_summaries=evaler_summaries) or local_save_policy( + step=step, evaler_summaries=evaler_summaries + ) + + ckpt_cfg.save_policy = config_for_function(lambda: _composite_save_policy) + self._non_tensor_manager: Checkpointer = ckpt_cfg.instantiate(parent=self) + self._tensor_manager: Optional[oecp.CheckpointManager] = None + # See comments of _eval_summaries in `OrbaxCheckpointer`. + self._eval_summaries = None + + # pylint: disable-next=redefined-builtin + def ckpt_dir(self, step: int, dir: Optional[str] = None) -> str: + """Obtains the checkpoint dir for the given step.""" + if dir is None: + dir = self._non_tensor_manager.directory + return str(ocp.step.build_step_path(dir, self._name_format, step)) + + def _get_abstract_state( + self, state_with_tensors: Nested[Tensor] + ) -> Nested[jax.ShapeDtypeStruct]: + """Generate the abstract states required by the Orbax emergency checkpointer.""" + return jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding), + state_with_tensors, + ) + + def _get_tensor_manager(self, state_with_tensors: Nested[Tensor]) -> oecp.CheckpointManager: + """Creates the emergency checkpoint manager if not exists. + + We defer the creation of this checkpoint manager because it requires the state dict, + which is not present during __init__. + """ + cfg: OrbaxEmergencyCheckpointer.Config = self.config + if self._tensor_manager is not None: + return self._tensor_manager + + save_policy = cfg.save_policy.instantiate() + local_save_policy = cfg.local_save_policy.instantiate() + + def _orbax_save_fn( + step: int, last_saved_step: Optional[int], wrapped_save_policy: CheckpointPolicy + ) -> bool: + del last_saved_step + return wrapped_save_policy(step=step, evaler_summaries=self._eval_summaries) + + # For meaning of these options, refer to + # https://github.com/google/orbax/blob/95be2c021bc8cbf4badd83a053ff57b7a9f9b314/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L277 + self._tensor_manager = oecp.CheckpointManager( + self._local_dir, + persistent_directory=os.path.join(cfg.dir, self._TENSORS_PREFIX), + global_mesh=thread_resources.env.physical_mesh, + abstract_state=self._get_abstract_state(state_with_tensors), + options=oecp.CheckpointManagerOptions( + local=oecp.LocalCheckpointOptions( + should_save_fn=functools.partial( + _orbax_save_fn, wrapped_save_policy=local_save_policy + ), + max_to_keep=cfg.local_keep_last_n, + ), + persistent=oecp.PersistentCheckpointOptions( + should_save_fn=functools.partial( + _orbax_save_fn, wrapped_save_policy=save_policy + ), + max_to_keep=cfg.keep_last_n, + ), + replica_axis_index=cfg.replica_axis_index, + async_options=oecp.checkpoint_manager.AsyncOptions( + timeout_secs=cfg.async_timeout_secs + ), + step_name_format=self._name_format, + cleanup_tmp_directories=True, + enable_async_checkpointing=True, + ), + ) + return self._tensor_manager + + def save( + self, *, step: int, state: Nested[Tensor], evaler_summaries: Optional[Dict[str, Any]] = None + ): + """See `BaseCheckpointer.save` for details.""" + assert self._eval_summaries is None, self._eval_summaries + self._eval_summaries = copy.deepcopy(evaler_summaries or {}) + + start_t = time.perf_counter() + state_with_tensors = jax.tree.map( + lambda x: x if isinstance(x, (Tensor, TensorSpec)) else None, state + ) + # Note that save() waits for prior serialization to finish. + self._non_tensor_manager.save(step=step, state=state) + self._get_tensor_manager(state_with_tensors).save( + step=step, args=ocp.args.PyTreeSave(item=state_with_tensors) + ) + self._eval_summaries = None + if (time_diff := time.perf_counter() - start_t) > 0.5: + logging.info("In-mem ckpt blocking time is %fs.", time_diff) + + def restore( + self, + *, + step: Optional[int] = None, + state: Union[Nested[Tensor], Nested[TensorSpec]], + ) -> Tuple[Optional[int], Nested[Tensor]]: + """See `BaseCheckpointer.restore` for details.""" + start_t = time.perf_counter() + cfg: OrbaxEmergencyCheckpointer.Config = self.config + state_with_tensors = jax.tree.map( + lambda x: x if isinstance(x, (Tensor, TensorSpec)) else None, state + ) + tensor_manager = self._get_tensor_manager(state_with_tensors) + if step is None: + # Find the intersection of the checkpoint steps managed by tensor and non-tensor + # manager, and then use the latest step in the intersection for restore. `all_steps` + # from tensor manager contains both local and persistent checkpoints. + common_steps = set(tensor_manager.all_steps()).intersection( + set( + ( + parse_step_from_dir(d) + for d in self._non_tensor_manager.checkpoint_paths( + self._non_tensor_manager.config.dir + ) + ) + ) + ) + if not common_steps: + logging.warning("Could not find any completed checkpoints under %s.", cfg.dir) + return None, state + step = max(common_steps) + + restore_step, state = self._non_tensor_manager.restore(step=step, state=state) + assert step == restore_step + + restored_state_with_tensors = tensor_manager.restore( + step=step, + args=ocp.args.PyTreeRestore(item=self._get_abstract_state(state_with_tensors)), + ) + # Merge non-tensor and tensor states by replacing leaves of the non-tensor Pytree with the + # not-None leaves of the tensor Pytree. + restored_state = jax.tree.map( + lambda non_tensor, tensor: non_tensor if tensor is None else tensor, + state, + restored_state_with_tensors, + ) + time_diff = time.perf_counter() - start_t + logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir) + return step, restored_state + + def wait_until_finished(self): + """See `BaseCheckpointer.wait_until_finished` docstring for details.""" + self._non_tensor_manager.wait_until_finished() + self._tensor_manager.wait_until_finished() + + def stop(self): + """See `BaseCheckpointer.stop` for details.""" + self._non_tensor_manager.stop() + self._tensor_manager.close() diff --git a/axlearn/common/checkpointer_orbax_test.py b/axlearn/common/checkpointer_orbax_test.py index cbb67630..cb3d6015 100644 --- a/axlearn/common/checkpointer_orbax_test.py +++ b/axlearn/common/checkpointer_orbax_test.py @@ -7,18 +7,44 @@ # pylint: disable=protected-access +import logging +import multiprocessing as mp import os +import socket import tempfile -from typing import Sequence +from contextlib import ExitStack, closing +from typing import Optional, Sequence import jax +import numpy as np import orbax.checkpoint as ocp +import pytest +import tensorflow as tf +from absl.logging import PythonFormatter +from absl.testing import parameterized from jax import numpy as jnp from jax.experimental import mesh_utils from axlearn.common import test_utils from axlearn.common.checkpointer import read_index_file -from axlearn.common.checkpointer_orbax import OrbaxCheckpointer +from axlearn.common.checkpointer_orbax import ( + OrbaxCheckpointer, + OrbaxEmergencyCheckpointer, + _dump_process_id, + _get_previous_process_id, + _init_consistent_proc_ids, + config_for_function, + every_n_steps_policy, + get_consistent_proc_id, + ocp, +) + + +def _find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] def _mesh(mesh_shape: Sequence[int]): @@ -26,8 +52,146 @@ def _mesh(mesh_shape: Sequence[int]): return jax.sharding.Mesh(devices, ("data", "model")) -class OrbaxCheckpointerTest(test_utils.TestCase): - def test_index(self): +def _logger_init(): + handler = logging.StreamHandler() + handler.setFormatter(PythonFormatter()) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + +def _test_orbax_main(process_id: int, port: int, persist_dir: str, local_dir: str, q: mp.Queue): + # pylint: disable=import-outside-toplevel + from orbax.checkpoint._src.multihost import multislice + from orbax.checkpoint.experimental.emergency import checkpoint_manager + + # Patch for GPU use. We don't need to use mock.patch because we're running in a subprocess. + multislice.get_device_memory = lambda: int(80e9) + + def slice_devices( + global_mesh: jax.sharding.Mesh, + *, + replica_id: int = 0, + replica_axis_index: int = 0, + ) -> np.ndarray: + return np.take( + global_mesh.devices, + replica_id, + axis=replica_axis_index, + ) + + def _all_devices_excepting_slice( + devices: np.ndarray, + *, + replica_id: int = 0, + replica_axis_index: int = 0, + ) -> np.ndarray: + return np.delete(devices, replica_id, axis=replica_axis_index) + + # We're not running in a true multi-slice environment. Patch the following two functions to + # mock multi-slice discovery. + multislice.slice_devices = slice_devices + checkpoint_manager._all_devices_excepting_slice = _all_devices_excepting_slice + + prev_process_id = get_consistent_proc_id( + distributed_coordinator=f"127.0.0.1:{port}", + num_processes=4, + process_id=process_id, + trainer_dir=persist_dir, + local_ckpt_dir=local_dir, + jax_backend="gpu", + ) + + jax.distributed.initialize( + coordinator_address=f"127.0.0.1:{port}", + num_processes=4, + process_id=prev_process_id, + local_device_ids=[process_id], + ) + + cfg: OrbaxEmergencyCheckpointer.Config = OrbaxEmergencyCheckpointer.default_config() + cfg.name = "emergency" + cfg.save_policy = config_for_function(every_n_steps_policy).set(n=25) + cfg.local_save_policy = config_for_function(every_n_steps_policy).set(n=5) + # Local checkpoint path suffix must be the same for orbax synchronization to work. + cfg.local_dir = local_dir + cfg.unique_str = persist_dir + cfg.dir = persist_dir + cfg.keep_last_n = 2 + cfg.replica_axis_index = 0 + cfg.async_timeout_secs = 5 + cfg.non_tensor_async_timeout_secs = 5 + checkpointer: OrbaxEmergencyCheckpointer = cfg.instantiate(parent=None) + mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 2), ["data", "model"]) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "model")) + x = jax.make_array_from_process_local_data(sharding, np.zeros((8, 8), dtype=np.float32), (8, 8)) + + input_iter = iter(tf.data.Dataset.counter()) + state = {"x": x, "y": None, "z": input_iter} + with mesh: + step, state = checkpointer.restore(step=None, state=state) + if step is None: + step = 0 + else: + # This is the step of the last persistent ckpt. + assert checkpointer.latest_checkpoint_step(checkpointer.config.dir) == 25 + # This step should already be garbage collected. + assert 20 not in checkpointer._tensor_manager.all_steps() + # Although the persistent save interval is 25 steps, local save interval is 5 + # steps, so we should be able to restore step 40. + assert step >= 40 + # Since we save after step X is finished, after restore the first step is X + 1. + step += 1 + for i in range(step, 100): + assert jnp.all(state["x"] == i).item() + assert i == next(input_iter) # input_iter's state is modified inplace. + state = {"x": state["x"] + 1, "y": state["y"], "z": state["z"]} + checkpointer.save(step=i, state=state) + if process_id == 0: + logging.info("step %d", i) + if i == 45 and q is not None and process_id == 0: + # Signal processes to be killed at around step 45. + q.put(1) + + checkpointer.wait_until_finished() + jax.distributed.shutdown() + + +def _test_init_proc_id_main( + *, + distributed_coordinator: Optional[str] = None, + num_processes: Optional[int] = None, + process_id: Optional[int] = None, + trainer_dir: str, + local_ckpt_dir: str, + proc_per_slice: int, + new_idx_map: dict[int, int], +): + # Fake some envs. + os.environ["MEGASCALE_NUM_SLICES"] = str(num_processes // proc_per_slice) + os.environ["MEGASCALE_SLICE_ID"] = f"{process_id // proc_per_slice}" + os.environ["TPU_WORKER_ID"] = str(process_id % proc_per_slice) + os.environ["TPU_WORKER_HOSTNAMES"] = ",".join(["a"] * proc_per_slice) + + if new_idx_map[process_id] != -1: + _dump_process_id( + local_ckpt_dir, unique_str=trainer_dir, process_index=new_idx_map[process_id] + ) + + jax.default_backend = lambda: "tpu" + _init_consistent_proc_ids( + jax_backend="gpu", # Have to use gpu here to avoid getting an error in `setup`. + distributed_coordinator=distributed_coordinator, + num_processes=num_processes, + process_id=process_id, + trainer_dir=trainer_dir, + local_ckpt_dir=local_ckpt_dir, + ) + + +class OrbaxCheckpointerTest(parameterized.TestCase): + # This test needs to run last since the other two tests require an uninitialized jax backend. + def test2_index(self): """Tests that index files saved with orbax can be read with `read_index_file`.""" mesh_shape = (1, 1) if not test_utils.is_supported_mesh_shape(mesh_shape): @@ -52,3 +216,116 @@ def test_index(self): ), ) self.assertEqual(ref_index, test_index["index"]) + + # This test can also run on CPU. The proc id init GPU logic is tested in `test_emergency_ckpt`. + def test1_init_proc_id_tpu(self): + free_port = _find_free_port() + new_idx_map = { + # First two slices are healthy, but have different slice id during restart. + 0: 2, + 1: 3, + 2: 6, + 3: 7, + 4: -1, # This failed slice has one node swapped out. + 5: 1, + 6: -1, # This failed slice has two nodes swapped out. + 7: -1, + } + with ExitStack() as stack: + num_processes = 8 + local_tempdirs = [ + stack.enter_context(tempfile.TemporaryDirectory()) for _ in range(num_processes) + ] + processes = [] + for i in range(num_processes): + proc = mp.Process( + target=_test_init_proc_id_main, + kwargs=dict( + distributed_coordinator=f"127.0.0.1:{free_port}", + num_processes=num_processes, + process_id=i, + trainer_dir="any", + local_ckpt_dir=local_tempdirs[i], + proc_per_slice=2, + new_idx_map=new_idx_map, + ), + ) + proc.start() + processes.append(proc) + + for p in processes: + p.join() + self.assertEqual(p.exitcode, 0) + + new_proc_ids = [ + _get_previous_process_id(local_dir, unique_str="any") + for local_dir in local_tempdirs + ] + for i in range(4): + self.assertEqual(new_proc_ids[i], new_idx_map[i]) + + if new_proc_ids[4] == 0: + self.assertEqual(new_proc_ids[5], 1) + self.assertEqual(new_proc_ids[6], 4) + self.assertEqual(new_proc_ids[7], 5) + elif new_proc_ids[4] == 4: + self.assertEqual(new_proc_ids[5], 5) + self.assertEqual(new_proc_ids[6], 0) + self.assertEqual(new_proc_ids[7], 1) + else: + self.fail("new proc id of proc 4 should be either 0 or 4") + + # This test requires 4 devices to run. Note: we cannot use skipif(jax.local_device_count() < 4) + # because it will initialize the backend, causing the jax.distributed.initialize to fail in + # _test_orbax_main. Using the `spawn` context from multiprocessing results in a different error. + # Since we don't have GPU/TPU CI anyway, we always skip this test. + @pytest.mark.skipif(False, reason="This test needs to be run manually.") + def test0_emergency_ckpt(self): + with ExitStack() as stack: + num_processes = 4 + local_tempdirs = [ + stack.enter_context(tempfile.TemporaryDirectory()) for _ in range(num_processes) + ] + persistent_tempdir = stack.enter_context(tempfile.TemporaryDirectory()) + q = mp.Queue() + # Populate log messages. + _logger_init() + + def start_processes(reverse_process_id: bool = False) -> list[mp.Process]: + free_port = _find_free_port() + processes = [] + for i in range(num_processes): + p = mp.Process( + target=_test_orbax_main, + args=( + i if not reverse_process_id else num_processes - i - 1, + free_port, + persistent_tempdir, + local_tempdirs[i], + q, + ), + ) + processes.append(p) + p.start() + return processes + + processes = start_processes() + + # Block until we get a signal from a process. + q.get() + + # Kill all processes to simulate a failure. + for p in processes: + p.kill() + + # Shuffle the process ids to verify that we are able to restore the process id. + processes = start_processes(reverse_process_id=True) + + try: + for p in processes: + p.join() + for p in processes: + self.assertEqual(p.exitcode, 0) + finally: + for p in processes: + p.kill() diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 7b485cd1..e7b0315c 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -47,12 +47,15 @@ read_state_spec, restore_tf_savables, ) -from axlearn.common.checkpointer_orbax import OrbaxCheckpointer -from axlearn.common.input_grain_test import range_dataset +from axlearn.common.checkpointer_orbax import _GRAIN_INSTALLED, OrbaxCheckpointer from axlearn.common.metrics import WeightedScalar from axlearn.common.summary_writer import SummaryWriter from axlearn.common.utils import VDict +# Conditionally import to allow running tests on Apple silicon where grain is not supported. +if _GRAIN_INSTALLED: + from axlearn.common.input_grain_test import range_dataset + def _mesh(mesh_shape: Sequence[int]): devices = mesh_utils.create_device_mesh(mesh_shape) @@ -112,7 +115,9 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): ) # When the given state has a different array shape: [3] instead of [2] for y. - with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"): + with self.assertRaisesRegex( + ValueError, "checkpoint tree dtypes or shapes|not compatible" + ): ckpt.restore( step=None, state=dict( @@ -124,7 +129,7 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): # Orbax throws AssertionError in this case. with self.assertRaisesRegex( (AssertionError, ValueError), - "(checkpoint tree dtypes or shapes|do not match)", + "(checkpoint tree dtypes or shapes|not compatible)", ): ckpt.restore( step=None, @@ -344,6 +349,8 @@ def tensors_only(tree): @parameterized.parameters([Checkpointer, OrbaxCheckpointer]) def test_grain(self, checkpointer_cls): + if not _GRAIN_INSTALLED: + self.skipTest("Cannot run when grain is not installed.") mesh_shape = (1, 1) if not test_utils.is_supported_mesh_shape(mesh_shape): return @@ -777,7 +784,7 @@ def test_every_n_steps_and_last_policy(self): self.assertTrue(policy(step=13, evaler_summaries={})) @parameterized.parameters([Checkpointer, OrbaxCheckpointer]) - def test_latest_checkpoint_path(self, checkpointer_cls: Type[BaseCheckpointer]): + def test_latest_checkpoint_path_and_step(self, checkpointer_cls: Type[BaseCheckpointer]): with tempfile.TemporaryDirectory() as td: # Test that the most recent checkpoint is returned. ckpt_paths = {} @@ -799,6 +806,7 @@ def test_latest_checkpoint_path(self, checkpointer_cls: Type[BaseCheckpointer]): final_ckpt_path = ckpt_paths[10] # Note: step 11 is not complete, so the latest path returns step 10. self.assertEqual(checkpointer_cls.latest_checkpoint_path(td), final_ckpt_path) + self.assertEqual(checkpointer_cls.latest_checkpoint_step(td), 10) @parameterized.parameters([Checkpointer, OrbaxCheckpointer]) def test_read_state_spec(self, checkpointer_cls: Type[BaseCheckpointer]): diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index 31532be2..34991e35 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -41,6 +41,7 @@ from absl import flags, logging +from axlearn.common.checkpointer_orbax import get_consistent_proc_id from axlearn.common.status_server import StatusHTTPServer from axlearn.common.utils import get_data_dir from axlearn.common.utils_spmd import setup as setup_spmd @@ -88,6 +89,12 @@ "", "See the docstring of your `health_check_module`.", ) +flags.DEFINE_string( + "local_ckpt_dir", + "", + "If specified, enable local checkpoint and saves checkpoints to this " + "directory. See `OrbaxEmergencyCheckpointer` for more details", +) FLAGS = flags.FLAGS @@ -107,13 +114,20 @@ def setup(): health_check = nullcontext() with health_check: - setup_spmd( + init_spmd_args = dict( + jax_backend=FLAGS.jax_backend, distributed_coordinator=FLAGS.distributed_coordinator, num_processes=FLAGS.num_processes, process_id=FLAGS.process_id, - jax_backend=FLAGS.jax_backend, initialization_timeout=FLAGS.initialization_timeout, ) + if FLAGS.local_ckpt_dir: + init_spmd_args["process_id"] = get_consistent_proc_id( + trainer_dir=FLAGS.trainer_dir, + local_ckpt_dir=FLAGS.local_ckpt_dir, + **init_spmd_args, + ) + setup_spmd(**init_spmd_args) if FLAGS.jax_profiler_port is not None: # Start jax.profiler for Tensorboard and profiling in open source. diff --git a/pyproject.toml b/pyproject.toml index 076c4706..bec336cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.5.23", + "orbax-checkpoint==0.8.0", ] # Grain input processing. Currently does not support macos. grain = [