StreamingDataset: Cleanup chunks right away if the dataset doesn't fit within the cache (#19168)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Dec 18, 2023
1 parent ecdfab0 commit 0a5cca6
Showing 5 changed files with 80 additions and 58 deletions.
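
The behavioral core of the commit: ChunksConfig gains a num_bytes property (the sum of all chunk_bytes entries in the index), and PrepareChunksThread compares it against max_cache_size up front. If the dataset cannot fit in the cache, each downloaded chunk is deleted as soon as it has been consumed, rather than waiting for the cache folder to grow past the limit. A minimal sketch of that decision, using made-up chunk sizes instead of a real index:

# Sketch (not part of the diff): the eager-cleanup decision introduced here,
# with hypothetical chunk descriptors shaped like the entries in index.json.
chunks = [{"chunk_bytes": 30_000_000}, {"chunk_bytes": 30_000_000}, {"chunk_bytes": 30_000_000}]
max_cache_size = 50_000_000  # bytes the local cache is allowed to hold

num_bytes = sum(c["chunk_bytes"] for c in chunks)  # what ChunksConfig.num_bytes now returns
delete_chunks_when_processed = num_bytes > max_cache_size if max_cache_size else False

# 90 MB of chunks against a 50 MB budget: delete each chunk right after it is consumed.
print(delete_chunks_when_processed)  # True
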
10 changes: 8 additions & 2 deletions src/lightning/data/streaming/config.py
@@ -64,7 +64,7 @@ def __init__(
self._downloader = None

if remote_dir:
self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)

def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]
@@ -85,6 +85,12 @@ def intervals(self) -> List[Tuple[int, int]]:
raise RuntimeError("The intervals should be defined.")
return self._intervals

@property
def num_bytes(self) -> int:
if self._config is None:
raise RuntimeError("The config should be defined.")
return sum(c["chunk_bytes"] for c in self._chunks)

@property
def data_format(self) -> Any:
if self._config is None:
@@ -146,7 +152,7 @@ def load(
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

if isinstance(remote_dir, str):
downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
downloader = get_downloader_cls(remote_dir, cache_dir, [])
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

if not os.path.exists(cache_index_filepath):
24 changes: 11 additions & 13 deletions src/lightning/data/streaming/downloader.py
@@ -12,8 +12,8 @@
# limitations under the License.
import os
import shutil
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from abc import ABC
from typing import Any, Dict, List
from urllib import parse

from lightning.data.streaming.client import S3Client
@@ -31,28 +31,27 @@ def download_chunk_from_index(self, chunk_index: int) -> None:
remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
self.download_file(remote_chunkpath, local_chunkpath)

@abstractmethod
def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
pass


class S3Downloader(Downloader):
@classmethod
def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
super().__init__(remote_dir, cache_dir, chunks)
self._client = S3Client()

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
obj = parse.urlparse(remote_filepath)

if obj.scheme != "s3":
raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")

# TODO: Add caching to avoid re-creating it
s3 = S3Client()

from boto3.s3.transfer import TransferConfig

extra_args: Dict[str, Any] = {}

# Issue: https://github.com/boto/boto3/issues/3113
s3.client.download_file(
self._client.client.download_file(
obj.netloc,
obj.path.lstrip("/"),
local_filepath,
@@ -62,8 +61,7 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None:


class LocalDownloader(Downloader):
@classmethod
def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
if not os.path.exists(remote_filepath):
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
if remote_filepath != local_filepath:
@@ -73,8 +71,8 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
_DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}


def get_downloader_cls(remote_dir: str) -> Type[Downloader]:
def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
for k, cls in _DOWNLOADERS.items():
if str(remote_dir).startswith(k):
return cls
return cls(remote_dir, cache_dir, chunks)
raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
6 changes: 2 additions & 4 deletions src/lightning/data/streaming/item_loader.py
@@ -90,7 +90,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
first_exists = exists = os.path.exists(chunk_filepath)

while not exists:
sleep(0.01)
sleep(0.1)
exists = os.path.exists(chunk_filepath)

# Wait to avoid any corruption when the file appears
@@ -166,7 +166,6 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
def _load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
if chunk_index in self._mmaps:
return

chunk = self._chunks[chunk_index]

# Skip the header
@@ -192,7 +191,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
first_exists = exists = os.path.exists(chunk_filepath)

while not exists:
sleep(0.01)
sleep(0.1)
exists = os.path.exists(chunk_filepath)

# Wait to avoid any corruption when the file appears
@@ -202,7 +201,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
self._chunk_filepaths[chunk_filepath] = True

self._load_chunk(chunk_index, chunk_filepath)

assert self._dtype

buffer: bytes = self._buffers[chunk_index]
73 changes: 36 additions & 37 deletions src/lightning/data/streaming/reader.py
@@ -18,7 +18,6 @@
from logging import Logger
from queue import Empty
from threading import Thread
from time import sleep
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.streaming.config import ChunksConfig
@@ -30,14 +29,21 @@

warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")


if _TORCH_GREATER_EQUAL_2_1_0:
pass


logger = Logger(__name__)


_END_TOKEN = "END"

# Note: The timeout here should not be too short. We need to prevent the caller from aggressively
# querying the queue and consuming too many CPU cycles.
_DEFAULT_TIMEOUT = 0.1
_LONG_DEFAULT_TIMEOUT = 5


class PrepareChunksThread(Thread):
"""This thread is responsible to download the chunks associated to a given worker."""

@@ -59,22 +65,7 @@ def __init__(
self._parent_cache_dir = os.path.dirname(self._config._cache_dir)
self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()

# populate back the queues with existing items. As they already exists, this is almost a no-op
for chunk_index in self._collect_ordered_chunk_indexes_from_cache():
self._to_download_queue.put(chunk_index)
self._to_delete_queue.put(chunk_index)

def _collect_ordered_chunk_indexes_from_cache(self) -> List[int]:
"""List the chunks available in the cache, order them based on their creation time and retrieves their
indexes."""
chunk_indexes = [
[self._config._get_chunk_index_from_filename(f), os.path.getctime(os.path.join(self._config._cache_dir, f))]
for f in os.listdir(self._config._cache_dir)
if f.endswith(".bin")
]
return [int(x[0]) for x in sorted(chunk_indexes, key=lambda x: x[1])]
self._delete_chunks_when_processed = self._config.num_bytes > max_cache_size if max_cache_size else False

def download(self, chunk_indexes: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
@@ -93,10 +84,15 @@ def _delete(self, chunk_index: int) -> None:

def stop(self) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
self._to_stop_queue.put(True)
self._to_download_queue.put(_END_TOKEN)

def _maybe_delete_chunks(self) -> None:
chunk_index = _get_from_queue(self._to_delete_queue)
reached_pre_download = self._pre_download_counter == self._max_pre_download

# we have already pre-downloaded some chunks, we just need to wait for them to be processed.
chunk_index = _get_from_queue(
self._to_delete_queue, timeout=_LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT
)

if chunk_index is not None:
self._pre_download_counter -= 1
@@ -105,14 +101,17 @@ def _maybe_delete_chunks(self) -> None:
self._chunks_index_to_be_deleted.append(chunk_index)

# Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it
while (
self._max_cache_size
and self._chunks_index_to_be_deleted
and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
):
while self._max_cache_size and self._chunks_index_to_be_deleted and self._can_delete_chunk():
# Delete the oldest chunk
self._delete(self._chunks_index_to_be_deleted.pop(0))

return

def _can_delete_chunk(self) -> bool:
if self._delete_chunks_when_processed:
return self._pre_download_counter == self._max_pre_download - 1
return self._max_cache_size is not None and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size

def _pre_load_chunk(self, chunk_index: int) -> None:
chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
self._item_loader.pre_load_chunk(chunk_index, chunk_filepath)
@@ -121,6 +120,9 @@ def run(self) -> None:
while True:
if self._pre_download_counter <= self._max_pre_download:
chunk_index = _get_from_queue(self._to_download_queue)
if chunk_index == _END_TOKEN:
return

if chunk_index is not None:
self._config.download_chunk_from_index(chunk_index)

@@ -135,11 +137,6 @@ def run(self) -> None:
if self._max_cache_size:
self._maybe_delete_chunks()

if _get_from_queue(self._to_stop_queue):
return

sleep(0.05)


class BinaryReader:
def __init__(
@@ -238,6 +235,9 @@ def read(self, index: ChunkedIndex) -> Any:
assert self._prepare_thread
self._prepare_thread.download([index.chunk_index])

if self._last_chunk_index is None:
self._last_chunk_index = index.chunk_index

# Fetch the element
chunk_filepath, begin, _ = self.config[index]
item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
@@ -246,9 +246,10 @@ def read(self, index: ChunkedIndex) -> Any:
# Otherwise, this could trigger segmentation fault error depending on the item loader used.
if self._config and self._config._remote_dir and index.chunk_index != self._last_chunk_index:
assert self._prepare_thread
if self._last_chunk_index is not None:
# inform the chunk has been completely consumed
self._prepare_thread.delete([self._last_chunk_index])
assert self._last_chunk_index is not None

# inform the chunk has been completely consumed
self._prepare_thread.delete([self._last_chunk_index])

# track the new chunk index as the latest one
self._last_chunk_index = index.chunk_index
@@ -294,11 +295,9 @@ def _get_folder_size(path: str) -> int:
return size


def _get_from_queue(queue: multiprocessing.Queue) -> Optional[Any]:
def _get_from_queue(queue: multiprocessing.Queue, timeout: float = _DEFAULT_TIMEOUT) -> Optional[Any]:
try:
# Note: The timeout here should not be too short. We need to prevent the caller from aggressively
# querying the queue and consuming too many CPU cycles.
return queue.get(timeout=0.1)
return queue.get(timeout=timeout)
except Empty:
pass
except OSError as e:
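
Also visible in reader.py: the worker loop no longer sleeps and polls a dedicated stop queue; shutdown is signalled by pushing _END_TOKEN onto the download queue, and _get_from_queue takes a per-call timeout so the thread can block longer (_LONG_DEFAULT_TIMEOUT) when it is only waiting for already pre-downloaded chunks to be consumed. The sentinel-plus-timeout pattern, sketched outside the real classes:

# Sketch: sentinel-based shutdown with per-call queue timeouts, mirroring the idea
# (not the exact code) behind the reworked PrepareChunksThread.run loop.
import multiprocessing
from queue import Empty

_END_TOKEN = "END"
_DEFAULT_TIMEOUT = 0.1

def get_from_queue(queue, timeout: float = _DEFAULT_TIMEOUT):
    try:
        return queue.get(timeout=timeout)
    except Empty:
        return None

to_download = multiprocessing.Queue()
to_download.put(7)           # a chunk index to fetch
to_download.put(_END_TOKEN)  # ask the loop to exit instead of flipping a stop flag

while True:
    item = get_from_queue(to_download)
    if item == _END_TOKEN:
        break                # run() now returns here; no more sleep(0.05)/stop-queue polling
    if item is not None:
        print(f"download chunk {item}")
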
25 changes: 23 additions & 2 deletions tests/tests_data/streaming/test_reader.py
@@ -4,11 +4,12 @@
import numpy as np
from lightning.data.streaming.cache import Cache
from lightning.data.streaming.config import ChunkedIndex
from lightning.data.streaming.reader import _get_folder_size
from lightning.data.streaming.item_loader import PyTreeLoader
from lightning.data.streaming.reader import PrepareChunksThread, _get_folder_size
from lightning_cloud.resolver import Dir


def test_reader_chunk_removal(tmpdir, monkeypatch):
def test_reader_chunk_removal(tmpdir):
cache_dir = os.path.join(tmpdir, "cache_dir")
remote_dir = os.path.join(tmpdir, "remote_dir")
os.makedirs(cache_dir, exist_ok=True)
@@ -79,3 +80,23 @@ def test_get_folder_size(tmpdir):
np.save(os.path.join(tmpdir, "array_2.npy"), array)

assert _get_folder_size(tmpdir) == 928 * 2


def test_prepare_chunks_thread(tmpdir):
cache_dir = os.path.join(tmpdir, "cache_dir")
os.makedirs(cache_dir, exist_ok=True)
cache = Cache(input_dir=cache_dir, chunk_size=2, max_cache_size=28020)

for i in range(25):
cache[i] = i

cache.done()
cache.merge()

cache._reader._try_load_config()

thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
assert thread._delete_chunks_when_processed

thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=10000)
assert not thread._delete_chunks_when_processed
