diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py
index 7942f9dc3edbe..d7ba2a4f88102 100644
--- a/src/lightning/app/utilities/network.py
+++ b/src/lightning/app/utilities/network.py
@@ -191,9 +191,9 @@ def get(self, path: str):
         return self.session.get(url)
 
     @_http_method_logger_wrapper
-    def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None):
+    def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None):
         url = urljoin(self.base_url, path)
-        return self.session.post(url, data=data, params=query_params)
+        return self.session.post(url, data=data, params=query_params, json=json)
 
     @_http_method_logger_wrapper
     def delete(self, path: str):
diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index 2fe7f7411bb86..0da89b5d1895b 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -14,7 +14,6 @@
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse
 
-import torch
 from tqdm.auto import tqdm as _tqdm
 
 from lightning import seed_everything
@@ -28,14 +27,8 @@
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
+from lightning.data.utilities.broadcast import broadcast_object
 from lightning.data.utilities.packing import _pack_greedily
-from lightning.fabric.accelerators.cuda import is_cuda_available
-from lightning.fabric.plugins.environments import LightningEnvironment
-from lightning.fabric.utilities.distributed import (
-    _distributed_is_initialized,
-    _init_dist_connection,
-)
-from lightning.fabric.utilities.distributed import group as _group
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads
@@ -785,11 +778,11 @@ def __init__(
         self.reorder_files = reorder_files
 
         # Ensure the input dir is the same across all nodes
-        self.input_dir = self._broadcast_object(self.input_dir)
+        self.input_dir = broadcast_object("input_dir", self.input_dir)
 
         if self.output_dir:
             # Ensure the output dir is the same across all nodes
-            self.output_dir = self._broadcast_object(self.output_dir)
+            self.output_dir = broadcast_object("output_dir", self.output_dir)
             print(f"Storing the files under {self.output_dir.path}")
 
         self.random_seed = random_seed
@@ -971,17 +964,3 @@ def _cleanup_cache(self) -> None:
             shutil.rmtree(cache_data_dir, ignore_errors=True)
 
         os.makedirs(cache_data_dir, exist_ok=True)
-
-    def _broadcast_object(self, obj: Any) -> Any:
-        """Enable to synchronize an object across machines using torch.distributed.collectives."""
-        num_nodes = _get_num_nodes()
-        if num_nodes == 1:
-            return obj
-
-        if not _distributed_is_initialized():
-            process_group_backend = "nccl" if is_cuda_available() else "gloo"
-            _init_dist_connection(LightningEnvironment(), process_group_backend, _get_node_rank(), num_nodes)
-
-        obj = [obj]
-        torch.distributed.broadcast_object_list(obj, 0, group=_group.WORLD)
-        return obj[0]
diff --git a/src/lightning/data/utilities/broadcast.py b/src/lightning/data/utilities/broadcast.py
new file mode 100644
index 0000000000000..1713a3b37edcd
--- /dev/null
+++ b/src/lightning/data/utilities/broadcast.py
@@ -0,0 +1,175 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import pickle
+from logging import Logger, getLogger
+from typing import Any, Callable, Dict, Optional
+from urllib.parse import urljoin
+
+import requests
+import urllib3
+
+# for backwards compatibility
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+# Use `getLogger` so the logger is part of the logging hierarchy and honours the application's configuration.
+logger: Logger = getLogger(__name__)
+
+_CONNECTION_RETRY_TOTAL = 2880
+_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
+_DEFAULT_REQUEST_TIMEOUT = 30  # seconds
+
+
+class _CustomRetryAdapter(HTTPAdapter):
+    """An ``HTTPAdapter`` which applies a default timeout to every request sent without an explicit one."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
+        super().__init__(*args, **kwargs)
+
+    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
+        # `requests.Session` always forwards a `timeout` keyword (``None`` when the caller didn't provide one),
+        # so check the value rather than the presence of the key, otherwise the default would never apply.
+        if kwargs.get("timeout") is None:
+            kwargs["timeout"] = self.timeout
+        return super().send(request, **kwargs)
+
+
+def _response(r: Any, *args: Any, **kwargs: Any) -> Any:
+    """Response hook raising an ``HTTPError`` for 4xx and 5xx responses."""
+    return r.raise_for_status()
+
+
+class _HTTPClient:
+    """A wrapper class around the requests library which handles chores like retries and timeouts automatically."""
+
+    def __init__(
+        self,
+        base_url: str,
+        auth_token: Optional[str] = None,
+        log_callback: Optional[Callable] = None,
+        use_retry: bool = True,
+    ) -> None:
+        self.base_url = base_url
+        retry_strategy = Retry(
+            # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
+            # but the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients
+            # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds
+            total=_CONNECTION_RETRY_TOTAL,
+            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
+            status_forcelist=[
+                408,  # Request Timeout
+                429,  # Too Many Requests
+                500,  # Internal Server Error
+                502,  # Bad Gateway
+                503,  # Service Unavailable
+                504,  # Gateway Timeout
+            ],
+        )
+        adapter = _CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
+        self.session = requests.Session()
+
+        self.session.hooks = {"response": _response}
+
+        if use_retry:
+            self.session.mount("http://", adapter)
+            self.session.mount("https://", adapter)
+
+        if auth_token:
+            self.session.headers.update({"Authorization": f"Bearer {auth_token}"})
+
+    def get(self, path: str) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.get(url)
+
+    def post(
+        self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None
+    ) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.post(url, data=data, params=query_params, json=json)
+
+    def delete(self, path: str) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.delete(url)
+
+
+class _ImmutableDistributedMap:
+    """The _ImmutableDistributedMap enables to create a distributed key value pair in the cloud.
+
+    The first process to perform the set operation defines its value.
+
+    """
+
+    def __init__(self) -> None:
+        token = _get_token()
+
+        lightning_app_external_url = os.getenv("LIGHTNING_APP_EXTERNAL_URL")
+        if lightning_app_external_url is None:
+            raise RuntimeError("The `LIGHTNING_APP_EXTERNAL_URL` should be set.")
+
+        self.public_client: _HTTPClient = _HTTPClient(lightning_app_external_url, auth_token=token, use_retry=False)
+
+        lightning_app_state_url = os.getenv("LIGHTNING_APP_STATE_URL")
+        if lightning_app_state_url is None:
+            raise RuntimeError("The `LIGHTNING_APP_STATE_URL` should be set.")
+
+        self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False)
+
+    def set_and_get(self, key: str, value: Any) -> Any:
+        """Post ``value`` under ``key`` to ``/broadcast`` and return the value sent back by the server."""
+        payload = {"key": key, "value": pickle.dumps(value, 0).decode()}
+
+        # Try the public address first
+        try:
+            resp = self.public_client.post("/broadcast", json=payload)
+        except (requests.exceptions.ConnectionError, urllib3.exceptions.MaxRetryError):
+            # fallback to the private one
+            resp = self.private_client.post("/broadcast", json=payload)
+
+        if resp.status_code != 200:
+            raise RuntimeError(f"Failed to broadcast the following {key=} {value=}.")
+        # NOTE: unpickling executes arbitrary code, so the broadcast endpoint must be a trusted service.
+        return pickle.loads(bytes(resp.json()["value"], "utf-8"))
+
+
+def broadcast_object(key: str, obj: Any) -> Any:
+    """Share ``obj`` across machines under ``key``, returning the value set by the first process to publish it.
+
+    When ``LIGHTNING_APP_EXTERNAL_URL`` isn't set, ``obj`` is returned unchanged.
+
+    """
+    if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is not None:
+        return _ImmutableDistributedMap().set_and_get(key, obj)
+    return obj
+
+
+def _get_token() -> Optional[str]:
+    """Retrieve a temporary token from the login endpoint, or ``None`` if ``LIGHTNING_CLOUD_URL`` isn't set."""
+    if os.getenv("LIGHTNING_CLOUD_URL") is None:
+        return None
+
+    payload = {"apiKey": os.getenv("LIGHTNING_API_KEY"), "username": os.getenv("LIGHTNING_USERNAME")}
+    url_login = os.getenv("LIGHTNING_CLOUD_URL", "") + "/v1/auth/login"
+    # Bound the request so that an unreachable login endpoint can't block the process indefinitely.
+    res = requests.post(url_login, data=json.dumps(payload), timeout=_DEFAULT_REQUEST_TIMEOUT)
+    body = res.json()
+    if "token" not in body:
+        # Don't include the payload in the message: it contains the API key.
+        raise RuntimeError(
+            f"You haven't properly setup your environment variables with {url_login} and "
+            f"username: {payload['username']}"
+        )
+    return body["token"]
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index efa886a9c41c4..0a5bf87fc67e3 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -233,17 +233,6 @@ def fn(*_, **__):
         _wait_for_file_to_exist(s3, obj, sleep_time=0.01)
 
 
-def test_broadcast_object(tmpdir, monkeypatch):
-    data_processor = DataProcessor(input_dir=str(tmpdir))
-    assert data_processor._broadcast_object("dummy") == "dummy"
-    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
-    monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True)
-    torch_mock = mock.MagicMock()
-    monkeypatch.setattr(data_processor_module, "torch", torch_mock)
-    assert data_processor._broadcast_object("dummy") == "dummy"
-    assert torch_mock.distributed.broadcast_object_list._mock_call_args.args == (["dummy"], 0)
-
-
 def test_cache_dir_cleanup(tmpdir, monkeypatch):
     cache_dir = os.path.join(tmpdir, "chunks")
     cache_data_dir = os.path.join(tmpdir, "data")
diff --git a/tests/tests_data/utilities/test_broadcast.py b/tests/tests_data/utilities/test_broadcast.py
new file mode 100644
index 0000000000000..97f404181915d
--- /dev/null
+++ b/tests/tests_data/utilities/test_broadcast.py
@@ -0,0 +1,22 @@
+import os
+from unittest import mock
+
+from lightning.data.utilities.broadcast import broadcast_object, requests
+
+
+@mock.patch.dict(
+    os.environ, {"LIGHTNING_APP_EXTERNAL_URL": "http://", "LIGHTNING_APP_STATE_URL": "http://"}, clear=True
+)
+def test_broadcast(monkeypatch):
+    session = mock.MagicMock()
+    resp = requests.Response()
+    resp.status_code = 200
+
+    def fn(*args, **kwargs):
+        nonlocal session
+        return {"value": session.post._mock_call_args_list[0].kwargs["json"]["value"]}
+
+    resp.json = fn
+    session.post.return_value = resp
+    monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=session))
+    assert broadcast_object("key", "value") == "value"