Add drop-in replacement for PyTorch DataLoader
This commit adds a data loader class that is compatible with PyTorch's DataLoader. It uses ProcessPoolExecutor, so existing Dataset implementations can be used as-is. Currently, only map-style datasets are supported.

This class is not intended to replicate all the attributes of DataLoader; so far, only the ``Iterable`` functionality is kept compatible.

A quick benchmark shows that this implementation is much faster than PyTorch's implementation.

![QPS and Time-to-First-batch](https://github.com/user-attachments/assets/6ebff0ae-b523-492b-95af-87973c7b5fca)

**QPS** (queries per second; higher is better)

| #workers | 1 | 2 | 4 | 8 | 16 | 32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL | 251.49 | 454.96 | 773.8 | 1146.63 | 1516.97 | 1763.82 |
| PyTorch | 265.22 | 410.2 | 446.48 | 320.63 | 172.6 | 91.83 |

**TTFB** (time to first batch, in seconds; lower is better)

| #workers | 1 | 2 | 4 | 8 | 16 | 32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL | 3.86 | 3.57 | 3.55 | 3.95 | 3.72 | 4.07 |
| PyTorch | 3.49 | 6.79 | 13.32 | 26.6 | 55.32 | 107.43 |

<details><summary>Benchmark code</summary>

```python
import logging
import time

from spdl.dataloader import get_pytorch_dataloader
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageNet
from torchvision.transforms import CenterCrop, Compose, PILToTensor, Resize


def _test(dataset: Dataset, num_workers: int, fn, batch_size: int = 32):
    print(f"{num_workers=}")
    num_items = 0
    t0 = time.monotonic()
    dataloader = fn(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context="forkserver",
    )
    for imgs, _ in dataloader:
        if num_items == 0:
            ttfb = time.monotonic() - t0
            print(f"{ttfb=}")
        num_items += len(imgs)
        if num_items > 10000:
            break
    elapsed = time.monotonic() - t0
    qps = num_items / elapsed
    print(f"{num_items=}, {elapsed=:.2f} ({qps=:.2f})")


def _main():
    logging.basicConfig(level=logging.INFO)
    root = "/home/moto/local/imagenet/"
    dataset = ImageNet(
        root=root,
        split="train",
        transform=Compose(
            [
                Resize(256),
                CenterCrop(224),
                PILToTensor(),
            ]
        ),
    )
    for num_workers in (32, 16, 8, 4, 2, 1):
        for fn in (get_pytorch_dataloader, DataLoader):
            _test(dataset, num_workers, fn)


if __name__ == "__main__":
    _main()
```

</details>
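For reference, a minimal usage sketch (the toy `SquaresDataset` below is made up for illustration; any existing map-style `Dataset` works the same way):

```python
from spdl.dataloader import get_pytorch_dataloader
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """A toy map-style dataset: maps an index to its square."""

    def __len__(self) -> int:
        return 1000

    def __getitem__(self, index: int) -> int:
        return index * index


if __name__ == "__main__":
    # Same call shape as torch.utils.data.DataLoader. Workers are
    # ProcessPoolExecutor processes sharing a single pair of I/O queues.
    dataloader = get_pytorch_dataloader(
        SquaresDataset(),
        batch_size=32,
        shuffle=True,
        num_workers=4,
    )
    for batch in dataloader:  # each batch is a collated tensor of squares
        pass
```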
Showing 2 changed files with 325 additions and 1 deletion.
@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

__all__ = ["PyTorchDataLoader", "get_pytorch_dataloader"]

import logging
import multiprocessing as mp
import os
import pickle
import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Executor, ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory
from types import ModuleType
from typing import cast, Sized, TYPE_CHECKING, TypeVar

from spdl._internal import import_utils
from spdl.pipeline import Pipeline, PipelineBuilder

if TYPE_CHECKING:
    import torch
else:
    torch: ModuleType = import_utils.lazy_import("torch")


K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")

_LG: logging.Logger = logging.getLogger(__name__)

class PyTorchDataLoader(Iterable[U]):
    """PyTorchDataLoader()

    A PyTorch-style data loader that works on map-style datasets.

    Use :py:func:`get_pytorch_dataloader` to instantiate this class.

    You can use this class as an almost drop-in replacement for PyTorch's
    DataLoader class. The architecture of this data loader differs in the
    following ways:

    - Only the dataset and the collate function are copied to the worker
      processes. (The sampler and generator are not copied.)
    - The dataset is copied to the worker processes via shared memory.
    - The sampler is executed in the main process, and the resulting indices
      are passed to the worker processes.
    - The worker processes share the same input/output queues.
      (PyTorch creates a separate pair of I/O queues for each worker process.)

    Due to the way Dataset is defined, this class still has to copy the
    dataset to each worker process, so the memory consumption is not reduced.
    However, fast initialization and reduced inter-process communication make
    this implementation faster than PyTorch's DataLoader.

    :ivar dataset: The source dataset.
    """
    def __init__(
        self,
        *,
        dataset: "torch.utils.data.dataset.Dataset[T]",
        shmem: SharedMemory,  # to keep the reference alive
        sampler: "torch.utils.data.sampler.Sampler[K]",
        fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
        executor: Executor,
        num_workers: int,
        timeout: float | None,
        buffer_size: int,
        output_order: str = "completion",
    ) -> None:
        self.dataset = dataset  # For external access.
        self._shmem: SharedMemory = shmem
        self._sampler = sampler
        self._fetch_fn = fetch_fn
        self._executor = executor
        self._num_workers = num_workers
        self._buffer_size = buffer_size
        self._timeout = timeout
        self._output_order = output_order

    def __len__(self) -> int:
        """Returns the number of samples/batches this data loader returns."""
        return len(cast(Sized, self._sampler))

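    # The loading work is expressed as an SPDL Pipeline:
    #   sampler (main process) -> fetch_fn (worker pool, concurrency=num_workers)
    #   -> sink buffer of `buffer_size` items.
    # A fresh pipeline is built for every call to `__iter__` (i.e. per epoch).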
    def _get_pipeline(self) -> Pipeline:
        return (
            PipelineBuilder()
            .add_source(self._sampler)
            .pipe(
                self._fetch_fn,
                executor=self._executor,
                output_order=self._output_order,
                concurrency=self._num_workers,
            )
            .add_sink(self._buffer_size)
            .build(num_threads=1)
        )

    def __iter__(self) -> Iterator[U]:
        """Iterate on the dataset and yield samples/batches."""
        pipeline = self._get_pipeline()
        with pipeline.auto_stop():
            for item in pipeline.get_iterator(timeout=self._timeout):
                yield item


################################################################################
# ProcessExecutor
################################################################################

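# Module-level state for the worker processes. Each worker is initialized via
# ProcessPoolExecutor's `initializer` hook (`_init_dataset`), which attaches to
# the shared memory segment, unpickles the dataset, and stores it here together
# with the collate function.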
_DATASET: "torch.utils.data.dataset.Dataset[T]" = None  # pyre-ignore: [15]
_COLLATE_FN: Callable = None  # pyre-ignore: [15]


def _get_item(index: K) -> ...:
    global _DATASET, _COLLATE_FN
    return _COLLATE_FN(_DATASET[index])


def _get_items(indices: list[K]) -> ...:
    global _DATASET, _COLLATE_FN
    if hasattr(_DATASET, "__getitems__"):
        return _COLLATE_FN(_DATASET.__getitems__(indices))  # pyre-ignore: [16]
    return _COLLATE_FN([_DATASET[index] for index in indices])


def _init_dataset(name: str, collate_fn: Callable) -> None:
    _LG.info("[%s] Initializing dataset.", os.getpid())
    shmem = SharedMemory(name=name)
    global _DATASET, _COLLATE_FN
    _DATASET = pickle.loads(shmem.buf)
    _COLLATE_FN = collate_fn


def _get_executor(
    name: str,
    collate_fn: Callable[[list[T]], U],
    num_workers: int,
    mp_ctx: mp.context.BaseContext,
) -> ProcessPoolExecutor:
    executor = ProcessPoolExecutor(
        max_workers=num_workers,
        mp_context=mp_ctx,
        initializer=_init_dataset,
        initargs=(name, collate_fn),
    )
    return executor

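# The dataset is pickled once in the main process and written into a shared
# memory segment; workers attach to it by name instead of receiving the pickle
# over a pipe, which keeps worker initialization fast.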
def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
    _LG.info("Serializing dataset.")
    t0 = time.monotonic()
    data = pickle.dumps(dataset)
    shmem = SharedMemory(create=True, size=len(data))
    shmem.buf[:] = data
    elapsed = time.monotonic() - t0
    _LG.info(
        "Wrote dataset into shared memory %s (%s bytes) in %.2f seconds",
        shmem.name,
        f"{len(data):_d}",
        elapsed,
    )
    return shmem


################################################################################
# resolve sampler, fetch and collate
################################################################################

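# Mirrors PyTorch's default sampler selection: a RandomSampler when
# shuffle=True, otherwise a SequentialSampler.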
def _get_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    shuffle: bool,
    generator: "torch.Generator | None",
) -> "torch.utils.data.sampler.Sampler[int]":
    from torch.utils.data.sampler import (
        RandomSampler,
        SequentialSampler,
    )

    assert hasattr(dataset, "__len__")
    ds = cast(Sized, dataset)
    return RandomSampler(ds, generator=generator) if shuffle else SequentialSampler(ds)

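# Resolves the (sampler, fetch function, collate function) triple, following
# the same precedence as torch.utils.data.DataLoader:
#   1. batch_sampler given     -> batched fetch, default_collate
#   2. batch_size given        -> wrap the (possibly default) sampler in a
#                                 BatchSampler; batched fetch, default_collate
#   3. sampler given, no batch -> per-item fetch, default_convert
#   4. neither                 -> default sampler; per-item fetch, default_convert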
def _resolve_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    collate_fn: Callable[[list[T]], U] | None = None,
    drop_last: bool = False,
    generator: "torch.Generator | None" = None,
) -> "tuple[torch.utils.data.sampler.Sampler[K], Callable[[K], U], Callable[[list[T]], U]]":
    from torch.utils.data.dataloader import default_collate, default_convert
    from torch.utils.data.sampler import BatchSampler

    if all(s is not None for s in [sampler, batch_sampler]):
        raise ValueError("`sampler` and `batch_sampler` are mutually exclusive.")

    if all(o is not None for o in [batch_size, batch_sampler]):
        raise ValueError("`batch_size` and `batch_sampler` are mutually exclusive.")

    if any(s is not None for s in [sampler, batch_sampler]) and shuffle:
        raise ValueError(
            "`shuffle` must be False when `batch_sampler` or `sampler` is provided."
        )

    if batch_sampler is not None and drop_last:
        raise ValueError("`drop_last` must be False when `batch_sampler` is provided.")

    if batch_size is None and drop_last:
        raise ValueError("`drop_last` must be False when `batch_size` is None.")

    if batch_sampler is not None:
        _sampler = batch_sampler
        _fetch_fn = _get_items
        _collate_fn = collate_fn or default_collate
    elif batch_size is not None:
        _sampler = BatchSampler(
            sampler or _get_sampler(dataset, shuffle, generator),  # pyre-ignore: [6]
            batch_size,
            drop_last,
        )
        _fetch_fn = _get_items
        _collate_fn = collate_fn or default_collate
    elif sampler is not None:
        _sampler = sampler
        _fetch_fn = _get_item
        _collate_fn = collate_fn or default_convert
    else:
        _sampler = _get_sampler(dataset, shuffle, generator)
        _fetch_fn = _get_item
        _collate_fn = collate_fn or default_convert

    return _sampler, _fetch_fn, _collate_fn


################################################################################
# get_pytorch_dataloader
################################################################################

def get_pytorch_dataloader(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    num_workers: int = 1,
    collate_fn: Callable[[list[T]], U] | None = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float | None = None,
    worker_init_fn: None = None,
    multiprocessing_context: str | mp.context.BaseContext | None = "forkserver",
    generator: "torch.Generator | None" = None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
    pin_memory_device: str | None = None,
) -> PyTorchDataLoader[U]:
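    """Create a :py:class:`PyTorchDataLoader`.

    The signature mirrors :py:class:`torch.utils.data.DataLoader`. Arguments
    that are not (yet) supported (``worker_init_fn``, ``pin_memory``,
    ``pin_memory_device``, ``persistent_workers=False``, and iterable-style
    datasets) raise a ``ValueError``.
    """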
    from torch.utils.data.dataloader import IterableDataset

    if isinstance(dataset, IterableDataset):
        raise ValueError("IterableDataset is not supported.")

    if worker_init_fn is not None:
        raise ValueError("`worker_init_fn` is not supported.")

    if pin_memory:
        raise ValueError("`pin_memory` is not supported (yet).")

    if pin_memory_device is not None:
        raise ValueError("`pin_memory_device` is not supported (yet).")

    if not persistent_workers:
        raise ValueError("`persistent_workers=False` is not supported.")

    if timeout is not None and timeout < 0:
        raise ValueError(f"`timeout` must be non-negative. Found: {timeout}.")

    if num_workers < 1:
        raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}.")

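    # Matches the PyTorch DataLoader convention: up to `prefetch_factor`
    # batches are kept in flight per worker.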
    buffer_size = prefetch_factor * num_workers

    _sampler, _fetch_fn, _collate_fn = _resolve_sampler(
        dataset,
        batch_size,
        shuffle,
        sampler,
        batch_sampler,
        collate_fn,
        drop_last,
        generator,
    )

    mp_ctx = (
        multiprocessing_context
        if isinstance(multiprocessing_context, mp.context.BaseContext)
        else mp.get_context(multiprocessing_context)
    )
    _LG.info("Using multiprocessing context: %s", mp_ctx.get_start_method())
    shmem = _serialize_dataset(dataset)
    executor = _get_executor(shmem.name, _collate_fn, num_workers, mp_ctx)

    return PyTorchDataLoader(
        dataset=dataset,
        shmem=shmem,
        sampler=_sampler,
        fetch_fn=_fetch_fn,
        executor=executor,
        num_workers=num_workers,
        timeout=timeout,
        buffer_size=buffer_size,
    )