diff --git a/examples/model_configs/openai_model.yaml b/examples/model_configs/openai_model.yaml
new file mode 100644
index 000000000..fee54a991
--- /dev/null
+++ b/examples/model_configs/openai_model.yaml
@@ -0,0 +1,6 @@
+model:
+  type: "openai" # can be base, openai, tgi, or endpoint
+  instance:
+    address: null
+    model_id: null # Required, pointing to the HF repo for loading the tokenizer
+    auth_token: null # Optional
\ No newline at end of file
diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py
index bd58d11d2..cd5ba70f5 100644
--- a/src/lighteval/evaluator.py
+++ b/src/lighteval/evaluator.py
@@ -33,12 +33,13 @@
 from lighteval.logging.hierarchical_logger import hlog
 from lighteval.models.base_model import BaseModel
 from lighteval.models.tgi_model import ModelClient
+from lighteval.models.oai_model import OAIModelClient
 from lighteval.tasks.lighteval_task import LightevalTask
 from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId


 def evaluate(  # noqa: C901
-    lm: Union[BaseModel, ModelClient],
+    lm: Union[BaseModel, ModelClient, OAIModelClient],
     requests_dict: Dict[RequestType, list[Request]],
     docs: Dict[TaskExampleId, Doc],
     task_dict: Dict[str, LightevalTask],
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index 87959ef61..a61836428 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -23,6 +23,7 @@
 import asyncio
 from typing import Coroutine, List, Optional, Union

+from lighteval.models.utils import retry_with_backoff
 import torch
 from huggingface_hub import (
     AsyncInferenceClient,
@@ -53,7 +54,6 @@
 BATCH_SIZE = 50


-
 class InferenceEndpointModel(LightevalModel):
     """InferenceEndpointModels can be used both with the free inference client, or with inference
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
@@ -156,14 +156,14 @@ def _async_process_request(
     ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
-        generated_text = self.async_client.text_generation(
+        generated_text = retry_with_backoff(lambda: self.async_client.text_generation(
             prompt=context,
             details=True,
             decoder_input_details=True,
             max_new_tokens=max_tokens,
             stop_sequences=stop_tokens,
             # truncate=,
-        )
+        ))

         return generated_text

diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index 75a29d02c..63bf713bf 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -219,6 +219,12 @@ class TGIModelConfig:
     inference_server_auth: str
     model_id: str

+@dataclass
+class OAIModelConfig:
+    address: str
+    model_id: str
+    auth_token: str
+

 @dataclass
 class DummyModelConfig:
@@ -282,6 +288,7 @@ def create_model_config(  # noqa: C901
     AdapterModelConfig,
     DeltaModelConfig,
     TGIModelConfig,
+    OAIModelConfig,
     InferenceEndpointModelConfig,
     DummyModelConfig,
 ]:
@@ -293,7 +300,7 @@ def create_model_config(  # noqa: C901
         accelerator (Union[Accelerator, None]): accelerator to use for model training.

     Returns:
-        Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration.
+        Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, OAIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration.

     Raises:
         ValueError: If both an inference server address and model arguments are provided.
@@ -312,8 +319,13 @@
         return BaseModelConfig(**args_dict)

-    with open(args.model_config_path, "r") as f:
-        config = yaml.safe_load(f)["model"]
+    # This option is not exposed on the CLI; it is meant for programmatic use, when the config is built
+    # in code, so that it does not have to be written to a temporary file before calling lighteval.
+    if hasattr(args, "model_config") and args.model_config:
+        config = args.model_config["model"]
+    else:
+        with open(args.model_config_path, "r") as f:
+            config = yaml.safe_load(f)["model"]

     if config["type"] == "tgi":
         return TGIModelConfig(
             inference_server_address=config["instance"]["address"],
@@ -321,6 +333,13 @@
             inference_server_auth=config["instance"]["inference_server_auth"],
             model_id=config["instance"]["model_id"],
         )
+
+    if config["type"] == "openai":
+        return OAIModelConfig(
+            address=config["instance"]["address"],
+            model_id=config["instance"]["model_id"],
+            auth_token=config["instance"]["auth_token"],
+        )

     if config["type"] == "endpoint":
         reuse_existing_endpoint = config["base_params"]["reuse_existing"]
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index c72d64038..f8d352297 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -37,8 +37,10 @@
     EnvConfig,
     InferenceEndpointModelConfig,
     InferenceModelConfig,
+    OAIModelConfig,
     TGIModelConfig,
 )
+from lighteval.models.oai_model import OAIModelClient
 from lighteval.models.tgi_model import ModelClient
 from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available

@@ -60,12 +62,13 @@ def load_model(  # noqa: C901
         BaseModelConfig,
         AdapterModelConfig,
         DeltaModelConfig,
+        OAIModelConfig,
         TGIModelConfig,
         InferenceEndpointModelConfig,
         DummyModelConfig,
     ],
     env_config: EnvConfig,
-) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel], ModelInfo]:
+) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient, DummyModel], ModelInfo]:
     """Will load either a model from an inference server or a model from a checkpoint, depending
     on the config type.

@@ -79,11 +82,14 @@ def load_model(  # noqa: C901
         ValueError: If you did not specify a base model when using delta weights or adapter weights

     Returns:
-        Union[BaseModel, AdapterModel, DeltaModel, ModelClient]: The model that will be evaluated
+        Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient]: The model that will be evaluated
     """
     # Inference server loading
     if isinstance(config, TGIModelConfig):
         return load_model_with_tgi(config)
+
+    if isinstance(config, OAIModelConfig):
+        return load_model_with_oai(config)

     if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
         return load_model_with_inference_endpoints(config, env_config=env_config)
@@ -95,6 +101,13 @@
     return load_dummy_model(config=config, env_config=env_config)


+def load_model_with_oai(config: OAIModelConfig):
+    model = OAIModelClient(
+        address=config.address, model_id=config.model_id, auth_token=config.auth_token
+    )
+    return model, ModelInfo(model_name=config.model_id, model_sha="unknown", model_dtype="unknown", model_size="unknown")
+
+
 def load_model_with_tgi(config: TGIModelConfig):
     if not is_tgi_available():
         raise ImportError(NO_TGI_ERROR_MSG)
diff --git a/src/lighteval/models/oai_model.py b/src/lighteval/models/oai_model.py
new file mode 100644
index 000000000..fe46be11b
--- /dev/null
+++ b/src/lighteval/models/oai_model.py
@@ -0,0 +1,83 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Inherits from InferenceEndpointModel rather than LightevalModel: both expose the same interface, so only
+# the client functions are overridden here, since this class talks to a different (OpenAI-compatible) client.
+
+import asyncio
+from typing import Coroutine
+
+from huggingface_hub import TextGenerationOutput
+from openai import AsyncOpenAI
+from transformers import AutoTokenizer
+
+from lighteval.models.endpoint_model import InferenceEndpointModel
+from lighteval.models.utils import retry_with_backoff
+
+
+class OAIModelClient(InferenceEndpointModel):
+    _DEFAULT_MAX_LENGTH: int = 4096
+
+    def __init__(self, address, model_id, auth_token=None) -> None:
+        self.client = AsyncOpenAI(base_url=address, api_key=(auth_token or "none"))
+        self.model_id = model_id
+        self._max_gen_toks = 256
+
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self._add_special_tokens = True
+        self.use_async = True
+
+    async def _async_process_request(
+        self, context: str, stop_tokens: list[str], max_tokens: int
+    ) -> Coroutine[None, TextGenerationOutput, str]:
+        # Todo: add an option to launch with conversational instead for chat prompts
+        output = await retry_with_backoff(
+            lambda: self.client.completions.create(
+                model="/repository",
+                prompt=context,
+                max_tokens=max_tokens,
+                stop=stop_tokens,
+            )
+        )
+
+        return TextGenerationOutput(generated_text=output.choices[0].text)
+
+    def _process_request(self, *args, **kwargs) -> TextGenerationOutput:
+        return asyncio.run(self._async_process_request(*args, **kwargs))
+
+    def set_cache_hook(self, cache_hook):
+        self.cache_hook = cache_hook
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def add_special_tokens(self):
+        return self._add_special_tokens
+
+    @property
+    def max_length(self) -> int:
+        if hasattr(self.tokenizer, "model_max_length"):
+            return self.tokenizer.model_max_length
+        return OAIModelClient._DEFAULT_MAX_LENGTH
+
+    @property
+    def disable_tqdm(self) -> bool:
+        return False
+
+    def cleanup(self):
+        pass
diff --git a/src/lighteval/models/utils.py b/src/lighteval/models/utils.py
index ec4a22758..76f8d840f 100644
--- a/src/lighteval/models/utils.py
+++ b/src/lighteval/models/utils.py
@@ -20,6 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

+import asyncio
+import logging
 import os
 from itertools import islice
 from typing import Optional, Union
@@ -98,3 +100,19 @@ def batched(iterable, n):
     it = iter(iterable)
     while batch := tuple(islice(it, n)):
         yield batch
+
+import random
+MAX_RETRIES = 15
+INITIAL_BACKOFF = 1
+async def retry_with_backoff(coro_fn):
+    for attempt in range(MAX_RETRIES):
+        try:
+            return await coro_fn()
+        except Exception as e:
+            if attempt < MAX_RETRIES - 1:
+                backoff_time = INITIAL_BACKOFF * (2**attempt) + random.uniform(0, 1)  # exponential backoff with jitter
+                logging.info(e)
+                logging.info(f"Encountered error, backing off and retrying in {backoff_time:.1f}s...")
+                await asyncio.sleep(backoff_time)
+            else:
+                raise e
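
Note (illustrative, not part of the patch): besides the YAML example above, the new model_config attribute checked in create_model_config lets callers build the config in code. Below is a minimal sketch of the dict shape it expects, mirroring examples/model_configs/openai_model.yaml. The address and model_id values are placeholders, not shipped defaults: any OpenAI-compatible base URL should work, and model_id has to be a Hugging Face repo because OAIModelClient loads its tokenizer with AutoTokenizer.from_pretrained.

    # Sketch only: placeholder values, assuming an OpenAI-compatible server is already running.
    openai_model_config = {
        "model": {
            "type": "openai",  # routed to OAIModelConfig by create_model_config
            "instance": {
                "address": "http://localhost:8080/v1",  # placeholder endpoint
                "model_id": "HuggingFaceH4/zephyr-7b-beta",  # placeholder HF repo, also used for the tokenizer
                "auth_token": None,  # optional
            },
        }
    }

    # Instead of --model_config_path, attach the dict to the parsed args before create_model_config runs:
    #   args.model_config = openai_model_config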
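For reference, retry_with_backoff expects a zero-argument callable that produces a fresh awaitable on each call, so every retry re-issues the request; passing an already-created coroutine would fail on the second attempt, since a coroutine can only be awaited once. A hedged usage sketch follows (the endpoint URL and prompt are made up):

    import asyncio

    from openai import AsyncOpenAI

    from lighteval.models.utils import retry_with_backoff


    async def main():
        client = AsyncOpenAI(base_url="http://localhost:8080/v1", api_key="none")  # placeholder server
        # The lambda defers the call: retry_with_backoff awaits it up to MAX_RETRIES times,
        # sleeping INITIAL_BACKOFF * 2**attempt plus jitter between failed attempts.
        output = await retry_with_backoff(
            lambda: client.completions.create(model="/repository", prompt="2 + 2 =", max_tokens=8)
        )
        print(output.choices[0].text)


    asyncio.run(main())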