Commit
Implemented the possibility to load predictions from details files and continue evaluating from there.
JoelNiklaus committed Jan 7, 2025
1 parent f6fee3a commit a95156e
Showing 5 changed files with 110 additions and 7 deletions.
31 changes: 29 additions & 2 deletions src/lighteval/logging/evaluation_tracker.py
@@ -209,9 +209,36 @@ def save_results(self, date_id: str, results_dict: dict):
with self.fs.open(output_results_file, "w") as f:
f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False))

def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
def _get_details_sub_folder(self, date_id: str):
output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name
output_dir_details_sub_folder = output_dir_details / date_id
if date_id == "latest":
# Get all folders in output_dir_details
if not self.fs.exists(output_dir_details):
raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")

# List all folders and filter out files
folders = [f['name'] for f in self.fs.listdir(output_dir_details) if f['type'] == 'directory']

if not folders:
raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")

# Folder names are timestamps, so the lexicographic max is the most recent run
date_id = max(folders)
return output_dir_details / date_id

def load_details_datasets(self, date_id: str) -> dict[str, Dataset]:
output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
date_id = output_dir_details_sub_folder.name # Overwrite date_id in case of latest
details_datasets = {}
for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
task_name = Path(file).stem.replace(f"details_", "").replace(f"_{date_id}", "")
dataset = load_dataset("parquet", data_files=file, split="train")
details_datasets[task_name] = dataset
return details_datasets


def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
logger.info(f"Saving details to {output_dir_details_sub_folder}")
for task_name, dataset in details_datasets.items():
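As a usage illustration (not part of the commit): the loader accepts either a concrete timestamp folder name or the special value "latest", which _get_details_sub_folder resolves to the most recent timestamp folder. A minimal sketch, assuming output_dir points at a previous run and that the tracker's general_config_logger.model_name has already been populated (which normally happens during model loading); the directory and the inspection loop are placeholders, not code from this change:

# Hedged sketch -- "./results" and the prior run are assumptions.
from lighteval.logging.evaluation_tracker import EvaluationTracker

tracker = EvaluationTracker(output_dir="./results")
details = tracker.load_details_datasets("latest")  # or an explicit date_id such as "2025-01-07T12-00-00.000000"
for task_name, dataset in details.items():
    # Each value is a datasets.Dataset with one row per evaluated sample.
    print(task_name, len(dataset), dataset.column_names)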
4 changes: 4 additions & 0 deletions src/lighteval/main_accelerate.py
@@ -67,6 +67,9 @@ def accelerate( # noqa C901
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -137,6 +140,7 @@ def accelerate( # noqa C901
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)

# TODO (nathan): better handling of model_args
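Note: the same load_responses_from_details_date_id option is added, unchanged, to the other entry points below (inference_endpoint, tgi, litellm, vllm). The sketch below shows, independently of the CLI layer, how the value reaches the pipeline; typer would normally expose the option as --load-responses-from-details-date-id, and both the launcher choice and the "latest" value here are illustrative assumptions, not code from this commit:

# Illustrative only -- a sketch of the parameter propagation.
from lighteval.pipeline import ParallelismManager, PipelineParameters

pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    load_responses_from_details_date_id="latest",  # reuse the most recent details folder instead of re-running the model
)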
12 changes: 12 additions & 0 deletions src/lighteval/main_endpoint.py
@@ -179,6 +179,9 @@ def inference_endpoint(
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -247,6 +250,7 @@ def inference_endpoint(
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)
pipeline = Pipeline(
tasks=tasks,
@@ -292,6 +296,9 @@ def tgi(
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -355,6 +362,7 @@ def tgi(
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)
pipeline = Pipeline(
tasks=tasks,
@@ -400,6 +408,9 @@ def litellm(
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -464,6 +475,7 @@ def litellm(
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)
pipeline = Pipeline(
tasks=tasks,
4 changes: 4 additions & 0 deletions src/lighteval/main_vllm.py
@@ -63,6 +63,9 @@ def vllm(
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -124,6 +127,7 @@ def vllm(
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)

if model_args.endswith(".yaml"):
66 changes: 61 additions & 5 deletions src/lighteval/pipeline.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import ast
import collections
import os
import random
@@ -34,10 +35,10 @@
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.utils.metric_utils import MetricCategory
from lighteval.models.model_loader import TransformersModel, load_model
from lighteval.models.model_output import ModelResponse
from lighteval.models.model_output import GenerativeMultiturnResponse, GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse, ModelResponse
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
from lighteval.tasks.requests import SampleUid
from lighteval.tasks.requests import RequestType, SampleUid
from lighteval.utils.imports import (
NO_ACCELERATE_ERROR_MSG,
NO_NANOTRON_ERROR_MSG,
@@ -95,6 +96,7 @@ class PipelineParameters:
max_samples: int | None = None
use_chat_template: bool = False
system_prompt: str | None = None
load_responses_from_details_date_id: str | None = None

def __post_init__(self): # noqa C901
if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -245,7 +247,11 @@ def evaluate(self):
config=self.model_config,
)

sample_id_to_responses = self._run_model()
if self.pipeline_parameters.load_responses_from_details_date_id:
sample_id_to_responses = self._load_responses_from_details()
else:
sample_id_to_responses = self._run_model()

self._compute_metrics(sample_id_to_responses)

if self.is_main_process():
@@ -261,6 +267,53 @@
except OSError:
pass


def _load_responses_from_details(self):
logger.info("--- LOADING RESPONSES FROM DETAILS ---")
sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)

request_types = list(self.requests.keys())
if len(request_types) > 1:
raise ValueError("Loading responses from details when there are multiple request types is currently not supported")
request_type = request_types[0]
if request_type == RequestType.LOGLIKELIHOOD:
model_response_type = LoglikelihoodResponse
elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
model_response_type = LoglikelihoodSingleTokenResponse
elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
model_response_type = LoglikelihoodResponse
elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
model_response_type = GenerativeMultiturnResponse
elif request_type == RequestType.GREEDY_UNTIL:
model_response_type = GenerativeResponse
else:
raise ValueError(f"Loading responses from details for request type {request_type} is currently not supported")

details_datasets = self.evaluation_tracker.load_details_datasets(self.pipeline_parameters.load_responses_from_details_date_id)
for task_name, dataset in details_datasets.items():
task: LightevalTask = self._get_task(task_name)
num_samples = len(dataset["predictions"])
max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
if num_samples > max_samples:
logger.warning(f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}")
num_samples = self.pipeline_parameters.max_samples
for metric_category, has_metric_category in task.has_metric_category.items():
if not has_metric_category:
continue
for idx in range(num_samples):
kwargs = {
"result": ast.literal_eval(dataset["predictions"][idx]),
"input_tokens": ast.literal_eval(dataset["input_tokens"][idx]),
"generated_tokens": ast.literal_eval(dataset["cont_tokens"][idx]),
"truncated_tokens_count": ast.literal_eval(dataset["truncated"][idx])[0],
"padded_tokens_count": ast.literal_eval(dataset["padded"][idx])[0]
}
if model_response_type == GenerativeResponse:
kwargs["logits"] = ast.literal_eval(dataset["pred_logits"][idx])
response = model_response_type(**kwargs)
sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
return sample_id_to_responses
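
(Aside, not part of the diff: the loader above assumes the details columns were serialized as Python-literal strings, hence the ast.literal_eval round-trips on predictions, tokens, and counts. A tiny sketch with made-up values:)

# Made-up serialized values, purely to illustrate the ast.literal_eval round-trip.
import ast

serialized_prediction = "['The answer is 42.']"
serialized_cont_tokens = "[101, 2023, 2003, 102]"

result = ast.literal_eval(serialized_prediction)             # -> ['The answer is 42.']
generated_tokens = ast.literal_eval(serialized_cont_tokens)  # -> [101, 2023, 2003, 102]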

def _run_model(self):
# Running all requests depending on the model call type (log likelihood, generative, ...)
# to be able to batch them
@@ -283,6 +336,10 @@ def _run_model(self):

return sample_id_to_responses

def _get_task(self, task_name: str):
short_task_name = task_name.rsplit("|", 1)[0]
return self.task_dict[short_task_name]
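
(Aside, illustration only: _get_task drops the last "|"-separated component of a details task name to recover the key used in self.task_dict. The task name below is a hypothetical placeholder, not one taken from the repository:)

task_name = "leaderboard|arc:challenge|25"     # hypothetical fully qualified name
short_task_name = task_name.rsplit("|", 1)[0]  # -> "leaderboard|arc:challenge"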

def _compute_metrics(self, sample_id_to_responses):
# To compute the metrics we first group the samples and task and then by metrics.
# This way we can batch the metrics computation for each task and metric category
@@ -307,8 +364,7 @@ def _compute_metrics(self, sample_id_to_responses):
task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])

for task_name, samples_per_metric in task_metric_category_groups.items():
short_task_name = task_name.rsplit("|", 1)[0]
task: LightevalTask = self.task_dict[short_task_name]
task: LightevalTask = self._get_task(task_name)

for metric_category, samples in samples_per_metric.items():
sample_ids = samples["ids"]
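A side note on the request-type dispatch in _load_responses_from_details: the if/elif chain pairs each RequestType with the ModelResponse subclass the model originally produced for it. The sketch below restates that mapping as a lookup table for readability; it is an illustration, not code from the repository:

# Equivalent reformulation of the dispatch above (sketch only).
from lighteval.models.model_output import (
    GenerativeMultiturnResponse,
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import RequestType

RESPONSE_TYPE_BY_REQUEST_TYPE = {
    RequestType.LOGLIKELIHOOD: LoglikelihoodResponse,
    RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: LoglikelihoodSingleTokenResponse,
    RequestType.LOGLIKELIHOOD_ROLLING: LoglikelihoodResponse,
    RequestType.GREEDY_UNTIL_MULTI_TURN: GenerativeMultiturnResponse,
    RequestType.GREEDY_UNTIL: GenerativeResponse,
}

def response_type_for(request_type: RequestType):
    # Mirror the error message used in the commit for unsupported request types.
    try:
        return RESPONSE_TYPE_BY_REQUEST_TYPE[request_type]
    except KeyError as err:
        raise ValueError(f"Loading responses from details for request type {request_type} is currently not supported") from err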
