From d26a5769621dc0b2ad38a633a98a66bc415ea101 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Mon, 29 Jul 2024 19:25:04 +0300 Subject: [PATCH 1/5] Added option for specifying model config in args directly --- src/lighteval/models/model_config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 75a29d02c..e1505ba04 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -312,8 +312,11 @@ def create_model_config( # noqa: C901 return BaseModelConfig(**args_dict) - with open(args.model_config_path, "r") as f: - config = yaml.safe_load(f)["model"] + if hasattr(args, "model_config") and args.model_config: + config = args.model_config["model"] + else: + with open(args.model_config_path, "r") as f: + config = yaml.safe_load(f)["model"] if config["type"] == "tgi": return TGIModelConfig( From 68230b3a253d7cd50eebd0722305cdeabc92006e Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Mon, 29 Jul 2024 19:28:08 +0300 Subject: [PATCH 2/5] Added comment explaining --- src/lighteval/models/model_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index e1505ba04..37c511c14 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -312,6 +312,8 @@ def create_model_config( # noqa: C901 return BaseModelConfig(**args_dict) + # This option isn't provided for use via the CLI but rather for programmatic use when constructing the config + # in code, so that the config won't need to be saved to a temp file when calling lighteval. if hasattr(args, "model_config") and args.model_config: config = args.model_config["model"] else: From 54ed088523c276c08ae1aa73deb876dce2484a43 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Sun, 25 Aug 2024 00:21:40 +0300 Subject: [PATCH 3/5] Added support for OpenAI model --- examples/model_configs/openai_model.yaml | 6 ++ src/lighteval/evaluator.py | 3 +- src/lighteval/models/model_config.py | 16 ++++- src/lighteval/models/model_loader.py | 17 ++++- src/lighteval/models/oai_model.py | 81 ++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 examples/model_configs/openai_model.yaml create mode 100644 src/lighteval/models/oai_model.py diff --git a/examples/model_configs/openai_model.yaml b/examples/model_configs/openai_model.yaml new file mode 100644 index 000000000..fee54a991 --- /dev/null +++ b/examples/model_configs/openai_model.yaml @@ -0,0 +1,6 @@ +model: + type: "openai" # can be base, openai, tgi, or endpoint + instance: + address: null + model_id: null # Required, pointing to the HF repo for loading the tokenized + auth_token: null # Optional \ No newline at end of file diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py index bd58d11d2..cd5ba70f5 100644 --- a/src/lighteval/evaluator.py +++ b/src/lighteval/evaluator.py @@ -33,12 +33,13 @@ from lighteval.logging.hierarchical_logger import hlog from lighteval.models.base_model import BaseModel from lighteval.models.tgi_model import ModelClient +from lighteval.models.oai_model import OAIModelClient from lighteval.tasks.lighteval_task import LightevalTask from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId def evaluate( # noqa: C901 - lm: Union[BaseModel, ModelClient], + lm: Union[BaseModel, ModelClient, OAIModelClient], requests_dict: Dict[RequestType, 
list[Request]], docs: Dict[TaskExampleId, Doc], task_dict: Dict[str, LightevalTask], diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 37c511c14..63bf713bf 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -219,6 +219,12 @@ class TGIModelConfig: inference_server_auth: str model_id: str +@dataclass +class OAIModelConfig: + address: str + model_id: str + auth_token: str + @dataclass class DummyModelConfig: @@ -282,6 +288,7 @@ def create_model_config( # noqa: C901 AdapterModelConfig, DeltaModelConfig, TGIModelConfig, + OAIModelConfig, InferenceEndpointModelConfig, DummyModelConfig, ]: @@ -293,7 +300,7 @@ def create_model_config( # noqa: C901 accelerator (Union[Accelerator, None]): accelerator to use for model training. Returns: - Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration. + Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, OAIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration. Raises: ValueError: If both an inference server address and model arguments are provided. @@ -326,6 +333,13 @@ def create_model_config( # noqa: C901 inference_server_auth=config["instance"]["inference_server_auth"], model_id=config["instance"]["model_id"], ) + + if config["type"] == "openai": + return OAIModelConfig( + address=config["instance"]["address"], + model_id=config["instance"]["model_id"], + auth_token=config["instance"]["auth_token"] + ) if config["type"] == "endpoint": reuse_existing_endpoint = config["base_params"]["reuse_existing"] diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index c72d64038..f8d352297 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -37,8 +37,10 @@ EnvConfig, InferenceEndpointModelConfig, InferenceModelConfig, + OAIModelConfig, TGIModelConfig, ) +from lighteval.models.oai_model import OAIModelClient from lighteval.models.tgi_model import ModelClient from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available @@ -60,12 +62,13 @@ def load_model( # noqa: C901 BaseModelConfig, AdapterModelConfig, DeltaModelConfig, + OAIModelConfig, TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig, ], env_config: EnvConfig, -) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel], ModelInfo]: +) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient, DummyModel], ModelInfo]: """Will load either a model from an inference server or a model from a checkpoint, depending on the config type. 
@@ -79,11 +82,14 @@ def load_model( # noqa: C901 ValueError: If you did not specify a base model when using delta weights or adapter weights Returns: - Union[BaseModel, AdapterModel, DeltaModel, ModelClient]: The model that will be evaluated + Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient]: The model that will be evaluated """ # Inference server loading if isinstance(config, TGIModelConfig): return load_model_with_tgi(config) + + if isinstance(config, OAIModelConfig): + return load_model_with_oai(config) if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig): return load_model_with_inference_endpoints(config, env_config=env_config) @@ -95,6 +101,13 @@ def load_model( # noqa: C901 return load_dummy_model(config=config, env_config=env_config) +def load_model_with_oai(config: OAIModelConfig): + model = OAIModelClient( + address=config.address, model_id=config.model_id, auth_token=config.auth_token + ) + return model, ModelInfo(model_name=config.model_id, model_sha='unknown', model_dtype='unknown', model_size='unknown') + + def load_model_with_tgi(config: TGIModelConfig): if not is_tgi_available(): raise ImportError(NO_TGI_ERROR_MSG) diff --git a/src/lighteval/models/oai_model.py b/src/lighteval/models/oai_model.py new file mode 100644 index 000000000..2462a09f2 --- /dev/null +++ b/src/lighteval/models/oai_model.py @@ -0,0 +1,81 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite +# the client functions, since they use a different client. 
+import asyncio +from typing import Coroutine +from lighteval.models.endpoint_model import InferenceEndpointModel +from huggingface_hub import TextGenerationOutput +from transformers import AutoTokenizer +from openai import AsyncOpenAI + +class OAIModelClient(InferenceEndpointModel): + _DEFAULT_MAX_LENGTH: int = 4096 + + def __init__(self, address, model_id, auth_token=None) -> None: + self.client = AsyncOpenAI(base_url=address, api_key=(auth_token or "none")) + self.model_id = model_id + self._max_gen_toks = 256 + + self._tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self._add_special_tokens = True + self.use_async = True + + async def _async_process_request( + self, context: str, stop_tokens: list[str], max_tokens: int + ) -> Coroutine[None, TextGenerationOutput, str]: + # Todo: add an option to launch with conversational instead for chat prompts + output = await self.client.completions.create( + model="/repository", + prompt=context, + max_tokens=max_tokens, + stop=stop_tokens) + + return TextGenerationOutput(generated_text=output.choices[0].text) + + def _process_request(self, *args, **kwargs) -> TextGenerationOutput: + return asyncio.run(self._async_process_request(*args, **kwargs)) + + def set_cache_hook(self, cache_hook): + self.cache_hook = cache_hook + + @property + def tokenizer(self): + return self._tokenizer + + @property + def add_special_tokens(self): + return self._add_special_tokens + + @property + def max_length(self) -> int: + if hasattr(self.tokenizer, "model_max_length"): + return self.tokenizer.model_max_length + return OAIModelClient._DEFAULT_MAX_LENGTH + + @property + def disable_tqdm(self) -> bool: + False + + def cleanup(self): + pass From 7415803941c042e20d75d71f0c9a4c90ae704e24 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Sun, 25 Aug 2024 13:35:19 +0300 Subject: [PATCH 4/5] Added backoff --- src/lighteval/models/endpoint_model.py | 4 ++-- src/lighteval/models/oai_model.py | 8 +++++--- src/lighteval/models/utils.py | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index 87959ef61..9781ee478 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -23,6 +23,7 @@ import asyncio from typing import Coroutine, List, Optional, Union +from lighteval.models.utils import retry_with_backoff import torch from huggingface_hub import ( AsyncInferenceClient, @@ -53,7 +54,6 @@ BATCH_SIZE = 50 - class InferenceEndpointModel(LightevalModel): """InferenceEndpointModels can be used both with the free inference client, or with inference endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation. 
@@ -165,7 +165,7 @@ def _async_process_request( # truncate=, ) - return generated_text + return retry_with_backoff(generated_text) def _process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput: # Todo: add an option to launch with conversational instead for chat prompts diff --git a/src/lighteval/models/oai_model.py b/src/lighteval/models/oai_model.py index 2462a09f2..5885dc37b 100644 --- a/src/lighteval/models/oai_model.py +++ b/src/lighteval/models/oai_model.py @@ -26,6 +26,7 @@ from typing import Coroutine from lighteval.models.endpoint_model import InferenceEndpointModel from huggingface_hub import TextGenerationOutput +from lighteval.models.utils import retry_with_backoff from transformers import AutoTokenizer from openai import AsyncOpenAI @@ -45,12 +46,13 @@ async def _async_process_request( self, context: str, stop_tokens: list[str], max_tokens: int ) -> Coroutine[None, TextGenerationOutput, str]: # Todo: add an option to launch with conversational instead for chat prompts - output = await self.client.completions.create( + output = await retry_with_backoff(self.client.completions.create( model="/repository", prompt=context, max_tokens=max_tokens, - stop=stop_tokens) - + stop=stop_tokens + )) + return TextGenerationOutput(generated_text=output.choices[0].text) def _process_request(self, *args, **kwargs) -> TextGenerationOutput: diff --git a/src/lighteval/models/utils.py b/src/lighteval/models/utils.py index ec4a22758..6fa6ad601 100644 --- a/src/lighteval/models/utils.py +++ b/src/lighteval/models/utils.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import asyncio import os from itertools import islice from typing import Optional, Union @@ -98,3 +99,17 @@ def batched(iterable, n): it = iter(iterable) while batch := tuple(islice(it, n)): yield batch + +import random +MAX_RETRIES = 5 +INITIAL_BACKOFF = 1 +async def retry_with_backoff(coro): + for attempt in range(MAX_RETRIES): + try: + return await coro + except Exception as e: + if attempt < MAX_RETRIES - 1: + backoff_time = INITIAL_BACKOFF * (2 ** attempt) + random.uniform(0, 1) + await asyncio.sleep(backoff_time) + else: + raise e From 2340a4fcb85a36123b1d3ed4494d8028e0817b03 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Sun, 25 Aug 2024 17:36:59 +0300 Subject: [PATCH 5/5] Fix retry with backoff --- src/lighteval/models/endpoint_model.py | 6 +++--- src/lighteval/models/oai_model.py | 2 +- src/lighteval/models/utils.py | 11 +++++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index 9781ee478..a61836428 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -156,16 +156,16 @@ def _async_process_request( ) -> Coroutine[None, list[TextGenerationOutput], str]: # Todo: add an option to launch with conversational instead for chat prompts # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational - generated_text = self.async_client.text_generation( + generated_text = retry_with_backoff(lambda: self.async_client.text_generation( prompt=context, details=True, decoder_input_details=True, max_new_tokens=max_tokens, stop_sequences=stop_tokens, # truncate=, - ) + )) - return retry_with_backoff(generated_text) + return generated_text def _process_request(self, context: str, stop_tokens: list[str], 
max_tokens: int) -> TextGenerationOutput: # Todo: add an option to launch with conversational instead for chat prompts diff --git a/src/lighteval/models/oai_model.py b/src/lighteval/models/oai_model.py index 5885dc37b..fe46be11b 100644 --- a/src/lighteval/models/oai_model.py +++ b/src/lighteval/models/oai_model.py @@ -46,7 +46,7 @@ async def _async_process_request( self, context: str, stop_tokens: list[str], max_tokens: int ) -> Coroutine[None, TextGenerationOutput, str]: # Todo: add an option to launch with conversational instead for chat prompts - output = await retry_with_backoff(self.client.completions.create( + output = await retry_with_backoff(lambda: self.client.completions.create( model="/repository", prompt=context, max_tokens=max_tokens, diff --git a/src/lighteval/models/utils.py b/src/lighteval/models/utils.py index 6fa6ad601..76f8d840f 100644 --- a/src/lighteval/models/utils.py +++ b/src/lighteval/models/utils.py @@ -21,6 +21,7 @@ # SOFTWARE. import asyncio +import logging import os from itertools import islice from typing import Optional, Union @@ -101,15 +102,17 @@ def batched(iterable, n): yield batch import random -MAX_RETRIES = 5 +MAX_RETRIES = 15 INITIAL_BACKOFF = 1 -async def retry_with_backoff(coro): +async def retry_with_backoff(coro_fn): for attempt in range(MAX_RETRIES): try: - return await coro + return await coro_fn() except Exception as e: if attempt < MAX_RETRIES - 1: - backoff_time = INITIAL_BACKOFF * (2 ** attempt) + random.uniform(0, 1) + backoff_time = INITIAL_BACKOFF * (2 ** attempt) + random.uniform(0, 1) # used to be 2 **, but waited too long + logging.info(e) + logging.info(f'Encountered error, backing off and retrying in {backoff_time}s...') await asyncio.sleep(backoff_time) else: raise e
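
The final patch above switches both clients to pass a lambda into retry_with_backoff rather than an already-created coroutine: a coroutine object can only be awaited once, so retrying requires a callable that builds a fresh request on every attempt. Below is a minimal, self-contained sketch of that pattern under assumed names (Flaky and the constants here are illustrative only); it mirrors the shape of the helper added in src/lighteval/models/utils.py but is not the library implementation itself.

import asyncio
import random

MAX_RETRIES = 5
INITIAL_BACKOFF = 1.0


async def retry_with_backoff(coro_fn):
    """Await coro_fn() again after jittered exponential backoff whenever it raises."""
    for attempt in range(MAX_RETRIES):
        try:
            return await coro_fn()
        except Exception:
            if attempt == MAX_RETRIES - 1:
                raise
            # 1s, 2s, 4s, ... plus up to 1s of jitter to avoid synchronized retries.
            await asyncio.sleep(INITIAL_BACKOFF * (2**attempt) + random.uniform(0, 1))


class Flaky:
    """Toy async callable that fails a fixed number of times before succeeding."""

    def __init__(self, failures: int):
        self.remaining = failures

    async def __call__(self):
        if self.remaining > 0:
            self.remaining -= 1
            raise ConnectionError("transient failure")
        return "ok"


if __name__ == "__main__":
    # Passing a callable (here a Flaky instance, in the patch a lambda wrapping the
    # OpenAI/TGI client call) lets every attempt create a fresh coroutine.
    print(asyncio.run(retry_with_backoff(Flaky(failures=2))))  # -> "ok"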