Grad Norm Differences Across Nodes #2240

Open · EugenHotaj opened this issue Jan 9, 2025 · 2 comments
Labels: discussion

Comments

EugenHotaj (Contributor) commented Jan 9, 2025

Continuing the discussion from #2172 (thanks @mirceamironenco, @ebsmothers for the fix!).

We have runs on the exact same dataset / hparams, except we change the number of nodes from 8 -> 2 -> 1. We noticed that when we reduce the number of nodes, the gradient norm goes up:

Here is an 8 node run:
[screenshot: grad norm curve, 8 nodes]

Here is a 2 node run:
[screenshot: grad norm curve, 2 nodes]

Here is a 1 node run:
[screenshot: grad norm curve, 1 node]

We can see the grad norm at initialization is ~4x different between the 8 node and 1 node runs. With the fix in #2172, I would expect the grad norms to be similar regardless of the world size. The only difference between the runs is the global batch size (64 on 1 node, 512 on 8 nodes), but I would not expect this to cause such a big difference.

Is it possible there are still some issues in how we compute / scale the gradients?

joecummings added the discussion label on Jan 9, 2025
EugenHotaj (Contributor, Author) commented Jan 9, 2025

I think this may actually just be a logging issue. To get the correct grad_norm we need to call grad_norm.full_tensor() like torchtitan does here: https://github.com/pytorch/torchtitan/blob/90567fc9827ffdf17bdd0349cd5276c662d0769a/torchtitan/utils.py#L396

I think right now we're just reporting the (unreduced) total norm from rank 0. If I call full_tensor() in our local runs, I get the same norm regardless of the world size, and it matches what we're seeing in an identical run in NeMo.
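For concreteness, here is a minimal sketch of the kind of fix being described (assuming FSDP2-style DTensor-sharded parameters; `model` and `max_grad_norm` are illustrative names rather than torchtune's actual recipe API, and on older PyTorch versions the DTensor import lives under `torch.distributed._tensor`):

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch


def clip_and_get_grad_norm(model: torch.nn.Module, max_grad_norm: float) -> float:
    # With DTensor-sharded parameters, clip_grad_norm_ returns the total norm as a
    # DTensor whose local value only reflects the shards owned by the current rank.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if isinstance(grad_norm, DTensor):
        # Reduce across ranks so the logged value is the true global norm,
        # similar to what torchtitan does before logging.
        grad_norm = grad_norm.full_tensor()
    return grad_norm.item()
```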

ebsmothers (Contributor) commented

Thanks @EugenHotaj for the issue and the update. I agree that full_tensor() should give the correct grad norm. Just ran a quick test to make sure there's no major perf hit to doing this and it looks good. Do you wanna open a PR with the fix? (If not, let me know and I am happy to do so.)
