
Migrate distributed state dict API #2138

Merged: 22 commits, Jan 8, 2025
Commits (22)
de72b9c
change load_from_full_model_state_dict
mori360 Dec 9, 2024
4b959f0
lint
mori360 Dec 10, 2024
2ad64d2
correct input parameter for gather_cpu_state_dict
mori360 Dec 10, 2024
19cf431
correct test_distributed
mori360 Dec 10, 2024
7e11c4f
correct NF4 tensor process, add todo
mori360 Dec 11, 2024
b22c2e8
correct load_from_full_model_state_dict with nf4 tensor
mori360 Dec 11, 2024
851fcf8
correct import error
mori360 Dec 11, 2024
9409ee8
correct nf4 tensor process in load_from_full_model_state_dict
mori360 Dec 11, 2024
581a65c
remove print
mori360 Dec 12, 2024
1b0ba78
modify load_from_full_model_state_dict to optimize memory cost as before
mori360 Dec 12, 2024
ca234c7
change load_from_full_model_state_dict
mori360 Dec 13, 2024
9cd28a9
add trainable_only to gather_cpu_state_dict
mori360 Dec 13, 2024
bfc7668
add trainable_only to gather_cpu_state_dict
mori360 Dec 13, 2024
f368b96
adjust DCP API usage
mori360 Dec 13, 2024
eda1b9f
adjust DCP API usage
mori360 Dec 13, 2024
db0eeed
add cpu_offload at load_from_full_optimizer_state_dict
mori360 Dec 17, 2024
91e9818
remove is_rank_zero from load_from_full_model_state_dict, add torch v…
mori360 Dec 20, 2024
11e24f1
import init_optim
mori360 Dec 20, 2024
3d0d26f
update _USE_DISTRIBUTED_STATE_DICT_API version check
mori360 Dec 20, 2024
e4730de
restructure load_from_full_model_state_dict, rewrite comments on vers…
mori360 Jan 6, 2025
ad08f42
Merge branch 'pytorch:main' into state_dict
mori360 Jan 8, 2025
4dfef98
change version check name, add more comments, load_from_full_model_st…
mori360 Jan 8, 2025
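
For reference, the commits above migrate four `torchtune.training` helpers so that they take the FSDP-sharded module directly. Below is a minimal before/after sketch of the call sites, reconstructed from the diffs in this PR; the wrapper functions and the `model`, `optimizer`, `full_sd`, `full_opt_sd`, `is_rank_zero`, and `device` names are placeholders for illustration, not code from the PR itself.

```python
from torchtune import training


def load_distributed_checkpoint(model, optimizer, full_sd, full_opt_sd, device):
    # Before: training.load_from_full_model_state_dict(
    #             model, full_sd, device, is_rank_zero, strict=True, cpu_offload=...)
    # After: is_rank_zero is no longer passed.
    training.load_from_full_model_state_dict(
        model, full_sd, device, strict=True, cpu_offload=False
    )
    # Before: training.load_from_full_optimizer_state_dict(optimizer, full_opt_sd, device)
    # After: the sharded model is passed as the first argument.
    training.load_from_full_optimizer_state_dict(model, optimizer, full_opt_sd, device)


def save_distributed_checkpoint(model, optimizer, is_rank_zero, device):
    # Before: training.gather_cpu_state_dict(model.state_dict(), is_rank_zero, device=device)
    # After: the module itself is passed instead of its (sharded) state dict.
    cpu_state_dict = training.gather_cpu_state_dict(model, is_rank_zero, device=device)
    # Before: training.get_full_optimizer_state_dict(optimizer, is_rank_zero, device=device)
    # After: the model is passed as the first argument here as well.
    opt_state_dict = training.get_full_optimizer_state_dict(
        model, optimizer, is_rank_zero, device=device
    )
    return cpu_state_dict, opt_state_dict
```

The file diffs below apply this same pattern to each distributed recipe and to the tests.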
6 changes: 3 additions & 3 deletions recipes/dev/early_exit_finetune_distributed.py
@@ -556,7 +556,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
Contributor: Do we also need to update _setup_optimizer in this recipe?

Contributor (Author): _setup_optimizer does not call load_from_full_model_state_dict, so it was not changed by the removal of self._is_rank_zero.

strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -757,7 +756,7 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._model,
self._is_rank_zero,
device=self._device,
)
@@ -773,6 +772,7 @@
log.info("Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -781,7 +781,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
self._model, opt, self._is_rank_zero, device=self._device
)
if self._is_rank_zero:
log.info(
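
For context, here is a sketch of the optimizer-in-backward branch of save_checkpoint after this change. The function and argument names are placeholders; `optim_ckpt_wrapper` stands in for the recipe's self._optim_ckpt_wrapper and is assumed to expose an optim_map of per-parameter optimizers, as in the diff above.

```python
from torchtune import training


def gather_optim_state_for_checkpoint(
    model, optimizer, optim_ckpt_wrapper, optimizer_in_bwd, is_rank_zero, device
):
    # Every call to get_full_optimizer_state_dict now receives the sharded
    # model as its first argument, including the per-parameter optimizers
    # used when the optimizer step runs in the backward pass.
    if not optimizer_in_bwd:
        return training.get_full_optimizer_state_dict(
            model, optimizer, is_rank_zero, device=device
        )
    opt_state_dict = {}
    for param, opt in optim_ckpt_wrapper.optim_map.items():
        opt_state_dict[param] = training.get_full_optimizer_state_dict(
            model, opt, is_rank_zero, device=device
        )
    return opt_state_dict
```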
3 changes: 2 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -547,7 +547,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -602,6 +601,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -617,6 +617,7 @@
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
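
A similar sketch for the loading side touched in this recipe: both the single-optimizer path and the optimizer-in-backward path of _setup_optimizer now pass the sharded model as the first argument of load_from_full_optimizer_state_dict. The function and argument names below are placeholders, not recipe code.

```python
from torchtune import training


def restore_optimizer_state(model, optimizer, opt_state_dict, device, optim_ckpt_wrapper=None):
    # Standard path: a single optimizer restored from the full state dict.
    if optim_ckpt_wrapper is None:
        if opt_state_dict:
            training.load_from_full_optimizer_state_dict(
                model, optimizer, opt_state_dict, device
            )
        return optimizer

    # Optimizer-in-backward path: one full state dict per parameter, each
    # loaded into the matching per-parameter optimizer state.
    for param in opt_state_dict.keys():
        training.load_from_full_optimizer_state_dict(
            model,
            optim_ckpt_wrapper.state_dict()[param],
            opt_state_dict[param],
            device,
        )
    return optim_ckpt_wrapper
```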
7 changes: 3 additions & 4 deletions recipes/knowledge_distillation_distributed.py
@@ -461,7 +461,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -486,7 +485,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -611,6 +608,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._model,
Contributor: I just realized we are doing things differently here than in the other recipes. It seems we could move the call to get_adapter_state_dict up before calling gather_cpu_state_dict; then you could make the same changes you did in e.g. lora_finetune_distributed.py (remove the call to get_adapter_state_dict and instead just pass trainable_only=self._save_adapter_weights_only to gather_cpu_state_dict). Let me know if that makes sense to you.

Contributor (Author): Discussed offline; we can take the get_adapter_state_dict cleanup across recipes as future work in the next PR.

self._is_rank_zero,
device=self._device,
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
11 changes: 4 additions & 7 deletions recipes/lora_dpo_distributed.py
@@ -385,7 +385,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -410,7 +409,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
@@ -458,6 +456,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -546,17 +545,15 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._model,
self._is_rank_zero,
device=self._device,
adapter_weights_only=self._save_adapter_weights_only,
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
11 changes: 4 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -480,7 +480,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -505,7 +504,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
@@ -549,6 +547,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -656,14 +655,11 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._model,
self._is_rank_zero,
device=self._device,
adapter_weights_only=self._save_adapter_weights_only,
)
utils.log_rank_zero(
log,
@@ -673,6 +669,7 @@
if intermediate_checkpoint:
utils.log_rank_zero(log, "Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
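
The LoRA recipes also drop the manual get_adapter_state_dict pre-filtering in favor of the adapter_weights_only flag shown in the diff above. A hedged sketch of the old versus new save path follows; everything outside the training.* call is a placeholder name.

```python
from torchtune import training


def gather_lora_cpu_state_dict(model, is_rank_zero, device, save_adapter_weights_only):
    # Old pattern (removed in this PR): build the sharded state dict first and
    # filter it to adapter weights on every rank before gathering:
    #     state_dict = model.state_dict()
    #     if save_adapter_weights_only:
    #         state_dict = get_adapter_state_dict(state_dict, device=None)
    #     cpu_state_dict = training.gather_cpu_state_dict(state_dict, is_rank_zero, ...)
    #
    # New pattern: pass the module itself and let gather_cpu_state_dict do the
    # adapter-only filtering through the adapter_weights_only flag.
    return training.gather_cpu_state_dict(
        model,
        is_rank_zero,
        device=device,
        adapter_weights_only=save_adapter_weights_only,
    )
```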
2 changes: 0 additions & 2 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -473,7 +473,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -500,7 +499,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
Contributor: (Commenting here for further down in the file, but) is there a reason you didn't also update save_checkpoint in this recipe? (We don't yet have a test for it, so it probably wasn't caught by CI.)

Contributor (Author): This PR only removes self._is_rank_zero from training.load_from_full_model_state_dict, which is not called in save_checkpoint. Tests for lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py will be added later.

cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
8 changes: 5 additions & 3 deletions recipes/qat_distributed.py
@@ -508,7 +508,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
@@ -562,6 +561,7 @@ def _setup_optimizer(
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._model,
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
@@ -577,6 +577,7 @@
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._model,
self._is_rank_zero,
device=self._device,
)
@@ -682,6 +683,7 @@
utils.log_rank_zero(log, "Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
@@ -690,7 +692,7 @@
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
self._model, opt, self._is_rank_zero, device=self._device
)
utils.log_rank_zero(
log,
11 changes: 4 additions & 7 deletions recipes/qat_lora_finetune_distributed.py
@@ -525,7 +525,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -550,7 +549,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
validate_missing_and_unexpected_for_lora(
@@ -589,6 +587,7 @@ def _setup_optimizer(
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
@@ -699,14 +698,11 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._model,
self._is_rank_zero,
device=self._device,
adapter_weights_only=self._save_adapter_weights_only,
)
if self._is_rank_zero:
log.info(
@@ -717,6 +713,7 @@
if self._is_rank_zero:
log.info("Retrieving optimizer state dict...")
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
2 changes: 0 additions & 2 deletions tests/torchtune/modules/peft/test_dora.py
@@ -347,7 +347,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
adapter_state_dict,
device,
is_rank_zero,
)
if is_rank_zero:
for dora_linear in [ffn.w1, ffn.w2, ffn.w3]:
@@ -377,7 +376,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
base_model_state_dict,
device,
is_rank_zero,
)

# After this, everything should be off meta device
13 changes: 5 additions & 8 deletions tests/torchtune/training/test_distributed.py
@@ -169,10 +169,9 @@ def test_lora_state_dict(self):
fsdp_optim_to_save.zero_grad()
expected_model_sd = base_model.state_dict()
expected_optim_sd = base_optim.state_dict()
model_full_sd = training.gather_cpu_state_dict(
fsdp_model_to_save.state_dict(), is_rank_zero
)
model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
optim_full_sd = training.get_full_optimizer_state_dict(
fsdp_model_to_save,
fsdp_optim_to_save,
is_rank_zero,
)
@@ -222,12 +221,12 @@ def test_lora_state_dict(self):
fsdp_model_to_load,
copy.deepcopy(base_model.state_dict()),
torch.device("cuda"),
is_rank_zero,
)
fsdp_optim_to_load = torch.optim.Adam(
fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
)
training.load_from_full_optimizer_state_dict(
fsdp_model_to_load,
fsdp_optim_to_load,
# mimic mmap=True where every rank see full SD
copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),
@@ -324,9 +323,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fsdp_model_to_save(inp)

expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()}
model_full_sd = training.gather_cpu_state_dict(
fsdp_model_to_save.state_dict(), is_rank_zero
)
model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
if is_rank_zero:
self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys()))
for key, value in model_full_sd.items():
@@ -357,7 +354,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fully_shard(m)
fully_shard(fsdp_model_to_load)
training.load_from_full_model_state_dict(
fsdp_model_to_load, expected_model_sd, torch.device("cuda"), is_rank_zero
fsdp_model_to_load, expected_model_sd, torch.device("cuda")
)
fsdp_model_to_load(inp)
sharded_model_sd = fsdp_model_to_load.state_dict()
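
Finally, a sketch of the round-trip exercised by the updated tests: gather full model and optimizer state dicts from a sharded copy, then load them back with the migrated signatures. It assumes torch.distributed is initialized and both models are already sharded (e.g. via fully_shard), as in tests/torchtune/training/test_distributed.py; the function and argument names are placeholders, and the explicit device argument in the last call mirrors the recipe call sites rather than the truncated test above.

```python
import copy

import torch

from torchtune import training


def roundtrip_state_dicts(
    base_model, fsdp_model_to_save, fsdp_optim_to_save,
    fsdp_model_to_load, fsdp_optim_to_load, is_rank_zero
):
    # Gather full (unsharded) state dicts on CPU from the sharded model/optim;
    # the module itself is now passed instead of model.state_dict().
    model_full_sd = training.gather_cpu_state_dict(fsdp_model_to_save, is_rank_zero)
    optim_full_sd = training.get_full_optimizer_state_dict(
        fsdp_model_to_save, fsdp_optim_to_save, is_rank_zero
    )

    # Load the full state dicts back into a second sharded copy;
    # is_rank_zero is no longer an argument of load_from_full_model_state_dict.
    training.load_from_full_model_state_dict(
        fsdp_model_to_load,
        copy.deepcopy(base_model.state_dict()),
        torch.device("cuda"),
    )
    training.load_from_full_optimizer_state_dict(
        fsdp_model_to_load,
        fsdp_optim_to_load,
        copy.deepcopy(optim_full_sd),
        torch.device("cuda"),  # device, as in the recipe call sites
    )
    return model_full_sd, optim_full_sd
```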