Add litellm inference (#385)
This PR enables running inference with any model provider supported by litellm, as well as using litellm as an LLM-as-a-judge backend.

---------

Co-authored-by: Egor Lebedev <[email protected]>
Co-authored-by: Kryvich <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nazim Ali <[email protected]>
Co-authored-by: vsabolcec <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Albert Villanova del Moral <[email protected]>
9 people authored Jan 2, 2025
1 parent 2ef9740 commit a2541b1
Showing 11 changed files with 493 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -25,7 +25,7 @@ jobs:
cache: 'pip'
- name: Install lighteval in editable mode
run: |
pip install -e .[dev,extended_tasks,multilingual]
pip install -e .[dev,extended_tasks,multilingual,litellm]
- name: Get cached files
uses: actions/cache@v4
id: get-cache
1 change: 1 addition & 0 deletions pyproject.toml
@@ -82,6 +82,7 @@ dependencies = [
]

[project.optional-dependencies]
litellm = ["litellm", "diskcache"]
tgi = ["text-generation==0.6.0"]
optimum = ["optimum==1.12.0"]
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
109 changes: 109 additions & 0 deletions src/lighteval/main_endpoint.py
@@ -367,3 +367,112 @@ def tgi(
pipeline.save_and_push_results()

return results


@app.command(rich_help_panel="Evaluation Backends")
def litellm(
# === general ===
model_name: Annotated[
str, Argument(help="The model name to evaluate (has to be available through the litellm API).")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
] = False,
system_prompt: Annotated[
Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
] = None,
dataset_loading_processes: Annotated[
int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
custom_tasks: Annotated[
Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
cache_dir: Annotated[
str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANEL_NAME_1)
] = CACHE_DIR,
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
] = "results",
push_to_hub: Annotated[
bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
push_to_tensorboard: Annotated[
bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
public_run: Annotated[
bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
results_org: Annotated[
Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2)
] = None,
save_details: Annotated[
bool, Option(help="Save detailed, per-sample results.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
# === debug ===
max_samples: Annotated[
Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3)
] = None,
override_batch_size: Annotated[
int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANEL_NAME_3)
] = -1,
job_id: Annotated[
int, Option(help="Optional job id for future reference.", rich_help_panel=HELP_PANEL_NAME_3)
] = 0,
):
"""
Evaluate models using LiteLLM as backend.
"""

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
save_details=save_details,
push_to_hub=push_to_hub,
push_to_tensorboard=push_to_tensorboard,
public=public_run,
hub_results_org=results_org,
)

# TODO (nathan): better handling of model_args
parallelism_manager = ParallelismManager.NONE

model_config = LiteLLMModelConfig(model=model_name)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
env_config=env_config,
job_id=job_id,
dataset_loading_processes=dataset_loading_processes,
custom_tasks_directory=custom_tasks,
override_batch_size=override_batch_size,
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
)
pipeline = Pipeline(
tasks=tasks,
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
)

pipeline.evaluate()

pipeline.show_results()

results = pipeline.get_results()

pipeline.save_and_push_results()

return results
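
For context, here is a minimal programmatic sketch of the flow this new command wires up. The model id, task string, and cache directory below are illustrative placeholders rather than values from this PR; litellm routes requests based on the provider prefix of the model name (e.g. openai/..., anthropic/...).

import os

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

# Mirrors what the CLI command builds; the token and cache dir are placeholders.
env_config = EnvConfig(token=os.getenv("HF_TOKEN"), cache_dir="~/.cache/huggingface/lighteval")
evaluation_tracker = EvaluationTracker(output_dir="results", save_details=False)

# Assumed provider-prefixed model id; any model litellm can route should work.
model_config = LiteLLMModelConfig(model="openai/gpt-4o-mini")

pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.NONE,  # same launcher the command uses
    env_config=env_config,
    use_chat_template=True,
    max_samples=8,  # small smoke test
)
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",  # assumed "suite|task|fewshot|truncate" format
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config,
)
pipeline.evaluate()
pipeline.show_results()
results = pipeline.get_results()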
39 changes: 36 additions & 3 deletions src/lighteval/metrics/llm_as_judge.py
@@ -28,7 +28,7 @@

from tqdm import tqdm

from lighteval.utils.imports import is_openai_available, is_vllm_available
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available


logging.getLogger("openai").setLevel(logging.ERROR)
@@ -73,7 +73,7 @@ def __init__(
model: str,
templates: Callable,
process_judge_response: Callable,
judge_backend: Literal["openai", "transformers", "tgi", "vllm"],
judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
url: str | None = None,
api_key: str | None = None,
):
@@ -93,7 +93,7 @@

def __lazy_load_client(self):
match self.backend:
# Wether we use openai or TGI models, we go trhough the openai API
# Whether we use openai or TGI models, we go through the openai API
# to route to the endpoint
case "openai" | "tgi" if is_openai_available():
if self.client is None:
@@ -104,6 +104,8 @@ def __lazy_load_client(self):
else:
self.client = OpenAI(base_url=self.url, api_key=self.api_key)
return self.__call_api_parallel
case "litellm" if is_litellm_available():
return self.__call_litellm
case "vllm" if is_vllm_available():
if self.pipe is None:
from vllm import LLM, SamplingParams
@@ -187,6 +189,37 @@ def __call_vllm(self, prompt):
outputs = [output.outputs[0].text for output in output]
return outputs

def __call_litellm(self, prompts):
import litellm

def __call_api(prompt):
for _ in range(self.API_MAX_RETRY):
try:
response = litellm.completion(
model=self.model,
messages=prompt,
response_format={"type": "text"},
max_tokens=512,
n=1,
caching=True,
)
text = response.choices[0].message.content
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
raise Exception("Failed to get response from the API")

results = []
with ThreadPoolExecutor(100) as executor:
for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
results.append(entry)

if None in results:
raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.")

return results

def __call_api_parallel(self, prompts):
results = []
with ThreadPoolExecutor(100) as executor:
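
For reference, this is the call pattern the new litellm judge backend issues per prompt, shown standalone. The model id and message content are placeholders; caching=True is why the diskcache extra was added to pyproject.toml.

import litellm

# Same arguments as the __call_api helper above; the prompt would normally be
# the chat-style message list produced by the judge template.
response = litellm.completion(
    model="gpt-4o-mini",  # assumed litellm-routable model id
    messages=[{"role": "user", "content": "Rate the answer from 1 to 10: ..."}],
    response_format={"type": "text"},
    max_tokens=512,
    n=1,
    caching=True,  # disk-backed response cache via diskcache
)
print(response.choices[0].message.content)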
5 changes: 4 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -859,7 +859,7 @@ def __init__(
judge_model_name: str,
template: Callable,
process_judge_response: Callable,
judge_backend: Literal["openai", "transformers", "vllm", "tgi"],
judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
short_judge_name: str | None = None,
) -> None:
match judge_backend:
@@ -872,6 +872,9 @@ def __init__(
case "tgi":
api_key = os.getenv("HF_TOKEN")
url = "https://api-inference.huggingface.co/v1/"
case "litellm":
api_key = None
url = None
case "transformers" | "vllm":
api = HfApi()
models = api.list_models(model_name=judge_model_name)
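
A hedged sketch of how a metric could select the new backend through the constructor shown above. The class name JudgeLLM, the template, and the response parser are assumptions for illustration, not code from this PR.

from lighteval.metrics.metrics_sample import JudgeLLM

# Illustrative stand-ins for a real judge prompt template and response parser.
def judge_template(question, answer, **kwargs):
    return [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}\nScore from 1 to 10:"}]

def process_judge_response(response):
    try:
        return int(response.strip().split()[0]) / 10
    except (ValueError, IndexError):
        return 0.0

judge_metric = JudgeLLM(
    judge_model_name="gpt-4o-mini",  # assumed litellm-routable model id
    template=judge_template,
    process_judge_response=process_judge_response,
    judge_backend="litellm",  # the backend added by this PR
)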
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/openai_model.py
@@ -145,7 +145,6 @@ def greedy_until(
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns: