diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py
index 02cca08f4..7f9ce7d78 100644
--- a/optimum/neuron/distributed/checkpointing.py
+++ b/optimum/neuron/distributed/checkpointing.py
@@ -145,15 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
         # This might not be the case anymore when `ParameterMetadata` uses slices.
         sharded_metadata = sharded_metadatas[name]
         if sharded_metadata.is_tied:
-            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
+            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
         else:
-            weights = [state_dict[name] for state_dict in state_dicts]
+            # Ensure that all tensors are contiguous before concatenating or further processing
+            weights = [state_dict[name].contiguous() for state_dict in state_dicts]
             tp_size = len(weights)
-            full_weight = torch.cat(
-                weights,
-                dim=sharded_metadata.partition_dim,
-            )
-            full_weight = full_weight.to("cpu")
+
+            full_weight = (
+                torch.cat(
+                    weights,
+                    dim=sharded_metadata.partition_dim,
+                )
+                .to("cpu")
+                .contiguous()
+            )  # Ensure the result is also contiguous
+
             if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
                 full_weight = (
                     torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
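
For context on why the added `.contiguous()` calls matter: sharded parameters can be views (transposed, chunked, or otherwise strided), whose storage layout no longer matches their logical shape. The sketch below is plain PyTorch, not the `optimum-neuron` code itself, and just illustrates the layout normalization the patch performs; the tensor values are made up for the example.

```python
import torch

full = torch.arange(12, dtype=torch.float32).reshape(3, 4)
shard = full.t()                      # a transposed view: same storage, swapped strides
print(shard.is_contiguous())          # False

compact = shard.contiguous()          # compact copy whose storage matches its shape
print(compact.is_contiguous())        # True

# torch.cat itself accepts non-contiguous inputs, but downstream consumers
# (e.g. serializers such as safetensors) reject non-contiguous tensors, so
# normalizing the layout at consolidation time is the safer choice.
merged = torch.cat([compact, compact], dim=0)
print(merged.is_contiguous(), merged.shape)  # True torch.Size([8, 3])
```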