From fee2ec32869d54be4214af6c85e44e5717fe5d07 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Mon, 20 Jan 2025 01:02:32 -0800 Subject: [PATCH] Hotfix for litellm judge (#490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Made litellm judge backend more robust. * Added failed flag to ModelResponse. * Fixed wrong model response. * Removed model response and replaced with string. --------- Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> --- src/lighteval/metrics/llm_as_judge.py | 13 ++++++------- src/lighteval/models/model_output.py | 1 - 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py index 81e1d7d30..23beda76f 100644 --- a/src/lighteval/metrics/llm_as_judge.py +++ b/src/lighteval/metrics/llm_as_judge.py @@ -28,7 +28,6 @@ from tqdm import tqdm -from lighteval.models.model_output import ModelResponse from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available @@ -194,6 +193,7 @@ def __call_litellm(self, prompts): import litellm def __call_api(prompt): + error_message = "ERROR: Failed to get response from the API." for _ in range(self.API_MAX_RETRY): try: kwargs = { @@ -206,20 +206,19 @@ def __call_api(prompt): } response = litellm.completion(**kwargs) text = response.choices[0].message.content - if not text or response.failed: + if not text or text == error_message: kwargs["caching"] = False response = litellm.completion(**kwargs) text = response.choices[0].message.content - if not text or response.failed: + if not text or text == error_message: # Just return an error response if the second attempt fails too - return ModelResponse( - text="Failed to get response from the API.", model=self.model, failed=True - ) + logger.error(f"Failed to get response from the API for prompt: {prompt}") + return error_message return text except Exception as e: logger.warning(f"{type(e), e}") time.sleep(self.API_RETRY_SLEEP) - return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True) + return error_message results = [] with ThreadPoolExecutor(100) as executor: diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py index b485371ca..7d0ba4818 100644 --- a/src/lighteval/models/model_output.py +++ b/src/lighteval/models/model_output.py @@ -33,7 +33,6 @@ class ModelResponse: generated_tokens: list[int] = field(default_factory=list) # model generations truncated_tokens_count: Optional[int] = 0 # How many tokens truncated padded_tokens_count: Optional[int] = 0 # How many tokens of padding - failed: bool = False def get_result_for_eval(self): raise NotImplementedError()