Optimizer state offload to CPU #204

Merged
14 commits merged into main from ap/opt_offload on Jan 30, 2025
Conversation

@apaz-cli (Contributor) commented Jan 22, 2025

./scripts/simulate_multi_node_diloco.sh 1 8 src/zeroband/train.py @ configs/10B/H100_devel.toml

Before:

22:55:18 [DEBUG] [Rank 0] Max memory used: 49695.58 MB

After:

22:55:18 [DEBUG] [Rank 0] Max memory used: 49695.58 MB

@awgu do you have any idea what's going on here? I added CPUOffloadPolicy, but it doesn't seem to be offloading.
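
For reference, here is a minimal sketch of how a CPUOffloadPolicy is typically wired into FSDP2's fully_shard. This is an assumption about the general setup, not this repo's code: the toy nn.Sequential stands in for the real model, an initialized process group is assumed, and the import path has moved between torch releases (older versions expose the same names under torch.distributed._composable.fsdp).

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, CPUOffloadPolicy  # path in recent torch releases

# Toy stand-in model; assumes torch.distributed is already initialized (e.g. via torchrun).
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# Shard each layer and then the root module, asking FSDP to keep the sharded
# parameters, gradients, and optimizer state in pinned CPU memory between uses.
for layer in model:
    fully_shard(layer, offload_policy=CPUOffloadPolicy())
fully_shard(model, offload_policy=CPUOffloadPolicy())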

@awgu commented Jan 22, 2025

I have been summoned 😮

I would recommend:

  1. Add some assertions to check whether your model.parameters() are actually on CPU when you expect them to be (e.g. after init).
  2. Use the memory snapshot tool, which records stack traces for allocations and can show why you are still seeing GPU allocations.

There might be a public API now, but this is what I usually do:

import pickle
import torch

# Add this somewhere early in your init code
torch.cuda.memory._record_memory_history()
...
# Later, dump the recorded allocation history to a pickle
snapshot = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)
# Or, to write one file per rank:
snapshot = torch.cuda.memory._snapshot()
with open(f"snapshot_{torch.distributed.get_rank()}.pickle", "wb") as f:
    pickle.dump(snapshot, f)
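
As an aside (not from the original comment): either pickle can be opened in PyTorch's snapshot viewer at https://pytorch.org/memory_viz to browse allocation stack traces over time, and recent torch versions also provide a one-call helper for the dump. Note that _dump_snapshot, like _snapshot, is a private API and may change between releases.

# Shortcut that records the snapshot straight to disk:
torch.cuda.memory._dump_snapshot("snapshot.pickle")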

@samsja (Collaborator) commented Jan 22, 2025

thanks @awgu !!

@apaz-cli (Contributor, author) commented Jan 23, 2025

@awgu I just spent a couple hours poking at it.

Actually, as far as I can tell, the gradients are not on CPU until the first time loss.backward() is called with model.set_requires_gradient_sync(True); until that point, param.grad is None for every parameter. I expected the gradients to be there (especially because I had planned to write my own CPU optimizer using register_post_accumulate_grad_hook()), and I'm wondering what's going on. Do you know what's up with that? Is it documented anywhere?

That's a different question, though. The answer to the original question turned out to be that the optimizer actually DID get offloaded, contrary to my belief. The max memory usage happens right at the beginning, when the model is materialized before fully_shard().

So now I'm trying to figure out how to avoid materializing the full model before sharding.
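
For context on the register_post_accumulate_grad_hook() idea mentioned above, here is a hypothetical sketch of an optimizer-step-in-backward hook. It is not code from this PR; `model` and the learning rate are assumptions, and the hook only runs once a parameter's .grad has actually been accumulated, which is also why param.grad is None before the first backward pass.

import torch

def make_step_hook(lr: float = 1e-4):
    def step(param: torch.Tensor) -> None:
        # Fires right after param.grad has been fully accumulated in backward:
        # apply a plain SGD update, then drop the gradient immediately so it
        # never has to survive until a separate optimizer.step() call.
        with torch.no_grad():
            param.add_(param.grad, alpha=-lr)
        param.grad = None
    return step

for p in model.parameters():  # `model` is assumed to already exist
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(make_step_hook())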

@awgu commented Jan 23, 2025

I think the unsharded model passed to FSDP is on GPU, which aligns with your observation that the peak memory is before calling fully_shard.

model = model.to(world_info.local_rank)

At this point:

logger.debug("model fsdped")

I would expect all model parameters to be on CPU:

# Run this right after the "model fsdped" log line above:
for param_name, param in model.named_parameters():
    assert param.device.type == "cpu", f"{param_name} is not on CPU!"

Could you check if that is the case?

@apaz-cli (Contributor, author) commented Jan 23, 2025

@awgu Yep, the parameters are all on CPU.

I now believe the issue was that I had uncommented torch.set_default_device("cuda") so the model would load faster, in about 5 seconds rather than 40, but that call also materializes the full model on the GPU. I should learn how that works. When I comment it out, I see the memory savings.
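
On avoiding materializing the full model before sharding: the usual FSDP2 pattern is meta-device initialization. The sketch below is an assumption about one way to apply it here, not this repo's code; the toy model, dimensions, and the final re-init loop are placeholders for the real initialization logic.

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, CPUOffloadPolicy

# Build the module tree on the meta device: shapes and dtypes only, no storage.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8192, 8192), nn.Linear(8192, 8192))  # toy stand-in

for layer in model:
    fully_shard(layer, offload_policy=CPUOffloadPolicy())
fully_shard(model, offload_policy=CPUOffloadPolicy())

# Allocate storage for this rank's shards only, then (re)initialize the weights.
# With CPUOffloadPolicy the shards live on CPU, hence device="cpu"; without
# offloading, device="cuda" would be the usual choice.
model.to_empty(device="cpu")
for module in model.modules():
    if isinstance(module, nn.Linear):
        module.reset_parameters()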

Comment on lines 1 to 2
name_model = "26B"
type_model = "llama2"
@samsja (Collaborator) commented:

let's create a 26B model config then instead of modifying the 10b_devel

@apaz-cli (Contributor, author) commented:

Fair. This is the one that I'm making changes to, to try to get the model sized to the machine.

@samsja (Collaborator) left a comment

Can you break this PR down into two or three PRs?

  • one with only the CPU offload code
  • one with all the other modifications
  • one with the Liger kernel updates

Otherwise it's hard to review, and to revert if needed.

@apaz-cli (Contributor, author):

Summary of changes:

  • Removed einops
  • Fixed types
  • Added logging to inner training loop
  • Changed the arguments on get_optimizer()
  • Fixed model initialization
  • Swap the order of detach() and clone()
  • Overlap loss all_reduce()s (see the sketch after this list)
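
Since only the summary is shown here, the following is a generic illustration of the "overlap loss all_reduce()s" item using async_op=True, not the PR's actual code; `loss` and `logger` are assumed to exist, and ReduceOp.AVG requires the NCCL backend.

import torch.distributed as dist

# Kick off the reduction of a detached copy of the loss without blocking.
loss_for_logging = loss.detach().clone()
work = dist.all_reduce(loss_for_logging, op=dist.ReduceOp.AVG, async_op=True)

loss.backward()   # the logging all_reduce overlaps with backward

work.wait()       # block only once the reduced value is actually needed
logger.info(f"avg loss: {loss_for_logging.item():.4f}")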


@samsja (Collaborator) left a comment

lfgtm

@apaz-cli merged commit eabafad into main on Jan 30, 2025 (1 of 2 checks passed)
@apaz-cli deleted the ap/opt_offload branch on January 30, 2025, 01:42