Skip to content

Commit

Permalink
Made loading details more robust against tensors being saved in the d…
Browse files Browse the repository at this point in the history
…etails files.
  • Loading branch information
JoelNiklaus committed Jan 14, 2025
1 parent dae2d2b commit 299b90c
Showing 1 changed file with 76 additions and 2 deletions.
78 changes: 76 additions & 2 deletions src/lighteval/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import collections
import os
import random
import re
import shutil
from contextlib import nullcontext
from dataclasses import dataclass, field
Expand Down Expand Up @@ -288,6 +289,79 @@ def _unpack(self, x):
else:
raise ValueError(f"Unknown type {type(x)} of prediction {x}")

def _parse_tensor_string(self, tensor_string):
    """
    Convert a string containing PyTorch-like `tensor([...], device='cuda:0', ...)`
    reprs into a Python list (or nested lists) of numbers.

    Every `tensor(...)` call is replaced by its first bracketed argument and the
    resulting text is parsed with `ast.literal_eval`. Parentheses and brackets are
    matched by nesting depth, so nested lists, multi-line lists and keyword
    arguments that themselves contain parentheses are handled. A `tensor(...)`
    call without any bracketed argument (e.g. a scalar tensor) is replaced by an
    empty list. Strings without any `tensor(...)` wrapper are parsed unchanged.

    Example:
        "[tensor([1, 2, 3], device='cuda:0'), tensor([[4,5],[6,7]], dtype=torch.int64)]"
        -> [[1, 2, 3], [[4, 5], [6, 7]]]

    Args:
        tensor_string: Text that is a valid Python literal once the
            `tensor(...)` wrappers have been stripped.

    Returns:
        The parsed Python object (typically a list or nested lists of numbers).

    Raises:
        ValueError: If the preprocessed string cannot be parsed by
            `ast.literal_eval` (for example because a `tensor(` is never closed).
    """

    def closing_index(text, start, open_char, close_char):
        """Return the index of the delimiter closing the one at `start`, or -1."""
        depth = 0
        quote = None
        for idx in range(start, len(text)):
            char = text[idx]
            if quote is not None:
                # Delimiters inside quoted values (e.g. device='cuda:0') don't count.
                if char == quote:
                    quote = None
                continue
            if char in "'\"":
                quote = char
            elif char == open_char:
                depth += 1
            elif char == close_char:
                depth -= 1
                if depth == 0:
                    return idx
        return -1

    opener = re.compile(r"tensor\s*\(")

    # Step 1: Replace every `tensor(...)` occurrence with just its bracketed list.
    # A lazy regex cannot do this correctly: it stops at the first `]` / `)`, which
    # truncates nested lists such as `[[4,5],[6,7]]`, so delimiters are matched by
    # depth instead.
    pieces = []
    cursor = 0
    while True:
        match = opener.search(tensor_string, cursor)
        if match is None:
            break
        paren_open = match.end() - 1
        paren_close = closing_index(tensor_string, paren_open, "(", ")")
        if paren_close == -1:
            # Unbalanced call: leave the remainder untouched so that literal_eval
            # reports the malformed input below.
            break

        inside = tensor_string[paren_open + 1 : paren_close]
        bracket_open = inside.find("[")
        bracket_close = closing_index(inside, bracket_open, "[", "]") if bracket_open != -1 else -1
        if bracket_close == -1:
            # No (complete) bracketed portion, so the call didn't match the
            # expected format; substitute something safe.
            replacement = "[]"
        else:
            replacement = inside[bracket_open : bracket_close + 1]

        pieces.append(tensor_string[cursor : match.start()])
        pieces.append(replacement)
        cursor = paren_close + 1

    pieces.append(tensor_string[cursor:])
    processed = "".join(pieces)

    # Step 2: Now we can safely parse the result with literal_eval.
    # If there's still something weird, surface it as a ValueError and keep the
    # original exception chained for debugging.
    try:
        return ast.literal_eval(processed)
    except Exception as e:
        raise ValueError(
            f"Failed to parse after preprocessing. Processed string:\n{processed}\n\nError: {e}"
        ) from e

def _load_responses_from_details(self):
logger.info("--- LOADING RESPONSES FROM DETAILS ---")
sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
Expand All @@ -314,8 +388,8 @@ def _load_responses_from_details(self):
num_samples = self.pipeline_parameters.max_samples

predictions = [self._unpack(ast.literal_eval(p)) for p in dataset["predictions"][:num_samples]]
input_tokens = [ast.literal_eval(t) for t in dataset["input_tokens"][:num_samples]]
cont_tokens = [ast.literal_eval(t) for t in dataset["cont_tokens"][:num_samples]]
input_tokens = [self._parse_tensor_string(t) for t in dataset["input_tokens"][:num_samples]]
cont_tokens = [self._parse_tensor_string(t) for t in dataset["cont_tokens"][:num_samples]]
truncated = [ast.literal_eval(t)[0] for t in dataset["truncated"][:num_samples]]
padded = [ast.literal_eval(p)[0] for p in dataset["padded"][:num_samples]]

Expand Down

0 comments on commit 299b90c

Please sign in to comment.