improved
optimass committed Jan 22, 2025
1 parent: 75c3092 · commit: 945a50e
Showing 1 changed file with 55 additions and 38 deletions.
browsergym/experiments/src/browsergym/experiments/benchmark/base.py (93 changes: 55 additions & 38 deletions)
@@ -96,57 +96,74 @@ def prepare_backends(self):
             prepare_backend(backend)
             logger.info(f"{backend} backend ready")

-    def subset_from_split(
+    def subset_from_split(self, split: Literal["train", "valid", "test"]):
+        split_column = "browsergym_split"
+
+        # check for a split column in metadata
+        if split_column not in self.task_metadata.columns:
+            raise NotImplementedError(
+                f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
+            )
+
+        # recover the target split
+        sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
+        sub_benchmark.name = f"{self.name}_{split}"
+
+        # check that the split exists (non-empty task list)
+        if not sub_benchmark.env_args_list:
+            raise ValueError(f"The default {split} split for this benchmark is empty.")
+
+        return sub_benchmark
+
+    def subset_from_list(
         self,
-        split: Literal["train", "valid", "test"],
-        task_splits: Optional[dict[str, list[str]]] = None,
+        task_list: list[str],
         benchmark_name_suffix: Optional[str] = "custom",
+        split: Optional[str] = None,
     ):
-        """Create a subset of the benchmark containing only tasks from the specified split.
+        """Create a sub-benchmark containing only the specified tasks.

         Args:
-            split: The split to filter for ("train", "valid", or "test")
-            task_splits: Optional dictionary mapping splits to lists of task names.
-                Example: {"train": ["task1", "task2"], "valid": ["task3", "task4"], "test": ["task5", "task6"]}
-            benchmark_name_suffix: Optional suffix to append to the new benchmark name
+            task_list: List of task names to include in the sub-benchmark.
+            benchmark_name_suffix: Optional suffix to append to the benchmark name. Defaults to "custom".
+            split: Optional split name to append to the benchmark name. Useful for organization.

         Returns:
-            A new Benchmark instance containing only tasks from the specified split.
+            Benchmark: A new benchmark instance containing only the specified tasks.

         Raises:
-            NotImplementedError: If task_splits is None and the metadata has no 'browsergym_split' column
-            ValueError: If the resulting split would be empty
+            ValueError: If the resulting task list is empty or if any specified task doesn't exist.
         """
-        if task_splits is not None:
-            sub_benchmark = Benchmark(
-                name=f"{self.name}_{benchmark_name_suffix}_{split}",
-                high_level_action_set_args=self.high_level_action_set_args,
-                is_multi_tab=self.is_multi_tab,
-                supports_parallel_seeds=self.supports_parallel_seeds,
-                backends=self.backends,
-                env_args_list=[
-                    env_args
-                    for env_args in self.env_args_list
-                    if env_args.task_name in task_splits[split]
-                ],
-                task_metadata=self.task_metadata,
-            )
-        else:
-            split_column = "browsergym_split"
-            # check for a split column in metadata
-            if split_column not in self.task_metadata.columns:
-                raise NotImplementedError(
-                    f"This benchmark does not provide default train/valid/test splits (missing a {repr(split_column)} column in task metadata)"
-                )
+        if not task_list:
+            raise ValueError("Task list cannot be empty")

-            # recover the target split
-            sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")
-            sub_benchmark.name = f"{self.name}_{split}"
+        # Validate that all requested tasks exist in the original benchmark
+        existing_tasks = {env_args.task_name for env_args in self.env_args_list}
+        invalid_tasks = set(task_list) - existing_tasks
+        if invalid_tasks:
+            raise ValueError(f"The following tasks do not exist in the benchmark: {invalid_tasks}")

-        # check that the split exists (non-empty task list)
+        name = f"{self.name}_{benchmark_name_suffix}"
+        if split:
+            name += f"_{split}"
+
+        sub_benchmark = Benchmark(
+            name=name,
+            high_level_action_set_args=self.high_level_action_set_args,
+            is_multi_tab=self.is_multi_tab,
+            supports_parallel_seeds=self.supports_parallel_seeds,
+            backends=self.backends,
+            env_args_list=[
+                env_args for env_args in self.env_args_list if env_args.task_name in task_list
+            ],
+            task_metadata=self.task_metadata,
+        )
+
+        # This check is redundant now due to the validation above, but kept for safety
         if not sub_benchmark.env_args_list:
-            raise ValueError(f"The {split} split for this benchmark is empty.")
+            raise ValueError(
+                f"The custom {split if split else ''} split for this benchmark is empty."
+            )

         return sub_benchmark
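
For context, a minimal usage sketch of the refactored subset_from_split(), which after this commit only reads the default splits from task metadata. The benchmark instance and its name below are hypothetical placeholders, not defined in this commit.

# Hypothetical setup: `benchmark` is an existing Benchmark instance whose
# task metadata includes a "browsergym_split" column.
train_benchmark = benchmark.subset_from_split("train")
print(train_benchmark.name)  # e.g. "miniwob_train" if the parent benchmark is named "miniwob"

# subset_from_split() raises NotImplementedError if the metadata has no
# "browsergym_split" column, and ValueError if the requested split is empty.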

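A companion sketch for the new subset_from_list(), which replaces the old task_splits argument with an explicit task list. The task names here are placeholders used for illustration only.

# Hypothetical task names; subset_from_list() validates each one against the
# benchmark's env_args_list and raises ValueError on any unknown task.
custom_valid = benchmark.subset_from_list(
    task_list=["miniwob.click-button", "miniwob.click-link"],
    benchmark_name_suffix="custom",
    split="valid",
)
print(custom_valid.name)  # "<parent name>_custom_valid"

Note the naming rule introduced here: the suffix is always appended, while the split is appended only when provided, so the same call without split= would yield "<parent name>_custom".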
