[PyTorch] Avoid parameters function in op backward pass #1403
Description
We have recently experienced some esoteric errors in the LayerNorm backward pass.
I haven't fully investigated, but I suspect that FSDP is manipulating module parameters so that they are only available in the forward pass. This interferes with a check the operation fuser performs in the backward pass to make sure the number of params and param grads match.
This PR tweaks the operation fuser to avoid calling `parameters()` in the backward pass. Instead, it counts the params for each op in the forward pass and caches the counts for use in the backward pass.
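As a rough illustration of the pattern (not the actual fuser code), here is a minimal sketch with hypothetical names like `_BiasOp` and `_SketchFuserFunc`: the per-op param counts are gathered in `forward()`, while the params are still visible, and the backward pass checks its bookkeeping against the cached counts instead of calling `parameters()` again.

```python
import torch
import torch.nn as nn


class _BiasOp(nn.Module):
    """Toy op with a single parameter, standing in for a real fusible op."""

    def __init__(self, size):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(size))


class _SketchFuserFunc(torch.autograd.Function):
    """Hypothetical stand-in for the fuser's autograd function."""

    @staticmethod
    def forward(ctx, x, ops, *biases):
        # Count each op's params now, while wrappers like FSDP still expose
        # them, and cache the counts for use in the backward pass.
        ctx.param_counts = [sum(1 for _ in op.parameters()) for op in ops]
        out = x
        for b in biases:
            out = out + b
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Produce one grad per bias and check the bookkeeping against the
        # cached counts instead of calling op.parameters() here, where the
        # params may no longer be visible.
        param_grads = [grad_out.sum(dim=0) for _ in ctx.param_counts]
        assert len(param_grads) == sum(ctx.param_counts)
        return (grad_out, None, *param_grads)


# The counts are gathered in forward() and reused after backward().
ops = [_BiasOp(8), _BiasOp(8)]
x = torch.randn(2, 8, requires_grad=True)
y = _SketchFuserFunc.apply(x, ops, *(op.bias for op in ops))
y.sum().backward()
```

The real change is presumably analogous: the forward pass records how many params each basic op owns, and the backward-pass consistency check consumes those cached counts rather than re-querying the module.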
Type of change

Changes
- Avoid `parameters` function in backward pass

Checklist: