Continuing the discussion from #2172 (thanks @mirceamironenco, @ebsmothers for the fix!).
We have runs on the exact same dataset / hparams, except we change the number of nodes from 8 -> 2 -> 1. We noticed that when we reduce the number of nodes, the gradient norm goes up:
Here is an 8 node run: [grad norm plot]
Here is a 2 node run: [grad norm plot]
Here is a 1 node run: [grad norm plot]
We can see the grad norm at initialization is ~4x different between the 8 node and 1 node runs. With the fix in #2172, I would expect the grad norms to be similar regardless of the world size. The only difference between the runs is the global batch size (64 on 1 node, 512 on 8 nodes), but I would not expect that to cause such a big difference.
Is it possible there are still some issues in how we compute / scale the gradients?
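For reference, here is a minimal diagnostic sketch of the kind of check that makes the discrepancy visible, assuming FSDP2-style DTensor parameters. The helper name `local_vs_full_grad_norm` and the import path are illustrative, not from the actual recipe: it compares each rank's local grad-norm contribution against the fully gathered norm, so you can see whether the logged value depends on the world size.

```python
# Hypothetical helper (not from torchtune): compare the rank-local grad norm
# contribution against the fully gathered norm for a model sharded with FSDP2.
import torch
from torch.distributed.tensor import DTensor  # public module in recent PyTorch versions


def local_vs_full_grad_norm(model: torch.nn.Module, norm_type: float = 2.0):
    local_sq, full_sq = 0.0, 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad
        if isinstance(g, DTensor):
            # Norm over only this rank's shard of the gradient.
            local_sq += torch.linalg.vector_norm(g.to_local(), norm_type) ** norm_type
            # Norm over the full gradient, materialized across all ranks.
            full_sq += torch.linalg.vector_norm(g.full_tensor(), norm_type) ** norm_type
        else:
            n = torch.linalg.vector_norm(g, norm_type) ** norm_type
            local_sq += n
            full_sq += n
    # If these two values differ (and the gap grows with world size), the logged
    # norm is the unreduced local value rather than the true global norm.
    return local_sq ** (1.0 / norm_type), full_sq ** (1.0 / norm_type)
```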
I think right now we're just reporting the (unreduced) total norm from rank 0. If I call full_tensor() in our local runs, I get the same norm regardless of the world size, and it matches what we're seeing in an identical run in NeMo.
Thanks @EugenHotaj for the issue and the update. I agree that full_tensor() should give the correct grad norm. Just ran a quick test to make sure there's no major perf hit to doing this, and it looks good. Do you wanna open a PR with the fix? (If not, let me know and I'm happy to do so.)
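For anyone landing here later, a minimal sketch of the shape the fix could take, assuming the recipe gets the total norm back from torch.nn.utils.clip_grad_norm_. The function name `clip_and_log_grad_norm` is illustrative, not the actual torchtune code:

```python
# Sketch of the proposed change (illustrative, not the exact torchtune diff):
# with DTensor parameters, clip_grad_norm_ returns the total norm as a DTensor
# whose local value is this rank's unreduced contribution, so materialize it
# with full_tensor() before logging.
import torch
from torch.distributed.tensor import DTensor


def clip_and_log_grad_norm(model: torch.nn.Module, max_norm: float) -> float:
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    if isinstance(grad_norm, DTensor):
        # Reduce across ranks; otherwise rank 0 logs only its local piece,
        # which shrinks as the world size grows.
        grad_norm = grad_norm.full_tensor()
    return grad_norm.item()
```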