From 78789c12a199d049940be498dc3ac2a66f23d644 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 17 Dec 2024 15:15:34 +0000 Subject: [PATCH] allow bette rmessage managment for litellm --- src/lighteval/models/litellm_model.py | 4 ++-- src/lighteval/tasks/prompt_manager.py | 20 ++++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py index 35f30ee7d..66f25ec16 100644 --- a/src/lighteval/models/litellm_model.py +++ b/src/lighteval/models/litellm_model.py @@ -112,7 +112,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se response = litellm.completion( model=self.model, - messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], + messages=prompt, max_completion_tokens=completion_tokens, logprobs=return_logits if self.provider == "openai" else None, stop=stop_sequence, @@ -234,7 +234,7 @@ def tokenizer(self): return self._tokenizer def tok_encode(self, text: str): - return self.tokenizer.encode(text) + return text @property def add_special_tokens(self) -> bool: diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py index cb9f94d04..c8c842223 100644 --- a/src/lighteval/tasks/prompt_manager.py +++ b/src/lighteval/tasks/prompt_manager.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union from lighteval.models.abstract_model import LightevalModel +from lighteval.models.litellm_model import LiteLLMClient from lighteval.tasks.requests import Doc from lighteval.utils.utils import as_list @@ -205,7 +206,10 @@ def _single_turn_context( system_prompt=system_prompt, use_chat_template=use_chat_template, ) - toks = self.model.tok_encode(output) + if not use_chat_template: + toks = self.model.tok_encode(output) + else: + toks = "".join([msg["content"] for msg in output]) # If we need to truncate few-shots to fit in the context if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None: @@ -223,9 +227,17 @@ def _single_turn_context( system_prompt=system_prompt, use_chat_template=use_chat_template, ) - toks = self.model.tokenizer(output)["input_ids"] + if not use_chat_template: + toks = self.model.tok_encode(output) + else: + toks = "".join([msg["content"] for msg in output]) + + if isinstance(self.model, LiteLLMClient): + return output, num_effective_fewshots - return output, num_effective_fewshots + return self.model.tokenizer.apply_chat_template( + output, tokenize=False, add_generation_prompt=True + ), num_effective_fewshots def get_examples( self, @@ -256,7 +268,7 @@ def get_examples( examples.insert(0, {"role": "system", "content": system_prompt + instruction}) else: # Else we add the instruction to the first example examples[0]["content"] = instruction + examples[0]["content"] - return self.model.tokenizer.apply_chat_template(examples, tokenize=False, add_generation_prompt=True) + return examples else: if system_prompt is not None: output = system_prompt + instruction + "\n\n".join(examples)