Migrate distributed state dict API #2138

Merged on Jan 8, 2025 (22 commits)

Contributor

@mori360 mori360 commented Dec 10, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Migrate to the distributed state dict APIs from torch.distributed.

Changelog

What are the changes made in this PR?

Switch to distributed state dict APIs from torch.distributed.

  • load_from_full_model_state_dict <- set_model_state_dict
  • gather_cpu_state_dict <- get_model_state_dict
  • load_from_full_optimizer_state_dict <- set_optimizer_state_dict
  • get_full_optimizer_state_dict <- get_optimizer_state_dict

To align the inputs, add model input to get_full_optimizer_state_dict and load_from_full_optimizer_state_dict.
Change the sharded_sd input for gather_cpu_state_dict to model.
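
For readers unfamiliar with the upstream functions, the mapping above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the wrapper names (`gather_full_state_dict_on_cpu`, `load_full_state_dict`, `gather_full_optimizer_state_dict`) and the specific `StateDictOptions` values are assumptions made for the example. It also suggests why the optimizer helpers gain a `model` argument: the upstream optimizer functions take the model alongside the optimizer.

```python
from typing import Any, Dict


def gather_full_state_dict_on_cpu(model: Any) -> Dict[str, Any]:
    """Hypothetical wrapper: collect an unsharded copy of the weights on CPU."""
    # Imported inside the function so the sketch can be imported on a machine
    # without a distributed PyTorch setup.
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
    )

    # Assumed options: with full_state_dict and cpu_offload both set, only
    # rank 0 receives the gathered tensors; other ranks get an empty dict.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_model_state_dict(model, options=options)


def load_full_state_dict(model: Any, full_sd: Dict[str, Any], strict: bool = False) -> Any:
    """Hypothetical wrapper: load an unsharded state dict into a sharded model."""
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    options = StateDictOptions(full_state_dict=True, strict=strict)
    # Returns a named tuple with missing_keys and unexpected_keys.
    return set_model_state_dict(model, model_state_dict=full_sd, options=options)


def gather_full_optimizer_state_dict(model: Any, optimizer: Any) -> Dict[str, Any]:
    """Hypothetical wrapper: the upstream call needs the model as well as the optimizer."""
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
    )

    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_optimizer_state_dict(model, optimizer, options=options)
```

These functions only do something useful inside an initialized distributed run with an FSDP-wrapped model; the option values shown are one plausible configuration, not necessarily the one torchtune uses.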

TODO:
NF4Tensor handling is unchanged; migrating it remains future work.

Test plan

pytest tests/torchtune/training/test_distributed.py
pytest tests -m integration_test
(early_exit_finetune_distributed and knowledge_distillation_distributed are not tested in the CI test)
tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed

We compared runs using the previous API and the new API: the losses are the same both on initial loading and when resuming from a checkpoint.

We also captured memory traces; the results show that the new API does not increase peak memory compared with the current one.
[Image: Screenshot 2025-01-02 at 1 10 18 PM]


pytorch-bot bot commented Dec 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4dfef98 with merge base 27fd3a1 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the "CLA Signed" label Dec 10, 2024
@joecummings joecummings added the "distributed" label Dec 10, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 3.38983% with 57 lines in your changes missing coverage. Please review.

Project coverage is 65.26%. Comparing base (f2bd4bc) to head (8b575be).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/training/_distributed.py 3.50% 55 Missing ⚠️
tests/torchtune/training/test_distributed.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2138       +/-   ##
==========================================
+ Coverage   9.33%   65.26%   +55.93%     
==========================================
  Files        289      334       +45     
  Lines      16959    19192     +2233     
==========================================
+ Hits        1583    12526    +10943     
+ Misses     15376     6666     -8710     


@mori360 mori360 changed the title from "Mitigate distributed state dict API" to "Migrate distributed state dict API" Dec 18, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Dec 19, 2024
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), make the following changes:
1. `load_from_full_model_state_dict` in TorchTune calls `set_model_state_dict` with an option controlling cpu_offload. Add cpu_offload to `_load_model_state_dict` so that tensors are moved to CPU when the option is True.
2. Relax the device check: lora_finetune might have 2 device types, so accept that as valid.
3. Some changes to optimize memory performance:
3.1 use `.detach().clone()` instead of using the view directly
3.2 if local_state is not meta, copy `full_tensor[slices]` to `ret.to_local()`
4. Add related unit tests

Memory performance calling from TorchTune with llama2/7B_full:
1. cpu_offload = True
<img width="555" alt="Screenshot 2024-12-18 at 1 36 47 PM" src="https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b" />

2. cpu_offload = False
<img width="555" alt="Screenshot 2024-12-18 at 1 36 52 PM" src="https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806" />

Pull Request resolved: #142845
Approved by: https://github.com/fegin
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
Contributor

How can we get view support for NF4?

cc @andrewor14

Contributor Author

Thank you for the review. We currently skip the NF4 tensor part and plan to support NF4 next quarter.

Contributor

Looks like there's already view support for NF4Tensor? What's the error you're getting?

also cc @drisspg @weifengpy

Collaborator

I brought this up with @ebsmothers and @gau-nernst in Discord. We didn't think that we needed to do anything else here; it should be safe to just switch to from_local.

Contributor Author

Thank you for the comments. Shall I switch to from_local in this PR, or handle it together with the rest of the NF4 tensor support in the next PR?

Contributor

Yes if possible it'd be great to move to from_local here assuming everything works. Imo the more that we can clean this function up the better, as is it has gotten quite unwieldy

Contributor

Bump here.
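
For context on the TODO discussed in this thread, a switch to `DTensor.from_local` might look roughly like the sketch below. It is untested against NF4Tensor, which is exactly the open question here. The helper name `wrap_local_shard` is hypothetical, and the import path assumes a recent PyTorch where `torch.distributed.tensor` is public (older releases expose it as `torch.distributed._tensor`).

```python
from typing import Any


def wrap_local_shard(sharded_param: Any, sharded_meta_param: Any) -> Any:
    """Hypothetical helper: wrap this rank's shard as a DTensor via from_local."""
    # Imported lazily so the sketch can be imported without torch installed.
    from torch.distributed.tensor import DTensor

    # shape and stride describe the global (unsharded) tensor, taken from the
    # meta-device parameter, mirroring the TensorMeta in the manual constructor.
    # requires_grad handling is omitted from this sketch.
    return DTensor.from_local(
        sharded_param,
        device_mesh=sharded_meta_param.device_mesh,
        placements=sharded_meta_param.placements,
        run_check=False,
        shape=sharded_meta_param.size(),
        stride=sharded_meta_param.stride(),
    )
```

The appeal is that the mesh, placements, and global metadata are passed through a public API instead of building `DTensorSpec` and `TensorMeta` by hand.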

@mori360 mori360 marked this pull request as ready for review December 20, 2024 23:57
@mori360 mori360 requested a review from joecummings December 20, 2024 23:58
Contributor

@ebsmothers ebsmothers left a comment

Thanks for your patience! Left a bunch of comments, please let me know if anything is unclear. One request is to also manually run lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py recipes as they do not currently have tests in our CI. Happy to provide any pointers here if you need.

@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-self._model.state_dict(),
+self._model,
Contributor

I just realized we are doing things differently here than in the other recipes.. seems to me like we could move the call to get_adapter_state_dict up before calling gather_cpu_state_dict, then you could make the same changes you did in e.g. lora_finetune_distributed.py (remove the call to get_adapter_state_dict and instead just pass trainable_only=self._save_adapter_weights_only to gather_cpu_state_dict). Lmk if that makes sense to you

Contributor Author

Discussed offline: we can take the get_adapter_state_dict cleanup across recipes as future work in the next PR.

 self._is_rank_zero,
 device=self._device,
+trainable_only=self._save_adapter_weights_only,
Contributor

Also one other thing we will have to be aware of is that in general it may not always be the case that trainable params == adapter params. This holds true today, but especially for multimodal models we need to be careful because some people may want to e.g. do LoRA finetuning on the image encoder and full finetuning on the text decoder. This was disabled in #2150 but we may want to add it back later and in that case this would be misleading. So I think trainable_only is potentially a misnomer and it may be best to rename it to adapter_weights_only or something like that.
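
The distinction raised here can be made concrete with a small, self-contained sketch. Everything in it is hypothetical (the `ParamInfo` stand-in, the parameter names, and the two selector functions); it only illustrates how "trainable" and "adapter" can diverge when LoRA is applied to one component and full finetuning to another.

```python
from dataclasses import dataclass
from typing import Dict, Set


@dataclass
class ParamInfo:
    """Stand-in for a model parameter; only the flag relevant here is kept."""

    requires_grad: bool


def select_trainable(params: Dict[str, ParamInfo]) -> Set[str]:
    """Names of all parameters that receive gradients."""
    return {name for name, p in params.items() if p.requires_grad}


def select_adapter(params: Dict[str, ParamInfo], adapter_names: Set[str]) -> Set[str]:
    """Names of parameters that belong to an adapter (e.g. LoRA A/B matrices)."""
    return {name for name in params if name in adapter_names}


# Hypothetical setup: LoRA on the image encoder, full finetuning on the text decoder.
params = {
    "encoder.q_proj.weight": ParamInfo(requires_grad=False),
    "encoder.q_proj.lora_a.weight": ParamInfo(requires_grad=True),
    "encoder.q_proj.lora_b.weight": ParamInfo(requires_grad=True),
    "decoder.q_proj.weight": ParamInfo(requires_grad=True),
}
adapter_names = {"encoder.q_proj.lora_a.weight", "encoder.q_proj.lora_b.weight"}

trainable = select_trainable(params)
adapter = select_adapter(params, adapter_names)

# Parameters that are trainable but are not adapter weights.
print(sorted(trainable - adapter))
```

In this setup a flag named trainable_only would also pick up the decoder weight, while a flag named adapter_weights_only would describe the intended subset.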

@@ -500,7 +499,6 @@ def _setup_model(
 model,
 base_model_state_dict,
 self._device,
-self._is_rank_zero,
Contributor

(Commenting here for further down in the file but) is there a reason you didn't also update save_checkpoint in this recipe? (We don't yet have a test for it so probably didn't get caught by CI)

Contributor Author

This PR only removes self._is_rank_zero from the call to training.load_from_full_model_state_dict, which is not called in save_checkpoint. I will add tests for lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py later.

@@ -556,7 +556,6 @@ def _setup_model(
 model,
 model_state_dict,
 self._device,
-self._is_rank_zero,
Contributor

Do we also need to update _setup_optimizer in this recipe?

Contributor Author

_setup_optimizer does not call load_from_full_model_state_dict, so the removal of self._is_rank_zero does not require any change in _setup_optimizer.


 Args:
-sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors
+model (FSDPModule): Model to generate fqn for cpu_state_dict
Contributor

nit but I don't think most people know what "fqn" means, might write something more descriptive here

) -> Dict[str, Any]:
"""
Converting sharded state dict into a full state dict on CPU
Returning non-empty result only on rank0 to avoid peaking CPU memory
TODO: add support for NF4Tensor
Contributor

Can we add more details here so that it's clear to someone who's reading the code? Something like "If the model does not contain any NF4 tensors, we directly use distributed state dict APIs. Otherwise, we need to manually gather any NF4 tensors until all-gather is supported in the NF4Tensor subclass"

torchtune/training/_distributed.py Outdated Show resolved Hide resolved
Comment on lines 194 to 257
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
    if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
        sharded_meta_param._local_tensor, NF4Tensor
    ):
        block_size = sharded_meta_param._local_tensor.block_size
        scaler_block_size = (
            sharded_meta_param._local_tensor.scaler_block_size
        )
        full_tensor = to_nf4(
            full_tensor,
            block_size=block_size,
            scaler_block_size=scaler_block_size,
        )
        # replicating logic from `_fsdp_param.py`` `_init_sharded_param`
        # otherwise `distribute_tensor(DTensor(local=NF4))`
        # requires dispatching `c10d.scatter_``
        # long-term solution is `swap_tensor`
        mesh = sharded_meta_param.device_mesh
        if mesh.ndim > 1:
            raise NotImplementedError(
                f"only support 1D FSDP but got {mesh.ndim=}"
            )
        shard_mesh_dim = 0
        shard_world_size = mesh.size(shard_mesh_dim)
        shard_rank = cast(
            torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
        ).rank()
        chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
            shard_rank
        ]
        sharded_param = full_tensor.new_zeros(chunk.size())
        sharded_param[: chunk.size(0)].copy_(chunk)

        # TODO: change to from_local API (need to add view support for NF4)
        sharded_tensor = DTensor(
            local_tensor=sharded_param,
            spec=DTensorSpec(
                mesh=sharded_meta_param.device_mesh,
                placements=sharded_meta_param.placements,
                tensor_meta=TensorMeta(
                    shape=sharded_meta_param.size(),
                    dtype=sharded_meta_param.dtype,
                    stride=sharded_meta_param.stride(),
                ),
            ),
            requires_grad=sharded_meta_param.requires_grad,
        )

    elif not hasattr(sharded_meta_param, "device_mesh"):
        # In cases where parts of the model aren't sharded, some parameters will be plain tensors
        sharded_tensor = full_tensor
    else:
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
    if cpu_offload:
        sharded_tensor = sharded_tensor.cpu()
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
Contributor

Is this a duplicate of L274-L335? I think this function is already complicated enough, if we can just use a single if/else branch to consolidate these that'd be preferable

Contributor Author

Restructured this into a single if/else: the if branch is taken when _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4, and the else branch otherwise.
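
As a side note on the NF4 branch quoted in this thread, the per-rank slicing (torch.chunk along dim 0, indexed by shard_rank) can be illustrated without tensors. The sketch below is a plain-Python stand-in that follows torch.chunk's documented behavior; the function name `chunk_row_ranges` is hypothetical and the code is not part of the PR.

```python
import math
from typing import List, Tuple


def chunk_row_ranges(num_rows: int, world_size: int) -> List[Tuple[int, int]]:
    """Row ranges obtained by chunking dim 0 into at most `world_size` pieces.

    Follows torch.chunk's documented behavior: every chunk has
    ceil(num_rows / world_size) rows except possibly the last one, and fewer
    than `world_size` chunks may be returned.
    """
    if num_rows <= 0 or world_size <= 0:
        raise ValueError("num_rows and world_size must be positive")
    chunk_size = math.ceil(num_rows / world_size)
    return [
        (start, min(start + chunk_size, num_rows))
        for start in range(0, num_rows, chunk_size)
    ]


# Even split: each of the 4 ranks gets 2 rows.
print(chunk_row_ranges(8, 4))
# Uneven split: the last rank gets a shorter chunk.
print(chunk_row_ranges(10, 4))
# Small tensor: only 3 chunks exist for 4 ranks.
print(chunk_row_ranges(5, 4))
```

The last case shows that an uneven split can leave the final rank with a shorter chunk or, for small tensors, with no chunk at all, which is one reason hand-rolled sharding logic is harder to maintain than a library call.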

Comment on lines 182 to 183
# There are some changes at `set_model_state_dict` to adjust multiple devices from local_state in TorchTune,
# keep version check until PyTorch changes are on stable.
Contributor

I don't understand this comment

Contributor Author

Comments updated on why we have the pytorch version check here

@mori360 mori360 marked this pull request as draft January 6, 2025 18:02
…ion check, change weights_only to adapter_weights_only
@mori360 mori360 marked this pull request as ready for review January 7, 2025 01:14
Contributor

@joecummings joecummings left a comment

I want to get this in b/c it's a huge improvement over our current functionality. I left a few notes just as reminders to myself, but definitely need to confirm that we are checking that these APIs are available everywhere they are used and then add in a little more explanation in places.

Thanks @mori360


_log: logging.Logger = get_logger()


_valid_distributed_single_node_nnodes = ["1:1", "1"]

torch_version = torch.__version__
_USE_DISTRIBUTED_STATE_DICT_API = (
Contributor

nit: The actual place this is used will determine whether you are actually "using" the distributed state dict API. This variable only says that the API is fully available. Therefore, I might suggest a name like _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE
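
To make the "availability" reading of this flag concrete, here is a minimal sketch of a version gate. The thresholds are assumptions for illustration: the nightly date follows the Dec 20 nightly mentioned elsewhere in this PR, while the stable version `(2, 6, 0)` is a placeholder, and the function name is hypothetical. It parses the version into integers because plain string comparison would misorder versions such as "2.10.0" and "2.6.0".

```python
import re
from typing import Tuple

# Assumed thresholds, for illustration only.
MIN_STABLE: Tuple[int, int, int] = (2, 6, 0)
MIN_NIGHTLY_DATE = "20241220"


def distributed_state_dict_api_is_available(torch_version: str) -> bool:
    """Return True if `torch_version` is new enough under the assumed thresholds."""
    nightly = re.search(r"\.dev(\d{8})", torch_version)
    if nightly:
        # Nightly builds look like "2.6.0.dev20241220+cu124"; compare the date stamp.
        # Fixed-width YYYYMMDD strings compare correctly as text.
        return nightly.group(1) >= MIN_NIGHTLY_DATE
    release = re.match(r"(\d+)\.(\d+)\.(\d+)", torch_version)
    if release is None:
        # Unrecognized version string: be conservative.
        return False
    return tuple(int(part) for part in release.groups()) >= MIN_STABLE


print(distributed_state_dict_api_is_available("2.5.1+cu121"))
print(distributed_state_dict_api_is_available("2.6.0.dev20241220+cu124"))
```

In real code the argument would be `torch.__version__`, and the result would be stored once in a module-level constant.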

# replicating logic from `_fsdp_param.py`` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_``
# long-term solution is `swap_tensor`
Contributor

Dumb question: why can't we use swap_tensor now?

Contributor Author

We plan to support nf4tensor as future work in Q1.

hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4:
Contributor

What's necessary to support NF4 using the distributed state dict API?

Contributor Author

We plan to handle NF4 and other tensor types in a more general way, so that the if has_nf4: special case is no longer needed.

@@ -154,7 +168,6 @@ def load_from_full_model_state_dict(
 model: "FSDPModule", # noqa
 full_sd: Dict[str, Any],
 device: torch.device,
-is_rank_zero: bool,
 strict: bool = False,
 cpu_offload: bool = False,
 ):
Contributor

Can we update the return type here (and docstring) since we actually return missing and unexpected keys?

Contributor

I actually cannot believe this is a public API with such poor documentation. Sorry @mori360!
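
For readers wondering what "missing and unexpected keys" refers to: `load_state_dict` in PyTorch returns a named tuple with `missing_keys` and `unexpected_keys` fields. The sketch below mimics that shape in plain Python; the `IncompatibleKeys` class, the `diff_state_dict_keys` function, and the key names are stand-ins invented for this illustration, not torchtune code.

```python
from typing import Iterable, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    """Stand-in mirroring the two fields returned by torch's load_state_dict."""

    missing_keys: List[str]
    unexpected_keys: List[str]


def diff_state_dict_keys(
    expected: Iterable[str], provided: Iterable[str]
) -> IncompatibleKeys:
    """Compare the keys a model expects against the keys a checkpoint provides."""
    expected_set, provided_set = set(expected), set(provided)
    return IncompatibleKeys(
        missing_keys=sorted(expected_set - provided_set),
        unexpected_keys=sorted(provided_set - expected_set),
    )


result = diff_state_dict_keys(
    expected=["layer.weight", "layer.bias"],
    provided=["layer.weight", "layer.lora_a.weight"],
)
print(result.missing_keys)     # keys the model expects but the checkpoint lacks
print(result.unexpected_keys)  # keys in the checkpoint that the model does not have
```

Documenting a return type of this shape tells callers that a non-strict load can succeed while still reporting keys worth inspecting.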

@@ -166,64 +179,95 @@ def load_from_full_model_state_dict(
- `is_rank_zero` matters if only rank 0 pass in non-empty `full_sd` and
Contributor

This comment can be removed now that it isn't a part of the params.

sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
Contributor

sigh: we shouldn't have two consecutive .to() calls. You don't need to fix it here - I can do it.
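
One way to collapse the two calls, sketched under the assumption that nothing else depends on the intermediate tensor: `torch.Tensor.to` accepts `device` and `dtype` together. The helper name `cast_and_move` is hypothetical.

```python
from typing import Any


def cast_and_move(full_tensor: Any, dtype: Any, device: Any) -> Any:
    """Convert dtype and device in one call instead of chaining two .to() calls."""
    # Equivalent in effect to full_tensor.to(dtype).to(device), but expressed
    # as a single conversion request.
    return full_tensor.to(device=device, dtype=dtype)
```

In the loop quoted above this would read `full_tensor = cast_and_move(full_tensor, sharded_meta_param.dtype, device)`, or simply the inlined `.to(device=..., dtype=...)` call.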

):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
Contributor

if we assume "full" means plain tensor which is the claim above, then this is - in fact - not a "full" tensor?

cpu_state_dict[param_name] = param.cpu()
torch.distributed.barrier()
return cpu_state_dict
else:
Contributor

Shouldn't there be a check here to make sure that the distributed APIs are available?

Contributor Author

The current distributed API already supports this path with ideal performance, so there is no need to check for API availability here.

full_state[group_id] = group_state
else:
del group_state
options = StateDictOptions(
Contributor

I love how clean this is now, but shouldn't we check to make sure Distributed APIs are available?

Contributor Author

@mori360 mori360 Jan 8, 2025

Same as above: the changes to the distributed API are there to support loading the state dict. The other 2 APIs work with ideal performance on the current distributed API, without the nightly version that landed on Dec 20.

@mori360 mori360 requested a review from joecummings January 8, 2025 21:07
Contributor

@joecummings joecummings left a comment

Awesome work! Thanks for working with us to address all comments :)

@joecummings joecummings merged commit 38bf427 into pytorch:main Jan 8, 2025
17 checks passed