Added custom model inference. #437

Open · wants to merge 34 commits into `main` from `add-custom-model`

Commits (34):
- 5909d4a Added first version of custom model. (JoelNiklaus, Dec 11, 2024)
- a2d6b63 Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 11, 2024)
- 2283c89 Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 11, 2024)
- 9563fab Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 12, 2024)
- 319d482 Merge branch 'main' into add-custom-model (clefourrier, Dec 12, 2024)
- 464edfe Merge branch 'main' into add-custom-model (clefourrier, Dec 12, 2024)
- 6096042 Moved custom model config. (JoelNiklaus, Dec 12, 2024)
- a7e1fe5 Added warning. (JoelNiklaus, Dec 12, 2024)
- 24b8bd3 Added custom model example for google translate. (JoelNiklaus, Dec 12, 2024)
- c177a8e Added documentation for custom model config. (JoelNiklaus, Dec 12, 2024)
- d712cdb Added docs. (JoelNiklaus, Dec 12, 2024)
- 7553147 Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 12, 2024)
- b41949c Fixed path error. (JoelNiklaus, Dec 12, 2024)
- aaaadb0 Fixed doc error. (JoelNiklaus, Dec 12, 2024)
- c85065f Added requirements file for google translate. (JoelNiklaus, Dec 12, 2024)
- f1103da Moved model loading function to reduce merge conflicts with litellm i… (JoelNiklaus, Dec 12, 2024)
- 71f871e Added diskcache and get source and target language from the task name. (JoelNiklaus, Dec 12, 2024)
- d1af518 Fixed problem with removing languages in the context. (JoelNiklaus, Dec 12, 2024)
- 2511158 Added retry logic. (JoelNiklaus, Dec 13, 2024)
- 7d5f76d Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 16, 2024)
- 743a284 Update google-translate requirements. (JoelNiklaus, Dec 16, 2024)
- 1a37f71 Added another example for a custom model. (JoelNiklaus, Dec 17, 2024)
- 2f27645 Made local mt model example more general to support madlad400 as well. (JoelNiklaus, Dec 17, 2024)
- a4d4fee Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 17, 2024)
- bd08781 Merge branch 'main' into add-custom-model (clefourrier, Dec 18, 2024)
- b7106e4 Make sure generation can happen on the GPU. (JoelNiklaus, Dec 18, 2024)
- a7d176c Fixed issue with src and tgt lang for seamless model. (JoelNiklaus, Dec 19, 2024)
- f1ba65c Added cleanup to free the GPU memory again. (JoelNiklaus, Dec 19, 2024)
- ace6e59 Fix dependency issues by switching to deep-translator. (JoelNiklaus, Dec 22, 2024)
- cfd7254 Made inference code more robust against empty responses. (JoelNiklaus, Dec 22, 2024)
- 3ddc104 Merge branch 'main' into add-custom-model (JoelNiklaus, Dec 23, 2024)
- f6df2a3 Merge branch 'main' into add-custom-model (clefourrier, Jan 2, 2025)
- 348e427 Merge branch 'main' into add-custom-model (JoelNiklaus, Jan 7, 2025)
- a63f4b3 Merge branch 'main' into add-custom-model (JoelNiklaus, Jan 11, 2025)
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -15,6 +15,8 @@
    title: Add a custom task
  - local: adding-a-new-metric
    title: Add a custom metric
  - local: evaluating-a-custom-model
    title: Evaluate a custom model
  - local: use-vllm-as-backend
    title: Use VLLM as backend
  - local: evaluate-the-model-on-a-server-or-container
129 changes: 129 additions & 0 deletions docs/source/evaluating-a-custom-model.mdx
@@ -0,0 +1,129 @@
# Evaluating a Custom Model

Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc.).

## Creating a Custom Model

1. Create a Python file containing your custom model implementation. The model must inherit from `LightevalModel` and implement all required methods.

Here's a basic example:

```python
from lighteval.models.abstract_model import LightevalModel

class MyCustomModel(LightevalModel):
    def __init__(self, config, env_config):
        super().__init__(config, env_config)
        # Initialize your model here...

    def greedy_until(self, requests, override_bs=None):
        # Implement generation logic
        pass

    def loglikelihood(self, requests, override_bs=None):
        # Implement loglikelihood computation
        pass

    def loglikelihood_rolling(self, requests, override_bs=None):
        # Implement rolling loglikelihood computation
        pass

    def loglikelihood_single_token(self, requests, override_bs=None):
        # Implement single-token loglikelihood computation
        pass
```

2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model.
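To make this concrete, here is a minimal sketch of how a file can be scanned for exactly one `LightevalModel` subclass. It is illustrative only, not lighteval's actual loading code, and the helper name `load_custom_model_class` is hypothetical:

```python
# Illustrative sketch only; `load_custom_model_class` is a hypothetical helper,
# not part of lighteval's public API.
import importlib.util
import inspect

from lighteval.models.abstract_model import LightevalModel


def load_custom_model_class(file_path: str) -> type:
    """Import `file_path` as a module and return its unique LightevalModel subclass."""
    spec = importlib.util.spec_from_file_location("custom_model_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Collect classes defined (or imported) in the module that subclass LightevalModel,
    # excluding the abstract base itself.
    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, LightevalModel) and obj is not LightevalModel
    ]
    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one LightevalModel subclass, found {len(candidates)}")
    return candidates[0]
```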

> [!TIP]
> You can find a complete example of a custom model implementation in `examples/custom_models/google_translate_model.py`.

## Running the Evaluation

You can evaluate your custom model using either the command line interface or the Python API.

### Using the Command Line

```bash
python -m lighteval custom \
"google-translate" \
"examples/custom_models/google_translate_model.py" \
"lighteval|wmt20:fr-de|0|0" \
--output-dir results \
--max-samples 10
```

The command takes three required arguments:
- The model name (used for tracking in results/logs)
- The path to your model implementation file
- The tasks to evaluate on (same format as other backends)

### Using the Python API

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.custom.custom_model import CustomModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

# Set up evaluation tracking
evaluation_tracker = EvaluationTracker(
    output_dir="results",
    save_details=True
)

# Configure the pipeline
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.CUSTOM,
    env_config=EnvConfig(cache_dir="tmp/")
)

# Configure your custom model
model_config = CustomModelConfig(
    model="my-custom-model",
    model_definition_file_path="path/to/my_model.py"
)

# Create and run the pipeline
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config
)

pipeline.evaluate()
pipeline.save_and_push_results()
```

## Required Methods

Your custom model must implement these core methods:

- `greedy_until`: For generating text until a stop sequence or max tokens is reached
- `loglikelihood`: For computing log probabilities of specific continuations
- `loglikelihood_rolling`: For computing rolling log probabilities of sequences
- `loglikelihood_single_token`: For computing log probabilities of single tokens

See the `LightevalModel` base class documentation for detailed method signatures and requirements.

## Best Practices

1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.

2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.

3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `__del__` methods (see the sketch after this list).

4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
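
As a rough illustration of point 3, a custom model backed by a remote API might open one client in `__init__` and release it in `__del__`. This is a minimal sketch under assumed names: `ApiBackedModel` is hypothetical, and the `httpx` dependency is an assumption for illustration, not something lighteval requires.

```python
# Minimal sketch only: `ApiBackedModel` and the use of httpx are assumptions
# for illustration, not part of lighteval.
import httpx

from lighteval.models.abstract_model import LightevalModel


class ApiBackedModel(LightevalModel):
    def __init__(self, config, env_config):
        super().__init__(config, env_config)
        # One shared HTTP client for the lifetime of the model avoids
        # re-opening connections on every request.
        self._client = httpx.Client(timeout=30.0)

    def __del__(self):
        # Close the connection pool on teardown; guard against instances
        # whose __init__ failed before _client was set.
        if getattr(self, "_client", None) is not None:
            self._client.close()
```

The same pattern applies to other resources such as temporary caches or GPU memory: acquire in `__init__`, release in `__del__` or an explicit cleanup step.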

## Example Use Cases

Custom models are particularly useful for:

- Evaluating models accessed through custom APIs
- Wrapping models with specialized preprocessing/postprocessing
- Testing novel model architectures
- Evaluating ensemble models
- Integrating with external services or tools

For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.
4 changes: 4 additions & 0 deletions docs/source/package_reference/models.mdx
@@ -28,6 +28,10 @@
[[autodoc]] models.endpoints.tgi_model.TGIModelConfig
[[autodoc]] models.endpoints.tgi_model.ModelClient

### Custom Model
[[autodoc]] models.custom.custom_model.CustomModelConfig
[[autodoc]] models.custom.custom_model.CustomModel

### Open AI Models
[[autodoc]] models.endpoints.openai_model.OpenAIClient

200 changes: 200 additions & 0 deletions examples/custom_models/google_translate_model.py
@@ -0,0 +1,200 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import hashlib
import logging
import os
import time
from typing import Optional

import diskcache
import tenacity
from deep_translator import GoogleTranslator
from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
    def __init__(self, config, env_config) -> None:
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )

        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility

        # Deep-translator also supports other translators
        self.translator = GoogleTranslator()

        # Initialize disk cache
        cache_dir = os.path.join(os.getcwd(), ".translation_cache")
        self.cache = diskcache.Cache(cache_dir)

        self.max_retries = 3
        self.retry_delay = 1

    def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str:
        """Generate a unique cache key for the translation request."""
        # IMPORTANT: In case we want to support other translators, we can add the translator name to the key
        key_string = f"{context}|{src_lang}|{tgt_lang}"
        return hashlib.md5(key_string.encode()).hexdigest()

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(3),
        wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
        retry=tenacity.retry_if_exception_type(Exception),
        before_sleep=lambda retry_state: time.sleep(1),
    )
    def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str:
        """Translate text using cache if available, otherwise call Google Translate with retry logic."""
        cache_key = self._get_cache_key(context, src_lang, tgt_lang)

        # Try to get from cache
        if cache_key in self.cache:
            result = self.cache[cache_key]
            if result is not None and result != "":
                return result
            logger.warning("Translation in cache is empty. Removing from cache and retrying...")
            del self.cache[cache_key]

        try:
            # Updated translation call for deep-translator
            self.translator.source = src_lang
            self.translator.target = tgt_lang
            result = self.translator.translate(context)
            if result is None or result == "":
                result = ""

            self.cache[cache_key] = result
            return result
        except Exception as e:
            logger.warning(f"Translation error: {str(e)}. Retrying...")
            raise  # Let tenacity handle the retry

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.
        Results are cached to disk to avoid repeated translations.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                # Extract source and target languages from task name
                # Format is like "community|sdst-text_level:de-fr|0"
                src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-")

                context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "")
                result = self._translate_with_cache(context, src_lang, tgt_lang)
                if result is None:
                    result = ""  # Set to empty string to prevent errors in metric computation

                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return text

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError