From 1ca095d442c16dccc361e922d3671c2ffa23d1cb Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Sun, 12 Jan 2025 16:42:34 -0500 Subject: [PATCH] Manually import classes/functions --- src/spdl/dataloader/_dataloader.py | 17 ++++++---- src/spdl/dataloader/_pytorch_dataloader.py | 8 ++--- src/spdl/pipeline/__init__.py | 39 ++++++++-------------- src/spdl/pipeline/_builder.py | 2 +- src/spdl/pipeline/_convert.py | 3 +- 5 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/spdl/dataloader/_dataloader.py b/src/spdl/dataloader/_dataloader.py index 393265c8..2b60a0d1 100644 --- a/src/spdl/dataloader/_dataloader.py +++ b/src/spdl/dataloader/_dataloader.py @@ -228,20 +228,23 @@ def __init__( self._output_order = output_order def _get_pipeline(self) -> Pipeline: - pipe_args = { - "concurrency": self._num_threads, - "output_order": self._output_order, - } - builder = PipelineBuilder().add_source(self._src) if self._preprocessor: - builder.pipe(self._preprocessor, **pipe_args) + builder.pipe( + self._preprocessor, + concurrency=self._num_threads, + output_order=self._output_order, + ) if self._batch_size: builder.aggregate(self._batch_size, drop_last=self._drop_last) if self._aggregator: - builder.pipe(self._aggregator, **pipe_args) + builder.pipe( + self._aggregator, + concurrency=self._num_threads, + output_order=self._output_order, + ) # Transfer runs in the default thread pool (with num_threads=1) # because GPU data transfer cannot be parallelized. diff --git a/src/spdl/dataloader/_pytorch_dataloader.py b/src/spdl/dataloader/_pytorch_dataloader.py index a7223830..d675b4a8 100644 --- a/src/spdl/dataloader/_pytorch_dataloader.py +++ b/src/spdl/dataloader/_pytorch_dataloader.py @@ -126,8 +126,8 @@ def __init__( dataset: "torch.utils.data.dataset.Dataset[T]", shmem: SharedMemory, # to keep the reference alive sampler: "torch.utils.data.sampler.Sampler[K]", - fetch_fn: Callable[[K], U] | Callable[[list[K]], U], - collate_fn: Callable[[list[T]], U], + fetch_fn: Callable[[K], U], + collate_fn: Callable[[list[U]], V], mp_ctx: mp.context.BaseContext, num_workers: int, timeout: float | None, @@ -149,7 +149,7 @@ def __len__(self) -> int: """Returns the number of samples/batches this data loader returns.""" return len(cast(Sized, self._sampler)) - def _get_pipeline(self) -> Pipeline: + def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]: executor = _get_executor( self._shmem.name, self._collate_fn, self._num_workers, self._mp_ctx ) @@ -321,7 +321,7 @@ def get_pytorch_dataloader( dataset=dataset, shmem=shmem, sampler=_sampler, - fetch_fn=_fetch_fn, + fetch_fn=_fetch_fn, # pyre-ignore collate_fn=_collate_fn, mp_ctx=mp_ctx, num_workers=num_workers, diff --git a/src/spdl/pipeline/__init__.py b/src/spdl/pipeline/__init__.py index 9adb19f8..960d1708 100644 --- a/src/spdl/pipeline/__init__.py +++ b/src/spdl/pipeline/__init__.py @@ -6,29 +6,18 @@ """Implements :py:class:`~spdl.pipeline.Pipeline`, a generic task execution engine.""" -# pyre-unsafe - -from typing import Any - -from . import _builder, _hook, _pipeline, _utils - -_mods = [ - _builder, - _hook, - _pipeline, - _utils, +# pyre-strict + +from ._builder import PipelineBuilder, PipelineFailure +from ._hook import PipelineHook, TaskStatsHook +from ._pipeline import Pipeline +from ._utils import create_task + +__all__ = [ + "Pipeline", + "PipelineBuilder", + "PipelineFailure", + "create_task", + "PipelineHook", + "TaskStatsHook", ] - -__all__ = sorted(item for mod in _mods for item in mod.__all__) - - -def __dir__(): - return __all__ - - -def __getattr__(name: str) -> Any: - for mod in _mods: - if name in mod.__all__: - return getattr(mod, name) - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/spdl/pipeline/_builder.py b/src/spdl/pipeline/_builder.py index be3394d2..3998130b 100644 --- a/src/spdl/pipeline/_builder.py +++ b/src/spdl/pipeline/_builder.py @@ -547,7 +547,7 @@ def pipe( /, *, concurrency: int = 1, - executor: type[Executor] | None = None, + executor: Executor | None = None, name: str | None = None, hooks: Sequence[PipelineHook] | None = None, report_stats_interval: float | None = None, diff --git a/src/spdl/pipeline/_convert.py b/src/spdl/pipeline/_convert.py index 9729dfbc..2634da30 100644 --- a/src/spdl/pipeline/_convert.py +++ b/src/spdl/pipeline/_convert.py @@ -48,6 +48,7 @@ def _to_batch_async_gen( ) -> Callable[[T], AsyncIterable[U]]: async def afunc(item: T) -> AsyncIterable[U]: loop = asyncio.get_running_loop() + # pyre-ignore: [6] for result in await loop.run_in_executor(executor, _wrap_gen, func, item): yield result @@ -111,4 +112,4 @@ def convert_to_async( return _to_async_gen(op, executor=executor) # Convert a regular sync function to async function. - return _to_async(op, executor=executor) + return _to_async(op, executor=executor) # pyre-ignore: [7]