Skip to content

Commit

Permalink
fix: stop checker
Browse files Browse the repository at this point in the history
  • Loading branch information
numb3r3 committed Jan 17, 2025
1 parent bcd0d60 commit db46379
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
20 changes: 12 additions & 8 deletions vllm/engine/output_processor/stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class RepetitionConfig:
medium_ngram_threshold: int = 8

# start checking for repetition after the first 1024 tokens
start_checking_after: int = 1024
start_checking_after: int = 1


class StopChecker:
Expand Down Expand Up @@ -132,6 +132,10 @@ def check_ngram_repetition(self,
if token == last_token_id:
repeated_at = seq.repeat_start_from + i
repeated_gap = output_len - repeated_at - 1
if seq.repeated_gap is None:
seq.repeated_gap = repeated_gap
elif seq.repeated_gap == repeated_gap:
break

if repeated_at is not None:
seq.repeated_count += 1
Expand All @@ -140,19 +144,19 @@ def check_ngram_repetition(self,
# f"\n==> token ({last_token}) at {output_len}\n"
# f"==> repeat_at: {repeated_at}\n"
# f"==> repeated_count: {seq.repeated_count}\n"
# f"==> repeated_gap: {repeated_gap}\n"
# f"==> repeate_start_from: {seq.repeat_start_from}"
# f"==> repeated_gap: {repeated_gap} vs {seq.repeated_gap}\n"
# f"==> repeate_start_from: {seq.repeat_start_from}\n"
# f"==> repeated_total: {seq.repeated_total}\n"
# )

seq.repeat_start_from = repeated_at

# reset the repetition count if the gap changes
if repeated_at is None or repeated_gap != seq.repeated_gap:
seq.repeated_count = 0
seq.repeated_gap = 0
seq.repeated_total = 0

if repeated_gap is not None:
if seq.repeated_gap != repeated_gap:
seq.repeated_gap = repeated_gap

if seq.repeated_count == seq.repeated_gap and seq.repeated_gap:
Expand All @@ -161,13 +165,13 @@ def check_ngram_repetition(self,

# print(f"==> repeated_total: {seq.repeated_total}")

repeate_ngram_size = seq.repeated_gap
# repeate_ngram_size = seq.repeated_gap
# print(f'==> repeate_ngram_size: {repeate_ngram_size}')

if repeate_ngram_size == 1:
if seq.repeated_gap == 1:
# single token repetition
is_done = seq.repeated_total > self.repetition_config.single_token_threshold
elif repeate_ngram_size > 64:
elif seq.repeated_gap > 64:
# paragraph repetition
is_done = seq.repeated_total >= self.repetition_config.large_ngram_threshold
else:
Expand Down
2 changes: 1 addition & 1 deletion vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __init__(
# the number of tokens repeated
self.repeated_count = 0
# the gap between the repeated tokens
self.repeated_gap = 0
self.repeated_gap = None
# the repeated ngram that we already generated
self.repeated_total = 0

Expand Down

0 comments on commit db46379

Please sign in to comment.