Integrate Orbax's emergency checkpoint. #820

Open
wants to merge 8 commits into
base: main
from

hanzhi713
Member

No description provided.

@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from b4a00eb to 0e74dc8 Compare November 12, 2024 19:49
@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from cf43485 to 56f51de Compare November 19, 2024 00:37
@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from d5a3e0f to bea8b71 Compare January 8, 2025 22:45
@hanzhi713 hanzhi713 marked this pull request as ready for review January 30, 2025 23:13
@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners January 30, 2025 23:13
del os.environ["JAX_PLATFORMS"]


class OrbaxEmergencyCheckpointer(BaseCheckpointer):
Contributor

How does this overlap with `class OrbaxCheckpointer(BaseCheckpointer):`?

Member Author

The regular Orbax checkpointer saves 1 GCS checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk), and also 1 checkpoint to GCS.

Member Author

Unfortunately it's not possible to share code between the two implementations.

Contributor

Thanks for the clarification.

  • How should users use them? Should they use both of them or one of them? How should they pick?
  • Or can we replace OrbaxCheckpointer with this class?

Member Author

Added a comment to clarify this:

    This checkpointer is designed to improve the goodput of large multi-slice training jobs that
    use data-parallelism across slices. At least two data-parallel slices are required. For other
    use cases where this is not applicable or ultimate goodput is not required, please use
    `OrbaxCheckpointer`.

Member Author

It requires at least two data-parallel slices.

Member Author

Not just data-parallelism, but data-parallel slices, i.e. it has to be multi-slice training.

Contributor

IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.
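
The suggestion above could be sketched roughly as follows. This is only an illustration under stated assumptions: `CheckpointTargets`, `plan_checkpoint_targets`, and the two flag names are hypothetical and are not part of AXLearn or Orbax, and whether Orbax can actually run without local checkpoints is not settled in this thread.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckpointTargets:
    # Hypothetical flags; neither AXLearn nor Orbax defines this class.
    save_local: bool  # Per-slice checkpoints on local storage (e.g. a ramdisk).
    save_persistent: bool  # The checkpoint written to GCS.


def plan_checkpoint_targets(num_slices: int) -> CheckpointTargets:
    """Decides which checkpoint targets to enable for a given slice count."""
    if num_slices < 1:
        raise ValueError(f"num_slices must be >= 1, got {num_slices}")
    # Local checkpoints only help when another slice can act as the restore source,
    # so they are disabled for single-slice jobs. The GCS checkpoint is always kept.
    return CheckpointTargets(save_local=num_slices > 1, save_persistent=True)
```

With a helper like this, a single-slice job would still go through the same checkpointer class and only skip the local saves.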

Member Author

Could be an idea.

Member Author

Still needs support from Orbax though.

@ruomingp ruomingp self-assigned this Jan 31, 2025
@hanzhi713 hanzhi713 requested a review from ruomingp January 31, 2025 21:35
@ruomingp ruomingp removed their request for review February 1, 2025 02:51
Contributor

@ruomingp ruomingp left a comment

Thanks for the explanation of the constraints. I wonder what the long term plan is.

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

@hanzhi713
Member Author

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

I don't know if Google has such a plan. The Orbax in-memory checkpointer actually uses the regular Orbax checkpointer under the hood, which may be required by its design or by the nature of the problem it solves.

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

Since the in-memory checkpointer uses the regular Orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e. the one stored to GCS) can be loaded by OrbaxStateBuilder (see #866). Therefore, we can say that the checkpoints are compatible for eval and inference purposes. It's just that the training checkpoint will be incompatible, meaning that OrbaxEmergencyCheckpointer's checkpoint cannot be loaded by OrbaxCheckpointer.

@hanzhi713
Member Author

I think in the long term, it's probably possible to unify the checkpoint structure between the two checkpointers (regular and in-memory), but it's unknown whether we can unify the codepath.

Comment on lines +519 to +521
However, the above procedure doesn't apply to some non-tensor states such as data iterators.
Data iterators are unique across jax processes, and thus cannot be stored on nodes. Orbax
emergency checkpointer doesn't support non-tensor states. Therefore, we reuse axlearn
Contributor

RE: "Orbax emergency checkpointer doesn't support non-tensor states"

While this is true, it seems that it can be extended to support non-tensor states. Specifically, the local checkpointer manager already writes process metadata:

https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L569-L574

which is implemented by writing json strings to a file:
https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/mesh_consistency.py#L74-L104

Maybe it can be extended to take user_process_metadata? Would this work? Will the Orbax team be receptive to this idea?
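
To make the idea concrete, here is a minimal sketch of per-process user metadata stored as a JSON file. It is not Orbax code: the function names, the file naming scheme, and the assumption that the non-tensor state (for example a data iterator position) is JSON-serializable are all hypothetical.

```python
import json
from pathlib import Path
from typing import Any


def _metadata_path(directory: Path, process_index: int) -> Path:
    # Hypothetical naming scheme: one JSON file per JAX process.
    return directory / f"user_process_metadata_{process_index}.json"


def save_user_process_metadata(
    directory: Path, process_index: int, metadata: dict[str, Any]
) -> Path:
    """Writes this process's metadata as JSON and returns the file path."""
    directory.mkdir(parents=True, exist_ok=True)
    path = _metadata_path(directory, process_index)
    # Write to a temporary file first, then rename, so that readers never
    # observe a partially written file.
    tmp_path = path.with_suffix(".json.tmp")
    tmp_path.write_text(json.dumps(metadata))
    tmp_path.replace(path)
    return path


def load_user_process_metadata(directory: Path, process_index: int) -> dict[str, Any]:
    """Reads back the metadata previously written for this process."""
    return json.loads(_metadata_path(directory, process_index).read_text())
```

State that is not JSON-serializable would need its own encoding, which this sketch does not attempt.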

Member Author

I've known from the beginning that it's doable. We'll need to talk to the Orbax team about their plans to support this feature.

Comment on lines +744 to +748
# Note that save() waits for prior serialization to finish.
self._non_tensor_manager.save(step=step, state=state)
self._get_tensor_manager(state_with_tensors).save(
step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
)
Contributor

How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?

Please add a comment.

Member Author

There's no special marker for completion of both. There are only markers for each of them individually. So, during restore, we look for both of them and only load a specific step when both markers exist.
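
The restore rule described here can be sketched as below. This is a minimal illustration rather than the code in this PR: `latest_complete_step` and its arguments are hypothetical names, and the two step lists stand in for whatever steps each manager reports as having a completion marker.

```python
from typing import Iterable, Optional


def latest_complete_step(
    tensor_steps: Iterable[int], non_tensor_steps: Iterable[int]
) -> Optional[int]:
    """Returns the newest step whose tensor and non-tensor states both completed.

    Returns None if no step has both markers.
    """
    # A step counts as complete only if it appears in both sets.
    common = set(tensor_steps) & set(non_tensor_steps)
    return max(common) if common else None
```

For example, if the tensor markers exist for steps 100, 200, and 300 but the non-tensor markers exist only for 100 and 200, this picks step 200 and ignores the partially written step 300.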

@hanzhi713
Member Author

@ruomingp I guess my question now is what the plan is here. Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?

I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.
