Added option for specifying model config directly, for programmatic use #245

Closed
6 changes: 6 additions & 0 deletions examples/model_configs/openai_model.yaml
@@ -0,0 +1,6 @@
model:
type: "openai" # can be base, openai, tgi, or endpoint
instance:
address: null
model_id: null # Required, pointing to the HF repo for loading the tokenizer
auth_token: null # Optional
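Since this PR is about passing the model config programmatically, the same structure can also be expressed as a plain Python dict instead of a YAML file. A minimal sketch, where the server address and model repo are placeholders rather than values from this PR:

model_config = {
    "model": {
        "type": "openai",
        "instance": {
            "address": "http://localhost:8000/v1",  # placeholder: any OpenAI-compatible server
            "model_id": "meta-llama/Llama-2-7b-hf",  # placeholder: HF repo used to load the tokenizer
            "auth_token": None,  # optional
        },
    }
}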
3 changes: 2 additions & 1 deletion src/lighteval/evaluator.py
@@ -33,12 +33,13 @@
from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
from lighteval.models.tgi_model import ModelClient
from lighteval.models.oai_model import OAIModelClient
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId


def evaluate( # noqa: C901
lm: Union[BaseModel, ModelClient],
lm: Union[BaseModel, ModelClient, OAIModelClient],
requests_dict: Dict[RequestType, list[Request]],
docs: Dict[TaskExampleId, Doc],
task_dict: Dict[str, LightevalTask],
6 changes: 3 additions & 3 deletions src/lighteval/models/endpoint_model.py
@@ -23,6 +23,7 @@
import asyncio
from typing import Coroutine, List, Optional, Union

from lighteval.models.utils import retry_with_backoff
import torch
from huggingface_hub import (
AsyncInferenceClient,
@@ -53,7 +54,6 @@

BATCH_SIZE = 50


class InferenceEndpointModel(LightevalModel):
"""InferenceEndpointModels can be used both with the free inference client, or with inference
endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
@@ -156,14 +156,14 @@ def _async_process_request(
) -> Coroutine[None, list[TextGenerationOutput], str]:
# Todo: add an option to launch with conversational instead for chat prompts
# https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
generated_text = self.async_client.text_generation(
generated_text = retry_with_backoff(lambda: self.async_client.text_generation(
prompt=context,
details=True,
decoder_input_details=True,
max_new_tokens=max_tokens,
stop_sequences=stop_tokens,
# truncate=,
)
))

return generated_text

25 changes: 22 additions & 3 deletions src/lighteval/models/model_config.py
@@ -219,6 +219,12 @@ class TGIModelConfig:
inference_server_auth: str
model_id: str

@dataclass
class OAIModelConfig:
address: str
model_id: str
auth_token: str


@dataclass
class DummyModelConfig:
@@ -282,6 +288,7 @@ def create_model_config( # noqa: C901
AdapterModelConfig,
DeltaModelConfig,
TGIModelConfig,
OAIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
]:
@@ -293,7 +300,7 @@
accelerator (Union[Accelerator, None]): accelerator to use for model training.

Returns:
Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration.
Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, OAIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration.

Raises:
ValueError: If both an inference server address and model arguments are provided.
@@ -312,15 +319,27 @@

return BaseModelConfig(**args_dict)

with open(args.model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]
# This option is not exposed via the CLI; it is meant for programmatic use, so a config built
# in code does not need to be written to a temporary file before calling lighteval.
if hasattr(args, "model_config") and args.model_config:
config = args.model_config["model"]
else:
with open(args.model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]

if config["type"] == "tgi":
return TGIModelConfig(
inference_server_address=config["instance"]["inference_server_address"],
inference_server_auth=config["instance"]["inference_server_auth"],
model_id=config["instance"]["model_id"],
)

if config["type"] == "openai":
return OAIModelConfig(
address=config["instance"]["address"],
model_id=config["instance"]["model_id"],
auth_token=config["instance"]["auth_token"]
)

if config["type"] == "endpoint":
reuse_existing_endpoint = config["base_params"]["reuse_existing"]
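A hedged sketch of the programmatic path added above: build an args object that carries model_config and let create_model_config dispatch on config["type"]. The SimpleNamespace fields and the instance values are illustrative assumptions; the exact set of attributes create_model_config reads from args may differ between lighteval versions.

from types import SimpleNamespace
from lighteval.models.model_config import create_model_config

args = SimpleNamespace(
    model_args=None,  # skip the CLI --model_args branch
    model_config={
        "model": {
            "type": "openai",
            "instance": {
                "address": "http://localhost:8000/v1",  # placeholder OpenAI-compatible server
                "model_id": "meta-llama/Llama-2-7b-hf",  # placeholder tokenizer repo
                "auth_token": None,
            },
        }
    },
)
config = create_model_config(args=args, accelerator=None)  # expected to return an OAIModelConfig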
17 changes: 15 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -37,8 +37,10 @@
EnvConfig,
InferenceEndpointModelConfig,
InferenceModelConfig,
OAIModelConfig,
TGIModelConfig,
)
from lighteval.models.oai_model import OAIModelClient
from lighteval.models.tgi_model import ModelClient
from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available

@@ -60,12 +62,13 @@ def load_model( # noqa: C901
BaseModelConfig,
AdapterModelConfig,
DeltaModelConfig,
OAIModelConfig,
TGIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
],
env_config: EnvConfig,
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel], ModelInfo]:
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient, DummyModel], ModelInfo]:
"""Will load either a model from an inference server or a model from a checkpoint, depending
on the config type.

@@ -79,11 +82,14 @@
ValueError: If you did not specify a base model when using delta weights or adapter weights

Returns:
Union[BaseModel, AdapterModel, DeltaModel, ModelClient]: The model that will be evaluated
Union[BaseModel, AdapterModel, DeltaModel, ModelClient, OAIModelClient]: The model that will be evaluated
"""
# Inference server loading
if isinstance(config, TGIModelConfig):
return load_model_with_tgi(config)

if isinstance(config, OAIModelConfig):
return load_model_with_oai(config)

if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
return load_model_with_inference_endpoints(config, env_config=env_config)
@@ -95,6 +101,13 @@
return load_dummy_model(config=config, env_config=env_config)


def load_model_with_oai(config: OAIModelConfig):
model = OAIModelClient(
address=config.address, model_id=config.model_id, auth_token=config.auth_token
)
return model, ModelInfo(model_name=config.model_id, model_sha='unknown', model_dtype='unknown', model_size='unknown')


def load_model_with_tgi(config: TGIModelConfig):
if not is_tgi_available():
raise ImportError(NO_TGI_ERROR_MSG)
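For illustration, a sketch of how the loader path above could be exercised once an OAIModelConfig exists. The config values are placeholders, the EnvConfig field names are assumed, and the returned ModelInfo carries 'unknown' placeholders because an OpenAI-compatible server does not report sha, dtype, or size.

from lighteval.models.model_config import EnvConfig, OAIModelConfig
from lighteval.models.model_loader import load_model

config = OAIModelConfig(
    address="http://localhost:8000/v1",  # placeholder
    model_id="meta-llama/Llama-2-7b-hf",  # placeholder
    auth_token=None,
)
env_config = EnvConfig(cache_dir=None, token=None)  # field names assumed; adjust to the installed version
model, model_info = load_model(config=config, env_config=env_config)
print(model_info.model_name)  # the configured model_id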
83 changes: 83 additions & 0 deletions src/lighteval/models/oai_model.py
@@ -0,0 +1,83 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Inherit from InferenceEndpointModel instead of LightevalModel since both expose the same interface;
# only the client functions are overridden, because this model uses a different (OpenAI) client.
import asyncio
from typing import Coroutine
from lighteval.models.endpoint_model import InferenceEndpointModel
from huggingface_hub import TextGenerationOutput
from lighteval.models.utils import retry_with_backoff
from transformers import AutoTokenizer
from openai import AsyncOpenAI

class OAIModelClient(InferenceEndpointModel):
_DEFAULT_MAX_LENGTH: int = 4096

def __init__(self, address, model_id, auth_token=None) -> None:
self.client = AsyncOpenAI(base_url=address, api_key=(auth_token or "none"))
self.model_id = model_id
self._max_gen_toks = 256

self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self._add_special_tokens = True
self.use_async = True

async def _async_process_request(
self, context: str, stop_tokens: list[str], max_tokens: int
) -> Coroutine[None, TextGenerationOutput, str]:
# Todo: add an option to launch with conversational instead for chat prompts
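# Editor note (assumption): "/repository" is the path where TGI-backed HF Inference Endpoints mount the
# deployed model; other OpenAI-compatible servers may expect self.model_id as the model name instead.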
output = await retry_with_backoff(lambda: self.client.completions.create(
model="/repository",
prompt=context,
max_tokens=max_tokens,
stop=stop_tokens
))

return TextGenerationOutput(generated_text=output.choices[0].text)

def _process_request(self, *args, **kwargs) -> TextGenerationOutput:
return asyncio.run(self._async_process_request(*args, **kwargs))

def set_cache_hook(self, cache_hook):
self.cache_hook = cache_hook

@property
def tokenizer(self):
return self._tokenizer

@property
def add_special_tokens(self):
return self._add_special_tokens

@property
def max_length(self) -> int:
if hasattr(self.tokenizer, "model_max_length"):
return self.tokenizer.model_max_length
return OAIModelClient._DEFAULT_MAX_LENGTH

@property
def disable_tqdm(self) -> bool:
return False

def cleanup(self):
pass
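A quick smoke-test sketch of the new client outside of evaluate(). The server address, model repo, and prompt are placeholders; _process_request simply runs the async request shown above.

from lighteval.models.oai_model import OAIModelClient

client = OAIModelClient(
    address="http://localhost:8000/v1",  # placeholder OpenAI-compatible server
    model_id="meta-llama/Llama-2-7b-hf",  # placeholder; the tokenizer is loaded from this repo
)
out = client._process_request(context="The capital of France is", stop_tokens=["\n"], max_tokens=8)
print(out.generated_text)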
18 changes: 18 additions & 0 deletions src/lighteval/models/utils.py
@@ -20,6 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import asyncio
import logging
import os
from itertools import islice
from typing import Optional, Union
@@ -98,3 +100,19 @@ def batched(iterable, n):
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch

import random
MAX_RETRIES = 15
INITIAL_BACKOFF = 1
async def retry_with_backoff(coro_fn):
for attempt in range(MAX_RETRIES):
try:
return await coro_fn()
except Exception as e:
if attempt < MAX_RETRIES - 1:
backoff_time = INITIAL_BACKOFF * (2 ** attempt) + random.uniform(0, 1) # exponential backoff with jitter (an earlier version grew faster and waited too long)
logging.info(e)
logging.info(f'Encountered error, backing off and retrying in {backoff_time}s...')
await asyncio.sleep(backoff_time)
else:
raise e
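Because retry_with_backoff awaits whatever the callable returns, callers pass a zero-argument lambda so every retry gets a fresh coroutine (as endpoint_model.py and oai_model.py do above). A self-contained sketch with a made-up flaky coroutine:

import asyncio
import random

from lighteval.models.utils import retry_with_backoff

async def flaky_request():
    # Hypothetical transient failure, for illustration only.
    if random.random() < 0.5:
        raise ConnectionError("transient error")
    return "ok"

print(asyncio.run(retry_with_backoff(lambda: flaky_request())))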