Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797)
* Fix StopStringCriteria to handle tokens above len(tokenizer)

This fixes #35244 by clipping token IDs to the bounds of the stop-string embedding table before performing the lookup. This prevents index errors when model.config.vocab_size > len(tokenizer). A rough sketch of the idea appears just before the diff below.

The fix:
1. Adds a clamp operation to ensure token IDs are within bounds
2. Adds a test case to verify the behavior

* Use self.stop_strings instead of stop_strings

* Handle clipping correctly

* make fixup

* Update test to the new embedding vecs

* Use much bigger values in the mismatch test

* Typo fix

* Slight simplification

---------

Co-authored-by: openhands <[email protected]>
Rocketknight1 and openhands-agent authored Feb 6, 2025
1 parent 28f73bc commit 4563ba2
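
A rough sketch of the clamping idea (toy sizes only; names such as tokenizer_len are illustrative and not from the library):

    import torch

    # Toy setup: the tokenizer defines 10 ids, but the model's vocab (and therefore the
    # generated ids) can be larger. The criteria builds one row per tokenizer id plus
    # one extra all -1 "dummy" row at the end of the table.
    tokenizer_len = 10
    embedding_vec = torch.full((tokenizer_len + 1, 3), -1, dtype=torch.int32)

    generated_ids = torch.tensor([[3, 7, 1024]])  # 1024 is outside the tokenizer's range

    # Clamping maps every out-of-range id onto the dummy row, so the lookup can never
    # index out of bounds and the extra ids can never match a stop string.
    safe_ids = torch.clamp(generated_ids, max=embedding_vec.size(0) - 1)
    rows = embedding_vec[safe_ids]  # shape (1, 3, 3); the row looked up for 1024 is all -1
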
Showing 2 changed files with 27 additions and 10 deletions.
19 changes: 12 additions & 7 deletions src/transformers/generation/stopping_criteria.py
@@ -245,26 +245,26 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str,
         vocab = tokenizer.get_vocab()
         token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
         self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
-            token_list, token_indices, self.stop_strings, tokenizer
+            token_list, token_indices, tokenizer
         )
 
         self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
         self.num_stop_strings = len(self.stop_strings)
         self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
 
-    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
+    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
         # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
-        if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
+        if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
            embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
                (token_list, token_indices, self.stop_strings)
            ]
-            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
+            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
         else:
             clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
             embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
-                clean_token_list, clean_token_indices, stop_strings
+                clean_token_list, clean_token_indices, self.stop_strings
             )
-            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
+            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
                 embedding_vec,
                 max_valid_positions,
                 max_valid_end_lens,
@@ -357,7 +357,9 @@ def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -
         )
         max_valid_end_lens = max(valid_end_lens)
         vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
-        gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
+        # We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values
+        # over the max to this, ensuring they do not contribute to stop string matching.
+        gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)
 
         for i, stop_string in enumerate(stop_strings):
             positions = token_valid_positions[stop_string]
@@ -395,6 +397,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         # Flip input_ids because we're only matching strings at the end of the generated sequence
         flipped_ids = torch.flip(input_ids, (1,))
 
+        # Clip out-of-vocab values to the dummy value at the end of the embedding vector
+        flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1)
+
         # Size of the vector of positions a single token can match
         max_valid_positions = self.max_valid_positions
 
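
For intuition on the "+ 2" sizing in the hunk above, a minimal sketch with toy values (vec_size and the ids here are made up) shows that the clamp added in __call__ lands any out-of-range id on the all -1 dummy row, which can never contribute a stop-string match:

    import numpy as np
    import torch

    token_indices = (0, 1, 2, 3, 4)  # toy tokenizer ids
    vec_size = 3                     # toy width standing in for the position/end-overlap/length columns

    # One row per real token id, plus a dummy all -1 row at index max(token_indices) + 1
    gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)
    # (the real code then fills in rows for tokens that can take part in a stop string,
    # leaving the final dummy row untouched)
    embedding_vec = torch.tensor(gather_vec, dtype=torch.int32)

    flipped_ids = torch.tensor([[2, 4, 50257]])  # 50257 is far beyond the toy vocab
    flipped_ids = torch.clamp(flipped_ids, max=embedding_vec.size(0) - 1)  # 50257 -> 5, the dummy row
    assert (embedding_vec[flipped_ids][0, -1] == -1).all()
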
18 changes: 15 additions & 3 deletions tests/generation/test_stopping_criteria.py
@@ -176,6 +176,18 @@ def test_stop_string_criteria(self):
         for i in range(len(false_strings)):
             self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
 
+    def test_stop_string_criteria_vocab_size_mismatch(self):
+        """Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+        # Create input_ids with tokens above len(tokenizer)
+        input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
+        scores = None
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
+
+        # This should not raise an error and should return False since no stop string is matched
+        self.assertFalse(criteria(input_ids, scores))
+
     def test_stop_string_matching_positions(self):
         stop_string = "stop"
         token_list = ["last", "top", "topper", "s", "p"]
@@ -200,14 +212,14 @@ def test_stop_string_embedding_vecs(self):
 
         # Positions inside the stop string where the token matches (excluding end overlaps)
         valid_positions = embedding_vec[:, 0].tolist()
-        self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
+        self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
 
         # Overlap lengths between end of stop string and start of token
         end_overlaps = embedding_vec[:, 1].tolist()
-        self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
+        self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
 
         # Length of each token
-        token_lengths = embedding_vec[:, 2].tolist()
+        token_lengths = embedding_vec[:-1, 2].tolist()
         self.assertEqual(token_lengths, [len(token) for token in token_list])
 
     def test_single_letter_stop_string(self):
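
End to end, the scenario the new test guards against looks roughly like the sketch below. The checkpoint name is a placeholder for any model whose config.vocab_size is padded above len(tokenizer); it is an assumption for illustration, not something this commit tests.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder checkpoint: assume config.vocab_size > len(tokenizer) (padded vocab)
    model_name = "some-org/causal-lm-with-padded-vocab"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    inputs = tokenizer("Hello there", return_tensors="pt")
    # Before this fix, an id above len(tokenizer) appearing in the sequence could make
    # StopStringCriteria's embedding lookup fail; with the clamp such ids simply never match.
    out = model.generate(**inputs, max_new_tokens=20, stop_strings=["\n\n"], tokenizer=tokenizer)
    print(tokenizer.decode(out[0], skip_special_tokens=True))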
