From 582329a45cefe3a0ddb09adde19d407aeee98d1e Mon Sep 17 00:00:00 2001 From: Jesus Orozco Date: Mon, 9 Sep 2024 17:06:23 +0000 Subject: [PATCH 1/9] Adding support for Pathways proxy --- Dockerfile | 2 ++ axlearn/common/launch.py | 2 +- axlearn/common/launch_trainer_main.py | 1 + axlearn/common/utils_spmd.py | 7 ++++--- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index acdb66593..cee5a48ea 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,6 +33,8 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" # Install dependencies. RUN pip install flit RUN pip install --upgrade pip +COPY previewutilities-0.0.5-py3-none-any.whl previewutilities-0.0.5-py3-none-any.whl +RUN pip install previewutilities-0.0.5-py3-none-any.whl ################################################################################ # CI container spec. # diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index 843454aab..f098aa7bb 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -113,7 +113,7 @@ def setup(): logging.info("Devices: %s", devices) local_devices = jax.local_devices() logging.info("Local Devices: %s", local_devices) - if not devices or not all(device.platform == FLAGS.jax_backend for device in devices): + if FLAGS.jax_backend != "proxy" and (not devices or not all(device.platform == FLAGS.jax_backend for device in devices) ): raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.") if FLAGS.data_dir: # TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR. diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 8d170a950..d11018584 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -7,6 +7,7 @@ from axlearn.common import launch, launch_trainer, measurement from axlearn.common.config import config_for_function +import previewutilities def main(_): measurement.initialize(flags.FLAGS) diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 142da070e..32d5a0e97 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -53,7 +53,7 @@ def setup( if initialization_timeout is not None: init_kwargs["initialization_timeout"] = initialization_timeout - if jax_backend == "tpu": + if jax_backend == "tpu" or jax_backend == "proxy": if not ( distributed_coordinator is None and num_processes is None and process_id is None ): @@ -115,5 +115,6 @@ def setup( f"({initialization_timeout} seconds)." ) else: - jax.distributed.initialize(**init_kwargs) - _jax_distributed_initialized = True + if jax_backend != "proxy": + jax.distributed.initialize(**init_kwargs) + _jax_distributed_initialized = True \ No newline at end of file From 8d3c643e48fa6d48851c582f41311c31b9650bf8 Mon Sep 17 00:00:00 2001 From: Jesus Orozco Date: Tue, 1 Oct 2024 03:28:17 +0000 Subject: [PATCH 2/9] Update pathways-utils dependency and fix formatting --- Dockerfile | 2 -- axlearn/common/launch.py | 4 +++- axlearn/common/launch_trainer_main.py | 2 +- axlearn/common/utils_spmd.py | 4 ++-- pyproject.toml | 1 + 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index cee5a48ea..acdb66593 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,8 +33,6 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" # Install dependencies. RUN pip install flit RUN pip install --upgrade pip -COPY previewutilities-0.0.5-py3-none-any.whl previewutilities-0.0.5-py3-none-any.whl -RUN pip install previewutilities-0.0.5-py3-none-any.whl ################################################################################ # CI container spec. # diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index f098aa7bb..c97ce0f39 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -113,7 +113,9 @@ def setup(): logging.info("Devices: %s", devices) local_devices = jax.local_devices() logging.info("Local Devices: %s", local_devices) - if FLAGS.jax_backend != "proxy" and (not devices or not all(device.platform == FLAGS.jax_backend for device in devices) ): + if FLAGS.jax_backend != "proxy" and ( + not devices or not all(device.platform == FLAGS.jax_backend for device in devices) + ): raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.") if FLAGS.data_dir: # TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR. diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index d11018584..8517611b8 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -2,12 +2,12 @@ """Main function for launching the trainer.""" +import pathwaysutils # pylint: disable=unused-import from absl import app, flags from axlearn.common import launch, launch_trainer, measurement from axlearn.common.config import config_for_function -import previewutilities def main(_): measurement.initialize(flags.FLAGS) diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 32d5a0e97..ce8bbe5b6 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -53,7 +53,7 @@ def setup( if initialization_timeout is not None: init_kwargs["initialization_timeout"] = initialization_timeout - if jax_backend == "tpu" or jax_backend == "proxy": + if jax_backend in ("tpu", "proxy"): if not ( distributed_coordinator is None and num_processes is None and process_id is None ): @@ -117,4 +117,4 @@ def setup( else: if jax_backend != "proxy": jax.distributed.initialize(**init_kwargs) - _jax_distributed_initialized = True \ No newline at end of file + _jax_distributed_initialized = True diff --git a/pyproject.toml b/pyproject.toml index 566780ee9..50970b470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,7 @@ gcp = [ "google-cloud-build==3.24.1", "ml_goodput_measurement==0.0.2", "pyOpenSSL>=22.1.0", # compat with cryptography version. + "pathwaysutils@git+https://github.com/google/pathways-utils", # for JAX+Pathways single-controller accelerator coordinator ] # For TPU training. # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install. From 0e61b765edd824e3f8652498c9e65806e461480f Mon Sep 17 00:00:00 2001 From: jesus-orozco <92802826+jesus-orozco@users.noreply.github.com> Date: Mon, 7 Oct 2024 09:45:15 -0700 Subject: [PATCH 3/9] Move pathways package to its own dependency tree and pin it to a specific tagged version --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 50970b470..01c04922e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,6 @@ gcp = [ "google-cloud-build==3.24.1", "ml_goodput_measurement==0.0.2", "pyOpenSSL>=22.1.0", # compat with cryptography version. - "pathwaysutils@git+https://github.com/google/pathways-utils", # for JAX+Pathways single-controller accelerator coordinator ] # For TPU training. # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install. @@ -144,6 +143,11 @@ orbax = [ "orbax-checkpoint==0.5.23", ] +# Pathways utilities. +pathways = [ + "pathwaysutils@git+https://github.com/google/pathways-utils@v0.0.5", # for JAX+Pathways single-controller accelerator coordinator +] + [tool.flit.module] # This defines the import name. https://flit.pypa.io/en/stable/pyproject_toml.html#module-section name = "axlearn" From 4260c383d78817e5e53b4f1238b70e7cd59bed1a Mon Sep 17 00:00:00 2001 From: Jesus Orozco Date: Wed, 9 Oct 2024 23:24:35 +0000 Subject: [PATCH 4/9] Relocate pathwaysutils import --- axlearn/cloud/gcp/job.py | 13 +++++++++++++ axlearn/common/launch_trainer_main.py | 1 - axlearn/common/utils_spmd.py | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index f7f9e7cbf..f2220b73b 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -6,6 +6,7 @@ """ import atexit +import importlib import io import logging import math @@ -384,6 +385,7 @@ class Config(GKEJob.Config): reservation: Optional[str] = None enable_tpu_ici_resiliency: Optional[bool] = None location_hint: Optional[str] = None + use_pathways: Optional[bool] = False @classmethod def define_flags(cls, fv: flags.FlagValues): @@ -398,6 +400,9 @@ def define_flags(cls, fv: flags.FlagValues): "not all TPU types support this flag.", **common_kwargs, ) + flags.DEFINE_boolean( + "use_pathways", False, "Wether the workload is pathways-enabled.", **common_kwargs + ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: @@ -418,6 +423,14 @@ def __init__(self, cfg: Config): raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}") super().__init__(cfg) self._gcsfuse_volume = "gcs-fuse-csi-ephemeral" + if cfg.use_pathways: + self._import_pathways() + + def _import_pathways(self): + try: + importlib.import_module("pathwaysutils") + except ModuleNotFoundError: + logging.error("An error occurred while importing pathways-utils.") def _build_container(self) -> Nested[Any]: """Builds a config for a single container. diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 8517611b8..8d170a950 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -2,7 +2,6 @@ """Main function for launching the trainer.""" -import pathwaysutils # pylint: disable=unused-import from absl import app, flags from axlearn.common import launch, launch_trainer, measurement diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index ce8bbe5b6..baebe4133 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -53,6 +53,7 @@ def setup( if initialization_timeout is not None: init_kwargs["initialization_timeout"] = initialization_timeout + # TPU resources orchestrated by Pathways use 'proxy' as the JAX backend if jax_backend in ("tpu", "proxy"): if not ( distributed_coordinator is None and num_processes is None and process_id is None From ac6bcd236bcc51ebc1a8bd3adaeb9b9ebb7504e7 Mon Sep 17 00:00:00 2001 From: jesus-orozco <92802826+jesus-orozco@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:57:47 -0700 Subject: [PATCH 5/9] Update pathwaysutils source to pypi --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5a7a8f957..f3f721aa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ audio = [ # Pathways utilities. pathways = [ - "pathwaysutils@git+https://github.com/google/pathways-utils@v0.0.5", # for JAX+Pathways single-controller accelerator coordinator + "pathwaysutils==0.0.7", # for JAX+Pathways single-controller accelerator coordinator ] [tool.flit.module] From 4868815f5fb43b2d267b47b9ef8a1d5f6595de6d Mon Sep 17 00:00:00 2001 From: Jesus Orozco Date: Mon, 11 Nov 2024 01:26:30 +0000 Subject: [PATCH 6/9] Refactor pathways config flag --- axlearn/cloud/gcp/job.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index d33e7828c..31a6f8c82 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -444,7 +444,7 @@ class Config(GKEJob.Config): enable_tpu_ici_resiliency: Optional[bool] = None location_hint: Optional[str] = None enable_tpu_smart_repair: bool = False - use_pathways: Optional[bool] = False + import_pathways: Optional[list[str]] = [] @classmethod def define_flags(cls, fv: flags.FlagValues): @@ -459,8 +459,8 @@ def define_flags(cls, fv: flags.FlagValues): "not all TPU types support this flag.", **common_kwargs, ) - flags.DEFINE_boolean( - "use_pathways", False, "Wether the workload is pathways-enabled.", **common_kwargs + flags.DEFINE_list( + "import_pathways", [], "Modules to enable pathways proxy.", **common_kwargs ) @classmethod @@ -485,14 +485,15 @@ def __init__(self, cfg: Config): raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}") super().__init__(cfg) self._output_volume_mount = dict(name="shared-output", mountPath="/output") - if cfg.use_pathways: - self._import_pathways() + if len(cfg.import_pathways) > 0: + self._import_pathways(cfg.import_pathways) - def _import_pathways(self): + def _import_pathways(self, import_pathways: list[str]): try: - importlib.import_module("pathwaysutils") + for module in import_pathways: + importlib.import_module(module) except ModuleNotFoundError: - logging.error("An error occurred while importing pathways-utils.") + logging.error("An error occurred while importing pathways dependencies.") def _maybe_add_volume_mount(self, volume_mounts: list[dict], *, spec: Optional[VolumeMount]): if spec: From b51e67ed9c11dbeaaa3abb1f812f181b75398eb3 Mon Sep 17 00:00:00 2001 From: jesus-orozco <92802826+jesus-orozco@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:39:34 -0800 Subject: [PATCH 7/9] Update axlearn/cloud/gcp/job.py Co-authored-by: Ruoming Pang --- axlearn/cloud/gcp/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 31a6f8c82..da47fcc53 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -444,7 +444,7 @@ class Config(GKEJob.Config): enable_tpu_ici_resiliency: Optional[bool] = None location_hint: Optional[str] = None enable_tpu_smart_repair: bool = False - import_pathways: Optional[list[str]] = [] + import_modules: list[str] = [] @classmethod def define_flags(cls, fv: flags.FlagValues): From 65372745498724e52334eeac70d34ccf1061b07c Mon Sep 17 00:00:00 2001 From: jesus-orozco <92802826+jesus-orozco@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:43:08 -0800 Subject: [PATCH 8/9] Update job.py with dynamic module imports --- axlearn/cloud/gcp/job.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index da47fcc53..8729c8b58 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -460,7 +460,7 @@ def define_flags(cls, fv: flags.FlagValues): **common_kwargs, ) flags.DEFINE_list( - "import_pathways", [], "Modules to enable pathways proxy.", **common_kwargs + "import_modules", [], "Modules to enable pathways proxy.", **common_kwargs ) @classmethod @@ -485,12 +485,12 @@ def __init__(self, cfg: Config): raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}") super().__init__(cfg) self._output_volume_mount = dict(name="shared-output", mountPath="/output") - if len(cfg.import_pathways) > 0: - self._import_pathways(cfg.import_pathways) + if len(cfg.import_modules) > 0: + self._import_modules(cfg.import_modules) - def _import_pathways(self, import_pathways: list[str]): + def _import_modules(self, import_modules: list[str]): try: - for module in import_pathways: + for module in import_modules: importlib.import_module(module) except ModuleNotFoundError: logging.error("An error occurred while importing pathways dependencies.") From 7dbb1b9b193f51478af045553463ae29b21e140e Mon Sep 17 00:00:00 2001 From: jesus-orozco <92802826+jesus-orozco@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:44:59 -0800 Subject: [PATCH 9/9] Update job.py - remove pathways from dynamic import error message --- axlearn/cloud/gcp/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 8729c8b58..c63b438d7 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -493,7 +493,7 @@ def _import_modules(self, import_modules: list[str]): for module in import_modules: importlib.import_module(module) except ModuleNotFoundError: - logging.error("An error occurred while importing pathways dependencies.") + logging.error("An error occurred while importing dependencies.") def _maybe_add_volume_mount(self, volume_mounts: list[dict], *, spec: Optional[VolumeMount]): if spec: