Skip to content

Commit

Permalink
Hotfix for litellm judge (#490)
Browse files Browse the repository at this point in the history
* Made litellm judge backend more robust.

* Added failed flag to ModelResponse.

* Fixed wrong model response.

* Removed model response and replaced with string.

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
JoelNiklaus and clefourrier authored Jan 20, 2025
1 parent 3b89734 commit fee2ec3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
13 changes: 6 additions & 7 deletions src/lighteval/metrics/llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from tqdm import tqdm

from lighteval.models.model_output import ModelResponse
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available


Expand Down Expand Up @@ -194,6 +193,7 @@ def __call_litellm(self, prompts):
import litellm

def __call_api(prompt):
error_message = "ERROR: Failed to get response from the API."
for _ in range(self.API_MAX_RETRY):
try:
kwargs = {
Expand All @@ -206,20 +206,19 @@ def __call_api(prompt):
}
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
if not text or response.failed:
if not text or text == error_message:
kwargs["caching"] = False
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
if not text or response.failed:
if not text or text == error_message:
# Just return an error response if the second attempt fails too
return ModelResponse(
text="Failed to get response from the API.", model=self.model, failed=True
)
logger.error(f"Failed to get response from the API for prompt: {prompt}")
return error_message
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
return error_message

results = []
with ThreadPoolExecutor(100) as executor:
Expand Down
1 change: 0 additions & 1 deletion src/lighteval/models/model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class ModelResponse:
generated_tokens: list[int] = field(default_factory=list) # model generations
truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
padded_tokens_count: Optional[int] = 0 # How many tokens of padding
failed: bool = False

def get_result_for_eval(self):
raise NotImplementedError()
Expand Down

0 comments on commit fee2ec3

Please sign in to comment.