From bc62cb98e45a97a50a7cc76d5e60b184d3fb6ae1 Mon Sep 17 00:00:00 2001
From: Erik Kaunismäki
Date: Tue, 17 Dec 2024 13:54:51 +0100
Subject: [PATCH 1/2] use concrete version instead of latest (#452)

* recommended to use concrete version instead of latest

* ruff style

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
---
 src/lighteval/models/endpoints/endpoint_model.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index 1344e2485..e50c0405c 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -216,9 +216,7 @@ def __init__(  # noqa: C901
                         **config.get_dtype_args(),
                         **config.get_custom_env_vars(),
                     },
-                    "url": (
-                        config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
-                    ),
+                    "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:3.0.1"),
                 },
             )
         else:  # Endpoint exists

From 51ca581fccdd047e98510632934c4773d10c59d0 Mon Sep 17 00:00:00 2001
From: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
Date: Tue, 17 Dec 2024 16:58:34 +0100
Subject: [PATCH 2/2] Adds serverless endpoints back (#445)

* init

* adding serverless endpoints back

* updated tests
---
 docs/source/package_reference/models.mdx      |  2 +-
 ..._model_lite.yaml => serverless_model.yaml} |  0
 src/lighteval/main_endpoint.py                | 19 +++++++-----
 .../models/endpoints/endpoint_model.py        | 29 ++++++++++++++-----
 src/lighteval/models/model_loader.py          |  4 +--
 tests/models/endpoints/test_endpoint_model.py |  2 +-
 6 files changed, 37 insertions(+), 19 deletions(-)
 rename examples/model_configs/{endpoint_model_lite.yaml => serverless_model.yaml} (100%)

diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx
index 096ce7be3..dcf5bc8dc 100644
--- a/docs/source/package_reference/models.mdx
+++ b/docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
 ## Endpoints-based Models
 ### InferenceEndpointModel
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
-[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

 ### TGI ModelClient

diff --git a/examples/model_configs/endpoint_model_lite.yaml b/examples/model_configs/serverless_model.yaml
similarity index 100%
rename from examples/model_configs/endpoint_model_lite.yaml
rename to examples/model_configs/serverless_model.yaml

diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 47d059660..04a00f0a5 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -146,6 +146,13 @@ def inference_endpoint(
         str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
+    free_endpoint: Annotated[
+        bool,
+        Option(
+            help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
+            rich_help_panel=HELP_PANEL_NAME_4,
+        ),
+    ] = False,
     # === Common parameters ===
     use_chat_template: Annotated[
         bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -200,9 +207,7 @@ def inference_endpoint(

     """
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.endpoints.endpoint_model import (
-        InferenceEndpointModelConfig,
-    )
+    from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -220,10 +225,10 @@ def inference_endpoint(
     parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote

     # Find a way to add this back
-    # if config["base_params"].get("endpoint_name", None):
-    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
-    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
+    if free_endpoint:
+        model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
+    else:
+        model_config = InferenceEndpointModelConfig.from_path(model_config_path)

     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,

diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index e50c0405c..80798b616 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -75,10 +75,18 @@


 @dataclass
-class InferenceModelConfig:
-    model: str
+class ServerlessEndpointModelConfig:
+    model_name: str
     add_special_tokens: bool = True

+    @classmethod
+    def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        return cls(**config["base_params"])
+

 @dataclass
 class InferenceEndpointModelConfig:
@@ -150,7 +158,7 @@ class InferenceEndpointModel(LightevalModel):
     """

     def __init__(  # noqa: C901
-        self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
+        self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "reuse_existing", False)
         self._max_length = None
@@ -280,10 +288,10 @@ def __init__(  # noqa: C901
         else:  # Free inference client
             self.endpoint = None
             self.endpoint_name = None
-            self.name = config.model
+            self.name = config.model_name
             self.revision = "default"
-            self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
-            self.client = InferenceClient(model=config.model, token=env_config.token)
+            self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
+            self.client = InferenceClient(model=config.model_name, token=env_config.token)

         self.use_async = True  # set to False for debug - async use is faster

@@ -293,7 +301,7 @@ def __init__(  # noqa: C901
         self.model_info = ModelInfo(
             model_name=self.name,
             model_sha=self.revision,
-            model_dtype=config.model_dtype or "default",
+            model_dtype=getattr(config, "model_dtype", "default"),
             model_size=-1,
         )
@@ -545,7 +553,12 @@ def loglikelihood(
             cont_toks = torch.tensor(cur_request.tokenized_continuation)
             len_choice = len(cont_toks)

-            logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+            if self.endpoint:  # inference endpoint
+                logits = [
+                    t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
+                ]  # to check
+            else:  # serverless endpoint
+                logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

             greedy_tokens = torch.tensor(logits).argmax(dim=-1)
             max_equal = (greedy_tokens == cont_toks).all().squeeze(0)

diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index b0817be4a..66eb99886 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -27,7 +27,7 @@
 from lighteval.models.endpoints.endpoint_model import (
     InferenceEndpointModel,
     InferenceEndpointModelConfig,
-    InferenceModelConfig,
+    ServerlessEndpointModelConfig,
 )
 from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
 from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -80,7 +80,7 @@ def load_model(  # noqa: C901
     if isinstance(config, TGIModelConfig):
         return load_model_with_tgi(config)

-    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
+    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
         return load_model_with_inference_endpoints(config, env_config=env_config)

     if isinstance(config, BaseModelConfig):

diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py
index 29fbb3c48..f4ba15d91 100644
--- a/tests/models/endpoints/test_endpoint_model.py
+++ b/tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
                 },
             ),
             (
-                "examples/model_configs/endpoint_model_lite.yaml",
+                "examples/model_configs/serverless_model.yaml",
                 {
                     "model_name": "meta-llama/Llama-3.1-8B-Instruct",
                     # Defaults:
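
A note on the renamed config: examples/model_configs/serverless_model.yaml is carried over unchanged (100% similarity rename), so its contents do not appear in this diff. Judging from ServerlessEndpointModelConfig.from_path above, which reads the "model" -> "base_params" mapping of the YAML, and from the expected values in tests/models/endpoints/test_endpoint_model.py, a minimal serverless config plausibly looks like the sketch below; the exact file contents are an assumption, not part of the patch:

    # hypothetical shape of examples/model_configs/serverless_model.yaml
    model:
      base_params:
        model_name: "meta-llama/Llama-3.1-8B-Instruct"  # any model reachable through the free serverless endpoints
        # add_special_tokens may also be set here; it defaults to True in ServerlessEndpointModelConfig

This code path is selected from the CLI through the new free_endpoint option in main_endpoint.py (exposed by typer as --free-endpoint), which loads ServerlessEndpointModelConfig instead of creating a dedicated inference endpoint.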