
Commit

feat: add MinTokensLogitsProcessor and min_tokens argument to server (#1333)

* implement min_tokens

* set default to 0

* pass min_tokens

* fix

* remove copy

* implement MinTokensLogitsProcessor

* format

* fix condition
twaka authored May 14, 2024
1 parent 389e09c commit 5212fb0
Showing 3 changed files with 44 additions and 0 deletions.
16 changes: 16 additions & 0 deletions llama_cpp/llama.py
@@ -2084,3 +2084,19 @@ def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])


class MinTokensLogitsProcessor(LogitsProcessor):
    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        if len(input_ids) - self.prompt_tokens < self.min_tokens:
            scores[self.token_eos] = -np.inf
        return scores
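
The processor records the prompt length on its first call, then pins the EOS logit to -inf until at least min_tokens completion tokens have been generated. Because prompt_tokens is cached, an instance should not be reused across generations. A minimal usage sketch against the Python API (the model path is a placeholder):

import llama_cpp

llama = llama_cpp.Llama(model_path="./model.gguf")  # placeholder path

# Suppress EOS until at least 16 new tokens have been generated.
processor = llama_cpp.MinTokensLogitsProcessor(16, llama.token_eos())

output = llama.create_completion(
    "Write a haiku about the sea.",
    max_tokens=64,
    logits_processor=llama_cpp.LogitsProcessorList([processor]),
)
print(output["choices"][0]["text"])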
20 changes: 20 additions & 0 deletions llama_cpp/server/app.py
@@ -275,6 +275,7 @@ async def create_completion(
        "best_of",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@
    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
"n",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)
llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@
    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

    iterator_or_completion: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
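With this wired into both endpoints, clients can pass min_tokens alongside the usual sampling parameters; if a request already carries logits processors, the min-tokens processor is appended rather than replacing them. A sketch of a chat request against a locally running server (base URL and payload values are placeholders):

import requests

# Hypothetical local server address; adjust to your deployment.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Tell me a short story."}],
        "max_tokens": 128,
        "min_tokens": 32,  # suppress EOS for the first 32 generated tokens
    },
)
print(response.json()["choices"][0]["message"]["content"])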
8 changes: 8 additions & 0 deletions llama_cpp/server/types.py
@@ -16,6 +16,12 @@
    default=16, ge=1, description="The maximum number of tokens to generate."
)

min_tokens_field = Field(
    default=0,
    ge=0,
    description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
)

temperature_field = Field(
    default=0.8,
    description="Adjust the randomness of the generated text.\n\n"
@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
    max_tokens: Optional[int] = Field(
        default=16, ge=0, description="The maximum number of tokens to generate."
    )
    min_tokens: int = min_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    min_p: float = min_p_field
@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
        default=None,
        description="The maximum number of tokens to generate. Defaults to inf",
    )
    min_tokens: int = min_tokens_field
    logprobs: Optional[bool] = Field(
        default=False,
        description="Whether to output the logprobs or not. Default is True"
