Skip to content

Commit

Permalink
Lightning App: Use the batch get endpoint (#19180)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2023
1 parent d4cb46d commit ecdfab0
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 15 deletions.
20 changes: 15 additions & 5 deletions src/lightning/app/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from lightning.app import _console
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
from lightning.app.core.constants import (
BATCH_DELTA_COUNT,
DEBUG_ENABLED,
FLOW_DURATION_SAMPLES,
FLOW_DURATION_THRESHOLD,
Expand Down Expand Up @@ -308,6 +309,14 @@ def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None)
except queue.Empty:
return None

@staticmethod
def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:
    """Drain up to ``BATCH_DELTA_COUNT`` elements from ``q`` in one call.

    Falls back to the queue's ``default_timeout`` when no timeout (or 0) is
    given, and returns an empty list instead of raising when the queue is empty.
    """
    effective_timeout = timeout if timeout else q.default_timeout
    try:
        return q.batch_get(timeout=effective_timeout, count=BATCH_DELTA_COUNT)
    except queue.Empty:
        # An empty queue is a normal condition here, not an error.
        return []

def check_error_queue(self) -> None:
exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type]
if isinstance(exception, Exception):
Expand Down Expand Up @@ -341,12 +350,15 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque

while (time() - t0) < self.state_accumulate_wait:
# TODO: Fetch all available deltas at once to reduce queue calls.
delta: Optional[
received_deltas: List[
Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]
] = self.get_state_changed_from_queue(
] = self.batch_get_state_changed_from_queue(
self.delta_queue # type: ignore[assignment,arg-type]
)
if delta:
if len(received_deltas) == 0:
break

for delta in received_deltas:
if isinstance(delta, _DeltaRequest):
deltas.append(delta.delta)
elif isinstance(delta, ComponentDelta):
Expand All @@ -364,8 +376,6 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque
deltas.append(delta)
else:
api_or_command_request_deltas.append(delta)
else:
break

if api_or_command_request_deltas:
_process_requests(self, api_or_command_request_deltas)
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def get_lightning_cloud_url() -> str:
# directory where system customization sync files will be copied to be packed into app tarball
SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"

BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))


def enable_multiple_works_in_default_container() -> bool:
return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
Expand Down
49 changes: 45 additions & 4 deletions src/lightning/app/core/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import multiprocessing
import pickle
import queue # needed as import instead from/import for mocking in tests
Expand All @@ -20,14 +21,15 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
from urllib.parse import urljoin

import backoff
import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout

from lightning.app.core.constants import (
BATCH_DELTA_COUNT,
HTTP_QUEUE_REFRESH_INTERVAL,
HTTP_QUEUE_REQUESTS_PER_SECOND,
HTTP_QUEUE_TOKEN,
Expand Down Expand Up @@ -189,6 +191,20 @@ def get(self, timeout: Optional[float] = None) -> Any:
"""
pass

@abstractmethod
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
    """Returns the left-most elements of the queue.

    Parameters
    ----------
    timeout:
        Read timeout in seconds. If the given timeout is 0, ``self.default_timeout`` is used.
        A timeout of None can be used to block indefinitely.
    count:
        The number of elements to get from the queue (implementations may return fewer).

    """

@property
def is_running(self) -> bool:
"""Returns True if the queue is running, False otherwise.
Expand All @@ -214,6 +230,12 @@ def get(self, timeout: Optional[float] = None) -> Any:
timeout = self.default_timeout
return self.queue.get(timeout=timeout, block=(timeout is None))

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
    """Return a single-element batch from the underlying multiprocessing queue.

    A multiprocessing queue has no native bulk read, so the "batch" is just
    the left-most element wrapped in a list; ``count`` is ignored.
    """
    effective_timeout = self.default_timeout if timeout == 0 else timeout
    item = self.queue.get(timeout=effective_timeout, block=(effective_timeout is None))
    return [item]


class RedisQueue(BaseQueue):
@requires("redis")
Expand Down Expand Up @@ -312,6 +334,9 @@ def get(self, timeout: Optional[float] = None) -> Any:
raise queue.Empty
return pickle.loads(out[1])

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
    # Redis path has no bulk read here: a "batch" is a single left-most element.
    # NOTE(review): `count` is ignored by this backend — confirm that is intended.
    return [self.get(timeout=timeout)]

def clear(self) -> None:
"""Clear all elements in the queue."""
self.redis.delete(self.name)
Expand Down Expand Up @@ -366,7 +391,6 @@ def __init__(self, queue: BaseQueue, requests_per_second: float):
self._seconds_per_request = 1 / requests_per_second

self._last_get = 0.0
self._last_put = 0.0

@property
def is_running(self) -> bool:
Expand All @@ -383,9 +407,12 @@ def get(self, timeout: Optional[float] = None) -> Any:
self._last_get = time.time()
return self._queue.get(timeout=timeout)

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
    """Rate-limited proxy for the wrapped queue's ``batch_get``.

    Blocks until a request is allowed under the requests-per-second budget,
    records the request time, then delegates to the wrapped queue.
    """
    self._wait_until_allowed(self._last_get)
    self._last_get = time.time()
    # Bug fix: `count` was previously dropped, so the wrapped queue always
    # fell back to its default batch size regardless of the caller's request.
    return self._queue.batch_get(timeout=timeout, count=count)

def put(self, item: Any) -> None:
    """Rate-limited proxy for the wrapped queue's ``put``.

    Bug fix: `__init__` no longer initializes ``self._last_put`` (removed in
    this change), so referencing it here raised AttributeError on the first
    put. Gets and puts now share the single ``_last_get`` timestamp, i.e. one
    requests-per-second budget for the whole queue.
    """
    self._wait_until_allowed(self._last_get)
    self._last_get = time.time()
    return self._queue.put(item)


Expand Down Expand Up @@ -477,6 +504,20 @@ def _get(self) -> Any:
# we consider the queue is empty to avoid failing the app.
raise queue.Empty

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
    """Pop up to ``count`` elements from the HTTP queue service in one request.

    Falls back to ``BATCH_DELTA_COUNT`` when ``count`` is None or 0. Raises
    ``queue.Empty`` both when the service reports no elements (HTTP 204) and
    when it is unreachable, so callers treat either case as an empty queue.
    """
    try:
        resp = self.client.post(
            f"v1/{self.app_id}/{self._name_suffix}",
            query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)},
        )
        if resp.status_code == 204:
            # 204 No Content: the server-side queue is empty.
            raise queue.Empty
        # Response body is a JSON list of base64-encoded pickled elements.
        # NOTE(review): pickle.loads on data received over HTTP is only safe if
        # the queue service is trusted infrastructure — confirm.
        return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
    except ConnectionError:
        # Note: If the Http Queue service isn't available,
        # we consider the queue is empty to avoid failing the app.
        raise queue.Empty

@backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
def put(self, item: Any) -> None:
if not self.app_id:
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/app/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def get(self, timeout: int = 0):
raise Empty()
return self._queue.pop(0)

def batch_get(self, timeout: int = 0, count: int = None):
    # Test-double mirror of `get` for the batch API: pops a single element
    # wrapped in a list; `count` is ignored by this helper.
    if not self._queue:
        raise Empty()
    return [self._queue.pop(0)]

def __contains__(self, item):
return item in self._queue

Expand Down
10 changes: 10 additions & 0 deletions src/lightning/app/utilities/packaging/lightning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "
tar_name = _copy_tar(lightning_cloud_project_path, root)
tar_files.append(os.path.join(root, tar_name))

lightning_launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
if lightning_launcher_project_path:
from lightning_launcher.__version__ import __version__ as cloud_version

# todo: check why logging.info is missing in outputs
print(f"Packaged Lightning Launcher with your application. Version: {cloud_version}")
_prepare_wheel(lightning_launcher_project_path)
tar_name = _copy_tar(lightning_launcher_project_path, root)
tar_files.append(os.path.join(root, tar_name))

return functools.partial(_cleanup, *tar_files)


Expand Down
13 changes: 7 additions & 6 deletions tests/tests_app/core/test_lightning_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ def run(self):
@pytest.mark.parametrize(
("sleep_time", "expect"),
[
(1, 0),
pytest.param(0, 10.0, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme
(0, 9),
pytest.param(9, 10.0, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme
],
)
@pytest.mark.flaky(reruns=5)
Expand All @@ -456,10 +456,10 @@ def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQu
time window."""

class SlowQueue(queue_type_cls):
def get(self, timeout):
def batch_get(self, timeout, count):
out = super().get(timeout)
sleep(sleep_time)
return out
return [out]

app = LightningApp(EmptyFlow())

Expand All @@ -480,7 +480,7 @@ def make_delta(i):
delta = app._collect_deltas_from_ui_and_work_queues()[-1]
generated = delta.to_dict()["values_changed"]["root['vars']['counter']"]["new_value"]
if sleep_time:
assert generated == expect
assert generated == expect, (generated, expect)
else:
# validate the flow should have aggregated at least expect.
assert generated > expect
Expand All @@ -497,7 +497,8 @@ def get(self, timeout):
app.delta_queue = SlowQueue("api_delta_queue", 0)
t0 = time()
assert app._collect_deltas_from_ui_and_work_queues() == []
assert (time() - t0) < app.state_accumulate_wait
delta = time() - t0
assert delta < app.state_accumulate_wait + 0.01, delta


class SimpleFlow2(LightningFlow):
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_app/core/test_queues.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import multiprocessing
import pickle
import queue
Expand Down Expand Up @@ -220,6 +221,24 @@ def test_http_queue_get(self, monkeypatch):
)
assert test_queue.get() == "test"

def test_http_queue_batch_get(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
adapter = requests_mock.Adapter()
test_queue.client.session.mount("http://", adapter)

adapter.register_uri(
"POST",
f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=popCount",
request_headers={"Authorization": "Bearer test-token"},
status_code=200,
json=[
base64.b64encode(pickle.dumps("test")).decode("utf-8"),
base64.b64encode(pickle.dumps("test2")).decode("utf-8"),
],
)
assert test_queue.batch_get() == ["test", "test2"]


def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
Expand Down

0 comments on commit ecdfab0

Please sign in to comment.