The cache_all_gather option defaults to True and there doesn't seem to be a way to set it to False programmatically. Is there a specific reason for that? The accompanying comment seems to suggest that it should be safe to set it to False.

I am trying to get finer-grained control over how XLA allocates memory for sharded parameters, and I was wondering if cache_all_gather could help with that. At the moment, I have a large model that, during the forward pass, gathers the sharded parameter array for the matmul, but then keeps the gathered copy in memory to reuse in the backward pass instead of discarding it and gathering the shards again. Has anyone tried to enforce memory allocation/deallocation with XLA before? Is that even possible?
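For concreteness, here is a minimal, hypothetical sketch of that kind of setup (the mesh, names, and shapes are made up, not taken from the model above): the weight matrix is sharded over one mesh axis and all-gathered just before the matmul, and under reverse-mode autodiff the gathered copy is saved for the backward pass by default.

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('shards',))

@partial(shard_map, mesh=mesh,
         in_specs=(P('shards', None), P(None, None)),
         out_specs=P(None, None))
def predict(w_block, x):
    # Gather the full weight matrix from its shards just before the matmul.
    w = jax.lax.all_gather(w_block, 'shards', tiled=True)
    return jnp.tanh(x @ w)

d = 8 * jax.device_count()
w = jnp.ones((d, d))   # sharded over 'shards' by shard_map's in_specs
x = jnp.ones((4, d))

# By default, the gathered weights stay live between the forward and backward passes.
grads = jax.grad(lambda w, x: predict(w, x).sum())(w, x)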
Those all-gathers are probably getting CSE'd away (common-subexpression elimination can merge the backward re-gather with the forward gather, keeping the gathered copy alive). Try this:

from functools import partial
import jax

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
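Continuing the hypothetical sketch above (same made-up mesh, names, and shapes), one way to apply that policy is to the per-device function that contains the all_gather, so the gathered weights are re-gathered during the backward pass instead of being saved as residuals; this is a sketch under those assumptions, not a drop-in fix:

@partial(shard_map, mesh=mesh,
         in_specs=(P('shards', None), P(None, None)),
         out_specs=P(None, None))
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_remat(w_block, x):
    # The all_gather result is not saved for the backward pass; it is replayed there.
    w = jax.lax.all_gather(w_block, 'shards', tiled=True)
    return jnp.tanh(x @ w)

grads = jax.grad(lambda w, x: predict_remat(w, x).sum())(w, x)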