From ece4f984919e41809c5d7e2fd23255b6685fdab3 Mon Sep 17 00:00:00 2001
From: gioannides
Date: Sat, 5 Oct 2024 00:25:56 -0700
Subject: [PATCH 1/2] Fix non-contiguous tensor issue in checkpoint
 consolidation

---
 optimum/neuron/distributed/checkpointing.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py
index 02cca08f4..97acb9b92 100644
--- a/optimum/neuron/distributed/checkpointing.py
+++ b/optimum/neuron/distributed/checkpointing.py
@@ -147,13 +147,15 @@ def consolidate_tensor_parallel_checkpoints(
         if sharded_metadata.is_tied:
             consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
         else:
-            weights = [state_dict[name] for state_dict in state_dicts]
+            # Ensure that all tensors are contiguous before concatenating or further processing
+            weights = [state_dict[name].contiguous() for state_dict in state_dicts]
             tp_size = len(weights)
+
             full_weight = torch.cat(
                 weights,
                 dim=sharded_metadata.partition_dim,
-            )
-            full_weight = full_weight.to("cpu")
+            ).contiguous()  # Ensure the result is also contiguous
+
             if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
                 full_weight = (
                     torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()

From 15f87cc0338448de1046b079f8f2e665ea3c3d92 Mon Sep 17 00:00:00 2001
From: gioannides
Date: Mon, 7 Oct 2024 01:10:03 -0700
Subject: [PATCH 2/2] Handle more edge cases where a tensor may not be
 contiguous after being placed on CPU

---
 optimum/neuron/distributed/checkpointing.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py
index 97acb9b92..7f9ce7d78 100644
--- a/optimum/neuron/distributed/checkpointing.py
+++ b/optimum/neuron/distributed/checkpointing.py
@@ -145,17 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
         # This might not be the case anymore when `ParameterMetadata` uses slices.
         sharded_metadata = sharded_metadatas[name]
         if sharded_metadata.is_tied:
-            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
+            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
         else:
             # Ensure that all tensors are contiguous before concatenating or further processing
             weights = [state_dict[name].contiguous() for state_dict in state_dicts]
             tp_size = len(weights)
 
-            full_weight = torch.cat(
-                weights,
-                dim=sharded_metadata.partition_dim,
-            ).contiguous()  # Ensure the result is also contiguous
-
+            full_weight = (
+                torch.cat(
+                    weights,
+                    dim=sharded_metadata.partition_dim,
+                )
+                .to("cpu")
+                .contiguous()
+            )  # Ensure the result is also contiguous
+
             if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
                 full_weight = (
                     torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
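
Note (editor's addition, not part of the patches): the sketch below is a minimal
reproduction of the failure mode these commits guard against. It assumes the
consolidated state dict is eventually written with safetensors, which refuses to
serialize non-contiguous tensors; the tensor name and file name are hypothetical.

    import torch
    from safetensors.torch import save_file

    # A tensor-parallel shard is often a view (e.g. a transpose or a slice)
    # of a larger buffer, so it shares storage and is not contiguous.
    shard = torch.randn(8, 4).t()  # hypothetical TP shard, shape (4, 8)
    assert not shard.is_contiguous()

    # safetensors rejects non-contiguous tensors at save time:
    try:
        save_file({"layer.weight": shard}, "consolidated.safetensors")
    except ValueError as err:
        print(err)  # "You are trying to save a non contiguous tensor..."

    # .contiguous() packs the data into a fresh dense buffer, after which
    # saving succeeds. The patches apply this to every shard, to the
    # concatenated result, and to tied weights moved with .to("cpu").
    save_file({"layer.weight": shard.contiguous()}, "consolidated.safetensors")

The second commit also moves the .to("cpu") inside the chained expression:
per its message, a tensor may still be non-contiguous after being placed on
CPU, so .contiguous() has to run last.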