Skip to content

Commit

Permalink
allow configurable scheduler load group
Browse files Browse the repository at this point in the history
Summary: Allow configurable scheduler load group for clean scheduler splits

Differential Revision: D67290464
  • Loading branch information
lgarg26 authored and facebook-github-bot committed Dec 16, 2024
1 parent c1a195a commit e6820d9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 4 additions & 1 deletion torchx/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def _configparser() -> configparser.ConfigParser:


def _get_scheduler(name: str) -> Scheduler:
schedulers = get_scheduler_factories()
schedulers = {
**get_scheduler_factories(),
**get_scheduler_factories(group="torchx.schedulers.orchestrator"),
}
if name not in schedulers:
raise ValueError(
f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"
Expand Down
8 changes: 5 additions & 3 deletions torchx/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler:
return run


def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
def get_scheduler_factories(
group: str = "torchx.schedulers",
) -> Dict[str, SchedulerFactory]:
"""
get_scheduler_factories returns all the available schedulers names and the
get_scheduler_factories returns all the available schedulers names under `group` and the
method to instantiate them.
The first scheduler in the dictionary is used as the default scheduler.
Expand All @@ -55,7 +57,7 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
default_schedulers[scheduler] = _defer_load_scheduler(path)

return load_group(
"torchx.schedulers",
group,
default=default_schedulers,
)

Expand Down

0 comments on commit e6820d9

Please sign in to comment.