Commit

Address comments
hanzhi713 committed Jan 30, 2025
1 parent 16d877a commit c1a476d
Showing 3 changed files with 24 additions and 19 deletions.
31 changes: 17 additions & 14 deletions axlearn/common/checkpointer_orbax_emergency.py
@@ -209,10 +209,12 @@ def to_string(self):

@property
def prev_slice_id(self):
assert self.num_proc_per_slice is not None
return self.inv_proc_id // self.num_proc_per_slice

@property
def cur_slice_id(self):
assert self.num_proc_per_slice is not None
return self.cur_proc_id // self.num_proc_per_slice

@classmethod
@@ -224,27 +226,27 @@ def from_string(
return cls(ls[0], int(ls[1]), int(ls[2]), key=key, num_proc_per_slice=num_proc_per_slice)


def _get_previous_process_info(local_dir: str, *, unique_str: str) -> _ProcessInfo:
def _get_previous_process_info(local_dir: str, *, trainer_dir: str) -> _ProcessInfo:
"""Gets process info from local checkpoint directory."""
path = os.path.join(local_dir, _get_unique_id(unique_str), _PROCESS_ID_FILE_NAME)
path = os.path.join(local_dir, _get_unique_id(trainer_dir), _PROCESS_ID_FILE_NAME)
if not fs.exists(path):
return _ProcessInfo(address="", inv_proc_id=-1, cur_proc_id=-1)

with fs.open(path) as f:
return _ProcessInfo.from_string(f.read())


def _dump_process_info(local_dir: str, *, unique_str: str, proc_info: _ProcessInfo):
def _dump_process_info(local_dir: str, *, trainer_dir: str, proc_info: _ProcessInfo):
"""Dumps process info to local checkpoint directory."""
local_dir = os.path.join(local_dir, _get_unique_id(unique_str))
local_dir = os.path.join(local_dir, _get_unique_id(trainer_dir))
fs.makedirs(local_dir)
process_id_file = os.path.join(local_dir, _PROCESS_ID_FILE_NAME)
with fs.open(process_id_file, "w") as f:
f.write(proc_info.to_string())


def _get_unique_id(unique_str: str) -> str:
return hashlib.sha256(unique_str.encode(), usedforsecurity=False).hexdigest()
def _get_unique_id(trainer_dir: str) -> str:
return hashlib.sha256(trainer_dir.encode(), usedforsecurity=False).hexdigest()


def _logger_init():
@@ -275,7 +277,7 @@ def _init_consistent_proc_ids(
timeout_ms = barrier_timeout_seconds * 1000
utils_spmd.setup(**setup_kwargs)
client: jax.lib.xla_extension.DistributedRuntimeClient = global_state.client
local_proc_info = _get_previous_process_info(local_ckpt_dir, unique_str=trainer_dir)
local_proc_info = _get_previous_process_info(local_ckpt_dir, trainer_dir=trainer_dir)
key_prefix = "axlearn/id_reassign"
# Local key just needs to be unique for each process.
local_proc_info.key = f"{key_prefix}/{jax.process_index()}"
Expand Down Expand Up @@ -398,7 +400,7 @@ def assign_fn(info: _ProcessInfo):
new_info.inv_proc_id,
new_info.address,
)
_dump_process_info(local_ckpt_dir, unique_str=trainer_dir, proc_info=new_info)
_dump_process_info(local_ckpt_dir, trainer_dir=trainer_dir, proc_info=new_info)
# Block to avoid coordinator exiting too early.
client.wait_at_barrier("axlearn/id-reassign-finalize", timeout_in_ms=timeout_ms)
jax.distributed.shutdown()
@@ -456,7 +458,7 @@ def get_consistent_proc_info(
f"Got exit code {proc.exitcode}. Please check the log above for errors."
)

info = _get_previous_process_info(local_ckpt_dir, unique_str=trainer_dir)
info = _get_previous_process_info(local_ckpt_dir, trainer_dir=trainer_dir)
if info.inv_proc_id == -1:
raise RuntimeError("Expects inv process id != -1, but got -1.")
logging.info(
@@ -537,9 +539,10 @@ class Config(BaseCheckpointer.Config):
local_dir: Ckpt base path for local storage. The content in this path must persist
across pod restarts unless the restart is caused by node failure. `local_dir` must
be the same for all processes or processes may hang.
unqiue_str: A string that's unique for the current run. Typically, this is set to
trainer_dir. Local checkpoint will be stored in local_dir/sha256(unique_str).
During init, all other folders in local_dir will be removed.
trainer_dir: A string that's unique to the current run, typically set to the
trainer's output directory (trainer_config.dir). Local checkpoints will be stored
in local_dir/sha256(trainer_dir). During init, all other folders in local_dir
will be removed to prevent unexpected memory usage.
save_policy: Save policy for persistent checkpoints.
local_save_policy: Save policy for local checkpoints. This should be more frequent than
`save_policy`. Note that data iterator will be saved with either `save_policy` or
@@ -557,7 +560,7 @@ class Config(BaseCheckpointer.Config):
every_n_steps_policy
).set(n=10)
local_dir: str = "/host-tmp/checkpoints"
unique_str: Required[str] = REQUIRED
trainer_dir: Required[str] = REQUIRED
non_tensor_async_timeout_secs: int = 300
async_timeout_secs: int = 3600
replica_axis_index: Required[int] = REQUIRED
@@ -603,7 +606,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
fs.makedirs(os.path.join(cfg.dir, self._NON_TENSORS_PREFIX))
fs.makedirs(os.path.join(cfg.dir, self._TENSORS_PREFIX))
# Cleanup local checkpoints from different runs.
unique_id = _get_unique_id(cfg.unique_str)
unique_id = _get_unique_id(cfg.trainer_dir)
for fd in fs.listdir(cfg.local_dir):
if not fd.startswith(".") and fd != unique_id:
fs.rmtree(os.path.join(cfg.local_dir, fd))
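
As a point of reference, a minimal sketch of what the renamed trainer_dir value drives in this file: the per-run local checkpoint folder is named by the sha256 digest of the trainer directory string, and at init every other non-hidden entry under local_dir is deleted. The sketch uses the standard library (os, shutil) in place of axlearn's fs wrapper, and cleanup_stale_runs is a hypothetical helper name.

import hashlib
import os
import shutil


def get_unique_id(trainer_dir: str) -> str:
    # Same derivation as _get_unique_id above: the sha256 hex digest of the
    # trainer directory string names the per-run local checkpoint folder.
    return hashlib.sha256(trainer_dir.encode(), usedforsecurity=False).hexdigest()


def cleanup_stale_runs(local_dir: str, trainer_dir: str) -> None:
    # Mirrors the cleanup in __init__: any non-hidden entry that does not match
    # the current run's id is removed from local_dir.
    unique_id = get_unique_id(trainer_dir)
    for fd in os.listdir(local_dir):
        if not fd.startswith(".") and fd != unique_id:
            shutil.rmtree(os.path.join(local_dir, fd))
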
6 changes: 3 additions & 3 deletions axlearn/common/checkpointer_orbax_emergency_test.py
@@ -96,7 +96,7 @@ def _all_devices_excepting_slice(
cfg.local_save_policy = config_for_function(every_n_steps_policy).set(n=5)
# Local checkpoint path suffix must be the same for orbax synchronization to work.
cfg.local_dir = local_dir
cfg.unique_str = persist_dir
cfg.trainer_dir = persist_dir
cfg.dir = persist_dir
cfg.keep_last_n = 2
cfg.replica_axis_index = 0
@@ -160,7 +160,7 @@ def _test_init_proc_id_main(
if new_idx_map[process_id] != -1:
_dump_process_info(
local_ckpt_dir,
unique_str=trainer_dir,
trainer_dir=trainer_dir,
proc_info=_ProcessInfo(distributed_coordinator, new_idx_map[process_id], process_id),
)

@@ -228,7 +228,7 @@ def test_init_proc_id_tpu(self):
self.assertEqual(p.exitcode, 0)

infos = [
_get_previous_process_info(local_dir, unique_str="any")
_get_previous_process_info(local_dir, trainer_dir="any")
for local_dir in local_tempdirs
]
for info in infos:
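
For context, a sketch of the round trip the test relies on with the renamed keyword, assuming the helpers behave as shown in the diff above; the address, process ids, and trainer_dir value are illustrative only.

import tempfile

from axlearn.common.checkpointer_orbax_emergency import (
    _ProcessInfo,
    _dump_process_info,
    _get_previous_process_info,
)

with tempfile.TemporaryDirectory() as local_dir:
    info = _ProcessInfo(address="10.0.0.1:8476", inv_proc_id=3, cur_proc_id=7)
    # Writes local_dir/sha256(trainer_dir)/<process id file>.
    _dump_process_info(local_dir, trainer_dir="gs://bucket/run-0", proc_info=info)
    # Reads it back using the same trainer_dir keyword.
    restored = _get_previous_process_info(local_dir, trainer_dir="gs://bucket/run-0")
    assert (restored.inv_proc_id, restored.cur_proc_id) == (3, 7)
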
6 changes: 4 additions & 2 deletions axlearn/common/launch_trainer.py
@@ -116,8 +116,10 @@ def get_trainer_config(
from axlearn.cloud.gcp.monitoring.tpu_device_monitor import create_tpu_monitor

trainer_config.device_monitor = create_tpu_monitor()
if hasattr(trainer_config.checkpointer, "unique_str"):
trainer_config.checkpointer.unique_str = trainer_config.dir
if hasattr(trainer_config.checkpointer, "trainer_dir"):
# Set trainer_dir if not already set.
if not isinstance(trainer_config.checkpointer.trainer_dir, str):
trainer_config.checkpointer.trainer_dir = trainer_config.dir
return trainer_config


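
A small self-contained sketch of the fallback added here: only checkpointer configs that expose a trainer_dir field are touched, and an explicitly set string value wins over trainer_config.dir. The config classes below are illustrative stand-ins, not axlearn's real config types.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeCheckpointerConfig:
    # Stand-in for the emergency checkpointer config, where trainer_dir is
    # REQUIRED and not a string until someone fills it in.
    trainer_dir: Optional[str] = None


@dataclass
class FakeTrainerConfig:
    dir: str = "gs://bucket/experiments/run-1"
    checkpointer: FakeCheckpointerConfig = field(default_factory=FakeCheckpointerConfig)


def maybe_set_trainer_dir(trainer_config: FakeTrainerConfig) -> None:
    ckpt = trainer_config.checkpointer
    # Same condition as the diff: fill in trainer_dir from trainer_config.dir
    # only when the field exists and has not already been set to a string.
    if hasattr(ckpt, "trainer_dir") and not isinstance(ckpt.trainer_dir, str):
        ckpt.trainer_dir = trainer_config.dir


cfg = FakeTrainerConfig()
maybe_set_trainer_dir(cfg)
assert cfg.checkpointer.trainer_dir == cfg.dir
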
