diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 3daa4052..9ad55466 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -33,9 +33,7 @@
       - local: package_reference/evaluation_tracker
        title: EvaluationTracker
       - local: package_reference/models
-        title: Models
-      - local: package_reference/model_config
-        title: ModelConfig
+        title: Models and ModelConfigs
       - local: package_reference/pipeline
        title: Pipeline
     title: Main classes
diff --git a/docs/source/package_reference/model_config.mdx b/docs/source/package_reference/model_config.mdx
deleted file mode 100644
index e2ecceb4..00000000
--- a/docs/source/package_reference/model_config.mdx
+++ /dev/null
@@ -1,10 +0,0 @@
-# ModelConfig
-
-[[autodoc]] models.model_config.BaseModelConfig
-
-[[autodoc]] models.model_config.AdapterModelConfig
-[[autodoc]] models.model_config.DeltaModelConfig
-[[autodoc]] models.model_config.InferenceEndpointModelConfig
-[[autodoc]] models.model_config.InferenceModelConfig
-[[autodoc]] models.model_config.TGIModelConfig
-[[autodoc]] models.model_config.VLLMModelConfig
diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx
index 34b5b273..096ce7be 100644
--- a/docs/source/package_reference/models.mdx
+++ b/docs/source/package_reference/models.mdx
@@ -4,24 +4,38 @@
 ### LightevalModel
 [[autodoc]] models.abstract_model.LightevalModel
 
+
 ## Accelerate and Transformers Models
 ### BaseModel
-[[autodoc]] models.base_model.BaseModel
+[[autodoc]] models.transformers.base_model.BaseModelConfig
+[[autodoc]] models.transformers.base_model.BaseModel
+
 ### AdapterModel
-[[autodoc]] models.adapter_model.AdapterModel
+[[autodoc]] models.transformers.adapter_model.AdapterModelConfig
+[[autodoc]] models.transformers.adapter_model.AdapterModel
+
 ### DeltaModel
-[[autodoc]] models.delta_model.DeltaModel
+[[autodoc]] models.transformers.delta_model.DeltaModelConfig
+[[autodoc]] models.transformers.delta_model.DeltaModel
 
-## Inference Endpoints and TGI Models
+## Endpoints-based Models
 ### InferenceEndpointModel
-[[autodoc]] models.endpoint_model.InferenceEndpointModel
-### ModelClient
-[[autodoc]] models.tgi_model.ModelClient
+[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
+[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel
+
+### TGI ModelClient
+[[autodoc]] models.endpoints.tgi_model.TGIModelConfig
+[[autodoc]] models.endpoints.tgi_model.ModelClient
+
+### OpenAI Models
+[[autodoc]] models.endpoints.openai_model.OpenAIClient
 
 ## Nanotron Model
 ### NanotronLightevalModel
-[[autodoc]] models.nanotron_model.NanotronLightevalModel
+[[autodoc]] models.nanotron.nanotron_model.NanotronLightevalModel
 
 ## VLLM Model
 ### VLLMModel
-[[autodoc]] models.vllm_model.VLLMModel
+[[autodoc]] models.vllm.vllm_model.VLLMModelConfig
+[[autodoc]] models.vllm.vllm_model.VLLMModel
diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py
index e7d18c80..27e4141f 100644
--- a/src/lighteval/main_accelerate.py
+++ b/src/lighteval/main_accelerate.py
@@ -107,7 +107,9 @@ def accelerate(  # noqa C901
     from accelerate import Accelerator, InitProcessGroupKwargs
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import AdapterModelConfig, BaseModelConfig, BitsAndBytesConfig, DeltaModelConfig
+    from lighteval.models.transformers.adapter_model import AdapterModelConfig
+    from lighteval.models.transformers.base_model import BaseModelConfig, BitsAndBytesConfig
+    from lighteval.models.transformers.delta_model import DeltaModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 5069c414..d17da432 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -201,7 +201,7 @@ def inference_endpoint(
     import yaml
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import (
+    from lighteval.models.endpoints.endpoint_model import (
         InferenceEndpointModelConfig,
     )
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py
index 4bd1681d..078000da 100644
--- a/src/lighteval/main_vllm.py
+++ b/src/lighteval/main_vllm.py
@@ -89,7 +89,7 @@ def vllm(
     Evaluate models using vllm as backend.
     """
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import VLLMModelConfig
+    from lighteval.models.vllm.vllm_model import VLLMModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     TOKEN = os.getenv("HF_TOKEN")
diff --git a/src/lighteval/models/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py
similarity index 97%
rename from src/lighteval/models/dummy_model.py
rename to src/lighteval/models/dummy/dummy_model.py
index b9fa60e0..ff89656b 100644
--- a/src/lighteval/models/dummy_model.py
+++ b/src/lighteval/models/dummy/dummy_model.py
@@ -23,12 +23,12 @@
 # inspired by https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/dummy.py
 
 import random
+from dataclasses import dataclass
 from typing import Optional
 
 from transformers import AutoTokenizer
 
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
-from lighteval.models.model_config import DummyModelConfig
 from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -39,6 +39,11 @@
 from lighteval.utils.utils import EnvConfig
 
 
+@dataclass
+class DummyModelConfig:
+    seed: int = 42
+
+
 class DummyModel(LightevalModel):
     """Dummy model to generate random baselines."""
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
similarity index 91%
rename from src/lighteval/models/endpoint_model.py
rename to src/lighteval/models/endpoints/endpoint_model.py
index bd82f058..11233896 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -24,7 +24,8 @@
 import logging
 import re
 import time
-from typing import Coroutine, List, Optional, Union
+from dataclasses import dataclass
+from typing import Coroutine, Dict, List, Optional, Union
 
 import requests
 import torch
@@ -47,7 +48,6 @@
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
-from lighteval.models.model_config import InferenceEndpointModelConfig, InferenceModelConfig
 from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -74,6 +74,59 @@
 ]
 
 
+@dataclass
+class InferenceModelConfig:
+    model: str
+    add_special_tokens: bool = True
+
+
+@dataclass
+class InferenceEndpointModelConfig:
+    endpoint_name: str = None
+    model_name: str = None
+    should_reuse_existing: bool = False
+    accelerator: str = "gpu"
+    model_dtype: str = None  # if empty, we use the default
+    vendor: str = "aws"
+    region: str = "us-east-1"  # this region has the most hardware options available
+    instance_size: str = None  # if none, we autoscale
+    instance_type: str = None  # if none, we autoscale
+    framework: str = "pytorch"
+    endpoint_type: str = "protected"
+    add_special_tokens: bool = True
+    revision: str = "main"
+    namespace: str = None  # The namespace under which to launch the endpoint. Defaults to the current user's namespace
+    image_url: str = None
+    env_vars: dict = None
+
+    def __post_init__(self):
+        # xor operator, one is None but not the other
+        if (self.instance_size is None) ^ (self.instance_type is None):
+            raise ValueError(
+                "When creating an inference endpoint, you need to specify explicitly both instance_type and instance_size, or none of them for autoscaling."
+            )
+
+        if not (self.endpoint_name is None) ^ int(self.model_name is None):
+            raise ValueError("You need to set either endpoint_name or model_name (but not both).")
+
+    def get_dtype_args(self) -> Dict[str, str]:
+        if self.model_dtype is None:
+            return {}
+        model_dtype = self.model_dtype.lower()
+        if model_dtype in ["awq", "eetq", "gptq"]:
+            return {"QUANTIZE": model_dtype}
+        if model_dtype == "8bit":
+            return {"QUANTIZE": "bitsandbytes"}
+        if model_dtype == "4bit":
+            return {"QUANTIZE": "bitsandbytes-nf4"}
+        if model_dtype in ["bfloat16", "float16"]:
+            return {"DTYPE": model_dtype}
+        return {}
+
+    def get_custom_env_vars(self) -> Dict[str, str]:
+        return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}
+
+
 class InferenceEndpointModel(LightevalModel):
     """InferenceEndpointModels can be used both with the free inference client, or with inference
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
diff --git a/src/lighteval/models/openai_model.py b/src/lighteval/models/endpoints/openai_model.py
similarity index 98%
rename from src/lighteval/models/openai_model.py
rename to src/lighteval/models/endpoints/openai_model.py
index 12fbeb95..b2ca2528 100644
--- a/src/lighteval/models/openai_model.py
+++ b/src/lighteval/models/endpoints/openai_model.py
@@ -24,13 +24,14 @@
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from typing import Optional
 
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel
-from lighteval.models.endpoint_model import ModelInfo
+from lighteval.models.endpoints.endpoint_model import ModelInfo
 from lighteval.models.model_output import (
     GenerativeResponse,
     LoglikelihoodResponse,
@@ -58,6 +59,11 @@
 logging.getLogger("httpx").setLevel(logging.ERROR)
 
 
+@dataclass
+class OpenAIModelConfig:
+    model: str
+
+
 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py
similarity index 94%
rename from src/lighteval/models/tgi_model.py
rename to src/lighteval/models/endpoints/tgi_model.py
index 99d7bd10..d95609a5 100644
--- a/src/lighteval/models/tgi_model.py
+++ b/src/lighteval/models/endpoints/tgi_model.py
@@ -21,13 +21,14 @@
 # SOFTWARE.
 import asyncio
+from dataclasses import dataclass
 from typing import Coroutine, Optional
 
 import requests
 from huggingface_hub import TextGenerationInputGrammarType, TextGenerationOutput
 from transformers import AutoTokenizer
 
-from lighteval.models.endpoint_model import InferenceEndpointModel, ModelInfo
+from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel, ModelInfo
 from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available
 
 
@@ -44,6 +45,13 @@ def divide_chunks(array, n):
         yield array[i : i + n]
 
 
+@dataclass
+class TGIModelConfig:
+    inference_server_address: str
+    inference_server_auth: str
+    model_id: str
+
+
 # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
 # the client functions, since they use a different client.
 class ModelClient(InferenceEndpointModel):
diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
deleted file mode 100644
index 1eda1e02..00000000
--- a/src/lighteval/models/model_config.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# MIT License
-
-# Copyright (c) 2024 The HuggingFace Team
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import logging
-from dataclasses import dataclass
-from typing import Dict, Optional, Union
-
-import torch
-from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig, PretrainedConfig
-
-from lighteval.models.utils import _get_model_sha
-from lighteval.utils.imports import (
-    NO_AUTOGPTQ_ERROR_MSG,
-    NO_BNB_ERROR_MSG,
-    NO_PEFT_ERROR_MSG,
-    is_accelerate_available,
-    is_autogptq_available,
-    is_bnb_available,
-    is_peft_available,
-)
-from lighteval.utils.utils import EnvConfig, boolstring_to_bool
-
-
-logger = logging.getLogger(__name__)
-
-if is_accelerate_available():
-    from accelerate import Accelerator
-
-
-@dataclass
-class BaseModelConfig:
-    """
-    Base configuration class for models.
-
-    Attributes:
-        pretrained (str):
-            HuggingFace Hub model ID name or the path to a pre-trained
-            model to load. This is effectively the `pretrained_model_name_or_path`
-            argument of `from_pretrained` in the HuggingFace `transformers` API.
-        accelerator (Accelerator): accelerator to use for model training.
-        tokenizer (Optional[str]): HuggingFace Hub tokenizer ID that will be
-            used for tokenization.
-        multichoice_continuations_start_space (Optional[bool]): Whether to add a
-            space at the start of each continuation in multichoice generation.
-            For example, context: "What is the capital of France?" and choices: "Paris", "London".
-            Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
-            True adds a space, False strips a space, None does nothing
-        pairwise_tokenization (bool): Whether to tokenize the context and continuation as separately or together.
-        subfolder (Optional[str]): The subfolder within the model repository.
-        revision (str): The revision of the model.
-        batch_size (int): The batch size for model training.
-        max_gen_toks (Optional[int]): The maximum number of tokens to generate.
-        max_length (Optional[int]): The maximum length of the generated output.
-        add_special_tokens (bool, optional, defaults to True): Whether to add special tokens to the input sequences.
-            If `None`, the default value will be set to `True` for seq2seq models (e.g. T5) and
-            `False` for causal models.
-        model_parallel (bool, optional, defaults to False):
-            True/False: force to use or not the `accelerate` library to load a large
-            model across multiple devices.
-            Default: None which corresponds to comparing the number of processes with
-            the number of GPUs. If it's smaller => model-parallelism, else not.
-        dtype (Union[str, torch.dtype], optional, defaults to None):):
-            Converts the model weights to `dtype`, if specified. Strings get
-            converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
-            Use `dtype="auto"` to derive the type from the model's weights.
-        device (Union[int, str]): device to use for model training.
-        quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model, manually provided to load a normally floating point
-            model at a quantized precision. Needed for 4-bit and 8-bit precision.
-        trust_remote_code (bool): Whether to trust remote code during model
-            loading.
-
-    Methods:
-        __post_init__(): Performs post-initialization checks on the configuration.
-        _init_configs(model_name, env_config): Initializes the model configuration.
-        init_configs(env_config): Initializes the model configuration using the environment configuration.
-        get_model_sha(): Retrieves the SHA of the model.
-
-    """
-
-    pretrained: str
-    accelerator: "Accelerator" = None
-    tokenizer: Optional[str] = None
-    multichoice_continuations_start_space: Optional[bool] = None
-    pairwise_tokenization: bool = False
-    subfolder: Optional[str] = None
-    revision: str = "main"
-    batch_size: int = -1
-    max_gen_toks: Optional[int] = 256
-    max_length: Optional[int] = None
-    add_special_tokens: bool = True
-    model_parallel: Optional[bool] = None
-    dtype: Optional[Union[str, torch.dtype]] = None
-    device: Union[int, str] = "cuda"
-    quantization_config: Optional[BitsAndBytesConfig] = None
-    trust_remote_code: bool = False
-    use_chat_template: bool = False
-    compile: bool = False
-
-    def __post_init__(self):
-        # Making sure this parameter is a boolean
-        self.multichoice_continuations_start_space = boolstring_to_bool(self.multichoice_continuations_start_space)
-
-        if self.multichoice_continuations_start_space is not None:
-            if self.multichoice_continuations_start_space:
-                logger.info(
-                    "You set `multichoice_continuations_start_space` to true. This will force multichoice continuations to use a starting space"
-                )
-            else:
-                logger.info(
-                    "You set `multichoice_continuations_start_space` to false. This will remove a leading space from multichoice continuations, if present."
-                )
-
-        self.model_parallel = boolstring_to_bool(self.model_parallel)
-        self.compile = boolstring_to_bool(self.compile)
-
-        if self.quantization_config is not None and not is_bnb_available():
-            raise ImportError(NO_BNB_ERROR_MSG)
-
-        if not isinstance(self.pretrained, str):
-            raise ValueError("Pretrained model name must be passed as string.")
-        if not isinstance(self.device, str):
-            raise ValueError("Current device must be passed as string.")
-
-    def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
-        revision = self.revision
-        if self.subfolder:
-            revision = f"{self.revision}/{self.subfolder}"
-        auto_config = AutoConfig.from_pretrained(
-            model_name,
-            revision=revision,
-            trust_remote_code=self.trust_remote_code,
-            cache_dir=env_config.cache_dir,
-            token=env_config.token,
-        )
-
-        # Gathering the model's automatic quantization config, if available
-        try:
-            model_auto_quantization_config = auto_config.quantization_config
-            logger.info("An automatic quantization config was found in the model's config. Using it to load the model")
-        except (AttributeError, KeyError):
-            model_auto_quantization_config = None
-
-        if model_auto_quantization_config is not None:
-            if self.quantization_config is not None:
-                # We don't load models quantized by default with a different user provided conf
-                raise ValueError("You manually requested quantization on a model already quantized!")
-
-            # We add the quantization to the model params we store
-            if model_auto_quantization_config["quant_method"] == "gptq":
-                if not is_autogptq_available():
-                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-                auto_config.quantization_config["use_exllama"] = None
-                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
-            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
-                if not is_bnb_available():
-                    raise ImportError(NO_BNB_ERROR_MSG)
-                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
-
-        return auto_config
-
-    def init_configs(self, env_config: EnvConfig) -> PretrainedConfig:
-        return self._init_configs(self.pretrained, env_config=env_config)
-
-    def get_model_sha(self):
-        return _get_model_sha(repo_id=self.pretrained, revision=self.revision)
-
-
-@dataclass
-class DeltaModelConfig(BaseModelConfig):
-    # Delta models look at the pretrained (= the delta weights) for the tokenizer and model config
-    base_model: str = None
-
-    def __post_init__(self):
-        self.revision = "main"
-
-        if not self.base_model:  # must have a default value bc of dataclass inheritance, but can't actually be None
-            raise ValueError("The base_model argument must not be null for a delta model config")
-
-        return super().__post_init__()
-
-    def get_model_sha(self):
-        return _get_model_sha(repo_id=self.pretrained, revision="main")
-
-
-@dataclass
-class AdapterModelConfig(BaseModelConfig):
-    # Adapter models have the specificity that they look at the base model (= the parent) for the tokenizer and config
-    base_model: str = None
-
-    def __post_init__(self):
-        if not is_peft_available():
-            raise ImportError(NO_PEFT_ERROR_MSG)
-
-        if not self.base_model:  # must have a default value bc of dataclass inheritance, but can't actually be None
-            raise ValueError("The base_model argument must not be null for an adapter model config")
-
-        return super().__post_init__()
-
-    def init_configs(self, env_config: EnvConfig):
-        return self._init_configs(self.base_model, env_config)
-
-
-@dataclass
-class VLLMModelConfig:
-    pretrained: str
-    gpu_memory_utilisation: float = 0.9  # lower this if you are running out of memory
-    revision: str = "main"  # revision of the model
-    dtype: str | None = None
-    tensor_parallel_size: int = 1  # how many GPUs to use for tensor parallelism
-    pipeline_parallel_size: int = 1  # how many GPUs to use for pipeline parallelism
-    data_parallel_size: int = 1  # how many GPUs to use for data parallelism
-    max_model_length: int | None = None  # maximum length of the model, ussually infered automatically. reduce this if you encouter OOM issues, 4096 is usually enough
-    swap_space: int = 4  # CPU swap space size (GiB) per GPU.
-    seed: int = 1234
-    trust_remote_code: bool = False
-    use_chat_template: bool = False
-    add_special_tokens: bool = True
-    multichoice_continuations_start_space: bool = (
-        True  # whether to add a space at the start of each continuation in multichoice generation
-    )
-    pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
-
-    subfolder: Optional[str] = None
-    temperature: float = 0.6  # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.
-
-
-@dataclass
-class OpenAIModelConfig:
-    model: str
-
-
-@dataclass
-class TGIModelConfig:
-    inference_server_address: str
-    inference_server_auth: str
-    model_id: str
-
-
-@dataclass
-class DummyModelConfig:
-    seed: int = 42
-
-
-@dataclass
-class InferenceModelConfig:
-    model: str
-    add_special_tokens: bool = True
-
-
-@dataclass
-class InferenceEndpointModelConfig:
-    endpoint_name: str = None
-    model_name: str = None
-    should_reuse_existing: bool = False
-    accelerator: str = "gpu"
-    model_dtype: str = None  # if empty, we use the default
-    vendor: str = "aws"
-    region: str = "us-east-1"  # this region has the most hardware options available
-    instance_size: str = None  # if none, we autoscale
-    instance_type: str = None  # if none, we autoscale
-    framework: str = "pytorch"
-    endpoint_type: str = "protected"
-    add_special_tokens: bool = True
-    revision: str = "main"
-    namespace: str = None  # The namespace under which to launch the endopint. Defaults to the current user's namespace
-    image_url: str = None
-    env_vars: dict = None
-
-    def __post_init__(self):
-        # xor operator, one is None but not the other
-        if (self.instance_size is None) ^ (self.instance_type is None):
-            raise ValueError(
-                "When creating an inference endpoint, you need to specify explicitely both instance_type and instance_size, or none of them for autoscaling."
-            )
-
-        if not (self.endpoint_name is None) ^ int(self.model_name is None):
-            raise ValueError("You need to set either endpoint_name or model_name (but not both).")
-
-    def get_dtype_args(self) -> Dict[str, str]:
-        if self.model_dtype is None:
-            return {}
-        model_dtype = self.model_dtype.lower()
-        if model_dtype in ["awq", "eetq", "gptq"]:
-            return {"QUANTIZE": model_dtype}
-        if model_dtype == "8bit":
-            return {"QUANTIZE": "bitsandbytes"}
-        if model_dtype == "4bit":
-            return {"QUANTIZE": "bitsandbytes-nf4"}
-        if model_dtype in ["bfloat16", "float16"]:
-            return {"DTYPE": model_dtype}
-        return {}
-
-    def get_custom_env_vars(self) -> Dict[str, str]:
-        return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index 1a409746..b0817be4 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -23,25 +23,18 @@
 import logging
 from typing import Union
 
-from lighteval.models.adapter_model import AdapterModel
-from lighteval.models.base_model import BaseModel
-from lighteval.models.delta_model import DeltaModel
-from lighteval.models.dummy_model import DummyModel
-from lighteval.models.endpoint_model import InferenceEndpointModel
-from lighteval.models.model_config import (
-    AdapterModelConfig,
-    BaseModelConfig,
-    DeltaModelConfig,
-    DummyModelConfig,
+from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig
+from lighteval.models.endpoints.endpoint_model import (
+    InferenceEndpointModel,
     InferenceEndpointModelConfig,
     InferenceModelConfig,
-    OpenAIModelConfig,
-    TGIModelConfig,
-    VLLMModelConfig,
 )
-from lighteval.models.openai_model import OpenAIClient
-from lighteval.models.tgi_model import ModelClient
-from lighteval.models.vllm_model import VLLMModel
+from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
+from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
+from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig
+from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
+from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig
+from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
 from lighteval.utils.imports import (
     NO_TGI_ERROR_MSG,
     NO_VLLM_ERROR_MSG,
diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py
similarity index 99%
rename from src/lighteval/models/nanotron_model.py
rename to src/lighteval/models/nanotron/nanotron_model.py
index 21b60504..b7e9b1a5 100644
--- a/src/lighteval/models/nanotron_model.py
+++ b/src/lighteval/models/nanotron/nanotron_model.py
@@ -42,13 +42,13 @@
     LoglikelihoodDataset,
     LoglikelihoodSingleTokenDataset,
 )
-from lighteval.models.base_model import LightevalModel, ModelInfo
 from lighteval.models.model_output import (
     Batch,
     GenerativeResponse,
     LoglikelihoodResponse,
     LoglikelihoodSingleTokenResponse,
 )
+from lighteval.models.transformers.base_model import LightevalModel, ModelInfo
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
     LoglikelihoodRequest,
diff --git a/src/lighteval/models/adapter_model.py b/src/lighteval/models/transformers/adapter_model.py
similarity index 80%
rename from src/lighteval/models/adapter_model.py
rename to src/lighteval/models/transformers/adapter_model.py
index 24de80f4..449c2c1a 100644
--- a/src/lighteval/models/adapter_model.py
+++ b/src/lighteval/models/transformers/adapter_model.py
@@ -22,14 +22,14 @@
 import logging
 from contextlib import nullcontext
+from dataclasses import dataclass
 
 import torch
 from transformers import AutoModelForCausalLM, PreTrainedTokenizer
 
-from lighteval.models.base_model import BaseModel
-from lighteval.models.model_config import AdapterModelConfig
+from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
 from lighteval.models.utils import _get_dtype
-from lighteval.utils.imports import is_peft_available
+from lighteval.utils.imports import NO_PEFT_ERROR_MSG, is_peft_available
 from lighteval.utils.utils import EnvConfig
 
 
@@ -39,6 +39,24 @@
     from peft import PeftModel
 
 
+@dataclass
+class AdapterModelConfig(BaseModelConfig):
+    # Adapter models have the specificity that they look at the base model (= the parent) for the tokenizer and config
+    base_model: str = None
+
+    def __post_init__(self):
+        if not is_peft_available():
+            raise ImportError(NO_PEFT_ERROR_MSG)
+
+        if not self.base_model:  # must have a default value bc of dataclass inheritance, but can't actually be None
+            raise ValueError("The base_model argument must not be null for an adapter model config")
+
+        return super().__post_init__()
+
+    def init_configs(self, env_config: EnvConfig):
+        return self._init_configs(self.base_model, env_config)
+
+
 class AdapterModel(BaseModel):
     def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConfig) -> PreTrainedTokenizer:
         # By default, we look at the model config for the model stored in `base_model`
diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/transformers/base_model.py
similarity index 87%
rename from src/lighteval/models/base_model.py
rename to src/lighteval/models/transformers/base_model.py
index fedc56a5..9b815d2b 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/transformers/base_model.py
@@ -22,6 +22,7 @@
 import logging
 import os
+from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import torch
@@ -30,12 +31,18 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    GPTQConfig,
+    PretrainedConfig,
+)
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
-from lighteval.models.model_config import BaseModelConfig
 from lighteval.models.model_output import (
     Batch,
     GenerativeMultiturnResponse,
@@ -43,7 +50,7 @@
     LoglikelihoodResponse,
     LoglikelihoodSingleTokenResponse,
 )
-from lighteval.models.utils import _get_dtype, _simplify_name, batched
+from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name, batched
 from lighteval.tasks.requests import (
     GreedyUntilMultiTurnRequest,
     GreedyUntilRequest,
@@ -52,9 +59,15 @@
     LoglikelihoodSingleTokenRequest,
     Request,
 )
-from lighteval.utils.imports import is_accelerate_available
+from lighteval.utils.imports import (
+    NO_AUTOGPTQ_ERROR_MSG,
+    NO_BNB_ERROR_MSG,
+    is_accelerate_available,
+    is_autogptq_available,
+    is_bnb_available,
+)
 from lighteval.utils.parallelism import find_executable_batch_size
-from lighteval.utils.utils import EnvConfig, as_list
+from lighteval.utils.utils import EnvConfig, as_list, boolstring_to_bool
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +82,145 @@
 STARTING_BATCH_SIZE = 512
 
 
+@dataclass
+class BaseModelConfig:
+    """
+    Base configuration class for models.
+
+    Attributes:
+        pretrained (str):
+            HuggingFace Hub model ID name or the path to a pre-trained
+            model to load. This is effectively the `pretrained_model_name_or_path`
+            argument of `from_pretrained` in the HuggingFace `transformers` API.
+        accelerator (Accelerator): accelerator to use for model training.
+        tokenizer (Optional[str]): HuggingFace Hub tokenizer ID that will be
+            used for tokenization.
+        multichoice_continuations_start_space (Optional[bool]): Whether to add a
+            space at the start of each continuation in multichoice generation.
+            For example, context: "What is the capital of France?" and choices: "Paris", "London".
+            Will be tokenized as: "What is the capital of France? Paris" and "What is the capital of France? London".
+            True adds a space, False strips a space, None does nothing
+        pairwise_tokenization (bool): Whether to tokenize the context and continuation separately or together.
+        subfolder (Optional[str]): The subfolder within the model repository.
+        revision (str): The revision of the model.
+        batch_size (int): The batch size for model training.
+        max_gen_toks (Optional[int]): The maximum number of tokens to generate.
+        max_length (Optional[int]): The maximum length of the generated output.
+        add_special_tokens (bool, optional, defaults to True): Whether to add special tokens to the input sequences.
+            If `None`, the default value will be set to `True` for seq2seq models (e.g. T5) and
+            `False` for causal models.
+        model_parallel (bool, optional, defaults to False):
+            True/False: force to use or not the `accelerate` library to load a large
+            model across multiple devices.
+            Default: None which corresponds to comparing the number of processes with
+            the number of GPUs. If it's smaller => model-parallelism, else not.
+        dtype (Union[str, torch.dtype], optional, defaults to None):
+            Converts the model weights to `dtype`, if specified. Strings get
+            converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
+            Use `dtype="auto"` to derive the type from the model's weights.
+        device (Union[int, str]): device to use for model training.
+        quantization_config (Optional[BitsAndBytesConfig]): quantization
+            configuration for the model, manually provided to load a normally floating point
+            model at a quantized precision. Needed for 4-bit and 8-bit precision.
+        trust_remote_code (bool): Whether to trust remote code during model
+            loading.
+
+    Methods:
+        __post_init__(): Performs post-initialization checks on the configuration.
+        _init_configs(model_name, env_config): Initializes the model configuration.
+        init_configs(env_config): Initializes the model configuration using the environment configuration.
+        get_model_sha(): Retrieves the SHA of the model.
+ + """ + + pretrained: str + accelerator: "Accelerator" = None + tokenizer: Optional[str] = None + multichoice_continuations_start_space: Optional[bool] = None + pairwise_tokenization: bool = False + subfolder: Optional[str] = None + revision: str = "main" + batch_size: int = -1 + max_gen_toks: Optional[int] = 256 + max_length: Optional[int] = None + add_special_tokens: bool = True + model_parallel: Optional[bool] = None + dtype: Optional[Union[str, torch.dtype]] = None + device: Union[int, str] = "cuda" + quantization_config: Optional[BitsAndBytesConfig] = None + trust_remote_code: bool = False + use_chat_template: bool = False + compile: bool = False + + def __post_init__(self): + # Making sure this parameter is a boolean + self.multichoice_continuations_start_space = boolstring_to_bool(self.multichoice_continuations_start_space) + + if self.multichoice_continuations_start_space is not None: + if self.multichoice_continuations_start_space: + logger.info( + "You set `multichoice_continuations_start_space` to true. This will force multichoice continuations to use a starting space" + ) + else: + logger.info( + "You set `multichoice_continuations_start_space` to false. This will remove a leading space from multichoice continuations, if present." + ) + + self.model_parallel = boolstring_to_bool(self.model_parallel) + self.compile = boolstring_to_bool(self.compile) + + if self.quantization_config is not None and not is_bnb_available(): + raise ImportError(NO_BNB_ERROR_MSG) + + if not isinstance(self.pretrained, str): + raise ValueError("Pretrained model name must be passed as string.") + if not isinstance(self.device, str): + raise ValueError("Current device must be passed as string.") + + def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig: + revision = self.revision + if self.subfolder: + revision = f"{self.revision}/{self.subfolder}" + auto_config = AutoConfig.from_pretrained( + model_name, + revision=revision, + trust_remote_code=self.trust_remote_code, + cache_dir=env_config.cache_dir, + token=env_config.token, + ) + + # Gathering the model's automatic quantization config, if available + try: + model_auto_quantization_config = auto_config.quantization_config + logger.info("An automatic quantization config was found in the model's config. 
Using it to load the model") + except (AttributeError, KeyError): + model_auto_quantization_config = None + + if model_auto_quantization_config is not None: + if self.quantization_config is not None: + # We don't load models quantized by default with a different user provided conf + raise ValueError("You manually requested quantization on a model already quantized!") + + # We add the quantization to the model params we store + if model_auto_quantization_config["quant_method"] == "gptq": + if not is_autogptq_available(): + raise ImportError(NO_AUTOGPTQ_ERROR_MSG) + auto_config.quantization_config["use_exllama"] = None + self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True) + elif model_auto_quantization_config["quant_method"] == "bitsandbytes": + if not is_bnb_available(): + raise ImportError(NO_BNB_ERROR_MSG) + self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config) + + return auto_config + + def init_configs(self, env_config: EnvConfig) -> PretrainedConfig: + return self._init_configs(self.pretrained, env_config=env_config) + + def get_model_sha(self): + return _get_model_sha(repo_id=self.pretrained, revision=self.revision) + + class BaseModel(LightevalModel): def __init__( self, diff --git a/src/lighteval/models/delta_model.py b/src/lighteval/models/transformers/delta_model.py similarity index 81% rename from src/lighteval/models/delta_model.py rename to src/lighteval/models/transformers/delta_model.py index 9aa8c01d..20780f1e 100644 --- a/src/lighteval/models/delta_model.py +++ b/src/lighteval/models/transformers/delta_model.py @@ -22,20 +22,37 @@ import logging from contextlib import nullcontext +from dataclasses import dataclass import torch from tqdm import tqdm from transformers import AutoModelForCausalLM -from lighteval.models.base_model import BaseModel -from lighteval.models.model_config import DeltaModelConfig -from lighteval.models.utils import _get_dtype +from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig +from lighteval.models.utils import _get_dtype, _get_model_sha from lighteval.utils.utils import EnvConfig logger = logging.getLogger(__name__) +@dataclass +class DeltaModelConfig(BaseModelConfig): + # Delta models look at the pretrained (= the delta weights) for the tokenizer and model config + base_model: str = None + + def __post_init__(self): + self.revision = "main" + + if not self.base_model: # must have a default value bc of dataclass inheritance, but can't actually be None + raise ValueError("The base_model argument must not be null for a delta model config") + + return super().__post_init__() + + def get_model_sha(self): + return _get_model_sha(repo_id=self.pretrained, revision="main") + + class DeltaModel(BaseModel): def _create_auto_model( self, diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py similarity index 92% rename from src/lighteval/models/vllm_model.py rename to src/lighteval/models/vllm/vllm_model.py index ecfe8fd8..2d413807 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -24,6 +24,7 @@ import itertools import logging import os +from dataclasses import dataclass from typing import Optional import torch @@ -31,7 +32,6 @@ from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset from lighteval.models.abstract_model import LightevalModel, ModelInfo -from lighteval.models.model_config import VLLMModelConfig from lighteval.models.model_output import ( 
     GenerativeResponse,
     LoglikelihoodResponse,
@@ -66,6 +66,30 @@
 STARTING_BATCH_SIZE = 512
 
 
+@dataclass
+class VLLMModelConfig:
+    pretrained: str
+    gpu_memory_utilisation: float = 0.9  # lower this if you are running out of memory
+    revision: str = "main"  # revision of the model
+    dtype: str | None = None
+    tensor_parallel_size: int = 1  # how many GPUs to use for tensor parallelism
+    pipeline_parallel_size: int = 1  # how many GPUs to use for pipeline parallelism
+    data_parallel_size: int = 1  # how many GPUs to use for data parallelism
+    max_model_length: int | None = None  # maximum length of the model, usually inferred automatically. reduce this if you encounter OOM issues, 4096 is usually enough
+    swap_space: int = 4  # CPU swap space size (GiB) per GPU.
+    seed: int = 1234
+    trust_remote_code: bool = False
+    use_chat_template: bool = False
+    add_special_tokens: bool = True
+    multichoice_continuations_start_space: bool = (
+        True  # whether to add a space at the start of each continuation in multichoice generation
+    )
+    pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
+
+    subfolder: Optional[str] = None
+    temperature: float = 0.6  # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.
+
+
 class VLLMModel(LightevalModel):
     def __init__(
         self,
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 9d08ba12..58724242 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -41,7 +41,7 @@
     apply_target_perplexity_metric,
 )
 from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
-from lighteval.models.base_model import BaseModel
+from lighteval.models.transformers.base_model import BaseModel
 from lighteval.tasks.prompt_manager import PromptManager
 from lighteval.tasks.requests import (
     Doc,
diff --git a/tests/models/test_abstract_model.py b/tests/models/test_abstract_model.py
index a598bdc4..e7fc0172 100644
--- a/tests/models/test_abstract_model.py
+++ b/tests/models/test_abstract_model.py
@@ -22,8 +22,7 @@
 
 from transformers import AutoTokenizer
 
-from lighteval.models.dummy_model import DummyModel
-from lighteval.models.model_config import DummyModelConfig
+from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig
 from lighteval.utils.utils import EnvConfig
diff --git a/tests/models/test_base_model.py b/tests/models/test_base_model.py
index dac396f5..4f26d292 100644
--- a/tests/models/test_base_model.py
+++ b/tests/models/test_base_model.py
@@ -20,9 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from lighteval.models.base_model import BaseModel
-from lighteval.models.model_config import BaseModelConfig
 from lighteval.models.model_loader import load_model
+from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig
 from lighteval.utils.utils import EnvConfig
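
Reviewer note, not part of the patch: a minimal sketch of where the relocated config dataclasses are imported from after this refactor. The model name, endpoint address, and token below are illustrative placeholders, and the snippet assumes this branch of lighteval is installed.

# Sketch only: configs now live next to their model classes instead of lighteval.models.model_config.
from lighteval.models.endpoints.tgi_model import TGIModelConfig
from lighteval.models.transformers.base_model import BaseModelConfig
from lighteval.models.vllm.vllm_model import VLLMModelConfig

base_config = BaseModelConfig(pretrained="HuggingFaceH4/zephyr-7b-beta")  # placeholder model ID
vllm_config = VLLMModelConfig(pretrained="HuggingFaceH4/zephyr-7b-beta", gpu_memory_utilisation=0.8)
tgi_config = TGIModelConfig(
    inference_server_address="http://localhost:8080",  # placeholder TGI endpoint
    inference_server_auth="hf_xxx",  # placeholder token
    model_id="HuggingFaceH4/zephyr-7b-beta",
)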