Skip to content

Commit

Permalink
Add source module for data source and iterator utils
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 10, 2025
1 parent 92c0299 commit b4bd718
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 45 deletions.
3 changes: 3 additions & 0 deletions docs/source/_templates/_custom_autosummary_class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
{%- if module.startswith("spdl.source") %}
:show-inheritance:
{%- endif %}
:members:

{%- block meths %}
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ API Reference
spdl.io
spdl.pipeline
spdl.dataloader
spdl.source
spdl.utils
3 changes: 2 additions & 1 deletion examples/imagenet_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
import spdl.io
import spdl.utils
import torch
from spdl.dataloader import DataLoader, ImageNet
from spdl.dataloader import DataLoader
from spdl.source.imagenet import ImageNet
from torch import Tensor
from torch.profiler import profile

Expand Down
26 changes: 7 additions & 19 deletions src/spdl/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@

# pyre-unsafe

import warnings
from typing import Any

from . import _dataloader, _iterators, _pytorch_dataloader
from ._source import _imagenet, _local_directory, _type
from . import _dataloader, _pytorch_dataloader

_mods = [
_dataloader,
_iterators,
_pytorch_dataloader,
_imagenet,
_local_directory,
_type,
]

__all__ = sorted(item for mod in _mods for item in mod.__all__)
Expand All @@ -36,23 +30,17 @@ def __getattr__(name: str) -> Any:
return getattr(mod, name)

# For backward compatibility
import spdl.pipeline
if name == "iterate_in_subprocess":
import warnings

if name in spdl.pipeline.__all__:
warnings.warn(
f"{name} has been moved to {spdl.pipeline.__name__}. "
"`iterate_in_subprocess` has been moved to `spdl.source.utils`. "
"Please update the import statement to "
f"`from {spdl.pipeline.__name__} import {name}`.",
"`from spdl.source.utils import iterate_in_subprocess`.",
stacklevel=2,
)
return getattr(spdl.pipeline, name)
import spdl.source.utils

if name == "run_in_subprocess":
warnings.warn(
"`run_in_subprocess` has been deprecated. "
"Use `iterate_in_subprocess` instead.",
stacklevel=2,
)
return _iterators.run_in_subprocess
return spdl.source.utils.iterate_in_subprocess

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
17 changes: 17 additions & 0 deletions src/spdl/source/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Iterables for traversing datasets and utilities for transforming them."""

# pyre-strict

from ._type import IterableWithShuffle

__all__ = ["IterableWithShuffle"]


def __dir__() -> list[str]:
    """Return the module's public names, i.e. the contents of ``__all__``."""
    public_names = __all__
    return public_names
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from os import PathLike
from pathlib import Path

from ._local_directory import LocalDirectory
from ._type import IterableWithShuffle
from .local_directory import LocalDirectory

# pyre-strict

Expand Down
File renamed without changes.
24 changes: 2 additions & 22 deletions src/spdl/dataloader/_iterators.py → src/spdl/source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
Iterator,
Sequence,
)
from typing import Any, TypeVar
from typing import TypeVar

from ._source._type import IterableWithShuffle
from ._type import IterableWithShuffle

T = TypeVar("T")
K = TypeVar("K")
Expand Down Expand Up @@ -180,26 +180,6 @@ def _drain() -> Iterator[T]:
_LG.warning("Failed to kill the worker process.")


def run_in_subprocess(
    fn: Callable[..., Iterable[T]],
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
    queue_size: int = 64,
    mp_context: str = "forkserver",
    timeout: float | None = None,
    daemon: bool = False,
) -> Iterator[T]:
    """Bind ``args``/``kwargs`` to ``fn`` and delegate to ``iterate_in_subprocess``.

    This is a thin wrapper: it builds a zero-argument callable with
    :py:func:`functools.partial` and forwards every remaining option
    unchanged to ``iterate_in_subprocess``.

    Args:
        fn: Callable returning an iterable when invoked with ``args`` and
            ``kwargs``.
        args: Positional arguments bound to ``fn``. ``None`` is treated as
            no positional arguments.
        kwargs: Keyword arguments bound to ``fn``. ``None`` is treated as
            no keyword arguments.
        queue_size: Forwarded as-is to ``iterate_in_subprocess``.
        mp_context: Forwarded as-is to ``iterate_in_subprocess``.
        timeout: Forwarded as-is to ``iterate_in_subprocess``.
        daemon: Forwarded as-is to ``iterate_in_subprocess``.

    Returns:
        The iterator returned by ``iterate_in_subprocess``.
    """
    from functools import partial

    # ``None`` defaults avoid mutable default arguments; normalise them here.
    bound_fn = partial(fn, *(args or ()), **(kwargs or {}))

    return iterate_in_subprocess(
        fn=bound_fn,
        queue_size=queue_size,
        mp_context=mp_context,
        timeout=timeout,
        daemon=daemon,
    )


################################################################################
# MergeIterator
################################################################################
Expand Down
2 changes: 1 addition & 1 deletion tests/spdl_unittest/dataloader/iterator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from unittest.mock import patch

import pytest
from spdl.dataloader import MergeIterator, repeat_source
from spdl.source.utils import MergeIterator, repeat_source


def test_mergeiterator_ordered():
Expand Down
3 changes: 2 additions & 1 deletion tests/spdl_unittest/dataloader/source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pathlib import Path
from tempfile import TemporaryDirectory

from spdl.dataloader import ImageNet, LocalDirectory
from spdl.source.imagenet import ImageNet
from spdl.source.local_directory import LocalDirectory


def _make_files(paths: Iterable[Path]) -> None:
Expand Down

0 comments on commit b4bd718

Please sign in to comment.