Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix t5 decoder compilation error since Neuron sdk 2.20 #732

Merged
merged 5 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def export_neuronx(
# Construct compiler configurations
if auto_cast is not None:
logger.info(f"Using Neuron: --auto-cast {auto_cast}")

auto_cast = "matmult" if auto_cast == "matmul" else auto_cast
compiler_args = ["--auto-cast", auto_cast]

Expand All @@ -552,6 +551,10 @@ def export_neuronx(
compiler_args = ["--auto-cast", "none"]

compiler_args.extend(["--optlevel", optlevel])
logger.info(f"Using Neuron: --optlevel {optlevel}")

if getattr(config._config, "is_encoder_decoder", False):
compiler_args.extend(["--model-type", "transformer"])

compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific

Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ def update_past(self, past_key_values):

def reorder_cache(self, past_key_values, beam_idx):
    """Reorder every cached tensor along dim 0 according to ``beam_idx``.

    Args:
        past_key_values: Mutable sequence of 4-D tensors whose first
            dimension is indexed by ``beam_idx``. Entries are replaced
            in place.
        beam_idx: 1-D integer tensor; entry ``i`` names the row of each
            cache tensor that should end up at position ``i``.

    Returns:
        The same ``past_key_values`` object, with every entry reordered.
    """
    # NOTE(review): `torch.gather` with a broadcast index selects the same
    # rows as `torch.index_select(t, 0, beam_idx)`; presumably gather is used
    # because it compiles with the Neuron SDK (see PR title) -- confirm.
    # Only ONE reordering must be applied: running index_select and then
    # gather would permute the cache twice.
    for i, cache in enumerate(past_key_values):
        # Broadcast the per-row index over the three trailing dimensions.
        gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(cache)
        past_key_values[i] = torch.gather(cache, dim=0, index=gather_index)
    return past_key_values

def forward(
Expand Down
4 changes: 2 additions & 2 deletions optimum/neuron/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,14 +462,14 @@ def forward(
decoder_hidden_states = None

# Skip pkv which can't be copied from memory to buffer
if output_attentions and self.config.neuron.get("output_attentions"):
if output_attentions and self.configs["decoder"].neuron.get("output_attentions"):
if self.config.is_encoder_decoder:
cross_attentions = outputs[-self.config.num_decoder_layers :]
cur_idx += self.config.num_decoder_layers
decoder_attentions = outputs[-(self.config.num_decoder_layers + cur_idx) : -cur_idx]
cur_idx += self.config.num_decoder_layers

if output_hidden_states and self.config.neuron.get("output_hidden_states"):
if output_hidden_states and self.configs["decoder"].neuron.get("output_hidden_states"):
decoder_hidden_states = outputs[-(self.config.num_decoder_layers + 1 + cur_idx) : -cur_idx]

decoder_outputs = ModelOutput(
Expand Down
2 changes: 1 addition & 1 deletion optimum/neuron/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.25.dev0"
__version__ = "0.0.27.dev0"

__sdk_version__ = "2.20.0"
9 changes: 0 additions & 9 deletions tests/cli/test_export_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,6 @@ def test_replace_unet(self):
check=True,
)

@unittest.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@requires_neuronx
def test_encoder_decoder(self):
model_id = "hf-internal-testing/tiny-random-t5"
Expand Down Expand Up @@ -335,9 +332,6 @@ def test_encoder_decoder(self):
check=True,
)

@unittest.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@requires_neuronx
def test_encoder_decoder_optional_outputs(self):
model_id = "hf-internal-testing/tiny-random-t5"
Expand Down Expand Up @@ -369,9 +363,6 @@ def test_encoder_decoder_optional_outputs(self):
check=True,
)

@unittest.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@requires_neuronx
def test_encoder_decoder_tp2(self):
model_id = "michaelbenayoun/t5-tiny-random"
Expand Down
3 changes: 0 additions & 3 deletions tests/exporters/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,6 @@ def test_export_sd_with_fused_lora_weights(self):
)


@unittest.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
class NeuronEncoderDecoderExportTestCase(unittest.TestCase):
Expand Down
42 changes: 0 additions & 42 deletions tests/generation/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen
assert sample_output.shape[0] == batch_size


@pytest.mark.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):
Expand All @@ -58,9 +55,6 @@ def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):
assert len(output[0].unique()) <= 5 + 1 # +1 for `decoder_start_token_id`


@pytest.mark.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_with_optional_outputs):
Expand All @@ -83,9 +77,6 @@ def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_
assert "decoder_hidden_states" in output


@pytest.mark.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path):
Expand All @@ -106,9 +97,6 @@ def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path):
assert len(output[0]) <= 5 + 1 # +1 for `decoder_start_token_id`


@pytest.mark.skip(
"T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_path_with_optional_outputs):
Expand All @@ -129,29 +117,6 @@ def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_p
assert "decoder_hidden_states" in output


@pytest.mark.skip(
    "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013."
)
@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_tp2(neuron_seq2seq_tp2_path):
    """Generate with the model at ``neuron_seq2seq_tp2_path`` and check optional outputs.

    Loads the model and tokenizer from the same path, runs ``generate`` with
    attentions and hidden states requested, and asserts that the decoder
    attentions, cross attentions and decoder hidden states are present in the
    returned output.
    """
    # NOTE(review): "tp2" presumably means tensor-parallel degree 2 -- confirm
    # against the fixture that provides `neuron_seq2seq_tp2_path`.
    model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_tp2_path)
    tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_tp2_path)
    inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")

    output = model.generate(
        **inputs,
        num_return_sequences=1,
        max_length=20,
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    assert "decoder_attentions" in output
    assert "cross_attentions" in output
    assert "decoder_hidden_states" in output


@pytest.mark.skip("Makes pytest fail, to fix.")
@pytest.mark.parametrize(
"gen_kwargs",
Expand Down Expand Up @@ -195,10 +160,3 @@ def test_general_seq2seq_generation(export_seq2seq_id, export_seq2seq_model_clas
model = export_seq2seq_model_class.from_pretrained(export_seq2seq_id)
tokenizer = AutoTokenizer.from_pretrained(export_seq2seq_id)
_test_model_generation_trn(model, tokenizer, 1, 10, **gen_kwargs)


# Required for multiprocessing tests: child processes must only be spawned
# from the main program. E.g. for tensor-parallel tracing,
# `neuronx_distributed.parallel_model_trace` spawns multiple processes to
# trace and compile the model.
if __name__ == "__main__":
    pytest.main([__file__])
48 changes: 48 additions & 0 deletions tests/generation/test_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx


@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_tp2(neuron_seq2seq_tp2_path):
    """Generate with the model at ``neuron_seq2seq_tp2_path`` and check optional outputs.

    Loads the model and tokenizer from the same path, runs ``generate`` with
    attentions and hidden states requested, and asserts that the decoder
    attentions, cross attentions and decoder hidden states are present in the
    returned output.
    """
    # NOTE(review): "tp2" presumably means tensor-parallel degree 2 -- confirm
    # against the fixture that provides `neuron_seq2seq_tp2_path`.
    model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_tp2_path)
    tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_tp2_path)
    inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")

    output = model.generate(
        **inputs,
        num_return_sequences=1,
        max_length=20,
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    assert "decoder_attentions" in output
    assert "cross_attentions" in output
    assert "decoder_hidden_states" in output


# Required for multiprocessing tests: child processes must only be spawned
# from the main program. E.g. for tensor-parallel tracing,
# `neuronx_distributed.parallel_model_trace` spawns multiple processes to
# trace and compile the model.
if __name__ == "__main__":
    pytest.main([__file__])
Loading