Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add drop-in-replacement for PyTorch DataLoader #303

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spdl/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import warnings
from typing import Any

from . import _dataloader, _iterators
from . import _dataloader, _iterators, _pytorch_dataloader

_mods = [
_dataloader,
_iterators,
_pytorch_dataloader,
]

__all__ = sorted(item for mod in _mods for item in mod.__all__)
Expand Down
323 changes: 323 additions & 0 deletions src/spdl/dataloader/_pytorch_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# Public names of this module. The package `__init__` builds its own `__all__`
# from the `__all__` of each submodule, so only these two names are re-exported.
__all__ = ["PyTorchDataLoader", "get_pytorch_dataloader"]

import logging
import multiprocessing as mp
import os
import pickle
import time
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Executor, ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory
from types import ModuleType
from typing import cast, Sized, TYPE_CHECKING, TypeVar

from spdl._internal import import_utils
from spdl.pipeline import Pipeline, PipelineBuilder

if TYPE_CHECKING:
import torch
else:
torch: ModuleType = import_utils.lazy_import("torch")


# Type variables used by the annotations in this module:
#   K - type of the values produced by a sampler (an index/key into the dataset)
#   T - type of one sample returned by the dataset
#   U - type of the items yielded by the data loader (output of the collate function)
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")

# Module-level logger, named after this module.
_LG: logging.Logger = logging.getLogger(__name__)


class PyTorchDataLoader(Iterable[U]):
    """PyTorchDataLoader()

    A PyTorch-style data loader that works on map-style dataset.
    Use :py:func:`get_pytorch_dataloader` to instantiate this class.
    You can use this class as almost drop-in replacement of PyTorch's DataLoader class.

    The architecture of data loader is different in following ways:

    - Only the dataset and the collate function are copied to the worker process.
      (Sampler and Generator are not copied)
    - The dataset is copied to worker processes via shared memory.
    - Sampler is executed in the main process and the resulting indices are passed to the
      worker processes.
    - Worker processes share the same input/output queues.
      (PyTorch creates a set of i/o queues for each worker process.)

    Due to the way Dataset is defined, this class still has to copy the dataset to
    each worker process. So the memory consumption is not reduced.
    However, fast initialization and reduced inter-process communication makes
    this implementation faster than PyTorch DataLoader.

    The instance owns the shared memory segment holding the serialized dataset
    and releases it (close + unlink) when the instance is garbage collected.

    :ivar: dataset: The source dataset.
    """

    def __init__(
        self,
        *,
        dataset: "torch.utils.data.dataset.Dataset[T]",
        shmem: SharedMemory,  # to keep the reference alive
        sampler: "torch.utils.data.sampler.Sampler[K]",
        fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
        executor: Executor,
        num_workers: int,
        timeout: float | None,
        buffer_size: int,
        output_order: str = "completion",
    ) -> None:
        """Store the components of the loader. All arguments are keyword-only.

        Args:
            dataset: The source dataset. Only kept for external access.
            shmem: Shared memory segment holding the serialized dataset.
                Ownership is transferred to this instance.
            sampler: Source of indices (or lists of indices) fed to ``fetch_fn``.
            fetch_fn: Function executed in ``executor`` for each sampler item.
            executor: Executor in which ``fetch_fn`` runs.
            num_workers: Concurrency of the fetch stage.
            timeout: Timeout passed to the pipeline iterator.
            buffer_size: Size of the sink buffer of the pipeline.
            output_order: Output order passed to the fetch stage of the pipeline.
        """
        self.dataset = dataset  # For external access.
        self._shmem: SharedMemory = shmem
        self._sampler = sampler
        self._fetch_fn = fetch_fn
        self._executor = executor
        self._num_workers = num_workers
        self._buffer_size = buffer_size
        self._timeout = timeout
        self._output_order = output_order

    def __del__(self) -> None:
        """Release the shared memory segment owned by this loader."""
        # `_shmem` may be missing if `__init__` did not run to completion.
        shmem = getattr(self, "_shmem", None)
        if shmem is None:
            return
        try:
            shmem.close()
            shmem.unlink()
        except (OSError, BufferError):
            # Best effort: the segment may already be gone, and a finalizer
            # must not raise (it may run during interpreter shutdown).
            pass

    def __len__(self) -> int:
        """Returns the number of samples/batches this data loader returns."""
        return len(cast(Sized, self._sampler))

    def _get_pipeline(self) -> Pipeline:
        """Build a new pipeline: sampler -> fetch (in the executor) -> sink."""
        return (
            PipelineBuilder()
            .add_source(self._sampler)
            .pipe(
                self._fetch_fn,
                executor=self._executor,
                output_order=self._output_order,
                concurrency=self._num_workers,
            )
            .add_sink(self._buffer_size)
            .build(num_threads=1)
        )

    def __iter__(self) -> Iterator[U]:
        """Iterate on the dataset and yields samples/batches."""
        # A fresh pipeline is built for every iteration of the loader.
        pipeline = self._get_pipeline()
        with pipeline.auto_stop():
            for item in pipeline.get_iterator(timeout=self._timeout):
                yield item


################################################################################
# ProcessExecutor
################################################################################

# Per-process state of the worker processes.
# Both are assigned by `_init_dataset` (the initializer passed to the process
# pool) and read by `_get_item` / `_get_items`. They stay `None` until
# `_init_dataset` has run in the current process.
_DATASET: "torch.utils.data.dataset.Dataset[T]" = None  # pyre-ignore: [15]
_COLLATE_FN: Callable = None  # pyre-ignore: [15]


def _get_item(index: K) -> ...:
    """Fetch the sample at ``index`` from the process-global dataset and collate it."""
    sample = _DATASET[index]
    return _COLLATE_FN(sample)


def _get_items(indices: list[K]) -> ...:
    """Fetch the samples for ``indices`` from the process-global dataset and collate them.

    The dataset's ``__getitems__`` is used when it defines one; otherwise the
    samples are fetched one by one.
    """
    if hasattr(_DATASET, "__getitems__"):
        samples = _DATASET.__getitems__(indices)  # pyre-ignore: [16]
    else:
        samples = [_DATASET[idx] for idx in indices]
    return _COLLATE_FN(samples)


def _init_dataset(name: str, collate_fn: Callable) -> None:
    """Initialize the process-global dataset and collate function.

    Used as the initializer of the worker processes: attaches to the shared
    memory segment ``name``, unpickles the dataset stored in it, and stores the
    dataset and ``collate_fn`` in the module globals.

    Args:
        name: Name of the shared memory segment holding the pickled dataset.
        collate_fn: The collate function to apply to fetched samples.
    """
    global _DATASET, _COLLATE_FN
    _LG.info("[%s] Initializing dataset.", os.getpid())
    shmem = SharedMemory(name=name)
    try:
        # NOTE: the payload is the pickle written by the process that created
        # the segment; do not point this at a segment from an untrusted source.
        _DATASET = pickle.loads(shmem.buf)
    finally:
        # Detach explicitly instead of relying on the finalizer. The creator
        # of the segment remains responsible for unlinking it.
        shmem.close()
    _COLLATE_FN = collate_fn


def _get_executor(
    name: str,
    collate_fn: Callable[[list[T]], U],
    num_workers: int,
    mp_ctx: mp.context.BaseContext,
) -> ProcessPoolExecutor:
    """Create the process pool whose workers initialize themselves with `_init_dataset`.

    Args:
        name: Name of the shared memory segment, passed to the initializer.
        collate_fn: Collate function, passed to the initializer.
        num_workers: Maximum number of worker processes.
        mp_ctx: Multiprocessing context used to start the workers.
    """
    init_args = (name, collate_fn)
    return ProcessPoolExecutor(
        max_workers=num_workers,
        mp_context=mp_ctx,
        initializer=_init_dataset,
        initargs=init_args,
    )


def _serialize_dataset(dataset: "torch.utils.data.dataset.Dataset[T]") -> SharedMemory:
    """Pickle ``dataset`` into a newly created shared memory segment.

    Args:
        dataset: The dataset to serialize. It must be picklable.

    Returns:
        The shared memory segment containing the pickled dataset. The caller
        owns the segment and is responsible for closing and unlinking it.
    """
    _LG.info("Serializing dataset.")
    t0 = time.monotonic()
    data = pickle.dumps(dataset)
    num_bytes = len(data)
    shmem = SharedMemory(create=True, size=num_bytes)
    try:
        # The allocated block may be larger than requested (some platforms
        # round up to the page size), so write into an explicit prefix rather
        # than the whole buffer; a whole-buffer assignment requires equal sizes.
        shmem.buf[:num_bytes] = data
    except BaseException:
        # Do not leave an orphaned segment behind if the copy fails.
        shmem.close()
        shmem.unlink()
        raise
    elapsed = time.monotonic() - t0
    _LG.info(
        "Written dataset into shared memory %s (%s bytes) in %.2f seconds",
        shmem.name,
        f"{num_bytes:_d}",
        elapsed,
    )
    return shmem


################################################################################
# resolve sampler, fetch and collate
################################################################################


def _get_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    shuffle: bool,
    generator: "torch.Generator | None",
) -> "torch.utils.data.sampler.Sampler[int]":
    """Create the default index sampler for ``dataset``.

    Args:
        dataset: A dataset which implements ``__len__``.
        shuffle: If true, a ``RandomSampler`` is returned, otherwise
            a ``SequentialSampler``.
        generator: Passed to ``RandomSampler``. Not used when ``shuffle`` is false.
    """
    from torch.utils.data.sampler import RandomSampler, SequentialSampler

    assert hasattr(dataset, "__len__")
    sized = cast(Sized, dataset)
    if shuffle:
        return RandomSampler(sized, generator=generator)
    return SequentialSampler(sized)


def _resolve_sampler(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    collate_fn: Callable[[list[T]], U] | None = None,
    drop_last: bool = False,
    generator: "torch.Generator | None" = None,
) -> "tuple[torch.utils.data.sampler.Sampler[K], Callable[[K], U], Callable[[list[T]], U]]":
    """Validate the sampling options and pick the sampler, fetch and collate functions.

    Args:
        dataset: The source dataset.
        batch_size: Number of samples per batch. ``None`` disables batching.
        shuffle: Whether the default sampler shuffles indices.
        sampler: Custom index sampler. Mutually exclusive with ``batch_sampler``.
        batch_sampler: Custom sampler yielding lists of indices.
        collate_fn: Custom collate function. When ``None``, ``default_collate``
            is used for batched loading and ``default_convert`` otherwise.
        drop_last: Passed to ``BatchSampler`` when one is created here.
        generator: Passed to the default random sampler.

    Returns:
        A tuple of the sampler to iterate in the main process, the function
        which fetches one sampler item, and the collate function.

    Raises:
        ValueError: If the combination of options is invalid.
    """
    from torch.utils.data.dataloader import default_collate, default_convert
    from torch.utils.data.sampler import BatchSampler

    has_sampler = sampler is not None
    has_batch_sampler = batch_sampler is not None

    if has_sampler and has_batch_sampler:
        raise ValueError("`sampler` and `batch_sampler` are mutually exclusive.")

    # `batch_size=1` is the default value of the signature, so it cannot be
    # told apart from "not specified". As in PyTorch's DataLoader, only a
    # non-default `batch_size` conflicts with `batch_sampler`.
    if has_batch_sampler and batch_size is not None and batch_size != 1:
        raise ValueError("`batch_size` and `batch_sampler` are mutually exclusive.")

    if (has_sampler or has_batch_sampler) and shuffle:
        raise ValueError(
            "`shuffle` must be False when `batch_sampler` or `sampler` is provided."
        )

    if has_batch_sampler and drop_last:
        raise ValueError("`drop_last` must be False when `batch_sampler` is provided.")

    if batch_size is None and drop_last:
        raise ValueError("`drop_last` must be False when `batch_size` is None.")

    if has_batch_sampler:
        # Batched loading, batches defined by the user-provided batch sampler.
        _sampler = batch_sampler
        batched = True
    elif batch_size is not None:
        # Batched loading, batches built from an index sampler.
        # NOTE: compare against `None` explicitly. A sampler defining `__len__`
        # is falsy when empty and must not be replaced by the default sampler.
        index_sampler = (
            sampler if has_sampler else _get_sampler(dataset, shuffle, generator)
        )
        _sampler = BatchSampler(
            index_sampler,  # pyre-ignore: [6]
            batch_size,
            drop_last,
        )
        batched = True
    elif has_sampler:
        # Non-batched loading with the user-provided index sampler.
        _sampler = sampler
        batched = False
    else:
        # Non-batched loading with the default index sampler.
        _sampler = _get_sampler(dataset, shuffle, generator)
        batched = False

    if batched:
        _fetch_fn = _get_items
        default_fn = default_collate
    else:
        _fetch_fn = _get_item
        default_fn = default_convert
    _collate_fn = collate_fn if collate_fn is not None else default_fn

    return _sampler, _fetch_fn, _collate_fn


################################################################################
# get_pytorch_dataloader
################################################################################


def get_pytorch_dataloader(
    dataset: "torch.utils.data.dataset.Dataset[T]",
    batch_size: int | None = 1,
    shuffle: bool = False,
    sampler: "torch.utils.data.sampler.Sampler[K] | None" = None,
    batch_sampler: "torch.utils.data.sampler.Sampler[list[K]] | None" = None,
    num_workers: int = 1,
    collate_fn: Callable[[list[T]], U] | None = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float | None = None,
    worker_init_fn: None = None,
    multiprocessing_context: str | mp.context.BaseContext | None = "forkserver",
    generator: "torch.Generator | None" = None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
    pin_memory_device: str | None = None,
) -> PyTorchDataLoader[U]:
    """Build a :py:class:`PyTorchDataLoader` for a map-style dataset.

    The signature mirrors the one of PyTorch's ``DataLoader``, but some
    options are not supported and are rejected (see ``Raises``).

    Args:
        dataset: A map-style dataset. It is pickled into shared memory,
            from which the worker processes load their copy.
        batch_size: Number of samples per batch. ``None`` disables batching.
        shuffle: Whether the default sampler shuffles indices.
        sampler: Custom index sampler.
        batch_sampler: Custom sampler yielding lists of indices.
        num_workers: Number of worker processes. Must be at least 1.
        collate_fn: Custom collate function, executed in the worker processes.
        pin_memory: Not supported. Must be ``False``.
        drop_last: Whether an incomplete last batch is dropped.
        timeout: Timeout passed to the iterator of the loader.
            Must not be negative.
        worker_init_fn: Not supported. Must be ``None``.
        multiprocessing_context: A multiprocessing context, or the name of
            a start method which is resolved with ``mp.get_context``.
        generator: Generator used by the default random sampler.
        prefetch_factor: The output buffer holds
            ``prefetch_factor * num_workers`` items.
        persistent_workers: Must be ``True``.
        pin_memory_device: Not supported. Must be ``None``.

    Returns:
        The data loader.

    Raises:
        ValueError: If an unsupported option is requested or an argument
            is invalid.
    """
    from torch.utils.data.dataloader import IterableDataset

    if isinstance(dataset, IterableDataset):
        raise ValueError("IterableDataset is not supported.")

    if worker_init_fn is not None:
        raise ValueError("`worker_init_fn` is not supported.")

    if pin_memory:
        raise ValueError("`pin_memory` is not supported (yet).")

    if pin_memory_device is not None:
        raise ValueError("`pin_memory_device` is not supported (yet).")

    if not persistent_workers:
        raise ValueError("`persistent_workers=False` is not supported. ")

    if timeout is not None and timeout < 0:
        # Zero passes this check, hence "non-negative" rather than "positive".
        raise ValueError(f"`timeout` must be non-negative. Found: {timeout}.")

    if num_workers < 1:
        raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}")

    buffer_size = prefetch_factor * num_workers

    _sampler, _fetch_fn, _collate_fn = _resolve_sampler(
        dataset,
        batch_size,
        shuffle,
        sampler,
        batch_sampler,
        collate_fn,
        drop_last,
        generator,
    )

    mp_ctx = (
        multiprocessing_context
        if isinstance(multiprocessing_context, mp.context.BaseContext)
        else mp.get_context(multiprocessing_context)
    )
    _LG.info("Using multiprocessing context: %s", mp_ctx.get_start_method())
    shmem = _serialize_dataset(dataset)
    try:
        executor = _get_executor(shmem.name, _collate_fn, num_workers, mp_ctx)
    except BaseException:
        # Nothing else references the segment yet; release it so that a
        # failure to create the process pool does not leak shared memory.
        shmem.close()
        shmem.unlink()
        raise

    return PyTorchDataLoader(
        dataset=dataset,
        shmem=shmem,
        sampler=_sampler,
        fetch_fn=_fetch_fn,
        executor=executor,
        num_workers=num_workers,
        timeout=timeout,
        buffer_size=buffer_size,
    )
Loading