adding serverless endpoints back
clefourrier committed Dec 12, 2024
1 parent f62cc89 commit 858d3d1
Showing 3 changed files with 20 additions and 14 deletions.
15 changes: 8 additions & 7 deletions src/lighteval/main_endpoint.py
@@ -33,10 +33,10 @@
TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")

HELP_PANNEL_NAME_1 = "Common Paramaters"
HELP_PANNEL_NAME_1 = "Common Parameters"
HELP_PANNEL_NAME_2 = "Logging Parameters"
HELP_PANNEL_NAME_3 = "Debug Paramaters"
HELP_PANNEL_NAME_4 = "Modeling Paramaters"
HELP_PANNEL_NAME_3 = "Debug Parameters"
HELP_PANNEL_NAME_4 = "Modeling Parameters"


@app.command(rich_help_panel="Evaluation Backends")
@@ -93,7 +93,7 @@ def openai(
Evaluate OPENAI models.
"""
from lighteval.logging.evaluation_tracker import EvaluationTracker
-from lighteval.models.model_config import OpenAIModelConfig
+from lighteval.models.endpoints.openai_model import OpenAIModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -147,9 +147,10 @@ def inference_endpoint(
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
free_endpoint: Annotated[
-str,
-Argument(
-help="True if you want to use the serverless free endpoints, False (default) if you want to spin up your own inference endpoint."
+bool,
+Option(
+help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
+rich_help_panel=HELP_PANNEL_NAME_4,
),
] = False,
# === Common parameters ===
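Note on the free_endpoint change above: switching from Argument(str) to Option(bool) turns the parameter into an on/off CLI flag rather than a required positional string. A minimal standalone Typer sketch of that pattern (not lighteval's actual CLI wiring; the command name and flag naming below are illustrative):

# sketch.py -- illustrative only, not part of the diff
import typer
from typing import Annotated

app = typer.Typer()

@app.command()
def demo(
    free_endpoint: Annotated[
        bool,
        typer.Option(help="Use serverless free endpoints instead of spinning up your own inference endpoint."),
    ] = False,
):
    # A bool Option defaults to False and is enabled with `--free-endpoint`.
    print(f"free_endpoint={free_endpoint}")

if __name__ == "__main__":
    app()  # e.g. `python sketch.py --free-endpoint`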
19 changes: 12 additions & 7 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -76,11 +76,11 @@

@dataclass
class ServerlessEndpointModelConfig:
-model: str
+model_name: str
add_special_tokens: bool = True

@classmethod
def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
import yaml

with open(path, "r") as f:
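Because the field is renamed from model to model_name, any code that builds a ServerlessEndpointModelConfig directly has to follow the rename. A minimal sketch of direct construction, assuming lighteval is installed (the model id below is only a placeholder):

# sketch -- illustrative only, not part of the diff
from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig

config = ServerlessEndpointModelConfig(
    model_name="my-org/my-model",  # placeholder model id
    add_special_tokens=True,       # default shown explicitly
)
print(config.model_name)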
@@ -282,10 +282,10 @@ def __init__( # noqa: C901
else: # Free inference client
self.endpoint = None
self.endpoint_name = None
-self.name = config.model
+self.name = config.model_name
self.revision = "default"
-self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
-self.client = InferenceClient(model=config.model, token=env_config.token)
+self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
+self.client = InferenceClient(model=config.model_name, token=env_config.token)

self.use_async = True # set to False for debug - async use is faster

@@ -295,7 +295,7 @@ def __init__( # noqa: C901
self.model_info = ModelInfo(
model_name=self.name,
model_sha=self.revision,
-model_dtype=config.model_dtype or "default",
+model_dtype=getattr(config, "model_dtype", "default"),
model_size=-1,
)
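The getattr change above matters because ServerlessEndpointModelConfig (unlike the dedicated endpoint config) does not define model_dtype, so plain attribute access would raise AttributeError on the serverless path. A standalone sketch of the fallback behaviour, using illustrative stand-in dataclasses rather than lighteval's own:

# sketch -- illustrative only, not part of the diff
from dataclasses import dataclass

@dataclass
class EndpointLikeConfig:      # stand-in for a config that defines model_dtype
    model_dtype: str = "float16"

@dataclass
class ServerlessLikeConfig:    # stand-in for a config without model_dtype
    model_name: str = "my-org/my-model"

print(getattr(EndpointLikeConfig(), "model_dtype", "default"))    # -> float16
print(getattr(ServerlessLikeConfig(), "model_dtype", "default"))  # -> default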

@@ -547,7 +547,12 @@ def loglikelihood(
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)

-logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+if self.endpoint: # inference endpoint
+    logits = [
+        t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
+    ] # to check
+else: # serverless endpoint
+    logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

greedy_tokens = torch.tensor(logits).argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
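For reference, the branch added above takes the prompt log-probabilities from details.prefill when a dedicated inference endpoint is running and from details.tokens on the serverless path. A small mock-based sketch of that selection logic (the response objects are faked with SimpleNamespace; only the field names come from the diff, the values are made up):

# sketch -- illustrative only, not part of the diff
from types import SimpleNamespace

def continuation_logprobs(response, len_choice, use_endpoint):
    # Mirror the branch above: prefill for a dedicated endpoint, tokens for serverless.
    source = response.details.prefill if use_endpoint else response.details.tokens
    return [t.logprob for t in source[-len_choice:] if t.logprob is not None]

def tok(logprob):
    return SimpleNamespace(logprob=logprob)

response = SimpleNamespace(
    details=SimpleNamespace(prefill=[tok(-0.1), tok(-2.3)], tokens=[tok(-0.5), tok(-1.2)]),
)
print(continuation_logprobs(response, len_choice=2, use_endpoint=False))  # -> [-0.5, -1.2]
print(continuation_logprobs(response, len_choice=2, use_endpoint=True))   # -> [-0.1, -2.3]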
