diff --git a/src/spdl/dataloader/__init__.py b/src/spdl/dataloader/__init__.py index fc025eae..ec981aa3 100644 --- a/src/spdl/dataloader/__init__.py +++ b/src/spdl/dataloader/__init__.py @@ -6,29 +6,24 @@ """Task specific data loading solutions based on :py:class:`~spdl.pipeline.Pipeline`.""" -# pyre-unsafe - from typing import Any -from . import _dataloader, _pytorch_dataloader +# pyre-unsafe +from ._dataloader import DataLoader +from ._pytorch_dataloader import get_pytorch_dataloader, PyTorchDataLoader -_mods = [ - _dataloader, - _pytorch_dataloader, +__all__ = [ + "DataLoader", + "get_pytorch_dataloader", + "PyTorchDataLoader", ] -__all__ = sorted(item for mod in _mods for item in mod.__all__) - -def __dir__(): +def __dir__() -> list[str]: return __all__ def __getattr__(name: str) -> Any: - for mod in _mods: - if name in mod.__all__: - return getattr(mod, name) - # For backward compatibility if name == "iterate_in_subprocess": import warnings