Add drop-in replacement for PyTorch DataLoader
This commit adds a data loader class that is compatible with
PyTorch's DataLoader.

It uses ProcessPoolExecutor, so existing Dataset implementations
can be used as-is.

Currently, it only supports map-style datasets.
It is not intended to make all attributes of DataLoader compatible;
so far, only the functionality as an ``Iterable`` is kept compatible.
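
As a usage illustration, switching mostly amounts to swapping the
constructor; a minimal sketch (the dataset and parameter values here
are illustrative):

```python
from spdl.dataloader import get_pytorch_dataloader

# Previously: dataloader = DataLoader(dataset, batch_size=32, num_workers=8, shuffle=True)
dataloader = get_pytorch_dataloader(
    dataset,  # an existing map-style Dataset, used as-is
    batch_size=32,
    num_workers=8,
    shuffle=True,
)

for batch in dataloader:  # iterates like a regular DataLoader
    ...
```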

A quick benchmark shows that this implementation scales much better
than PyTorch's: its throughput keeps improving up to 32 workers,
while PyTorch's peaks at 4 workers and degrades beyond that.

![QPS and Time-to-First-batch](https://github.com/user-attachments/assets/6ebff0ae-b523-492b-95af-87973c7b5fca)

**QPS** (queries per second; higher is better)

| #workers |      1 |      2 |      4 |       8 |      16 |      32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL     | 251.49 | 454.96 |  773.8 | 1146.63 | 1516.97 | 1763.82 |
| PyTorch  | 265.22 |  410.2 | 446.48 |  320.63 |   172.6 |   91.83 |

**TTFB** (time to first batch, in seconds; lower is better)

| #workers |      1 |      2 |      4 |       8 |      16 |      32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL     |   3.86 |   3.57 |   3.55 |    3.95 |    3.72 |    4.07 |
| PyTorch  |   3.49 |   6.79 |  13.32 |    26.6 |   55.32 |  107.43 |

<details><summary>Benchmark code</summary>

```python
import logging
import time

from spdl.dataloader import get_pytorch_dataloader
from torch.utils.data import DataLoader
from torchvision.datasets import ImageNet
from torchvision.transforms import CenterCrop, Compose, PILToTensor, Resize

def _test(dataset: ImageNet, num_workers: int, fn, batch_size: int = 32):
    print(f"{num_workers=}")
    num_items = 0
    t0 = time.monotonic()
    dataloader = fn(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context="forkserver",
    )

    for imgs, _ in dataloader:
        if num_items == 0:
            ttfb = time.monotonic() - t0
            print(f"{ttfb=}")
        num_items += len(imgs)
        if num_items > 10000:
            break
    elapsed = time.monotonic() - t0
    qps = num_items / elapsed
    print(f"{num_items=}, {elapsed=:.2f} ({qps=:.2f})")

def _main():
    logging.basicConfig(level=logging.INFO)
    root = "/home/moto/local/imagenet/"

    dataset = ImageNet(
        root=root,
        split="train",
        transform=Compose(
            [
                Resize(256),
                CenterCrop(224),
                PILToTensor(),
            ]
        ),
    )
    for num_workers in (32, 16, 8, 4, 2, 1):
        for fn in (get_pytorch_dataloader, DataLoader):
            _test(dataset, num_workers, fn)

if __name__ == "__main__":
    _main()
```

</details>
mthrok committed Dec 27, 2024
1 parent 56220ad commit d95d0ac
Showing 2 changed files with 325 additions and 1 deletion.
**src/spdl/dataloader/__init__.py** (2 additions & 1 deletion)

```diff
@@ -11,11 +11,12 @@
 import warnings
 from typing import Any
 
-from . import _dataloader, _iterators
+from . import _dataloader, _iterators, _pytorch_dataloader
 
 _mods = [
     _dataloader,
     _iterators,
+    _pytorch_dataloader,
 ]
 
 __all__ = sorted(item for mod in _mods for item in mod.__all__)
```
**src/spdl/dataloader/_pytorch_dataloader.py** (new file, 323 additions)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

__all__ = ["PyTorchDataLoader", "get_pytorch_dataloader"]

import logging
import multiprocessing as mp
import os
import pickle
import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Executor, ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory
from types import ModuleType
from typing import cast, Sized, TYPE_CHECKING, TypeVar

from spdl._internal import import_utils
from spdl.pipeline import Pipeline, PipelineBuilder

if TYPE_CHECKING:
    import torch
else:
    torch: ModuleType = import_utils.lazy_import("torch")


K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")

_LG: logging.Logger = logging.getLogger(__name__)


class PyTorchDataLoader(Iterable[U]):
    """PyTorchDataLoader()

    A PyTorch-style data loader that works on map-style datasets.

    Use :py:func:`get_pytorch_dataloader` to instantiate this class.

    You can use this class as an almost drop-in replacement of PyTorch's
    DataLoader class. The architecture of the data loader is different
    in the following ways:

    - Only the dataset and the collate function are copied to the worker
      processes. (Sampler and Generator are not copied.)
    - The dataset is copied to the worker processes via shared memory.
    - The sampler is executed in the main process, and the resulting indices
      are passed to the worker processes.
    - Worker processes share the same input/output queues.
      (PyTorch creates a set of I/O queues for each worker process.)

    Due to the way Dataset is defined, this class still has to copy the
    dataset to each worker process, so the memory consumption is not reduced.
    However, fast initialization and reduced inter-process communication
    make this implementation faster than PyTorch's DataLoader.

    :ivar dataset: The source dataset.
    """

    def __init__(
        self,
        *,
        dataset: "torch.utils.data.dataset.Dataset[T]",
        shmem: SharedMemory,  # to keep the reference alive
        sampler: "torch.utils.data.sampler.Sampler[K]",
        fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
        executor: Executor,
        num_workers: int,
        timeout: float | None,
        buffer_size: int,
        output_order: str = "completion",
    ) -> None:
        self.dataset = dataset  # For external access.
        self._shmem: SharedMemory = shmem
        self._sampler = sampler
        self._fetch_fn = fetch_fn
        self._executor = executor
        self._num_workers = num_workers
        self._buffer_size = buffer_size
        self._timeout = timeout
        self._output_order = output_order

    def __len__(self) -> int:
        """Returns the number of samples/batches this data loader returns."""
        return len(cast(Sized, self._sampler))

    def _get_pipeline(self) -> Pipeline:
        return (
            PipelineBuilder()
            .add_source(self._sampler)
            .pipe(
                self._fetch_fn,
                executor=self._executor,
                output_order=self._output_order,
                concurrency=self._num_workers,
            )
            .add_sink(self._buffer_size)
            .build(num_threads=1)
        )

    def __iter__(self) -> Iterator[U]:
        """Iterate on the dataset and yield samples/batches."""
        pipeline = self._get_pipeline()
        with pipeline.auto_stop():
            for item in pipeline.get_iterator(timeout=self._timeout):
                yield item


################################################################################
# ProcessExecutor
################################################################################

_DATASET: "torch.utils.data.dataset.Dataset[T]" = None # pyre-ignore: [15]
_COLLATE_FN: Callable = None # pyre-ignore: [15]


def _get_item(index: K) -> ...:
    global _DATASET, _COLLATE_FN
    return _COLLATE_FN(_DATASET[index])


def _get_items(indices: list[K]) -> ...:
    global _DATASET, _COLLATE_FN
    if hasattr(_DATASET, "__getitems__"):
        return _COLLATE_FN(_DATASET.__getitems__(indices))  # pyre-ignore: [16]
    return _COLLATE_FN([_DATASET[index] for index in indices])


def _init_dataset(name: str, collate_fn: Callable) -> None:
    _LG.info("[%s] Initializing dataset.", os.getpid())
    shmem = SharedMemory(name=name)
    global _DATASET, _COLLATE_FN
    _DATASET = pickle.loads(shmem.buf)
    _COLLATE_FN = collate_fn
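
# NOTE: This initializer runs once in each worker process at startup
# (via ProcessPoolExecutor's ``initializer`` argument below); it attaches
# to the shared-memory segment by name and unpickles the dataset into the
# module-level globals above.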


def _get_executor(
    name: str,
    collate_fn: Callable[[list[T]], U],
    num_workers: int,
    mp_ctx: mp.context.BaseContext,
) -> ProcessPoolExecutor:
    executor = ProcessPoolExecutor(
        max_workers=num_workers,
        mp_context=mp_ctx,
        initializer=_init_dataset,
        initargs=(name, collate_fn),
    )
    return executor


def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
    _LG.info("Serializing dataset.")
    t0 = time.monotonic()
    data = pickle.dumps(dataset)
    shmem = SharedMemory(create=True, size=len(data))
    shmem.buf[:] = data
    elapsed = time.monotonic() - t0
    _LG.info(
        "Wrote dataset into shared memory %s (%s bytes) in %.2f seconds",
        shmem.name,
        f"{len(data):_d}",
        elapsed,
    )
    return shmem


################################################################################
# resolve sampler, fetch and collate
################################################################################


def _get_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    shuffle: bool,
    generator: "torch.Generator | None",
) -> "torch.utils.data.sampler.Sampler[int]":
    from torch.utils.data.sampler import (
        RandomSampler,
        SequentialSampler,
    )

    assert hasattr(dataset, "__len__")
    ds = cast(Sized, dataset)
    return RandomSampler(ds, generator=generator) if shuffle else SequentialSampler(ds)


def _resolve_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    collate_fn: Callable[[list[T]], U] | None = None,
    drop_last: bool = False,
    generator: "torch.Generator | None" = None,
) -> "tuple[torch.utils.data.sampler.Sampler[K], Callable[[K], U], Callable[[list[T]], U]]":
    from torch.utils.data.dataloader import default_collate, default_convert
    from torch.utils.data.sampler import BatchSampler

    if all(s is not None for s in [sampler, batch_sampler]):
        raise ValueError("`sampler` and `batch_sampler` are mutually exclusive.")

    if all(o is not None for o in [batch_size, batch_sampler]):
        raise ValueError("`batch_size` and `batch_sampler` are mutually exclusive.")

    if any(s is not None for s in [sampler, batch_sampler]) and shuffle:
        raise ValueError(
            "`shuffle` must be False when `batch_sampler` or `sampler` is provided."
        )

    if batch_sampler is not None and drop_last:
        raise ValueError("`drop_last` must be False when `batch_sampler` is provided.")

    if batch_size is None and drop_last:
        raise ValueError("`drop_last` must be False when `batch_size` is None.")

    if batch_sampler is not None:
        _sampler = batch_sampler
        _fetch_fn = _get_items
        _collate_fn = collate_fn or default_collate
    elif batch_size is not None:
        _sampler = BatchSampler(
            sampler or _get_sampler(dataset, shuffle, generator),  # pyre-ignore: [6]
            batch_size,
            drop_last,
        )
        _fetch_fn = _get_items
        _collate_fn = collate_fn or default_collate
    elif sampler is not None:
        _sampler = sampler
        _fetch_fn = _get_item
        _collate_fn = collate_fn or default_convert
    else:
        _sampler = _get_sampler(dataset, shuffle, generator)
        _fetch_fn = _get_item
        _collate_fn = collate_fn or default_convert

    return _sampler, _fetch_fn, _collate_fn
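
# The branch order above mirrors torch.utils.data.DataLoader's behavior:
# an explicit batch_sampler takes precedence, then automatic batching via
# batch_size, and a bare sampler (or the default one) fetches one item at
# a time.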


################################################################################
# get_pytorch_dataloader
################################################################################


def get_pytorch_dataloader(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    num_workers: int = 1,
    collate_fn: Callable[[list[T]], U] | None = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float | None = None,
    worker_init_fn: None = None,
    multiprocessing_context: str | mp.context.BaseContext | None = "forkserver",
    generator: "torch.Generator | None" = None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
    pin_memory_device: str | None = None,
) -> PyTorchDataLoader[U]:
    from torch.utils.data.dataloader import IterableDataset

    if isinstance(dataset, IterableDataset):
        raise ValueError("IterableDataset is not supported.")

    if worker_init_fn is not None:
        raise ValueError("`worker_init_fn` is not supported.")

    if pin_memory:
        raise ValueError("`pin_memory` is not supported (yet).")

    if pin_memory_device is not None:
        raise ValueError("`pin_memory_device` is not supported (yet).")

    if not persistent_workers:
        raise ValueError("`persistent_workers=False` is not supported.")

    if timeout is not None and timeout < 0:
        raise ValueError(f"`timeout` must be non-negative. Found: {timeout}.")

    if num_workers < 1:
        raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}.")

    buffer_size = prefetch_factor * num_workers

    _sampler, _fetch_fn, _collate_fn = _resolve_sampler(
        dataset,
        batch_size,
        shuffle,
        sampler,
        batch_sampler,
        collate_fn,
        drop_last,
        generator,
    )

    mp_ctx = (
        multiprocessing_context
        if isinstance(multiprocessing_context, mp.context.BaseContext)
        else mp.get_context(multiprocessing_context)
    )
    _LG.info("Using multiprocessing context: %s", mp_ctx.get_start_method())
    shmem = _serialize_dataset(dataset)
    executor = _get_executor(shmem.name, _collate_fn, num_workers, mp_ctx)

    return PyTorchDataLoader(
        dataset=dataset,
        shmem=shmem,
        sampler=_sampler,
        fetch_fn=_fetch_fn,
        executor=executor,
        num_workers=num_workers,
        timeout=timeout,
        buffer_size=buffer_size,
    )
```
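
The core of the speed-up is the shared-memory handoff implemented by
``_serialize_dataset`` and ``_init_dataset`` above: the parent pickles the
dataset into a named ``SharedMemory`` segment once, and every worker attaches
by name and unpickles it at startup. A minimal standalone sketch of that
round trip (the toy list below stands in for a real dataset):

```python
import pickle
from multiprocessing.shared_memory import SharedMemory

# Parent: pickle once into a named shared-memory segment.
payload = pickle.dumps(list(range(10)))  # stand-in for a pickled Dataset
shmem = SharedMemory(create=True, size=len(payload))
shmem.buf[: len(payload)] = payload

# Worker: attach by name and reconstruct the object.
view = SharedMemory(name=shmem.name)
dataset = pickle.loads(view.buf)  # pickle ignores trailing bytes after STOP
assert dataset[:3] == [0, 1, 2]

view.close()
shmem.close()
shmem.unlink()  # the creator is responsible for releasing the segment
```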
