Integrate Orbax's emergency checkpoint. #820
Conversation
del os.environ["JAX_PLATFORMS"]

class OrbaxEmergencyCheckpointer(BaseCheckpointer):
How does this overlap with `OrbaxCheckpointer` in axlearn/common/checkpointer_orbax.py (line 169 at 140a18f)?

class OrbaxCheckpointer(BaseCheckpointer):
The orbax regular checkpointer saves 1 gcs checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk), and also 1 checkpoint to gcs.
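To make the difference concrete, here is a small illustrative sketch (not code from the PR) of how many checkpoint copies each approach writes per save, following the n-slice description above:

```python
def checkpoints_per_save(num_slices: int) -> dict:
    """Illustrative only: per-save checkpoint copies for each approach.

    The regular Orbax checkpointer writes a single persistent (GCS) checkpoint
    regardless of slice count. The emergency checkpointer additionally keeps
    local copies (usually on a ramdisk) on n-1 of the n slices, plus the one
    persistent GCS checkpoint.
    """
    if num_slices < 2:
        raise ValueError("Emergency checkpointing assumes >= 2 data-parallel slices.")
    return {
        "regular": {"persistent": 1, "local": 0},
        "emergency": {"persistent": 1, "local": num_slices - 1},
    }
```

For example, with 4 slices the emergency checkpointer keeps 3 local copies plus 1 GCS copy per save.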
Unfortunately it's not possible to share code between the two implementations.
Thanks for the clarification.
- How should users use them? Should they use both of them or one of them? How should they pick?
- Or can we replace OrbaxCheckpointer with this class?
Added a comment to clarify this:
This checkpointer is designed to improve the goodput of large multi-slice training jobs that
use data-parallelism across slices. At least two data-parallel slices are required. For other
use cases where this is not applicable or ultimate goodput is not required, please use
`OrbaxCheckpointer`.
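The guidance above could be expressed as a small selection helper. This is a hypothetical sketch, not part of the PR; the class names are the ones discussed in this thread:

```python
def pick_checkpointer(num_slices: int, want_max_goodput: bool) -> str:
    """Hypothetical helper mirroring the docstring guidance above.

    The emergency checkpointer only applies to multi-slice, data-parallel
    training (at least two slices); everything else should use the regular
    OrbaxCheckpointer.
    """
    if num_slices >= 2 and want_max_goodput:
        return "OrbaxEmergencyCheckpointer"
    return "OrbaxCheckpointer"
```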
It requires at least two data-parallel slices.
Not just data-parallelism, but data-parallel slices, i.e. it has to be multi-slice training.
IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.
Could be an idea.
Still needs support from orbax though.
Thanks for the explanation of the constraints. I wonder what the long term plan is.
Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?
Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?
I don't know if Google has such a plan. The orbax in-memory checkpointer actually uses the orbax regular checkpointer under the hood, which might be required by design/by nature of the problem that it solves.
Since the in-memory checkpointer uses the regular orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e. the one stored to gcs) can be loaded by
I think in the long term, it's probably possible to unify the checkpoint structure between the two checkpointers (regular and in-memory), but it's unknown whether we can unify the codepath.
However, the above procedure doesn't apply to some non-tensor states such as data iterators. Data iterators are unique across jax processes, and thus cannot be stored on nodes. Orbax emergency checkpointer doesn't support non-tensor states. Therefore, we reuse axlearn
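The per-process nature of this state is why tensors and non-tensors take different paths. A minimal sketch of the non-tensor side (hypothetical function and file layout, not the actual axlearn implementation): each jax process writes its own iterator state as a separate file keyed by process index.

```python
import json
import os
import tempfile


def save_non_tensor_state(directory: str, step: int, process_index: int, state: dict) -> str:
    """Hypothetical sketch: write per-process non-tensor state as JSON.

    Each jax process owns a distinct data iterator, so each writes its own
    file; tensor state goes through the (emergency) Orbax checkpointer instead.
    """
    step_dir = os.path.join(directory, f"step_{step:08d}")
    os.makedirs(step_dir, exist_ok=True)
    path = os.path.join(step_dir, f"process_{process_index}.json")
    with open(path, "w") as f:
        json.dump(state, f)
    return path


# Usage example with a throwaway directory.
saved_path = save_non_tensor_state(
    tempfile.mkdtemp(), step=100, process_index=0, state={"iterator_position": 42}
)
```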
RE: "Orbax emergency checkpointer doesn't support non-tensor states"
While this is true, it seems that it can be extended to support non-tensor states. Specifically, the local checkpointer manager already writes process metadata:
which is implemented by writing json strings to a file:
https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/mesh_consistency.py#L74-L104
Maybe it can be extended to take user_process_metadata? Would this work? Will the Orbax team be receptive to this idea?
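To illustrate the suggestion: Orbax's mesh_consistency module persists process metadata as JSON strings in a file, and that mechanism could in principle carry an extra user-supplied payload. The sketch below is hypothetical; `user_process_metadata` is the name proposed in this comment, not an existing Orbax API:

```python
import json
import os
import tempfile


def write_process_metadata(path, process_index, distributed_to_device_ids,
                           user_process_metadata=None):
    """Sketch of the JSON-to-file approach used by Orbax's mesh_consistency,
    extended with a hypothetical `user_process_metadata` field."""
    record = {
        "process_index": process_index,
        "distributed_to_device_ids": distributed_to_device_ids,
    }
    if user_process_metadata is not None:
        # Extra, user-defined per-process state (e.g. a data-iterator position).
        record["user_process_metadata"] = user_process_metadata
    with open(path, "w") as f:
        json.dump(record, f)
    return path


# Usage example with a throwaway directory.
meta_path = write_process_metadata(
    os.path.join(tempfile.mkdtemp(), "process_metadata.json"),
    process_index=0,
    distributed_to_device_ids=[[0, 1, 2, 3]],
    user_process_metadata={"iterator_position": 42},
)
```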
I know it's doable in the beginning. We'll need to talk to Orbax for their plan to support this feature.
# Note that save() waits for prior serialization to finish.
self._non_tensor_manager.save(step=step, state=state)
self._get_tensor_manager(state_with_tensors).save(
    step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
)
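The ordering in the snippet above can be sketched with stand-in managers. This is illustrative only (the recording manager and `wait_until_finished` call here are stand-ins, not the PR's actual classes): any in-flight async save is finished first, then the non-tensor state is written, then the tensor save is started.

```python
class CombinedCheckpointer:
    """Sketch of the save ordering above, with stand-in managers."""

    def __init__(self, non_tensor_manager, tensor_manager):
        self._non_tensor_manager = non_tensor_manager
        self._tensor_manager = tensor_manager

    def save(self, step, state, state_with_tensors):
        # Block until any prior (async) tensor serialization has finished.
        self._tensor_manager.wait_until_finished()
        self._non_tensor_manager.save(step=step, state=state)
        self._tensor_manager.save(step=step, state=state_with_tensors)


class _RecordingManager:
    """Test double that records calls so the ordering is observable."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def wait_until_finished(self):
        self.log.append(f"{self.name}.wait")

    def save(self, step, state):
        self.log.append(f"{self.name}.save@{step}")


log = []
ckpt = CombinedCheckpointer(
    _RecordingManager("non_tensor", log), _RecordingManager("tensor", log)
)
ckpt.save(step=100, state={"it": 1}, state_with_tensors={"w": [0.0]})
```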
How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?
Please add a comment.
There's no special marker for completion of both. There are only markers for each of them individually. So, during restore, we look for both of them and only load a specific step when both markers exist.
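The restore-side rule described above amounts to intersecting the two managers' committed steps. A minimal sketch (hypothetical helper, not code from the PR):

```python
def latest_restorable_step(tensor_steps, non_tensor_steps):
    """A step is restorable only if BOTH managers committed it (each writes
    its own completion marker); pick the latest step present in both sets."""
    common = set(tensor_steps) & set(non_tensor_steps)
    return max(common) if common else None
```

For example, if the tensor manager committed steps 100, 200, and 300 but the non-tensor manager only reached 200, restore falls back to step 200.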
@ruomingp I guess my question now is what the plan here is. Should we wait for Orbax's support for non-tensor states and a unified checkpointer API? I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.