From e6820d98b4c93f69ab20db02f0beeadb04a66f98 Mon Sep 17 00:00:00 2001 From: Lakshya Garg Date: Mon, 16 Dec 2024 11:11:17 -0800 Subject: [PATCH] allow configurable scheduler load group Summary: Allow configurable scheduler load group for clean scheduler splits Differential Revision: D67290464 --- torchx/runner/config.py | 5 ++++- torchx/schedulers/__init__.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torchx/runner/config.py b/torchx/runner/config.py index a7c0f4e8c..7f5349770 100644 --- a/torchx/runner/config.py +++ b/torchx/runner/config.py @@ -197,7 +197,10 @@ def _configparser() -> configparser.ConfigParser: def _get_scheduler(name: str) -> Scheduler: - schedulers = get_scheduler_factories() + schedulers = { + **get_scheduler_factories(), + **get_scheduler_factories(group="torchx.schedulers.orchestrator"), + } if name not in schedulers: raise ValueError( f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}" diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 4fb47b8e8..06df90e19 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler: return run -def get_scheduler_factories() -> Dict[str, SchedulerFactory]: +def get_scheduler_factories( + group: str = "torchx.schedulers", +) -> Dict[str, SchedulerFactory]: """ - get_scheduler_factories returns all the available schedulers names and the + get_scheduler_factories returns all the available schedulers names under `group` and the method to instantiate them. The first scheduler in the dictionary is used as the default scheduler. @@ -55,7 +57,7 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]: default_schedulers[scheduler] = _defer_load_scheduler(path) return load_group( - "torchx.schedulers", + group, default=default_schedulers, )