
Commit

Fixed gnarly bug with details loading to prevent loading too many examples.

JoelNiklaus committed Jan 11, 2025
1 parent eaedd04 commit ca8331a
Showing 1 changed file with 11 additions and 11 deletions.
src/lighteval/pipeline.py (22 changes: 11 additions & 11 deletions)
@@ -297,27 +297,27 @@ def _load_responses_from_details(self):

         for task_name, dataset in tqdm(details_datasets.items(), desc="Loading responses from details for tasks"):
             task: LightevalTask = self._get_task(task_name)
-            num_samples = len(dataset["predictions"])
+            num_samples = len(set(dataset["specifics"]))
             max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
             if num_samples > max_samples:
                 logger.warning(
                     f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}"
                 )
                 num_samples = self.pipeline_parameters.max_samples

-            predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
-            input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
-            cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
-            truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
-            padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
-
-            if model_response_type == GenerativeResponse:
-                logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
-
             for metric_category, has_metric_category in task.has_metric_category.items():
                 if not has_metric_category:
                     continue

+                # Pre-evaluate all the literal strings once
+                predictions = [ast.literal_eval(p) for p in dataset["predictions"][:num_samples]]
+                input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
+                cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
+                truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
+                padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]
+
+                if model_response_type == GenerativeResponse:
+                    logits = [ast.literal_eval(p) for p in dataset["pred_logits"][:num_samples]]
+
                 for idx in range(num_samples):
                     kwargs = {
                         "result": predictions[idx],
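Both the old and the new code parse every cell with ast.literal_eval, since the details files store predictions and token lists as stringified Python literals. A quick illustration of that round trip, using a made-up cell value:

    import ast

    stored = "[1128, 278, 4996]"       # hypothetical "input_tokens" cell
    tokens = ast.literal_eval(stored)  # parses Python literal syntax only
    assert tokens == [1128, 278, 4996]

Unlike eval(), ast.literal_eval() accepts only literal expressions (strings, numbers, tuples, lists, dicts, sets, booleans, None), so it is the safer choice for round-tripping serialized values.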
