diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 647b195d7..80841501f 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -1437,13 +1437,6 @@ def wrapper(
             self._init_weights(lm_head)
         return orig_resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
 
-        new_embedding_shape = resized_embedding.weight.shape
-        if embedding_shape != new_embedding_shape:
-            resized_embedding._shape_before_resized = embedding_shape
-            lm_head = self.get_output_embeddings()
-            if lm_head is not None:
-                lm_head._shape_before_resized = embedding_shape
-        return resized_embedding
     bound_wrapper = wrapper.__get__(orig_resize_token_embeddings.__self__)
     return bound_wrapper
 
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index d80614478..6608d5825 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -949,7 +949,7 @@ def _inner_training_loop(
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
         # Mark step before training to materialize any tensor before creating the training graph.
-        # xm.mark_step()
+        xm.mark_step()
 
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py
index 298ec3d67..30a59dd26 100644
--- a/tests/distributed/test_model_parallelization.py
+++ b/tests/distributed/test_model_parallelization.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 import torch.utils._pytree as pytree
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +44,7 @@
 
 import optimum
 from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
-from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
+from optimum.neuron.distributed.utils import compute_query_indices_for_rank
 from optimum.neuron.utils.cache_utils import (
     get_num_neuron_cores,
 )
@@ -221,6 +221,10 @@ def sequence_parallel_enabled(self, request):
     def parallelize_embeddings(self, request):
         return request.param
 
+    @pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_tied", "tied_embeddings"])
+    def tie_embeddings(self, request):
+        return request.param
+
     def early_skip(self, fixtures_kwargs):
         pp_size = fixtures_kwargs.get("pp_size", None)
         parallel_sizes = fixtures_kwargs.get("parallel_sizes", None)
@@ -562,14 +566,17 @@ def test_llama_v2_gqa(
     )
 
     @pytest.mark.parallel_sizes((2, 2, 1))
-    def test_resize_embedding(self):
+    def test_resize_embedding(self, tie_embeddings):
         tp_size = get_tensor_model_parallel_size()
         tp_group = get_tensor_model_parallel_group()
 
         static_seed_patcher = StaticSeedPatcher(42)
 
+        config = AutoConfig.from_pretrained(LLAMA_V2_MODEL_NAME)
+        config.tie_word_embeddings = tie_embeddings
+
         with static_seed_patcher:
-            orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
+            orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
         orig_model.eval()
         vocab_size = orig_model.config.vocab_size
         new_vocab_size = vocab_size + tp_size
@@ -577,9 +584,9 @@ def test_resize_embedding(self):
         with static_seed_patcher:
             orig_model.resize_token_embeddings(new_vocab_size)
 
-        # with lazy_load_for_parallelism(tensor_parallel_size=tp_size):
-        model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
-        model.eval()
+        with static_seed_patcher:
+            model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
+            model.eval()
 
         with static_seed_patcher:
             model.resize_token_embeddings(new_vocab_size)
@@ -593,20 +600,6 @@ def test_resize_embedding(self):
         with static_seed_patcher:
             model = accelerator.prepare_model(model)
 
-        # Tying weights to end up with the same LM head.
-        orig_model.lm_head.weight = orig_model.model.embed_tokens.weight
-        model.lm_head.weight = model.model.embed_tokens.weight
-        print(orig_model.model.embed_tokens.weight.shape)
-        print(model.model.embed_tokens.weight.shape)
-
-        # for t1, t2 in zip(orig_model.named_parameters(), model.to("cpu").named_parameters()):
-        #     n1, p1 = t1
-        #     _, p2 = t2
-        #     xm.master_print(f"{n1}, p1 = {p1}, p2 = {p2}")
-
-        xm.master_print(orig_model.lm_head.weight)
-        xm.master_print(model.lm_head.weight)
-
         # First we check that the embedding weights match
         gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
         torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
@@ -621,13 +614,14 @@ def test_resize_embedding(self):
         inputs = {k: v.to("xla") for k, v in inputs.items()}
         orig_model = orig_model.to("xla")
         orig_logits = orig_model(**inputs).logits
-        xm.master_print(orig_logits)
+        xm.mark_step()
         logits = model(**inputs).logits
-        xm.master_print(logits)
-        # gathered = [torch.empty_like(logits) for _ in range(tp_size)]
-        # torch.distributed.all_gather(gathered, logits, group=tp_group)
-        # gathered_logits = torch.cat(gathered, dim=2)
-        # torch.testing.assert_close(orig_logits, gathered_logits.to("cpu"))
+        xm.mark_step()
+        gathered = [torch.empty_like(logits) for _ in range(tp_size)]
+        torch.distributed.all_gather(gathered, logits, group=tp_group)
+        gathered_logits = torch.cat(gathered, dim=2)
+        xm.mark_step()
+        torch.testing.assert_close(orig_logits, gathered_logits)
 
     @pytest.mark.parametrize(