Improved loading speed and added more useful error messages.
JoelNiklaus committed Jan 11, 2025
1 parent 874541b commit 742a672
Showing 2 changed files with 39 additions and 16 deletions.
25 changes: 17 additions & 8 deletions src/lighteval/logging/evaluation_tracker.py
@@ -215,28 +215,37 @@ def _get_details_sub_folder(self, date_id: str):
         # Get all folders in output_dir_details
         if not self.fs.exists(output_dir_details):
             raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")

         # List all folders and filter out files
-        folders = [f['name'] for f in self.fs.listdir(output_dir_details) if f['type'] == 'directory']
+        folders = [f["name"] for f in self.fs.listdir(output_dir_details) if f["type"] == "directory"]

         if not folders:
             raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")

         # Parse timestamps and get latest
         date_id = max(folders)
         return output_dir_details / date_id

-    def load_details_datasets(self, date_id: str) -> dict[str, Dataset]:
+    def load_details_datasets(self, date_id: str, task_names: list[str]) -> dict[str, Dataset]:
         output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
-        date_id = output_dir_details_sub_folder.name  # Overwrite date_id in case of latest
         logger.info(f"Loading details from {output_dir_details_sub_folder}")
+        date_id = output_dir_details_sub_folder.name  # Overwrite date_id in case of latest
         details_datasets = {}
         for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
-            task_name = Path(file).stem.replace(f"details_", "").replace(f"_{date_id}", "")
+            task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "")
+            if "|".join(task_name.split("|")[:-1]) not in task_names:
+                logger.info(f"Skipping {task_name} because it is not in the task_names list")
+                continue
             dataset = load_dataset("parquet", data_files=file, split="train")
             details_datasets[task_name] = dataset

+        for task_name in task_names:
+            if not any(loaded_task_name.startswith(task_name) for loaded_task_name in details_datasets.keys()):
+                raise ValueError(
+                    f"Task {task_name} not found in details datasets. Check the tasks to be evaluated or the date_id used to load the details ({self.pipeline_parameters.load_responses_from_details_date_id})."
+                )
         return details_datasets

     def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
         output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
         self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
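The new task_names argument lets the tracker load only the details files for the tasks actually being evaluated, and it now fails fast with a descriptive error when a requested task has no saved details instead of silently returning nothing. A minimal usage sketch, assuming an EvaluationTracker already configured with the output directory of an earlier run; the directory, date_id, and task name below are illustrative, not taken from this commit:

from lighteval.logging.evaluation_tracker import EvaluationTracker

# Hypothetical tracker pointing at the output directory of a previous run.
tracker = EvaluationTracker(output_dir="./evals")

# "latest" is resolved to the newest timestamp folder (see the
# "Overwrite date_id in case of latest" comment above). Task names use the
# suite|task form that the "|".join(...) filter in the diff compares against.
details = tracker.load_details_datasets(
    date_id="latest",
    task_names=["lighteval|gsm8k"],
)

for task_name, dataset in details.items():
    print(task_name, len(dataset["predictions"]))

If one of the requested tasks has no matching details file, the new check raises a ValueError naming the task and the date_id that was used.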
30 changes: 22 additions & 8 deletions src/lighteval/pipeline.py
@@ -31,6 +31,7 @@
 from enum import Enum, auto

 import numpy as np
+from tqdm import tqdm

 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
@@ -291,9 +292,10 @@ def _load_responses_from_details(self):
         model_response_type = self._get_model_response_type(request_types[0])

         details_datasets = self.evaluation_tracker.load_details_datasets(
-            self.pipeline_parameters.load_responses_from_details_date_id
+            self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
         )
-        for task_name, dataset in details_datasets.items():
+
+        for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
             task: LightevalTask = self._get_task(task_name)
             num_samples = len(dataset["predictions"])
             max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
@@ -305,16 +307,28 @@
             for metric_category, has_metric_category in task.has_metric_category.items():
                 if not has_metric_category:
                     continue
+
+                # Pre-evaluate all the literal strings once
+                predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
+                input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
+                cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
+                truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+                padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+                if model_response_type == GenerativeResponse:
+                    logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
                 for idx in range(num_samples):
                     kwargs = {
-                        "result": ast.literal_eval(dataset["predictions"][idx]),
-                        "input_tokens": ast.literal_eval(dataset["input_tokens"][idx]),
-                        "generated_tokens": ast.literal_eval(dataset["cont_tokens"][idx]),
-                        "truncated_tokens_count": ast.literal_eval(dataset["truncated"][idx])[0],
-                        "padded_tokens_count": ast.literal_eval(dataset["padded"][idx])[0],
+                        "result": predictions[idx],
+                        "input_tokens": input_tokens[idx],
+                        "generated_tokens": cont_tokens[idx],
+                        "truncated_tokens_count": truncated[idx],
+                        "padded_tokens_count": padded[idx],
                     }
                     if model_response_type == GenerativeResponse:
-                        kwargs["logits"] = ast.literal_eval(dataset["pred_logits"][idx])
+                        kwargs["logits"] = logits[idx]
+
                     response = model_response_type(**kwargs)
                     sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
         return sample_id_to_responses
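The hoisting above appears to be the main source of the loading-speed gain: indexing a datasets.Dataset column as dataset["predictions"][idx] materializes the whole column on every access, and each cell was additionally re-parsed with ast.literal_eval inside the per-sample loop. After the change, each column is read once, parsed once, and the loop only indexes plain Python lists. A small self-contained sketch of that pattern, using made-up data rather than the real details schema:

import ast

from datasets import Dataset

# Made-up details-style rows: each cell stores a Python literal as a string,
# mirroring how predictions and token lists are serialized in the parquet files.
dataset = Dataset.from_dict(
    {
        "predictions": ["['4']", "['7']", "['12']"],
        "input_tokens": ["[[1, 2, 3]]", "[[4, 5]]", "[[6]]"],
    }
)

num_samples = len(dataset["predictions"])

# Read and parse each column once, outside the hot loop...
predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]

# ...then the per-sample loop only does cheap list indexing.
for idx in range(num_samples):
    print(predictions[idx], input_tokens[idx])

The tqdm wrapper added in the earlier hunk does not change the cost; it only adds a per-task progress bar while the responses are rebuilt.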
