Skip to content

Commit

Permalink
Add the ability to randomly sample the space of scenarios which will …
Browse files Browse the repository at this point in the history
…be useful when using the studies to perform hyperparameter tuning.
  • Loading branch information
bojan-karlas committed Jul 9, 2024
1 parent a1b86d6 commit 60abbfb
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 21 deletions.
65 changes: 46 additions & 19 deletions experiments/datascope/experiments/bench/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from glob import glob
from inspect import signature
from io import TextIOBase, StringIO, SEEK_END
from itertools import product
from logging import Logger
from matplotlib.figure import Figure
from methodtools import lru_cache
Expand All @@ -49,6 +48,7 @@
Callable,
Dict,
Generic,
Hashable,
Iterable,
List,
Optional,
Expand All @@ -64,6 +64,8 @@
Protocol,
)

from .generator import ConfigGenerator, GridConfigGenerator, RandomConfigGenerator, CombinedConfigGenerator


def represent(x: Any):
if isinstance(x, Enum):
Expand Down Expand Up @@ -714,7 +716,7 @@ def _compose_attributes(cls: Type["Configurable"], attributes: Dict[str, Any]) -
target_cls._get_attribute_descriptors()
)
for kk, vd in target_attribute_descriptors.items():
if vd.inherit:
if vd.inherit and kk in attributes:
target_attributes[kk] = attributes[kk]

# Pass down all attributes that are prefixed with the current key.
Expand Down Expand Up @@ -885,33 +887,58 @@ def get_keyword_replacements(cls: Type["Scenario"]) -> Dict[str, str]:
return {}

@classmethod
def get_instances(cls: Type["Scenario"], **kwargs: Any) -> Iterable["Scenario"]:
def get_instances(
cls: Type["Scenario"],
subset_sample_size: Union[int, float] = -1,
subset_grid_attributes: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterable["Scenario"]:
if cls == Scenario:
# for id, scenario in Scenario.scenarios.items():
for id, subclass in Scenario.get_subclasses().items():
if kwargs.get("scenario", None) is None or id in kwargs["scenario"]:
assert issubclass(subclass, Scenario)
for instance in subclass.get_instances(**kwargs):
for instance in subclass.get_instances(
**kwargs,
subset_sample_size=subset_sample_size,
subset_grid_attributes=subset_grid_attributes,
):
yield instance
else:
# Build a dictionary of attribute domains.
attribute_descriptors: Dict[str, AttributeDescriptor] = cls._get_attribute_descriptors(flattened=True)
domains = []
names = list(attribute_descriptors.keys())
for name in names:
if name in kwargs and kwargs[name] is not None:
domain = kwargs[name]
if not isinstance(domain, Iterable) or isinstance(domain, str):
domain = [domain]
domains.append(list(domain))
else:
domains.append([None])
for values in product(*domains):
attributes = dict((name, value) for (name, value) in zip(names, values) if value is not None)
composed_attributes = cls._compose_attributes(attributes)
attribute_domains: Dict[str, List[Hashable]] = dict(
{
name: ([domain] if not isinstance(domain, Iterable) or isinstance(domain, str) else list(domain))
for name, domain in kwargs.items()
if name in attribute_descriptors
}
)

# Initialize the appropriate config generator based on the provided attributes.
generator: Optional[ConfigGenerator] = None
if subset_sample_size > 0:
subset_grid_attributes = subset_grid_attributes or []
random_config_space = {k: v for k, v in attribute_domains.items() if k not in subset_grid_attributes}
grid_config_space = {k: v for k, v in attribute_domains.items() if k in subset_grid_attributes}
genrators: List[ConfigGenerator] = []
if len(random_config_space) > 0:
genrators.append(RandomConfigGenerator(random_config_space, subset_sample_size))
if len(grid_config_space) > 0:
genrators.append(GridConfigGenerator(grid_config_space))
generator = CombinedConfigGenerator(*genrators)

else:
generator = GridConfigGenerator(attribute_domains)

assert generator is not None
for config in generator:
instance_attributes = {k: v for k, v in config.items() if v is not None}
composed_attributes = cls._compose_attributes(instance_attributes)
if cls.is_valid_config(**composed_attributes):
scenario = cls(**composed_attributes)
# assert isinstance(scenario, Scenario)
yield scenario
else:
generator.register_invalid_config(config)

@classmethod
def is_valid_config(cls, **attributes: Any) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions experiments/datascope/experiments/bench/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from multiprocessing import Pool, Lock
from multiprocessing.synchronize import Lock as LockType
from tqdm import tqdm
from typing import Any, Optional, Sequence, Tuple, List, Type
from typing import Any, Optional, Sequence, Tuple, List, Type, Union

from ..datasets import DEFAULT_BATCH_SIZE, DEFAULT_CACHE_DIR, Dataset
from ..pipelines import Pipeline
Expand Down Expand Up @@ -44,6 +44,8 @@ def run(
slurm_args: Optional[str] = None,
eventstream_host_ip: Optional[str] = None,
eventstream_host_port: Optional[int] = None,
subset_sample_size: Union[int, float] = -1,
subset_grid_attributes: Optional[List[str]] = None,
**attributes: Any
) -> None:
# If we should continue the execution of an existing study, then we should load it.
Expand All @@ -65,7 +67,11 @@ def run(
)

# Construct a study from a set of scenarios.
scenarios = list(Scenario.get_instances(**attributes))
scenarios = list(
Scenario.get_instances(
**attributes, subset_sample_size=subset_sample_size, subset_grid_attributes=subset_grid_attributes
)
)
if study is not None:
existing_scenarios = list(study.scenarios)
for cs in scenarios:
Expand Down
140 changes: 140 additions & 0 deletions experiments/datascope/experiments/bench/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from abc import ABC, abstractmethod
from collections import defaultdict, ChainMap
from itertools import product
from typing import List, Hashable, Iterator, Union, Set, Tuple, Dict
import itertools
import random
import math


MAX_SAMPLING_ATTEMPTS = 10000


class ConfigGenerator(ABC):
    """Common base of all generators that enumerate configurations of a config space.

    A configuration is a dictionary that assigns to every variable of the config space one of the
    values listed for that variable.

    Attributes:
        config_space:
            Maps each variable name to the list of values that variable can take.
        invalid_configs:
            Configurations reported through `register_invalid_config`, stored as tuples of values
            ordered like the keys of `config_space`.
    """

    def __init__(self, config_space: Dict[str, List[Hashable]]) -> None:
        self.invalid_configs: Set[Tuple[Hashable, ...]] = set()
        self.config_space = config_space

    @abstractmethod
    def __iter__(self) -> Iterator[Dict[str, Hashable]]:
        """Yields the configurations produced from the config space, one dictionary at a time."""
        ...

    def register_invalid_config(self, config: Dict[str, Hashable]) -> None:
        """Records a configuration as invalid.

        The configuration is projected onto the variables of the config space (in the key order of
        the config space) and the resulting tuple is added to `invalid_configs`. How that record is
        used is up to the concrete generator.

        Args:
            config:
                The configuration to record. Keys not present in the config space are ignored.

        Raises:
            KeyError: If `config` lacks one of the variables of the config space.
        """
        projected = [config[name] for name in self.config_space]
        self.invalid_configs.add(tuple(projected))


class GridConfigGenerator(ConfigGenerator):
    """Generates configurations by iterating over the product of all variables in the config space.

    Configurations are produced in the order of `itertools.product`, i.e. the last variable of the
    config space varies fastest. Configurations recorded via `register_invalid_config` are not
    consulted here; every combination is produced exactly once per iteration.
    """

    def __iter__(self) -> Iterator[Dict[str, Hashable]]:
        """Yields one dictionary per element of the cartesian product of all variable domains.

        An empty config space yields exactly one empty configuration (the empty product), instead of
        failing while unpacking the keys and values of the space.
        """
        # Split keys and domains without `zip(*items)`, which cannot be unpacked into two names
        # when the config space has no entries.
        keys = list(self.config_space)
        domains = [self.config_space[key] for key in keys]
        for combination in itertools.product(*domains):
            yield dict(zip(keys, combination))


class RandomConfigGenerator(ConfigGenerator):
    """Generates random configurations from the config space, never yielding the same one twice.

    Sampling first tries to cover every value of every variable at least once; after that, values
    are drawn uniformly from each variable's domain. For a given seed the sequence of samples is
    reproducible because candidate values are always considered in the order of the config space.

    Attributes:
        sample_size:
            Number of valid configurations to sample. A `float` is interpreted as a fraction of the
            total number of combinations in the config space (rounded up).
        sampled_configs:
            Set of configurations that were already yielded, as tuples ordered like the keys of the
            config space.
        sampled_values:
            Dictionary tracking, for each variable, the values that appeared in returned samples.
        seed:
            Random seed.
        random:
            The `random.Random` instance (initialised with `seed`) used for all draws.
    """

    def __init__(self, config_space: Dict[str, List[Hashable]], sample_size: Union[int, float], seed: int = 0) -> None:
        super().__init__(config_space)
        self.sample_size = sample_size
        self.sampled_configs: Set[Tuple[Hashable, ...]] = set()
        self.sampled_values: Dict[str, Set[Hashable]] = defaultdict(set)
        self.seed = seed
        self.random = random.Random(seed)

    def __iter__(self) -> Iterator[Dict[str, Hashable]]:
        """Yields distinct random configurations until enough valid ones were produced.

        Iteration stops once the number of yielded configurations that were not registered as
        invalid reaches the sample size, or once every combination of the config space has been
        either yielded or registered as invalid. Because the stop condition is re-evaluated after
        every yield, a configuration that the consumer registers as invalid between two steps does
        not count towards the sample size.

        The bookkeeping lives on the instance, so iterating again continues from the configurations
        sampled so far rather than starting over.

        Raises:
            ValueError: If no valid configuration could be drawn within `MAX_SAMPLING_ATTEMPTS`
                attempts (see `_sample_config`).
        """
        total_combinations = math.prod(len(values) for values in self.config_space.values())
        sample_size = (
            math.ceil(self.sample_size * total_combinations)
            if isinstance(self.sample_size, float)
            else self.sample_size
        )

        while (
            len(self.sampled_configs - self.invalid_configs) < sample_size
            and len(self.sampled_configs | self.invalid_configs) < total_combinations
        ):
            config = self._sample_config()
            if config not in self.sampled_configs:
                self.sampled_configs.add(config)
                yield dict(zip(self.config_space, config))

    def _sample_config(self) -> Tuple[Hashable, ...]:
        """Draws a single configuration that is not registered as invalid.

        The returned configuration may already have been sampled before; de-duplication is the
        responsibility of the caller.

        Returns:
            The sampled values as a tuple ordered like the keys of the config space.

        Raises:
            ValueError: If every one of `MAX_SAMPLING_ATTEMPTS` re-sampling attempts produced an
                invalid configuration.
        """
        sampled_config: List[Hashable] = []

        for key, values in self.config_space.items():
            seen = self.sampled_values[key]
            # Prefer values of this variable that never appeared in a sample. The candidates are kept
            # in domain order (de-duplicated via `dict.fromkeys`) rather than taken from a set
            # difference: the iteration order of a set of strings changes between interpreter runs,
            # which would make the draw irreproducible even with a fixed seed.
            remaining_values = [value for value in dict.fromkeys(values) if value not in seen]
            sampled_config.append(self.random.choice(remaining_values or values))

        # If the sampled configuration is invalid, we randomly reset values of variables until a valid configuration is
        # found. This is done to avoid getting stuck in a situation where no valid configurations can be sampled.
        sampling_attempts = 0
        while tuple(sampled_config) in self.invalid_configs:
            sampling_attempts += 1
            if sampling_attempts >= MAX_SAMPLING_ATTEMPTS:
                raise ValueError(
                    f"Could not sample a valid configuration after {MAX_SAMPLING_ATTEMPTS} attempts. "
                    "This may happen if the sample size is too large compared to the number of valid configurations."
                )
            # Re-draw one variable at a time and stop as soon as the configuration becomes valid, so
            # that as few variables as possible are changed.
            for i, values in enumerate(self.config_space.values()):
                sampled_config[i] = self.random.choice(values)
                if tuple(sampled_config) not in self.invalid_configs:
                    break

        # Register the values of the sampled configuration.
        for key, value in zip(self.config_space, sampled_config):
            self.sampled_values[key].add(value)

        return tuple(sampled_config)


class CombinedConfigGenerator(ConfigGenerator):
    """Combines configurations from a list of generators operating on disjoint sets of config variables.

    Every yielded configuration is the union of one configuration from each generator; all
    combinations (the cartesian product) are produced.

    NOTE: `itertools.product` consumes each of its input iterables completely before it yields the
    first combination. Invalid configurations registered while iterating over this generator
    therefore cannot influence what the underlying generators produce during that same pass.

    Attributes:
        generators:
            The list of generators.
        config_space:
            Union of the config spaces of all generators, in generator order.
        invalid_configs:
            Combined configurations registered as invalid, as tuples ordered like `config_space`.
    """

    def __init__(self, *generators: ConfigGenerator):
        # Initialise the base class so that the inherited `config_space` and `invalid_configs`
        # attributes exist on combined generators as well.
        combined_space: Dict[str, List[Hashable]] = {}
        for generator in generators:
            combined_space.update(generator.config_space)
        super().__init__(combined_space)
        self.generators = generators

    def __iter__(self) -> Iterator[Dict[str, Hashable]]:
        """Yields the merged configuration for every combination of the generators' configurations.

        Keys appear in generator order. Should two generators produce the same key (they are
        expected not to), the value from the later generator wins.
        """
        for configs in product(*self.generators):
            merged: Dict[str, Hashable] = {}
            for config in configs:
                merged.update(config)
            yield merged

    def register_invalid_config(self, config: Dict[str, Hashable]):
        """Registers a combined configuration as invalid with every underlying generator.

        Each generator records only the part of `config` that belongs to its own config space. The
        full configuration is additionally recorded in this generator's own `invalid_configs`.

        Args:
            config:
                The combined configuration to register as invalid.

        Raises:
            KeyError: If `config` lacks a variable of one of the generators' config spaces.
        """
        for generator in self.generators:
            generator.register_invalid_config(config)
        super().register_invalid_config(config)
19 changes: 19 additions & 0 deletions experiments/datascope/experiments/bench/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,25 @@ def main():
help="The port to use for receiving events from distributed jobs (if slurm is used as backend).",
)

parser_run.add_argument(
"--subset-sample-size",
type=int,
action=env_default("SUBSET_SAMPLE_SIZE"),
required=False,
default=-1,
help="The number of samples to take from the scenario configuration space. Default: -1 (no sampling).",
)

parser_run.add_argument(
"--subset-grid-attributes",
type=str,
nargs="+",
action=env_default("SUBSET_GRID_ATTRIBUTES"),
required=False,
default=None,
help="The attributes that will be treated as grid attributes when sampling the scenario configuration space.",
)

# Build arguments from scenario attributes.
Scenario.add_dynamic_arguments(parser=parser_run, all_iterable=True, single_instance=False)

Expand Down

0 comments on commit 60abbfb

Please sign in to comment.