
Commit

Test pass
michaelbenayoun committed Sep 17, 2024
1 parent 9a6c01e · commit 1aacc9e
Showing 3 changed files with 22 additions and 35 deletions.
optimum/neuron/distributed/utils.py (7 changes: 0 additions & 7 deletions)
@@ -1437,13 +1437,6 @@ def wrapper(
             self._init_weights(lm_head)
 
         return orig_resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
-        new_embedding_shape = resized_embedding.weight.shape
-        if embedding_shape != new_embedding_shape:
-            resized_embedding._shape_before_resized = embedding_shape
-            lm_head = self.get_output_embeddings()
-            if lm_head is not None:
-                lm_head._shape_before_resized = embedding_shape
-        return resized_embedding
 
     bound_wrapper = wrapper.__get__(orig_resize_token_embeddings.__self__)
     return bound_wrapper
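The surviving lines bind the replacement wrapper back onto the model through the descriptor protocol. A minimal, self-contained sketch of that binding trick, with purely illustrative names (Greeter and greet are not from the repository):

    class Greeter:
        def __init__(self, name):
            self.name = name


    def greet(self):
        # A plain function whose first argument is the instance it will be bound to.
        return f"Hello from {self.name}"


    # Calling __get__ on a function turns it into a bound method of an existing
    # object, which is what wrapper.__get__(orig_resize_token_embeddings.__self__)
    # does with the object that owned the original resize method.
    greeter = Greeter("neuron")
    bound_greet = greet.__get__(greeter)
    print(bound_greet())  # Hello from neuron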
optimum/neuron/trainers.py (2 changes: 1 addition & 1 deletion)
@@ -949,7 +949,7 @@ def _inner_training_loop(
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
         # Mark step before training to materialize any tensor before creating the training graph.
-        # xm.mark_step()
+        xm.mark_step()
 
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
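For readers unfamiliar with the XLA lazy-execution model the re-enabled call relies on: operations on XLA tensors are only recorded into a pending graph, and xm.mark_step() is what compiles and materializes them. A minimal sketch, assuming torch_xla is installed and an XLA device is available (the tensors and shapes are illustrative only):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    weights = torch.randn(4, 4, device=device)
    inputs = torch.randn(4, 4, device=device)

    # This matmul is only recorded into the lazy XLA graph at this point...
    outputs = inputs @ weights

    # ...and mark_step() cuts the graph, compiles it, and materializes the pending
    # tensors, so the subsequent training loop starts from concrete values instead
    # of folding this work into the training graph it is about to trace.
    xm.mark_step()
    print(outputs.cpu())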
tests/distributed/test_model_parallelization.py (48 changes: 21 additions & 27 deletions)
@@ -20,7 +20,7 @@
 import pytest
 import torch
 import torch.utils._pytree as pytree
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
@@ -44,7 +44,7 @@
 
 import optimum
 from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
-from optimum.neuron.distributed.utils import compute_query_indices_for_rank, lazy_load_for_parallelism
+from optimum.neuron.distributed.utils import compute_query_indices_for_rank
 from optimum.neuron.utils.cache_utils import (
     get_num_neuron_cores,
 )
@@ -221,6 +221,10 @@ def sequence_parallel_enabled(self, request):
     def parallelize_embeddings(self, request):
         return request.param
 
+    @pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_tied", "tied_embeddings"])
+    def tie_embeddings(self, request):
+        return request.param
+
     def early_skip(self, fixtures_kwargs):
         pp_size = fixtures_kwargs.get("pp_size", None)
         parallel_sizes = fixtures_kwargs.get("parallel_sizes", None)
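The new fixture follows the same pattern as the surrounding ones: a class-scoped, parametrized fixture makes every test in the class that requests it run once per parameter. A small, hypothetical sketch of the mechanism (TestExample and test_reports_parameter are illustrative, not part of the test suite):

    import pytest


    class TestExample:
        @pytest.fixture(scope="class", params=[False, True], ids=["embeddings_not_tied", "tied_embeddings"])
        def tie_embeddings(self, request):
            return request.param

        def test_reports_parameter(self, tie_embeddings):
            # Collected twice, once per parameter:
            #   TestExample::test_reports_parameter[embeddings_not_tied]
            #   TestExample::test_reports_parameter[tied_embeddings]
            assert tie_embeddings in (False, True)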
@@ -562,24 +566,27 @@ def test_llama_v2_gqa(
         )
 
     @pytest.mark.parallel_sizes((2, 2, 1))
-    def test_resize_embedding(self):
+    def test_resize_embedding(self, tie_embeddings):
         tp_size = get_tensor_model_parallel_size()
         tp_group = get_tensor_model_parallel_group()
 
         static_seed_patcher = StaticSeedPatcher(42)
 
+        config = AutoConfig.from_pretrained(LLAMA_V2_MODEL_NAME)
+        config.tie_word_embeddings = tie_embeddings
+
         with static_seed_patcher:
-            orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
+            orig_model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
         orig_model.eval()
         vocab_size = orig_model.config.vocab_size
         new_vocab_size = vocab_size + tp_size
 
         with static_seed_patcher:
             orig_model.resize_token_embeddings(new_vocab_size)
 
-        # with lazy_load_for_parallelism(tensor_parallel_size=tp_size):
-        model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME)
-        model.eval()
+        with static_seed_patcher:
+            model = AutoModelForCausalLM.from_pretrained(LLAMA_V2_MODEL_NAME, config=config)
+        model.eval()
 
         with static_seed_patcher:
             model.resize_token_embeddings(new_vocab_size)
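Why the test now toggles config.tie_word_embeddings before loading: with tied embeddings, the LM head shares its weight with the input embedding, so resizing the token embeddings also resizes the head. A minimal illustration using a tiny, randomly initialized Llama config rather than the LLAMA_V2_MODEL_NAME checkpoint (the sizes are arbitrary):

    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        tie_word_embeddings=True,
    )
    model = LlamaForCausalLM(config)
    model.resize_token_embeddings(130)

    # The input embedding was resized...
    assert model.get_input_embeddings().weight.shape[0] == 130
    # ...and, because the weights are tied, the LM head points at the very same tensor.
    assert model.lm_head.weight is model.model.embed_tokens.weight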
@@ -593,20 +600,6 @@ def test_resize_embedding(self):
         with static_seed_patcher:
             model = accelerator.prepare_model(model)
 
-        # Tying weights to end up with the same LM head.
-        orig_model.lm_head.weight = orig_model.model.embed_tokens.weight
-        model.lm_head.weight = model.model.embed_tokens.weight
-        print(orig_model.model.embed_tokens.weight.shape)
-        print(model.model.embed_tokens.weight.shape)
-
-        # for t1, t2 in zip(orig_model.named_parameters(), model.to("cpu").named_parameters()):
-        #     n1, p1 = t1
-        #     _, p2 = t2
-        #     xm.master_print(f"{n1}, p1 = {p1}, p2 = {p2}")
-
-        xm.master_print(orig_model.lm_head.weight)
-        xm.master_print(model.lm_head.weight)
-
         # First we check that the embedding weights match
         gathered = [torch.empty_like(model.model.embed_tokens.weight) for _ in range(tp_size)]
         torch.distributed.all_gather(gathered, model.model.embed_tokens.weight, group=tp_group)
@@ -621,13 +614,14 @@
         inputs = {k: v.to("xla") for k, v in inputs.items()}
         orig_model = orig_model.to("xla")
         orig_logits = orig_model(**inputs).logits
-        xm.master_print(orig_logits)
+        xm.mark_step()
         logits = model(**inputs).logits
-        xm.master_print(logits)
-        # gathered = [torch.empty_like(logits) for _ in range(tp_size)]
-        # torch.distributed.all_gather(gathered, logits, group=tp_group)
-        # gathered_logits = torch.cat(gathered, dim=2)
-        # torch.testing.assert_close(orig_logits, gathered_logits.to("cpu"))
+        xm.mark_step()
+        gathered = [torch.empty_like(logits) for _ in range(tp_size)]
+        torch.distributed.all_gather(gathered, logits, group=tp_group)
+        gathered_logits = torch.cat(gathered, dim=2)
+        xm.mark_step()
+        torch.testing.assert_close(orig_logits, gathered_logits)
 
 
 @pytest.mark.parametrize(
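The rewritten assertions follow the usual gather-and-compare pattern for tensor parallelism: each rank computes a shard of the logits along the vocabulary dimension, all_gather collects the shards, torch.cat reassembles the full tensor, and torch.testing.assert_close compares it against the unsharded reference. A self-contained, single-process sketch of that pattern on CPU with the gloo backend (world size 1, purely illustrative, not the Neuron/XLA setup used by the test):

    import os

    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    world_size = dist.get_world_size()
    reference = torch.randn(2, 4, 8)  # full logits: (batch, sequence, vocab)
    shard = reference.chunk(world_size, dim=2)[dist.get_rank()]  # this rank's vocab slice

    # Gather every rank's shard and stitch the vocabulary dimension back together.
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard)
    reassembled = torch.cat(gathered, dim=2)

    torch.testing.assert_close(reference, reassembled)
    dist.destroy_process_group()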
