From 312cb7af95523b48da31b5d745b4b271f63c7bbf Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 20 Nov 2024 17:10:31 +0100
Subject: [PATCH] Copy tensor model attributes when initializing

---
 optimum/neuron/distributed/base.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 7941c6842..20bbcec6a 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -330,6 +330,7 @@ def _initialize_or_load_weights(
         from neuronx_distributed import parallel_layers
         from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
         from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank
+        from neuronx_distributed.parallel_layers.utils import copy_tensor_model_parallel_attributes

         weight_map = getattr(model, "_weight_map", {})
         with torch.no_grad():
@@ -389,6 +390,7 @@ def _initialize_or_load_weights(
                     if device is not None:
                         weight_data = weight_data.to(device)
                     new_parameter = torch.nn.Parameter(weight_data)
+                    copy_tensor_model_parallel_attributes(new_parameter, parameter)
                 elif parameter.device != torch.device("meta") and (
                     was_already_initialized_during_parallelization(parameter)
                     or not parameter_can_be_initialized(model, module, attribute_name)
@@ -401,6 +403,7 @@ def _initialize_or_load_weights(
                     # We first create the module on CPU, initialize it and then move it on device if needed.
                     device = torch.device("cpu")
                     new_parameter = torch.nn.Parameter(torch.empty_like(parameter, device=device))
+                    copy_tensor_model_parallel_attributes(new_parameter, parameter)

                     modules_to_initialize[module].append(attribute_name)

                 setattr(
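
Note on the change: copy_tensor_model_parallel_attributes (imported above from neuronx_distributed.parallel_layers.utils) propagates the tensor-parallel metadata that the parallel layers attach to their weights onto the freshly created torch.nn.Parameter, which would otherwise lose it when the parameter is rebuilt from loaded or empty weight data. The sketch below only illustrates the idea; the attribute names (tensor_model_parallel, partition_dim, partition_stride) follow the Megatron-style convention and are an assumption here, as is the helper name copy_tp_attributes_sketch; it is not the library's actual implementation.

import torch

# Attribute names used by Megatron-style tensor-parallel layers (an assumption
# here; the real helper in neuronx_distributed.parallel_layers.utils is the
# source of truth for which attributes get copied).
_TP_ATTRIBUTES = ("tensor_model_parallel", "partition_dim", "partition_stride")


def copy_tp_attributes_sketch(destination: torch.nn.Parameter, source: torch.nn.Parameter) -> None:
    # Copy tensor-parallel metadata from `source` onto `destination` when present.
    for attr in _TP_ATTRIBUTES:
        if hasattr(source, attr):
            setattr(destination, attr, getattr(source, attr))


# A freshly created parameter does not carry the metadata that the parallel
# layer originally attached to its weight, so it has to be copied over.
original = torch.nn.Parameter(torch.empty(1024, 256))
original.tensor_model_parallel = True
original.partition_dim = 0
original.partition_stride = 1

new_parameter = torch.nn.Parameter(torch.empty_like(original))
copy_tp_attributes_sketch(new_parameter, original)
assert getattr(new_parameter, "tensor_model_parallel", False)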