Skip to content

Commit

Permalink
lightning.data: Remove torch distributed for the Dataset Optimizer (#…
Browse files Browse the repository at this point in the history
…19182)

Co-authored-by: thomas <[email protected]>
  • Loading branch information
tchaton and thomas authored Dec 20, 2023
1 parent 0a5cca6 commit 1284713
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/lightning/app/utilities/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ def get(self, path: str):
return self.session.get(url)

@_http_method_logger_wrapper
def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None):
def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None):
url = urljoin(self.base_url, path)
return self.session.post(url, data=data, params=query_params)
return self.session.post(url, data=data, params=query_params, json=json)

@_http_method_logger_wrapper
def delete(self, path: str):
Expand Down
27 changes: 3 additions & 24 deletions src/lightning/data/streaming/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse

import torch
from tqdm.auto import tqdm as _tqdm

from lightning import seed_everything
Expand All @@ -28,14 +27,8 @@
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.utilities.broadcast import broadcast_object
from lightning.data.utilities.packing import _pack_greedily
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.utilities.distributed import (
_distributed_is_initialized,
_init_dist_connection,
)
from lightning.fabric.utilities.distributed import group as _group

if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads
Expand Down Expand Up @@ -785,11 +778,11 @@ def __init__(
self.reorder_files = reorder_files

# Ensure the input dir is the same across all nodes
self.input_dir = self._broadcast_object(self.input_dir)
self.input_dir = broadcast_object("input_dir", self.input_dir)

if self.output_dir:
# Ensure the output dir is the same across all nodes
self.output_dir = self._broadcast_object(self.output_dir)
self.output_dir = broadcast_object("output_dir", self.output_dir)
print(f"Storing the files under {self.output_dir.path}")

self.random_seed = random_seed
Expand Down Expand Up @@ -971,17 +964,3 @@ def _cleanup_cache(self) -> None:
shutil.rmtree(cache_data_dir, ignore_errors=True)

os.makedirs(cache_data_dir, exist_ok=True)

def _broadcast_object(self, obj: Any) -> Any:
    """Synchronize ``obj`` across machines using ``torch.distributed`` collectives.

    With a single node the object is returned untouched. Otherwise the distributed
    connection is initialized on demand and the object is exchanged through
    ``broadcast_object_list`` with source rank ``0``.
    """
    node_count = _get_num_nodes()
    if node_count == 1:
        return obj

    if not _distributed_is_initialized():
        # Pick the backend from CUDA availability: NCCL with CUDA, Gloo without.
        if is_cuda_available():
            backend = "nccl"
        else:
            backend = "gloo"
        _init_dist_connection(LightningEnvironment(), backend, _get_node_rank(), node_count)

    # ``broadcast_object_list`` works in place on a list, so wrap the object in one.
    container = [obj]
    torch.distributed.broadcast_object_list(container, 0, group=_group.WORLD)
    return container[0]
159 changes: 159 additions & 0 deletions src/lightning/data/utilities/broadcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import pickle
from logging import Logger
from typing import Any, Callable, Dict, Optional
from urllib.parse import urljoin

import requests
import urllib3

# for backwards compatibility
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Obtain the logger from the logging registry instead of instantiating ``Logger`` directly:
# a directly constructed logger has no parent, so it ignores the handlers and levels
# configured by the application.
logger = logging.getLogger(__name__)

# Defaults for the HTTP helpers defined in this module.
_CONNECTION_RETRY_TOTAL = 2880
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
_DEFAULT_REQUEST_TIMEOUT = 30  # seconds


class _CustomRetryAdapter(HTTPAdapter):
    """``HTTPAdapter`` that applies a default timeout to requests sent without one.

    Args:
        timeout: Default timeout in seconds. It is popped from the keyword arguments
            before the remaining arguments are forwarded to ``HTTPAdapter``.
            Defaults to ``_DEFAULT_REQUEST_TIMEOUT``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
        super().__init__(*args, **kwargs)

    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
        """Send ``request``, using the adapter's default timeout when none was given.

        ``requests.Session`` forwards ``timeout=None`` when the caller did not specify
        a timeout, so a missing key and an explicit ``None`` are both treated as
        "not set". Positional arguments are forwarded unchanged.
        """
        # ``HTTPAdapter.send(request, stream, timeout, ...)``: two or more extra
        # positional arguments mean the timeout was already supplied positionally.
        timeout_given_positionally = len(args) >= 2
        if not timeout_given_positionally and kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, *args, **kwargs)


def _response(r: Any, *args: Any, **kwargs: Any) -> Any:
    """Response hook: delegate to ``r.raise_for_status()`` and return its result.

    Extra positional and keyword arguments are accepted and ignored.
    """
    outcome = r.raise_for_status()
    return outcome


class _HTTPClient:
    """Small wrapper around a ``requests.Session`` bound to a base URL.

    Every response goes through the ``_response`` hook. When ``use_retry`` is true, a
    retrying adapter with a default timeout is mounted for both ``http://`` and
    ``https://``. When ``auth_token`` is truthy, it is sent as a bearer token.
    ``log_callback`` is accepted but not used by this class.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        log_callback: Optional[Callable] = None,
        use_retry: bool = True,
    ) -> None:
        self.base_url = base_url

        # Status codes for which a retry is forced.
        retryable_statuses = [
            408,  # Request Timeout
            429,  # Too Many Requests
            500,  # Internal Server Error
            502,  # Bad Gateway
            503,  # Service Unavailable
            504,  # Gateway Timeout
        ]
        # The wait time between retries increases exponentially according to
        # backoff_factor * (2 ** (retry - 1)), but the maximum wait time is 120 secs.
        # With a large total (2880), clients stay alive for a very long time (~ 4 days)
        # while retrying every 120 seconds.
        retry_policy = Retry(
            total=_CONNECTION_RETRY_TOTAL,
            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
            status_forcelist=retryable_statuses,
        )
        retry_adapter = _CustomRetryAdapter(max_retries=retry_policy, timeout=_DEFAULT_REQUEST_TIMEOUT)

        self.session = requests.Session()
        self.session.hooks = {"response": _response}

        if use_retry:
            for scheme in ("http://", "https://"):
                self.session.mount(scheme, retry_adapter)

        if auth_token:
            self.session.headers.update({"Authorization": f"Bearer {auth_token}"})

    def get(self, path: str) -> Any:
        """Issue a GET request to ``path`` joined onto the base URL."""
        return self.session.get(urljoin(self.base_url, path))

    def post(
        self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None
    ) -> Any:
        """Issue a POST request to ``path`` joined onto the base URL."""
        target = urljoin(self.base_url, path)
        return self.session.post(target, data=data, params=query_params, json=json)

    def delete(self, path: str) -> Any:
        """Issue a DELETE request to ``path`` joined onto the base URL."""
        return self.session.delete(urljoin(self.base_url, path))


class _ImmutableDistributedMap:
    """The _ImmutableDistributedMap enables to create a distributed key value pair in the cloud.

    The first process to perform the set operation defines its value.

    Raises:
        RuntimeError: On construction, if ``LIGHTNING_APP_EXTERNAL_URL`` or
            ``LIGHTNING_APP_STATE_URL`` is not set.
    """

    def __init__(self) -> None:
        token = _get_token()

        lightning_app_external_url = os.getenv("LIGHTNING_APP_EXTERNAL_URL")
        if lightning_app_external_url is None:
            raise RuntimeError("The `LIGHTNING_APP_EXTERNAL_URL` should be set.")

        self.public_client: _HTTPClient = _HTTPClient(lightning_app_external_url, auth_token=token, use_retry=False)

        lightning_app_state_url = os.getenv("LIGHTNING_APP_STATE_URL")
        if lightning_app_state_url is None:
            raise RuntimeError("The `LIGHTNING_APP_STATE_URL` should be set.")

        self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False)

    def set_and_get(self, key: str, value: Any) -> Any:
        """Post ``value`` under ``key`` to ``/broadcast`` and return the value sent back.

        Args:
            key: Name of the entry.
            value: Any picklable object.

        Returns:
            The object unpickled from the ``"value"`` field of the response.

        Raises:
            RuntimeError: If the response status code is not 200.
        """
        # Protocol 0 is text-based, but ``str`` content is written with the
        # raw-unicode-escape codec, which emits Latin-1 bytes (e.g. b"\xe9" for "é").
        # Decoding as UTF-8 would raise on such values, whereas latin-1 maps every
        # byte to one code point and round-trips losslessly through JSON.
        payload = {"key": key, "value": pickle.dumps(value, 0).decode("latin-1")}

        # Try the public address first
        try:
            resp = self.public_client.post("/broadcast", json=payload)
        except (requests.exceptions.ConnectionError, urllib3.exceptions.MaxRetryError):
            # fallback to the private one
            resp = self.private_client.post("/broadcast", json=payload)

        if resp.status_code != 200:
            raise RuntimeError(f"Failed to broadcast the following {key=} {value=}.")

        # NOTE(security): this unpickles data returned by the broadcast endpoint, and
        # unpickling can execute arbitrary code. The endpoint must be trusted.
        return pickle.loads(resp.json()["value"].encode("latin-1"))


def broadcast_object(key: str, obj: Any) -> Any:
    """Broadcast ``obj`` across machines under ``key``.

    When ``LIGHTNING_APP_EXTERNAL_URL`` is not set, the object is returned unchanged.
    Otherwise the result of ``_ImmutableDistributedMap().set_and_get(key, obj)`` is returned.
    """
    if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is None:
        return obj
    distributed_map = _ImmutableDistributedMap()
    return distributed_map.set_and_get(key, obj)


def _get_token() -> Optional[str]:
    """Try to retrieve a temporary token from the login endpoint.

    Returns:
        ``None`` when ``LIGHTNING_CLOUD_URL`` is not set, otherwise the ``"token"``
        field of the JSON body returned by ``<LIGHTNING_CLOUD_URL>/v1/auth/login``.

    Raises:
        RuntimeError: If the login response does not contain a ``"token"`` field.
    """
    cloud_url = os.getenv("LIGHTNING_CLOUD_URL")
    if cloud_url is None:
        return None

    payload = {"apiKey": os.getenv("LIGHTNING_API_KEY"), "username": os.getenv("LIGHTNING_USERNAME")}
    url_login = cloud_url + "/v1/auth/login"
    # Bound the request so an unreachable endpoint cannot hang the process forever.
    res = requests.post(url_login, data=json.dumps(payload), timeout=_DEFAULT_REQUEST_TIMEOUT)
    body = res.json()
    if "token" not in body:
        # Never echo the API key: exception messages usually end up in logs.
        redacted = {**payload, "apiKey": "***" if payload["apiKey"] else None}
        raise RuntimeError(
            f"You haven't properly setup your environment variables with {url_login} and data: \n{redacted}"
        )
    return body["token"]
11 changes: 0 additions & 11 deletions tests/tests_data/streaming/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,6 @@ def fn(*_, **__):
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)


def test_broadcast_object(tmpdir, monkeypatch):
    """``_broadcast_object`` returns its input, with and without the two-node setting."""
    processor = DataProcessor(input_dir=str(tmpdir))

    # Before the node-count variable is set.
    assert processor._broadcast_object("dummy") == "dummy"

    # With two nodes requested: report the process group as initialized and swap in a fake ``torch``.
    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
    monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True)
    fake_torch = mock.MagicMock()
    monkeypatch.setattr(data_processor_module, "torch", fake_torch)

    assert processor._broadcast_object("dummy") == "dummy"
    broadcast_call = fake_torch.distributed.broadcast_object_list._mock_call_args
    assert broadcast_call.args == (["dummy"], 0)


def test_cache_dir_cleanup(tmpdir, monkeypatch):
cache_dir = os.path.join(tmpdir, "chunks")
cache_data_dir = os.path.join(tmpdir, "data")
Expand Down
22 changes: 22 additions & 0 deletions tests/tests_data/utilities/test_broadcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from unittest import mock

from lightning.data.utilities.broadcast import broadcast_object, requests


@mock.patch.dict(
    os.environ, {"LIGHTNING_APP_EXTERNAL_URL": "http://", "LIGHTNING_APP_STATE_URL": "http://"}, clear=True
)
def test_broadcast(monkeypatch):
    """``broadcast_object`` returns the value echoed back by the (faked) broadcast endpoint."""
    fake_session = mock.MagicMock()

    def echo_posted_value(*args, **kwargs):
        # Hand back the pickled value from the first POST made through the fake session.
        first_call = fake_session.post.call_args_list[0]
        return {"value": first_call.kwargs["json"]["value"]}

    response = requests.Response()
    response.status_code = 200
    response.json = echo_posted_value
    fake_session.post.return_value = response

    # Every ``requests.Session()`` created by the client now yields the fake session.
    monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=fake_session))
    assert broadcast_object("key", "value") == "value"

0 comments on commit 1284713

Please sign in to comment.