diff --git a/pyproject.toml b/pyproject.toml
index a834769..064bbbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
 ]
 dependencies = [
     "pydantic>=1.10.0,<2",
-    "pydantic-yaml>=1.0.0a2",
+    "pydantic-yaml>=1.1.2",
     "kedro",
     "fsspec",
 ]
@@ -39,7 +39,9 @@ dev = [
     "mypy==1.4.1",
     "pytest==7.4.0",
     # required for testing
-    "kedro[pandas]",
+    "pandas~=1.5.3",
+    "pyspark~=3.4.1",
+    "kedro[pandas,spark]",
 ]
 docs = [
     "mkdocs",
diff --git a/src/pydantic_kedro/_local_caching.py b/src/pydantic_kedro/_local_caching.py
new file mode 100644
index 0000000..298e4a3
--- /dev/null
+++ b/src/pydantic_kedro/_local_caching.py
@@ -0,0 +1,76 @@
+"""Implementation of local caching.
+
+When we load something from a remote location, we currently need to copy it to
+the local disk. This is a limitation of `pydantic-kedro` due to particular
+libraries (e.g. Spark) not working with `fsspec` URLs.
+
+Ideally we would just use a `tempfile.TemporaryDirectory`; however, because some
+libraries do lazy loading (Spark, Polars, and many others) we actually need to
+instantiate the files locally.
+"""
+
+import atexit
+import logging
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Union
+
+logger = logging.getLogger(__name__)
+
+_INITIAL_TMPDIR: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory(prefix="pydantic_kedro_")
+PYD_KEDRO_CACHE_DIR: Path = Path(_INITIAL_TMPDIR.name)
+"""Local-ish cache directory for pydantic-kedro.
+
+DO NOT MODIFY - use `set_cache_dir(path)` and `get_cache_dir()` instead.
+
+TODO: Consider using module-level getattr. See https://peps.python.org/pep-0562/
+"""
+
+
+def set_cache_dir(path: Union[Path, str]) -> None:
+    """Set the 'local' caching directory for pydantic-kedro.
+
+    For Spark and other multi-machine setups, it might make more sense to use
+    a common mount location.
+    """
+    global PYD_KEDRO_CACHE_DIR, _INITIAL_TMPDIR
+
+    cache_dir = Path(path).resolve()
+    logger.info("Preparing to set cache directory to: %s", cache_dir)
+    logger.info("Clearing old path: %s", PYD_KEDRO_CACHE_DIR)
+    remove_temp_objects()
+
+    if cache_dir.exists():
+        logger.warning("Cache path exists, reusing existing path: %s", cache_dir)
+    else:
+        logger.warning("Creating cache directory: %s", cache_dir)
+        cache_dir.mkdir(parents=True, exist_ok=True)
+    PYD_KEDRO_CACHE_DIR = cache_dir
+
+
+def get_cache_dir() -> Path:
+    """Get the caching directory for pydantic-kedro."""
+    global PYD_KEDRO_CACHE_DIR
+
+    return PYD_KEDRO_CACHE_DIR
+
+
+def remove_temp_objects() -> None:
+    """Remove temporary objects at exit.
+
+    This will be called at the exit of your application.
+
+    NOTE: This will NOT handle clearing objects when you change the cache
+    directory outside of `set_cache_dir()`.
+ """ + global PYD_KEDRO_CACHE_DIR, _INITIAL_TMPDIR + + shutil.rmtree(PYD_KEDRO_CACHE_DIR, ignore_errors=True) + PYD_KEDRO_CACHE_DIR.unlink(missing_ok=True) + if _INITIAL_TMPDIR is not None: + # We no longer use this directory + _INITIAL_TMPDIR.cleanup() + + +atexit.register(remove_temp_objects) diff --git a/src/pydantic_kedro/datasets/folder.py b/src/pydantic_kedro/datasets/folder.py index 2da7057..1d524d7 100644 --- a/src/pydantic_kedro/datasets/folder.py +++ b/src/pydantic_kedro/datasets/folder.py @@ -18,6 +18,7 @@ from pydantic_kedro._dict_io import PatchPydanticIter, dict_to_model from pydantic_kedro._internals import get_kedro_default, get_kedro_map, import_string +from pydantic_kedro._local_caching import get_cache_dir __all__ = ["PydanticFolderDataSet"] @@ -216,17 +217,18 @@ def _load(self) -> BaseModel: if isinstance(fs, LocalFileSystem): return self._load_local(self._filepath) else: - from tempfile import TemporaryDirectory - - with TemporaryDirectory(prefix="pyd_kedro_") as tmpdir: - # Copy from remote... yes, I know, not ideal! - m_remote = fsspec.get_mapper(self._filepath) - m_local = fsspec.get_mapper(tmpdir) - for k, v in m_remote.items(): - m_local[k] = v - - # Load locally - return self._load_local(tmpdir) + # Making a temp directory in the current cache dir location + tmpdir = get_cache_dir() / str(uuid4()).replace("-", "") + tmpdir.mkdir(exist_ok=False, parents=True) + + # Copy from remote... yes, this is not ideal! + m_remote = fsspec.get_mapper(self._filepath) + m_local = fsspec.get_mapper(str(tmpdir)) + for k, v in m_remote.items(): + m_local[k] = v + + # Load locally + return self._load_local(str(tmpdir)) def _load_local(self, filepath: str) -> BaseModel: """Load Pydantic model from the local filepath. diff --git a/src/pydantic_kedro/datasets/zip.py b/src/pydantic_kedro/datasets/zip.py index 0ffead9..c0f10a5 100644 --- a/src/pydantic_kedro/datasets/zip.py +++ b/src/pydantic_kedro/datasets/zip.py @@ -2,12 +2,15 @@ from tempfile import TemporaryDirectory from typing import Any, Dict +from uuid import uuid4 import fsspec from fsspec.implementations.zip import ZipFileSystem from kedro.io.core import AbstractDataSet from pydantic import BaseModel +from pydantic_kedro._local_caching import get_cache_dir + from .folder import PydanticFolderDataSet @@ -50,18 +53,20 @@ def _load(self) -> BaseModel: Pydantic model. 
""" filepath = self._filepath - with TemporaryDirectory(prefix="pyd_kedro_") as tmpdir: - m_local = fsspec.get_mapper(tmpdir) - # Unzip via copying to folder - with fsspec.open(filepath) as zip_file: - zip_fs = ZipFileSystem(fo=zip_file) # type: ignore - m_zip = zip_fs.get_mapper() - for k, v in m_zip.items(): - m_local[k] = v - zip_fs.close() - # Load folder dataset - pfds = PydanticFolderDataSet(tmpdir) - res = pfds.load() + # Making a temp directory in the current cache dir location + tmpdir = get_cache_dir() / str(uuid4()).replace("-", "") + tmpdir.mkdir(exist_ok=False, parents=True) + m_local = fsspec.get_mapper(str(tmpdir)) + # Unzip via copying to folder + with fsspec.open(filepath) as zip_file: + zip_fs = ZipFileSystem(fo=zip_file) # type: ignore + m_zip = zip_fs.get_mapper() + for k, v in m_zip.items(): + m_local[k] = v + zip_fs.close() + # Load folder dataset + pfds = PydanticFolderDataSet(str(tmpdir)) + res = pfds.load() return res def _save(self, data: BaseModel) -> None: diff --git a/src/test/test_ds_spark.py b/src/test/test_ds_spark.py new file mode 100644 index 0000000..320114b --- /dev/null +++ b/src/test/test_ds_spark.py @@ -0,0 +1,56 @@ +"""Test dataset for PySpark specifically.""" + +from typing import Any, Union + +import pytest +from kedro.extras.datasets.spark import SparkDataSet +from pyspark.sql import DataFrame, SparkSession + +from pydantic_kedro import ( + ArbModel, + PydanticAutoDataSet, + PydanticFolderDataSet, + PydanticZipDataSet, +) + +Kls = Union[PydanticAutoDataSet, PydanticFolderDataSet, PydanticZipDataSet] + + +class _SparkModel(ArbModel): + """Spark model, configured to use SparkDataSet (mult-file parquet).""" + + class Config(ArbModel.Config): + kedro_map = {DataFrame: SparkDataSet} + + +class FlatSparkModel(_SparkModel): + """Flat model that tests Spark using Picke dataset (default).""" + + df: DataFrame + val: int + + +@pytest.fixture +def spark() -> SparkSession: + """Create a Spark session for testing.""" + return SparkSession.Builder().appName("pydantic-kedro-testing").getOrCreate() + + +@pytest.mark.parametrize("kls", [PydanticAutoDataSet, PydanticFolderDataSet, PydanticZipDataSet]) +@pytest.mark.parametrize( + "df_raw", + [ + [{"a": 1, "b": 2, "c": 3}], + ], +) +def test_spark_flat_model(kls: Kls, df_raw: list[dict[str, Any]], spark: SparkSession, tmpdir): + """Test roundtripping of the flat Spark model, using Kedro's SparkDataSet.""" + dfx = spark.createDataFrame(df_raw) + mdl = FlatSparkModel(df=dfx, val=1) + paths = [f"{tmpdir}/model_on_disk", f"memory://{tmpdir}/model_in_memory"] + for path in paths: + ds: Kls = kls(path) # type: ignore + ds.save(mdl) + m2 = ds.load() + assert isinstance(m2, FlatSparkModel) + assert m2.df.count() == mdl.df.count()