diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index b950a69f8b64..4627aeb97027 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -245,26 +245,26 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str,
         vocab = tokenizer.get_vocab()
         token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
         self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
-            token_list, token_indices, self.stop_strings, tokenizer
+            token_list, token_indices, tokenizer
         )

         self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
         self.num_stop_strings = len(self.stop_strings)
         self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)

-    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
+    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
         # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
-        if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
+        if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
             embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
                 (token_list, token_indices, self.stop_strings)
             ]
-            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
+            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
         else:
             clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
             embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
-                clean_token_list, clean_token_indices, stop_strings
+                clean_token_list, clean_token_indices, self.stop_strings
             )
-            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
+            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
                 embedding_vec,
                 max_valid_positions,
                 max_valid_end_lens,
@@ -357,7 +357,9 @@ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -
         )
         max_valid_end_lens = max(valid_end_lens)
         vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
-        gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
+        # We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values
+        # over the max to this, ensuring they do not contribute to stop string matching.
+        gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)

         for i, stop_string in enumerate(stop_strings):
             positions = token_valid_positions[stop_string]
@@ -395,6 +397,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         # Flip input_ids because we're only matching strings at the end of the generated sequence
         flipped_ids = torch.flip(input_ids, (1,))

+        # Clip out-of-vocab values to the dummy value at the end of the embedding vector
+        flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1)
+
         # Size of the vector of positions a single token can match
         max_valid_positions = self.max_valid_positions

diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index e8594dcdb07e..ace7d496dab6 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -176,6 +176,18 @@ def test_stop_string_criteria(self):
             for i in range(len(false_strings)):
                 self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))

+    def test_stop_string_criteria_vocab_size_mismatch(self):
+        """Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+        # Create input_ids with tokens above len(tokenizer)
+        input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
+        scores = None
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
+
+        # This should not raise an error and should return False since no stop string is matched
+        self.assertFalse(criteria(input_ids, scores))
+
     def test_stop_string_matching_positions(self):
         stop_string = "stop"
         token_list = ["last", "top", "topper", "s", "p"]
@@ -200,14 +212,14 @@ def test_stop_string_embedding_vecs(self):

         # Positions inside the stop string where the token matches (excluding end overlaps)
         valid_positions = embedding_vec[:, 0].tolist()
-        self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
+        self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])

         # Overlap lengths between end of stop string and start of token
         end_overlaps = embedding_vec[:, 1].tolist()
-        self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
+        self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])

         # Length of each token
-        token_lengths = embedding_vec[:, 2].tolist()
+        token_lengths = embedding_vec[:-1, 2].tolist()
         self.assertEqual(token_lengths, [len(token) for token in token_list])

     def test_single_letter_stop_string(self):
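For context, a minimal repro sketch of the scenario the clamp handles, mirroring the new test_stop_string_criteria_vocab_size_mismatch test above. It assumes torch and a transformers build containing this patch, plus access to the openai-community/gpt2 tokenizer; the top-level import path and the final print line are illustrative rather than taken from the patch.

import torch
from transformers import AutoTokenizer, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])

# Token IDs at or above len(tokenizer) can occur when the model's vocabulary is padded
# beyond the tokenizer's; previously these indexed past the end of the embedding vector.
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]])

# With the clamp in __call__, out-of-vocab IDs map to the dummy row and never match a
# stop string, so this returns a falsy tensor instead of raising an indexing error.
print(criteria(input_ids, scores=None))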