From d95d0ac717aecca28824d48f26bd8d5bfaa27ac9 Mon Sep 17 00:00:00 2001
From: moto <855818+mthrok@users.noreply.github.com>
Date: Fri, 27 Dec 2024 14:33:16 -0500
Subject: [PATCH] Add drop-in replacement for PyTorch DataLoader

This commit adds a data loader class that is compatible with PyTorch's
DataLoader. It uses ProcessPoolExecutor, so existing Dataset
implementations can be used as-is. Currently it only supports map-style
datasets.

It is not intended to make the attributes of DataLoader compatible.
So far, only the behavior as an ``Iterable`` is kept compatible.

A quick benchmark shows that this implementation is much faster than
PyTorch's implementation.

![QPS and Time-to-First-batch](https://github.com/user-attachments/assets/6ebff0ae-b523-492b-95af-87973c7b5fca)

**QPS** (higher is better)

| #workers |      1 |      2 |      4 |       8 |      16 |      32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL     | 251.49 | 454.96 |  773.8 | 1146.63 | 1516.97 | 1763.82 |
| PyTorch  | 265.22 |  410.2 | 446.48 |  320.63 |   172.6 |   91.83 |

**TTFB** (time to first batch, in seconds; lower is better)

| #workers |    1 |    2 |     4 |     8 |    16 |     32 |
|----------|------|------|-------|-------|-------|--------|
| SPDL     | 3.86 | 3.57 |  3.55 |  3.95 |  3.72 |   4.07 |
| PyTorch  | 3.49 | 6.79 | 13.32 |  26.6 | 55.32 | 107.43 |
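The intended usage mirrors `torch.utils.data.DataLoader`. Below is a minimal sketch of the drop-in swap; the `TensorDataset` is just a stand-in, and any map-style dataset works:

```
import torch
from torch.utils.data import TensorDataset

from spdl.dataloader import get_pytorch_dataloader


def main():
    # Stand-in map-style dataset; anything implementing __getitem__/__len__ works.
    dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 10, (1000,)))

    # Same call shape as torch.utils.data.DataLoader for the common arguments.
    dataloader = get_pytorch_dataloader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        multiprocessing_context="forkserver",
    )

    for features, labels in dataloader:
        ...  # training step


if __name__ == "__main__":
    main()
```

Under the hood, the workers are plain `concurrent.futures.ProcessPoolExecutor` processes: the dataset is pickled once into shared memory and each worker unpickles it in its initializer, while the sampler runs in the main process.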
Benchmark code

```
import logging
import time

from spdl.dataloader import get_pytorch_dataloader
from torch.utils.data import DataLoader
from torchvision.datasets import ImageNet
from torchvision.transforms import CenterCrop, Compose, PILToTensor, Resize


def _test(dataset, num_workers: int, fn, batch_size: int = 32):
    print(f"{num_workers=}")
    num_items = 0
    t0 = time.monotonic()
    dataloader = fn(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context="forkserver",
    )
    for imgs, _ in dataloader:
        if num_items == 0:
            ttfb = time.monotonic() - t0
            print(f"{ttfb=}")
        num_items += len(imgs)
        if num_items > 10000:
            break
    elapsed = time.monotonic() - t0
    qps = num_items / elapsed
    print(f"{num_items=}, {elapsed=:.2f} ({qps=:.2f})")


def _main():
    logging.basicConfig(level=logging.INFO)

    root = "/home/moto/local/imagenet/"
    dataset = ImageNet(
        root=root,
        split="train",
        transform=Compose(
            [
                Resize(256),
                CenterCrop(224),
                PILToTensor(),
            ]
        ),
    )

    for num_workers in (32, 16, 8, 4, 2, 1):
        for fn in (get_pytorch_dataloader, DataLoader):
            _test(dataset, num_workers, fn)


if __name__ == "__main__":
    _main()
```
---
 src/spdl/dataloader/__init__.py            |   3 +-
 src/spdl/dataloader/_pytorch_dataloader.py | 323 +++++++++++++++++++++
 2 files changed, 325 insertions(+), 1 deletion(-)
 create mode 100644 src/spdl/dataloader/_pytorch_dataloader.py

diff --git a/src/spdl/dataloader/__init__.py b/src/spdl/dataloader/__init__.py
index fd4330ff..711c139a 100644
--- a/src/spdl/dataloader/__init__.py
+++ b/src/spdl/dataloader/__init__.py
@@ -11,11 +11,12 @@
 import warnings
 from typing import Any
 
-from . import _dataloader, _iterators
+from . import _dataloader, _iterators, _pytorch_dataloader
 
 _mods = [
     _dataloader,
     _iterators,
+    _pytorch_dataloader,
 ]
 
 __all__ = sorted(item for mod in _mods for item in mod.__all__)
diff --git a/src/spdl/dataloader/_pytorch_dataloader.py b/src/spdl/dataloader/_pytorch_dataloader.py
new file mode 100644
index 00000000..39d66d8a
--- /dev/null
+++ b/src/spdl/dataloader/_pytorch_dataloader.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+__all__ = ["PyTorchDataLoader", "get_pytorch_dataloader"]
+
+import logging
+import multiprocessing as mp
+import os
+import pickle
+import time
+from collections.abc import Callable, Iterable, Iterator
+from concurrent.futures import Executor, ProcessPoolExecutor
+from multiprocessing.shared_memory import SharedMemory
+from types import ModuleType
+from typing import cast, Sized, TYPE_CHECKING, TypeVar
+
+from spdl._internal import import_utils
+from spdl.pipeline import Pipeline, PipelineBuilder
+
+if TYPE_CHECKING:
+    import torch
+else:
+    torch: ModuleType = import_utils.lazy_import("torch")
+
+
+K = TypeVar("K")
+T = TypeVar("T")
+U = TypeVar("U")
+
+_LG: logging.Logger = logging.getLogger(__name__)
+
+
+class PyTorchDataLoader(Iterable[U]):
+    """PyTorchDataLoader()
+
+    A PyTorch-style data loader that works on map-style datasets.
+    Use :py:func:`get_pytorch_dataloader` to instantiate this class.
+    You can use this class as an almost drop-in replacement for PyTorch's DataLoader class.
+
+    The architecture of this data loader differs from PyTorch's in the following ways:
+
+    - Only the dataset and the collate function are copied to the worker process.
+      (The sampler and generator are not copied.)
+    - The dataset is copied to worker processes via shared memory.
+    - The sampler is executed in the main process and the resulting indices are passed to the
+      worker processes.
+    - Worker processes share the same input/output queues.
+      (PyTorch creates a set of I/O queues for each worker process.)
+
+    Due to the way Dataset is defined, this class still has to copy the dataset to
+    each worker process, so the memory consumption is not reduced.
+    However, fast initialization and reduced inter-process communication make
+    this implementation faster than PyTorch's DataLoader.
+
+    :ivar dataset: The source dataset.
+    """
+
+    def __init__(
+        self,
+        *,
+        dataset: "torch.utils.data.dataset.Dataset[T]",
+        shmem: SharedMemory,  # to keep the reference alive
+        sampler: "torch.utils.data.sampler.Sampler[K]",
+        fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
+        executor: Executor,
+        num_workers: int,
+        timeout: float | None,
+        buffer_size: int,
+        output_order: str = "completion",
+    ) -> None:
+        self.dataset = dataset  # For external access.
+        self._shmem: SharedMemory = shmem
+        self._sampler = sampler
+        self._fetch_fn = fetch_fn
+        self._executor = executor
+        self._num_workers = num_workers
+        self._buffer_size = buffer_size
+        self._timeout = timeout
+        self._output_order = output_order
+
+    def __len__(self) -> int:
+        """Returns the number of samples/batches this data loader returns."""
+        return len(cast(Sized, self._sampler))
+
+    def _get_pipeline(self) -> Pipeline:
+        return (
+            PipelineBuilder()
+            .add_source(self._sampler)
+            .pipe(
+                self._fetch_fn,
+                executor=self._executor,
+                output_order=self._output_order,
+                concurrency=self._num_workers,
+            )
+            .add_sink(self._buffer_size)
+            .build(num_threads=1)
+        )
+
+    def __iter__(self) -> Iterator[U]:
+        """Iterate over the dataset and yield samples/batches."""
+        pipeline = self._get_pipeline()
+        with pipeline.auto_stop():
+            for item in pipeline.get_iterator(timeout=self._timeout):
+                yield item
+
+
+################################################################################
+# ProcessExecutor
+################################################################################
+
+_DATASET: "torch.utils.data.dataset.Dataset[T]" = None  # pyre-ignore: [15]
+_COLLATE_FN: Callable = None  # pyre-ignore: [15]
+
+
+def _get_item(index: K) -> ...:
+    global _DATASET, _COLLATE_FN
+    return _COLLATE_FN(_DATASET[index])
+
+
+def _get_items(indices: list[K]) -> ...:
+    global _DATASET, _COLLATE_FN
+    if hasattr(_DATASET, "__getitems__"):
+        return _COLLATE_FN(_DATASET.__getitems__(indices))  # pyre-ignore: [16]
+    return _COLLATE_FN([_DATASET[index] for index in indices])
+
+
+def _init_dataset(name: str, collate_fn: Callable) -> None:
+    _LG.info("[%s] Initializing dataset.", os.getpid())
+    shmem = SharedMemory(name=name)
+    global _DATASET, _COLLATE_FN
+    _DATASET = pickle.loads(shmem.buf)
+    _COLLATE_FN = collate_fn
+
+
+def _get_executor(
+    name: str,
+    collate_fn: Callable[[list[T]], U],
+    num_workers: int,
+    mp_ctx: mp.context.BaseContext,
+) -> ProcessPoolExecutor:
+    executor = ProcessPoolExecutor(
+        max_workers=num_workers,
+        mp_context=mp_ctx,
+        initializer=_init_dataset,
+        initargs=(name, collate_fn),
+    )
+    return executor
+
+
+def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
+    _LG.info("Serializing dataset.")
+    t0 = time.monotonic()
+    data = pickle.dumps(dataset)
+    shmem = SharedMemory(create=True, size=len(data))
+    shmem.buf[:] = data
+    elapsed = time.monotonic() - t0
+    _LG.info(
+        "Wrote dataset to shared memory %s (%s bytes) in %.2f seconds",
+        shmem.name,
+        f"{len(data):_d}",
+        elapsed,
+    )
+    return shmem
+
+
+################################################################################
+# resolve sampler, fetch and collate
+################################################################################
+
+
+def _get_sampler(
+    dataset: "torch.utils.data.dataset.Dataset[T]",
+    shuffle: bool,
+    generator: "torch.Generator | None",
+) -> "torch.utils.data.sampler.Sampler[int]":
+    from torch.utils.data.sampler import (
+        RandomSampler,
+        SequentialSampler,
+    )
+
+    assert hasattr(dataset, "__len__")
+    ds = cast(Sized, dataset)
+    return RandomSampler(ds, generator=generator) if shuffle else SequentialSampler(ds)
+
+
+def _resolve_sampler(
+    dataset: "torch.utils.data.dataset.Dataset[T]",
+    batch_size: int | None = 1,
+    shuffle: bool = False,
+    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
+    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
+    collate_fn: Callable[[list[T]], U] | None = None,
+    drop_last: bool = False,
+    generator: "torch.Generator | None" = None,
+) -> "tuple[torch.utils.data.sampler.Sampler[K], Callable[[K], U], Callable[[list[T]], U]]":
+    from torch.utils.data.dataloader import default_collate, default_convert
+    from torch.utils.data.sampler import BatchSampler
+
+    if all(s is not None for s in [sampler, batch_sampler]):
+        raise ValueError("`sampler` and `batch_sampler` are mutually exclusive.")
+
+    if all(o is not None for o in [batch_size, batch_sampler]):
+        raise ValueError("`batch_size` and `batch_sampler` are mutually exclusive.")
+
+    if any(s is not None for s in [sampler, batch_sampler]) and shuffle:
+        raise ValueError(
+            "`shuffle` must be False when `batch_sampler` or `sampler` is provided."
+        )
+
+    if batch_sampler is not None and drop_last:
+        raise ValueError("`drop_last` must be False when `batch_sampler` is provided.")
+
+    if batch_size is None and drop_last:
+        raise ValueError("`drop_last` must be False when `batch_size` is None.")
+
+    if batch_sampler is not None:
+        _sampler = batch_sampler
+        _fetch_fn = _get_items
+        _collate_fn = collate_fn or default_collate
+    elif batch_size is not None:
+        _sampler = BatchSampler(
+            sampler or _get_sampler(dataset, shuffle, generator),  # pyre-ignore: [6]
+            batch_size,
+            drop_last,
+        )
+        _fetch_fn = _get_items
+        _collate_fn = collate_fn or default_collate
+    elif sampler is not None:
+        _sampler = sampler
+        _fetch_fn = _get_item
+        _collate_fn = collate_fn or default_convert
+    else:
+        _sampler = _get_sampler(dataset, shuffle, generator)
+        _fetch_fn = _get_item
+        _collate_fn = collate_fn or default_convert
+
+    return _sampler, _fetch_fn, _collate_fn
+
+
+################################################################################
+# get_pytorch_dataloader
+################################################################################
+
+
+def get_pytorch_dataloader(
+    dataset: "torch.utils.data.dataset.Dataset[T]",
+    batch_size: int | None = 1,
+    shuffle: bool = False,
+    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
+    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
+    num_workers: int = 1,
+    collate_fn: Callable[[list[T]], U] | None = None,
+    pin_memory: bool = False,
+    drop_last: bool = False,
+    timeout: float | None = None,
+    worker_init_fn: None = None,
+    multiprocessing_context: str | mp.context.BaseContext | None = "forkserver",
+    generator: "torch.Generator | None" = None,
+    *,
+    prefetch_factor: int = 2,
+    persistent_workers: bool = True,
+    pin_memory_device: str | None = None,
+) -> PyTorchDataLoader[U]:
+    from torch.utils.data.dataloader import IterableDataset
+
+    if isinstance(dataset, IterableDataset):
+        raise ValueError("IterableDataset is not supported.")
+
+    if worker_init_fn is not None:
+        raise ValueError("`worker_init_fn` is not supported.")
+
+    if pin_memory:
+        raise ValueError("`pin_memory` is not supported (yet).")
+
+    if pin_memory_device is not None:
+        raise ValueError("`pin_memory_device` is not supported (yet).")
+
+    if not persistent_workers:
+        raise ValueError("`persistent_workers=False` is not supported.")
+
+    if timeout is not None and timeout < 0:
+        raise ValueError(f"`timeout` must be non-negative. Found: {timeout}.")
+
+    if num_workers < 1:
+        raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}")
+
+    buffer_size = prefetch_factor * num_workers
+
+    _sampler, _fetch_fn, _collate_fn = _resolve_sampler(
+        dataset,
+        batch_size,
+        shuffle,
+        sampler,
+        batch_sampler,
+        collate_fn,
+        drop_last,
+        generator,
+    )
+
+    mp_ctx = (
+        multiprocessing_context
+        if isinstance(multiprocessing_context, mp.context.BaseContext)
+        else mp.get_context(multiprocessing_context)
+    )
+    _LG.info("Using multiprocessing context: %s", mp_ctx.get_start_method())
+    shmem = _serialize_dataset(dataset)
+    executor = _get_executor(shmem.name, _collate_fn, num_workers, mp_ctx)
+
+    return PyTorchDataLoader(
+        dataset=dataset,
+        shmem=shmem,
+        sampler=_sampler,
+        fetch_fn=_fetch_fn,
+        executor=executor,
+        num_workers=num_workers,
+        timeout=timeout,
+        buffer_size=buffer_size,
+    )