Make cache_all_gather configurable #20508

Open
lgiacomoni opened this issue Dec 13, 2024 · 2 comments

Comments

@lgiacomoni

lgiacomoni commented Dec 13, 2024

This option defaults to True and there doesn't seem to be a way to set it to False programmatically. Is there a specific reason for that? The accompanying comment seems to suggest that it should be safe to set it to False.

I am trying to get finer-grained control over how XLA allocates memory for sharded parameters, and I was wondering if cache_all_gather could help with that. At the moment I have a large model that, during the forward pass, gathers the sharded parameter array for the matmul, but then keeps it in memory to reuse in the backward pass instead of discarding the shards and gathering them again. Has anyone tried to enforce memory allocation/deallocation with XLA before? Is that even possible?
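For reference, a minimal sketch of the pattern described above (the shapes, names, and mesh layout here are hypothetical, not taken from the reporter's model):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical setup: a large weight matrix sharded across all devices
# along its first dimension (FSDP-style parameter sharding).
mesh = Mesh(jax.devices(), axis_names=("model",))
w = jax.device_put(jnp.ones((8192, 8192)), NamedSharding(mesh, P("model", None)))
x = jnp.ones((64, 8192))

@jax.jit
def loss(w, x):
    # the matmul needs the full weight, so XLA inserts an all-gather here
    return jnp.sum(x @ w)

# Without rematerialization, the gathered copy of `w` may be saved as a
# residual for the backward pass instead of being re-gathered.
grads = jax.grad(loss)(w, x)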

@patrick-toulme
Contributor

Those all-gathers are probably getting CSE'd away. Try this:

from functools import partial
import jax

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
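For completeness, a small self-contained sketch of how this policy might be wired up (the function names and shapes below are hypothetical; the policy itself is the one suggested above):

from functools import partial
import jax
import jax.numpy as jnp

# Save everything except all-gather results, so gathered weights are
# re-gathered (rematerialized) on the backward pass instead of being kept.
not_all_gather = lambda op, *_, **__: str(op) != 'all_gather'

@partial(jax.remat, policy=not_all_gather)
def predict(w, x):  # hypothetical prediction function
    return x @ w

def loss(w, x):
    return jnp.sum(predict(w, x))

# The gradient computation recomputes the all-gather of `w` in the backward pass.
grads = jax.jit(jax.grad(loss))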

@patrick-toulme
Contributor

Run your trainer with this, and you can see the number of all-gathers after each compiler pass:

export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DUMP_PATH} --xla_dump_hlo_pass_re='.*'"
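Once the dump is written, the all-gathers can be counted directly from the HLO text files. A rough sketch (the file-name pattern and the HLO_DUMP_PATH handling are assumptions about a typical XLA dump layout, and substring counting is only approximate):

import glob, os

# Hypothetical helper: count all-gather instructions in the dumped HLO text.
# Assumes HLO_DUMP_PATH is the directory passed to --xla_dump_to and that
# post-optimization dumps contain "after_optimizations" in their file names.
dump_dir = os.environ["HLO_DUMP_PATH"]
for path in sorted(glob.glob(os.path.join(dump_dir, "*after_optimizations*.txt"))):
    with open(path) as f:
        count = f.read().count("all-gather")
    print(f"{os.path.basename(path)}: {count} all-gather instruction(s)")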
