Migrate distributed state dict API #2138
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 4dfef98 with merge base 27fd3a1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #2138       +/-   ##
==========================================
+ Coverage    9.33%   65.26%   +55.93%
==========================================
  Files         289      334       +45
  Lines       16959    19192     +2233
==========================================
+ Hits         1583    12526    +10943
+ Misses      15376     6666     -8710

☔ View full report in Codecov by Sentry.
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), make the following changes here:
1. `load_from_full_model_state_dict` in torchtune calls `set_model_state_dict` with options controlling whether to use cpu_offload. Add cpu_offload to `_load_model_state_dict` to move tensors to CPU when the config is True.
2. Change the device check: lora_finetune might have 2 device types, so accept that as valid.
3. Some changes to optimize memory performance:
   3.1 use `.detach().clone()` instead of a view directly
   3.2 if local_state is not meta, copy `full_tensor[slices]` into `ret.to_local()`
4. Add related unit tests.

Memory performance calling from torchtune with llama2/7B_full:
1. cpu_offload = True (memory trace screenshot: https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b)
2. cpu_offload = False (memory trace screenshot: https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806)

Pull Request resolved: #142845
Approved by: https://github.com/fegin
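The first point above describes `load_from_full_model_state_dict` delegating to `set_model_state_dict` with a cpu_offload option. Below is a minimal sketch of that delegation, assuming the `StateDictOptions` fields from `torch.distributed.checkpoint.state_dict`; the function name, the `broadcast_from_rank0` choice, and the omission of the wrapper's `device` argument are assumptions for illustration, not the exact torchtune code.

```python
from typing import Any, Dict

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)


def load_from_full_model_state_dict_sketch(
    model: torch.nn.Module,
    full_sd: Dict[str, Any],
    strict: bool = False,
    cpu_offload: bool = False,
):
    # Delegate sharding/loading of the full (unsharded) state dict to the
    # distributed state dict API instead of manually building DTensors.
    options = StateDictOptions(
        full_state_dict=True,        # `full_sd` is an unsharded state dict
        broadcast_from_rank0=True,   # assumption: only rank 0 holds `full_sd`
        strict=strict,
        cpu_offload=cpu_offload,     # keep sharded params on CPU when requested
    )
    # Returns missing/unexpected keys, like nn.Module.load_state_dict.
    return set_model_state_dict(model, full_sd, options=options)
```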
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
How can we get view support for NF4?
cc @andrewor14
Thank you for the review, we currently skip the NF4 tensor part and plan to support NF4 in the next quarter.
Looks like there's already view support for NF4Tensor? What's the error you're getting?
also cc @drisspg @weifengpy
I brought this up with @ebsmothers and @gau-nernst in Discord. We didn't think we needed to do anything else here; it should be safe to just switch to from_local.
Thank you for the comments. Shall I switch to from_local in this PR, or handle it together with the other NF4 tensor support in the next PR?
Yes, if possible it'd be great to move to from_local here, assuming everything works. IMO the more we can clean this function up the better; as-is it has gotten quite unwieldy.
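For reference, a sketch of what the switch could look like, using the variable names from the quoted code above. The helper name is hypothetical; whether the NF4 local tensor cooperates with the view/metadata handling inside `DTensor.from_local` is exactly the open question in this thread.

```python
from torch.distributed.tensor import DTensor


def to_dtensor_via_from_local(sharded_param, sharded_meta_param) -> DTensor:
    # Same mesh/placements/global shape as the manual
    # DTensor(..., spec=DTensorSpec(...)) path, built via the public API.
    return DTensor.from_local(
        sharded_param,
        device_mesh=sharded_meta_param.device_mesh,
        placements=sharded_meta_param.placements,
        run_check=False,
        shape=sharded_meta_param.size(),
        stride=sharded_meta_param.stride(),
    )
```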
Bump here.
Thanks for your patience! Left a bunch of comments, please let me know if anything is unclear. One request is to also manually run the lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py recipes, as they do not currently have tests in our CI. Happy to provide any pointers here if you need.
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
        # To prevent GPU memory from spiking during checkpoint save,
        # we consolidate the full model and optim state dicts on CPU for rank 0
        cpu_state_dict = training.gather_cpu_state_dict(
-           self._model.state_dict(),
+           self._model,
I just realized we are doing things differently here than in the other recipes. Seems to me like we could move the call to get_adapter_state_dict up before calling gather_cpu_state_dict; then you could make the same changes you did in e.g. lora_finetune_distributed.py (remove the call to get_adapter_state_dict and instead just pass trainable_only=self._save_adapter_weights_only to gather_cpu_state_dict). Lmk if that makes sense to you.
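Roughly, the suggested change inside save_checkpoint would look like the sketch below. The attribute and argument names come from the diffs in this PR; it is a sketch of recipe-scope code (with `self` and `training` provided by the recipe), not a drop-in patch.

```python
# Instead of gathering everything and then filtering with
# get_adapter_state_dict(cpu_state_dict) afterwards, let
# gather_cpu_state_dict skip non-adapter weights up front:
cpu_state_dict = training.gather_cpu_state_dict(
    self._model,
    self._is_rank_zero,
    device=self._device,
    trainable_only=self._save_adapter_weights_only,
)
```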
Discussed offline, we can take the get_adapter_state_dict cleanup across recipes as future work in the next PR.
recipes/lora_finetune_distributed.py
Outdated
            self._is_rank_zero,
            device=self._device,
            trainable_only=self._save_adapter_weights_only,
Also one other thing we will have to be aware of is that in general it may not always be the case that trainable params == adapter params. This holds true today, but especially for multimodal models we need to be careful because some people may want to e.g. do LoRA finetuning on the image encoder and full finetuning on the text decoder. This was disabled in #2150 but we may want to add it back later, and in that case this would be misleading. So I think trainable_only is potentially a misnomer and it may be best to rename it to adapter_weights_only or something like that.
@@ -500,7 +499,6 @@ def _setup_model(
            model,
            base_model_state_dict,
            self._device,
-           self._is_rank_zero,
(Commenting here for further down in the file, but) is there a reason you didn't also update save_checkpoint in this recipe? (We don't yet have a test for it, so it probably didn't get caught by CI.)
The PR here only removes self._is_rank_zero from training.load_from_full_model_state_dict, which is not called in save_checkpoint. Would add the tests for lora_finetune_distributed_multi_dataset.py and early_exit_finetune_distributed.py later.
@@ -556,7 +556,6 @@ def _setup_model(
            model,
            model_state_dict,
            self._device,
-           self._is_rank_zero,
Do we also need to update _setup_optimizer in this recipe?
_setup_optimizer does not call load_from_full_model_state_dict, so it was not updated with the removal of self._is_rank_zero.
torchtune/training/_distributed.py
Outdated
    Args:
        sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors
        model (FSDPModule): Model to generate fqn for cpu_state_dict
nit but I don't think most people know what "fqn" means, might write something more descriptive here
torchtune/training/_distributed.py
Outdated
) -> Dict[str, Any]:
    """
    Converting sharded state dict into a full state dict on CPU
    Returning non-empty result only on rank0 to avoid peaking CPU memory
    TODO: add support for NF4Tensor
Can we add more details here so that it's clear to someone who's reading the code? Something like "If the model does not contain any NF4 tensors, we directly use distributed state dict APIs. Otherwise, we need to manually gather any NF4 tensors until all-gather is supported in the NF4Tensor subclass"
torchtune/training/_distributed.py
Outdated
        for param_name, full_tensor in full_sd.items():
            sharded_meta_param = meta_sharded_sd.get(param_name)
            full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
            if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
                sharded_meta_param._local_tensor, NF4Tensor
            ):
                block_size = sharded_meta_param._local_tensor.block_size
                scaler_block_size = (
                    sharded_meta_param._local_tensor.scaler_block_size
                )
                full_tensor = to_nf4(
                    full_tensor,
                    block_size=block_size,
                    scaler_block_size=scaler_block_size,
                )
                # replicating logic from `_fsdp_param.py` `_init_sharded_param`
                # otherwise `distribute_tensor(DTensor(local=NF4))`
                # requires dispatching `c10d.scatter_`
                # long-term solution is `swap_tensor`
                mesh = sharded_meta_param.device_mesh
                if mesh.ndim > 1:
                    raise NotImplementedError(
                        f"only support 1D FSDP but got {mesh.ndim=}"
                    )
                shard_mesh_dim = 0
                shard_world_size = mesh.size(shard_mesh_dim)
                shard_rank = cast(
                    torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
                ).rank()
                chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
                    shard_rank
                ]
                sharded_param = full_tensor.new_zeros(chunk.size())
                sharded_param[: chunk.size(0)].copy_(chunk)

                # TODO: change to from_local API (need to add view support for NF4)
                sharded_tensor = DTensor(
                    local_tensor=sharded_param,
                    spec=DTensorSpec(
                        mesh=sharded_meta_param.device_mesh,
                        placements=sharded_meta_param.placements,
                        tensor_meta=TensorMeta(
                            shape=sharded_meta_param.size(),
                            dtype=sharded_meta_param.dtype,
                            stride=sharded_meta_param.stride(),
                        ),
                    ),
                    requires_grad=sharded_meta_param.requires_grad,
                )

            elif not hasattr(sharded_meta_param, "device_mesh"):
                # In cases where parts of the model aren't sharded, some parameters will be plain tensors
                sharded_tensor = full_tensor
            else:
                sharded_tensor = distribute_tensor(
                    full_tensor,
                    sharded_meta_param.device_mesh,
                    sharded_meta_param.placements,
                )
            if cpu_offload:
                sharded_tensor = sharded_tensor.cpu()
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        return model.load_state_dict(sharded_sd, strict=strict, assign=True)
Is this a duplicate of L274-L335? I think this function is already complicated enough, if we can just use a single if/else branch to consolidate these that'd be preferable
Switched the structure here to `if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4` / `else`.
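In outline, the consolidated control flow being described looks roughly like the sketch below. The NF4 detection mirrors the snippet quoted later in this thread; the import path is an assumption and the branch bodies are placeholders, not the actual implementation.

```python
from torchao.dtypes.nf4tensor import NF4Tensor  # assumed import path


def _has_nf4_params(model) -> bool:
    # NF4-quantized params show up as DTensors whose local tensor is an NF4Tensor.
    return any(
        hasattr(p, "_local_tensor") and isinstance(p._local_tensor, NF4Tensor)
        for p in model.parameters()
    )


# Inside load_from_full_model_state_dict:
#     if _USE_DISTRIBUTED_STATE_DICT_API and not _has_nf4_params(model):
#         ...  # delegate to set_model_state_dict (see sketch near the top of this PR)
#     else:
#         ...  # manual shard-and-load fallback for NF4 / older torch
```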
torchtune/training/_distributed.py
Outdated
# There are some changes at `set_model_state_dict` to adjust multiple devices from local_state in TorchTune,
# keep the version check until the PyTorch changes are in stable.
I don't understand this comment
Comments updated to explain why we have the PyTorch version check here.
…ion check, change weights_only to adapter_weights_only
I want to get this in b/c it's a huge improvement over our current functionality. I left a few notes just as reminders to myself, but definitely need to confirm that we are checking that these APIs are available everywhere they are used and then add in a little more explanation in places.
Thanks @mori360
torchtune/training/_distributed.py
Outdated
_log: logging.Logger = get_logger()


_valid_distributed_single_node_nnodes = ["1:1", "1"]

torch_version = torch.__version__
_USE_DISTRIBUTED_STATE_DICT_API = (
nit: The actual place this is used will determine whether you are actually "using" the distributed state dict API. This variable only says that the API is fully available. Therefore, I might suggest a name like _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE
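For illustration, a sketch of what such an availability flag could look like under the suggested name. The version cutoffs are assumptions, loosely based on the later comment that the required set_model_state_dict changes landed in the 2024-12-20 nightly; the real check in this PR may differ.

```python
import re

import torch

torch_version = torch.__version__


def _version_tuple(version: str) -> tuple:
    # "2.6.0.dev20241220+cu124" -> (2, 6, 0); "2.5.1+cu121" -> (2, 5, 1)
    return tuple(int(x) for x in re.findall(r"\d+", version.split("dev")[0])[:3])


# Assumed cutoffs: stable >= 2.6, or a nightly dated 2024-12-20 or later.
_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = (
    "dev" not in torch_version and _version_tuple(torch_version) >= (2, 6, 0)
) or ("dev" in torch_version and torch_version.split("dev")[1][:8] >= "20241220")
```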
                # replicating logic from `_fsdp_param.py` `_init_sharded_param`
                # otherwise `distribute_tensor(DTensor(local=NF4))`
                # requires dispatching `c10d.scatter_`
                # long-term solution is `swap_tensor`
Dumb question: why can't we use swap_tensor now?
We plan to support NF4Tensor as future work in Q1.
torchtune/training/_distributed.py
Outdated
        hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
        for param in model.parameters()
    )
    if _USE_DISTRIBUTED_STATE_DICT_API and not has_nf4:
What's necessary to support NF4 using the distributed state dict API?
We plan to handle NF4 and other tensor subclasses in a more general way to avoid the `if has_nf4:` branch.
torchtune/training/_distributed.py
Outdated
@@ -154,7 +168,6 @@ def load_from_full_model_state_dict(
    model: "FSDPModule",  # noqa
    full_sd: Dict[str, Any],
    device: torch.device,
-   is_rank_zero: bool,
    strict: bool = False,
    cpu_offload: bool = False,
):
Can we update the return type here (and docstring) since we actually return missing and unexpected keys?
I actually cannot believe this is a public API with such poor documentation. Sorry @mori360!
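For reference, a minimal sketch of the documented signature being asked for. The parameter list is taken from the diff above; the return annotation is an assumption based on nn.Module.load_state_dict (and set_model_state_dict) returning missing and unexpected keys via the private `_IncompatibleKeys` type.

```python
from typing import Any, Dict

import torch
from torch.nn.modules.module import _IncompatibleKeys


def load_from_full_model_state_dict(
    model: "FSDPModule",  # noqa
    full_sd: Dict[str, Any],
    device: torch.device,
    strict: bool = False,
    cpu_offload: bool = False,
) -> _IncompatibleKeys:
    """Load a full (unsharded) state dict into a sharded FSDP2 model.

    Returns:
        ``_IncompatibleKeys`` named tuple with ``missing_keys`` and
        ``unexpected_keys`` fields, as returned by ``load_state_dict``.
    """
    ...
```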
torchtune/training/_distributed.py
Outdated
@@ -166,64 +179,95 @@ def load_from_full_model_state_dict(
    - `is_rank_zero` matters if only rank 0 pass in non-empty `full_sd` and
This comment can be removed now that it isn't a part of the params.
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
sigh: we shouldn't have two consecutive to calls. You don't need to fix here - I can do it.
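i.e. something like a single combined call (a sketch of what the reviewer means, using the variable names from the quoted line):

```python
# One cast/transfer instead of two chained .to() calls:
full_tensor = full_tensor.to(device=device, dtype=sharded_meta_param.dtype)
```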
        ):
            block_size = sharded_meta_param._local_tensor.block_size
            scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
            full_tensor = to_nf4(
if we assume "full" means plain tensor which is the claim above, then this is - in fact - not a "full" tensor?
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
Bump here.
                cpu_state_dict[param_name] = param.cpu()
        torch.distributed.barrier()
        return cpu_state_dict
    else:
Shouldn't there be a check here to make sure that the distributed APIs are available?
The distributed API in the current stable release already supports this path with the desired performance, so we don't need to check distributed API availability here.
            full_state[group_id] = group_state
        else:
            del group_state
    options = StateDictOptions(
I love how clean this is now, but shouldn't we check to make sure Distributed APIs are available?
Same as above: the distributed API changes that landed in the Dec 20 nightly are only needed for loading the model state dict; the other two APIs already achieve the desired performance with the current distributed API, without those nightly changes.
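As context for this thread, a minimal sketch of the pattern in question, i.e. gathering a full optimizer state dict on CPU through the distributed state dict API. The option fields are real `StateDictOptions` fields; the function name and how torchtune post-processes or filters the result per rank are not shown and are assumptions.

```python
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)


def gather_full_optimizer_state_dict_sketch(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> dict:
    # full_state_dict=True gathers the sharded optimizer state into full
    # tensors; cpu_offload=True moves the result to host memory.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_optimizer_state_dict(model, optimizer, options=options)
```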
Awesome work! Thanks for working with us to address all comments :)
Context

What is the purpose of this PR?
Migrate distributed state dict APIs from torch.distributed.

Changelog

What are the changes made in this PR?
Switch to the distributed state dict APIs from torch.distributed:
- load_from_full_model_state_dict <- set_model_state_dict
- gather_cpu_state_dict <- get_model_state_dict (a minimal sketch of this mapping is shown after this list)
- load_from_full_optimizer_state_dict <- set_optimizer_state_dict
- get_full_optimizer_state_dict <- get_optimizer_state_dict

To align the inputs, add a model input to get_full_optimizer_state_dict and load_from_full_optimizer_state_dict. Change the sharded_sd input of gather_cpu_state_dict to model.

TODO:
- NF4Tensor handling is kept the same and remains future work.
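As referenced in the changelog above, here is a minimal sketch of the gather_cpu_state_dict <- get_model_state_dict mapping. It assumes ignore_frozen_params is what backs the trainable_only flag and that the wrapper keeps the rank-0-only return behavior; the actual torchtune implementation may differ.

```python
from typing import Any, Dict

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)


def gather_cpu_state_dict_sketch(
    model: torch.nn.Module,
    is_rank_zero: bool,
    trainable_only: bool = False,
) -> Dict[str, Any]:
    # Gather full (unsharded) parameters and offload them to CPU.
    options = StateDictOptions(
        full_state_dict=True,
        cpu_offload=True,
        ignore_frozen_params=trainable_only,  # assumption: maps trainable_only
    )
    cpu_state_dict = get_model_state_dict(model, options=options)
    # Assumption: return a non-empty dict only on rank 0 to limit host memory.
    return cpu_state_dict if is_rank_zero else {}
```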
Test plan

- pytest tests/torchtune/training/test_distributed.py
- pytest tests -m integration_test (early_exit_finetune_distributed and knowledge_distillation_distributed are not tested in CI)
- tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
- tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed

We compared runs with the previous API and the new API; the losses are the same for initial loading and for resuming from checkpoint.
We also drew the memory traces; the results show that the new API does not increase peak memory compared with the current one.