Introduce Dataset Optimizer (#18788)
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 142977d)
tchaton authored and lantiga committed Nov 6, 2023
1 parent: 3aafbea · commit: 9bb08c3
Showing 14 changed files with 1,297 additions and 96 deletions.
2 changes: 1 addition & 1 deletion requirements/app/app.txt
@@ -1,4 +1,4 @@
-lightning-cloud ==0.5.39 # Must be pinned to ensure compatibility
+lightning-cloud ==0.5.41 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.0.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0
3 changes: 2 additions & 1 deletion src/lightning/data/cache/__init__.py
@@ -13,5 +13,6 @@

 from lightning.data.cache.cache import Cache
 from lightning.data.cache.dataloader import LightningDataLoader
+from lightning.data.cache.dataset_optimizer import DatasetOptimizer

-__all__ = ["Cache", "LightningDataLoader"]
+__all__ = ["Cache", "DatasetOptimizer", "LightningDataLoader"]
28 changes: 19 additions & 9 deletions src/lightning/data/cache/cache.py
@@ -15,7 +15,7 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union

-from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
+from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.cache.reader import BinaryReader
 from lightning.data.cache.sampler import ChunkedIndex
 from lightning.data.cache.writer import BinaryWriter
@@ -46,11 +46,13 @@ def __init__(
         """
         super().__init__()
-        if not _TORCH_2_1_0_AVAILABLE:
+        if not _TORCH_GREATER_EQUAL_2_1_0:
             raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
-        self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
-        self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
-        self._cache_dir = cache_dir
+        self._writer = BinaryWriter(
+            str(cache_dir), chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
+        )
+        self._reader = BinaryReader(str(cache_dir), remote_dir=remote_dir, compression=compression)
+        self._cache_dir = str(cache_dir)
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()

@@ -66,19 +68,27 @@ def __setitem__(self, index: int, data: Any) -> None:
         """Store an item in the writer."""
         self._writer[index] = data

+    def _add_item(self, index: int, data: Any) -> Optional[str]:
+        """Store an item in the writer and optionally return the chunk path."""
+        return self._writer.add_item(index, data)
+
     def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
         """Read an item in the reader."""
         if isinstance(index, int):
             index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
         return self._reader.read(index)

-    def done(self) -> None:
+    def done(self) -> Optional[List[str]]:
         """Inform the writer the chunking phase is finished."""
-        self._writer.done()
+        return self._writer.done()
+
+    def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None:
+        """Inform the writer the chunking phase is finished."""
+        self._writer.merge(num_workers, node_rank=node_rank)

-    def merge(self, num_workers: int = 1) -> None:
+    def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
         """Inform the writer the chunking phase is finished."""
-        self._writer.merge(num_workers)
+        self._writer._merge_no_wait(node_rank=node_rank)

     def __len__(self) -> int:
         return self._reader.get_length()
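Taken together, the cache.py changes keep the write, done/merge, read flow, but done() now returns the written chunk paths and merge() can target a specific node_rank. A hypothetical end-to-end sketch assembled only from the signatures visible above; the constructor arguments, the sample layout, and reading back in the same process are assumptions rather than something this commit demonstrates:

    from lightning.data.cache import Cache

    # Assumed constructor call; cache_dir is coerced to str internally as of this commit.
    cache = Cache("/tmp/my_cache", chunk_bytes=1 << 26)

    # Write phase: store items by index, then close the chunking phase.
    for i in range(100):
        cache[i] = {"index": i, "value": i**2}
    chunk_paths = cache.done()   # now returns Optional[List[str]]
    cache.merge(num_workers=1)   # node_rank is optional and defaults to None

    # Read phase: integer indexing is wrapped into a ChunkedIndex under the hood.
    sample = cache[0]
    print(len(cache), sample)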
5 changes: 3 additions & 2 deletions src/lightning/data/cache/config.py
@@ -15,11 +15,11 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple

-from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
+from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.cache.downloader import get_downloader_cls
 from lightning.data.cache.sampler import ChunkedIndex

-if _TORCH_2_1_0_AVAILABLE:
+if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import treespec_loads


@@ -54,6 +54,7 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]):
             if (end - start) != chunk["chunk_size"]:
                 raise Exception(
                     "The config intervals doesn't match the number of samples. This shouldn't have happened."
+                    f" Found {end} {start} {chunk['chunk_size']}"
                 )
             self._intervals.append((chunk["interval"][0], chunk["interval"][1]))

5 changes: 4 additions & 1 deletion src/lightning/data/cache/constants.py
@@ -15,7 +15,10 @@

 _INDEX_FILENAME = "index.json"
 _DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
+_DEFAULT_FAST_DEV_RUN_ITEMS = 10

 # This is required for full pytree serialization / deserialization support
-_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
+_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_41 = RequirementCache("lightning-cloud>=0.5.41")
+_BOTO3_AVAILABLE = RequirementCache("boto3")
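The new flags follow the same RequirementCache pattern as the renamed Torch check: each constant is truthy only when the requirement is importable at a satisfying version, so other modules can gate optional imports on it (as config.py and dataloader.py do above). A small illustration of that gating pattern, assuming RequirementCache comes from lightning_utilities, which this hunk does not show:

    from lightning_utilities.core.imports import RequirementCache  # assumed import location

    _BOTO3_AVAILABLE = RequirementCache("boto3")

    if _BOTO3_AVAILABLE:
        import boto3  # only import when the requirement is met
    else:
        boto3 = None  # callers can fall back or raise a helpful error instead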
42 changes: 35 additions & 7 deletions src/lightning/data/cache/dataloader.py
@@ -32,11 +32,11 @@
 from torch.utils.data.sampler import BatchSampler, Sampler

 from lightning.data.cache import Cache
-from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE
+from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
 from lightning.data.cache.sampler import CacheBatchSampler
-from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
+from lightning.data.datasets.env import _DistributedEnv

-if _TORCH_2_1_0_AVAILABLE:
+if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten

 logger = logging.Logger(__name__)
@@ -154,13 +154,27 @@ def __init__(self, global_rank: int, profile: bool = False) -> None:
         self._global_rank = global_rank
         self._profile = profile

-    def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None:
+    def __call__(
+        self,
+        dataset_kind: Any,
+        dataset: Any,
+        index_queue: Any,
+        data_queue: Any,
+        done_event: Any,
+        auto_collation: Any,
+        collate_fn: Any,
+        drop_last: Any,
+        base_seed: Any,
+        init_fn: Any,
+        worker_id: Any,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         from torch.utils.data._utils import worker

         from lightning.data.cache.cache import Cache

-        rank = _WorkerEnv.detect().rank
-        enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
+        enable_profiling = self._global_rank == 0 and worker_id == 0 and _VIZ_TRACKER_AVAILABLE and self._profile

         if enable_profiling:
             from viztracer import VizTracer
@@ -180,7 +194,21 @@ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":

         _DatasetKind.create_fetcher = create_fetcher_fn # type: ignore

-        reloaded_worker._worker_loop(dataset_kind, *args, **kwargs)
+        reloaded_worker._worker_loop(
+            dataset_kind,
+            dataset,
+            index_queue,
+            data_queue,
+            done_event,
+            auto_collation,
+            collate_fn,
+            drop_last,
+            base_seed,
+            init_fn,
+            worker_id,
+            *args,
+            **kwargs,
+        )

         if dataset_kind == _DatasetKind.Map:
             assert fetcher
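The reworked __call__ forwards the worker loop's positional arguments explicitly instead of through a bare *args, which is what lets it read worker_id directly rather than detecting it via _WorkerEnv; profiling is only switched on for global rank 0, worker 0 when viztracer is installed. For reference, a minimal sketch of the VizTracer usage this enables (the output filename is an assumption, and this is the generic viztracer API rather than code from this commit):

    from viztracer import VizTracer

    tracer = VizTracer(output_file="result_worker_0.json")  # hypothetical output path
    tracer.start()
    # ... run the code being profiled, e.g. the worker's fetch loop ...
    tracer.stop()
    tracer.save()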
