Commit 4f376b5
Merge branch 'main' into fix_max_time_tests
gante authored Jan 29, 2025
2 parents f902552 + 4d1d489 commit 4f376b5
Showing 67 changed files with 296 additions and 253 deletions.

README.md: 2 changes (1 addition, 1 deletion)
````diff
@@ -283,7 +283,7 @@ If you'd like to play with the examples or need the bleeding edge of the code an
 ```
 git clone https://github.com/huggingface/transformers.git
 cd transformers
-pip install
+pip install .
 ```
 
 ### With conda
````

```diff
@@ -1,5 +1,5 @@
 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-11.html#rel-23-11
-FROM nvcr.io/nvidia/pytorch:23.04-py3
+FROM nvcr.io/nvidia/pytorch:23.11-py3
 LABEL maintainer="Hugging Face"
 
 ARG DEBIAN_FRONTEND=noninteractive
```

src/transformers/data/processors/squad.py: 2 changes (1 addition, 1 deletion)
```diff
@@ -249,7 +249,7 @@ def squad_convert_example_to_features(
         else:
             p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
 
-        pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
+        pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
         special_token_indices = np.asarray(
             tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
         ).nonzero()
```
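
For context, a minimal sketch of why the `np.atleast_1d` wrap matters, assuming `span["input_ids"]` can be a plain Python list so the `==` comparison collapses to a scalar:

```python
import numpy as np

# If `span["input_ids"]` is a plain Python list, `list == pad_token_id`
# evaluates to the scalar `False`, and newer NumPy versions reject
# `where`/`nonzero` on a 0-d condition.
scalar_condition = [0, 1, 2] == 99                # plain bool: False
print(np.where(np.atleast_1d(scalar_condition)))  # (array([], dtype=int64),)

# An array-valued condition passes through the wrap unchanged.
array_condition = np.array([99, 1, 99]) == 99
print(np.where(np.atleast_1d(array_condition)))   # (array([0, 2]),)
```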

src/transformers/pipelines/text_generation.py: 55 changes (51 additions, 4 deletions)
```diff
@@ -3,11 +3,13 @@
 import types
 from typing import Dict
 
-from ..utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
 from .base import Pipeline, build_pipeline_init_args
 
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
     from .pt_utils import KeyDataset
```

```diff
@@ -380,13 +382,44 @@ def _forward(self, model_inputs, **generate_kwargs):
         if "generation_config" not in generate_kwargs:
             generate_kwargs["generation_config"] = self.generation_config
 
-        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+        output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+
+        if isinstance(output, ModelOutput):
+            generated_sequence = output.sequences
+            other_outputs = {k: v for k, v in output.items() if k != "sequences"}
+            out_b = generated_sequence.shape[0]
+
+            if self.framework == "pt":
+                for key, value in other_outputs.items():
+                    if isinstance(value, torch.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = torch.stack(value).swapaxes(0, 1)
+                        other_outputs[key] = value
+            elif self.framework == "tf":
+                for key, value in other_outputs.items():
+                    if isinstance(value, tf.Tensor) and value.shape[0] == out_b:
+                        other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
+                    if isinstance(value, tuple) and len(value[0]) == out_b:
+                        value = tf.stack(value).swapaxes(0, 1)
+                        other_outputs[key] = value
+        else:
+            generated_sequence = output
+            other_outputs = {}
+
         out_b = generated_sequence.shape[0]
         if self.framework == "pt":
             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
         elif self.framework == "tf":
             generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
-        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
+
+        model_outputs = {
+            "generated_sequence": generated_sequence,
+            "input_ids": input_ids,
+            "prompt_text": prompt_text,
+        }
+        model_outputs.update(other_outputs)
+        return model_outputs
 
     def postprocess(
         self,
```
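
To make the regrouping above concrete, a minimal PyTorch sketch (shapes and counts are illustrative, not taken from the pipeline):

```python
import torch

# `generate` flattens the batch: in_b prompts with num_return_sequences
# completions each come back as (in_b * num_return_sequences, seq_len).
in_b = 2                        # two input prompts (hypothetical)
generated = torch.zeros(6, 10)  # out_b = 6, i.e. 3 completions per prompt
out_b = generated.shape[0]

# The same reshape the pipeline applies to `generated_sequence` and to any
# extra per-sequence tensors found on the ModelOutput:
regrouped = generated.reshape(in_b, out_b // in_b, *generated.shape[1:])
print(regrouped.shape)  # torch.Size([2, 3, 10])

# Step-wise tuples such as `scores` (one (out_b, vocab_size) tensor per
# generated token) are stacked and transposed so the batch dimension leads:
scores = tuple(torch.zeros(out_b, 100) for _ in range(4))
print(torch.stack(scores).swapaxes(0, 1).shape)  # torch.Size([6, 4, 100])
```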
```diff
@@ -400,7 +433,19 @@ def postprocess(
         prompt_text = model_outputs["prompt_text"]
         generated_sequence = generated_sequence.numpy().tolist()
         records = []
-        for sequence in generated_sequence:
+        other_outputs = model_outputs.get("additional_outputs", {})
+        splitted_keys = {}
+        if other_outputs:
+            if self.framework == "pt":
+                for k, v in other_outputs.items():
+                    if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+            elif self.framework == "tf":
+                for k, v in other_outputs.items():
+                    if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence):
+                        splitted_keys[k] = v.numpy().tolist()
+
+        for idx, sequence in enumerate(generated_sequence):
             if return_type == ReturnType.TENSORS:
                 record = {"generated_token_ids": sequence}
             elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
```
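
A small sketch of the split performed here, assuming a hypothetical `sequences_scores` entry whose leading dimension matches the number of generated sequences:

```python
import torch

# One score per generated sequence (name and values illustrative).
sequences_scores = torch.tensor([0.5, 0.25, 0.125])
splitted = sequences_scores.numpy().tolist()  # [0.5, 0.25, 0.125]
# Record i later gets `splitted[i]` attached under the same key.
print(splitted[0])  # 0.5
```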
```diff
@@ -444,6 +489,8 @@ def postprocess(
                         # When we're not starting from a prefill, the output is a new assistant message
                         all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
                 record = {"generated_text": all_text}
+            for key, values in splitted_keys.items():
+                record[key] = values[idx]
             records.append(record)
 
         return records
```
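
Taken together, a hypothetical end-to-end sketch of what these hunks enable; the checkpoint name is illustrative, and `return_dict_in_generate`/`output_scores` are standard `generate` flags that make it return a `ModelOutput` rather than a bare tensor:

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")  # illustrative checkpoint
outputs = generator(
    "Hello, my name is",
    max_new_tokens=5,
    return_dict_in_generate=True,  # generate() returns a ModelOutput
    output_scores=True,            # attach per-step scores to that output
)
# The old _forward assumed a bare tensor and would fail on a ModelOutput;
# with this change the pipeline unpacks it and regroups the extra tensors
# per prompt, as shown in the hunks above.
print(outputs[0]["generated_text"])
```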