Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually import classes/functions #320

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/spdl/dataloader/_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,23 @@ def __init__(
self._output_order = output_order

def _get_pipeline(self) -> Pipeline:
pipe_args = {
"concurrency": self._num_threads,
"output_order": self._output_order,
}

builder = PipelineBuilder().add_source(self._src)
if self._preprocessor:
builder.pipe(self._preprocessor, **pipe_args)
builder.pipe(
self._preprocessor,
concurrency=self._num_threads,
output_order=self._output_order,
)

if self._batch_size:
builder.aggregate(self._batch_size, drop_last=self._drop_last)

if self._aggregator:
builder.pipe(self._aggregator, **pipe_args)
builder.pipe(
self._aggregator,
concurrency=self._num_threads,
output_order=self._output_order,
)

# Transfer runs in the default thread pool (with num_threads=1)
# because GPU data transfer cannot be parallelized.
Expand Down
8 changes: 4 additions & 4 deletions src/spdl/dataloader/_pytorch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def __init__(
dataset: "torch.utils.data.dataset.Dataset[T]",
shmem: SharedMemory, # to keep the reference alive
sampler: "torch.utils.data.sampler.Sampler[K]",
fetch_fn: Callable[[K], U] | Callable[[list[K]], U],
collate_fn: Callable[[list[T]], U],
fetch_fn: Callable[[K], U],
collate_fn: Callable[[list[U]], V],
mp_ctx: mp.context.BaseContext,
num_workers: int,
timeout: float | None,
Expand All @@ -149,7 +149,7 @@ def __len__(self) -> int:
"""Returns the number of samples/batches this data loader returns."""
return len(cast(Sized, self._sampler))

def _get_pipeline(self) -> Pipeline:
def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]:
executor = _get_executor(
self._shmem.name, self._collate_fn, self._num_workers, self._mp_ctx
)
Expand Down Expand Up @@ -321,7 +321,7 @@ def get_pytorch_dataloader(
dataset=dataset,
shmem=shmem,
sampler=_sampler,
fetch_fn=_fetch_fn,
fetch_fn=_fetch_fn, # pyre-ignore
collate_fn=_collate_fn,
mp_ctx=mp_ctx,
num_workers=num_workers,
Expand Down
39 changes: 14 additions & 25 deletions src/spdl/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,18 @@

"""Implements :py:class:`~spdl.pipeline.Pipeline`, a generic task execution engine."""

# pyre-unsafe

from typing import Any

from . import _builder, _hook, _pipeline, _utils

_mods = [
_builder,
_hook,
_pipeline,
_utils,
# pyre-strict

from ._builder import PipelineBuilder, PipelineFailure
from ._hook import PipelineHook, TaskStatsHook
from ._pipeline import Pipeline
from ._utils import create_task

__all__ = [
"Pipeline",
"PipelineBuilder",
"PipelineFailure",
"create_task",
"PipelineHook",
"TaskStatsHook",
]

__all__ = sorted(item for mod in _mods for item in mod.__all__)


def __dir__():
return __all__


def __getattr__(name: str) -> Any:
for mod in _mods:
if name in mod.__all__:
return getattr(mod, name)

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
2 changes: 1 addition & 1 deletion src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def pipe(
/,
*,
concurrency: int = 1,
executor: type[Executor] | None = None,
executor: Executor | None = None,
name: str | None = None,
hooks: Sequence[PipelineHook] | None = None,
report_stats_interval: float | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/spdl/pipeline/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _to_batch_async_gen(
) -> Callable[[T], AsyncIterable[U]]:
async def afunc(item: T) -> AsyncIterable[U]:
loop = asyncio.get_running_loop()
# pyre-ignore: [6]
for result in await loop.run_in_executor(executor, _wrap_gen, func, item):
yield result

Expand Down Expand Up @@ -111,4 +112,4 @@ def convert_to_async(
return _to_async_gen(op, executor=executor)

# Convert a regular sync function to async function.
return _to_async(op, executor=executor)
return _to_async(op, executor=executor) # pyre-ignore: [7]
Loading