Merge pull request #64 from NowanIlfideme/feature/spark-tests
Add PySpark support (SparkDataSet)
NowanIlfideme authored Sep 12, 2023
2 parents a416722 + 6000cb2 commit 83ee958
Showing 5 changed files with 166 additions and 25 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
]
dependencies = [
"pydantic>=1.10.0,<2",
"pydantic-yaml>=1.0.0a2",
"pydantic-yaml>=1.1.2",
"kedro",
"fsspec",
]
@@ -39,7 +39,9 @@ dev = [
"mypy==1.4.1",
"pytest==7.4.0",
# required for testing
"kedro[pandas]",
"pandas~=1.5.3",
"pyspark~=3.4.1",
"kedro[pandas,spark]",
]
docs = [
"mkdocs",
76 changes: 76 additions & 0 deletions src/pydantic_kedro/_local_caching.py
@@ -0,0 +1,76 @@
"""Implementation of local caching.
When we load something from a remote location, we currently need to copy it to
the local disk. This is a limitation of `pydantic-kedro` due to particular
libraries (e.g. Spark) not working with `fsspec` URLs.
Ideally we would just use a `tempfile.TemporaryDirectory`, however because some
libraries do lazy loading (Spark, Polars, so many...) we actually need to
instantiate the files locally.
"""

import atexit
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Union

logger = logging.getLogger(__name__)

_INITIAL_TMPDIR: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory(prefix="pydantic_kedro_")
PYD_KEDRO_CACHE_DIR: Path = Path(_INITIAL_TMPDIR.name)
"""Local-ish cache directory for pydantic-kedro.
DO NOT MODIFY - use `set_cache_dir(path)` and `get_cache_dir()` instead.
TODO: Consider using module-level getattr. See https://peps.python.org/pep-0562/
"""


def set_cache_dir(path: Union[Path, str]) -> None:
    """Set the 'local' caching directory for pydantic-kedro.
    For Spark and other multi-machine setups, it might make more sense to use
    a common mount location.
    """
    global PYD_KEDRO_CACHE_DIR, _INITIAL_TMPDIR

    cache_dir = Path(path).resolve()
    logger.info("Preparing to set cache directory to: %s", cache_dir)
    logger.info("Clearing old path: %s", PYD_KEDRO_CACHE_DIR)
    remove_temp_objects()

    if cache_dir.exists():
        logger.warning("Cache path exists, reusing existing path: %s", cache_dir)
    else:
        logger.warning("Creating cache directory: %s", cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
    PYD_KEDRO_CACHE_DIR = cache_dir


def get_cache_dir() -> Path:
    """Get caching directory for pydantic-kedro."""
    global PYD_KEDRO_CACHE_DIR

    return PYD_KEDRO_CACHE_DIR


def remove_temp_objects() -> None:
    """Remove temporary objects at exit.
    This will be called at the exit of your application.
    NOTE: This will NOT handle clearing objects when you change the cache
    directory outside of `set_cache_dir()`.
    """
    global PYD_KEDRO_CACHE_DIR, _INITIAL_TMPDIR

    shutil.rmtree(PYD_KEDRO_CACHE_DIR, ignore_errors=True)
    PYD_KEDRO_CACHE_DIR.unlink(missing_ok=True)
    if _INITIAL_TMPDIR is not None:
        # We no longer use this directory
        _INITIAL_TMPDIR.cleanup()


atexit.register(remove_temp_objects)
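
For context, a minimal usage sketch of the caching helpers added above, imported from the private module the diff itself uses; the shared-mount path is illustrative only, not something the diff prescribes.

    # Sketch: point the pydantic-kedro cache at a shared mount before loading
    # remote datasets, e.g. so every Spark executor can reach the staged files.
    # The path below is illustrative.
    from pydantic_kedro._local_caching import get_cache_dir, set_cache_dir

    set_cache_dir("/mnt/shared/pydantic_kedro_cache")  # reused if it already exists
    print(get_cache_dir())  # resolved Path to the new cache directory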
24 changes: 13 additions & 11 deletions src/pydantic_kedro/datasets/folder.py
@@ -18,6 +18,7 @@

from pydantic_kedro._dict_io import PatchPydanticIter, dict_to_model
from pydantic_kedro._internals import get_kedro_default, get_kedro_map, import_string
from pydantic_kedro._local_caching import get_cache_dir

__all__ = ["PydanticFolderDataSet"]

@@ -216,17 +217,18 @@ def _load(self) -> BaseModel:
        if isinstance(fs, LocalFileSystem):
            return self._load_local(self._filepath)
        else:
            from tempfile import TemporaryDirectory

            with TemporaryDirectory(prefix="pyd_kedro_") as tmpdir:
                # Copy from remote... yes, I know, not ideal!
                m_remote = fsspec.get_mapper(self._filepath)
                m_local = fsspec.get_mapper(tmpdir)
                for k, v in m_remote.items():
                    m_local[k] = v

                # Load locally
                return self._load_local(tmpdir)
            # Making a temp directory in the current cache dir location
            tmpdir = get_cache_dir() / str(uuid4()).replace("-", "")
            tmpdir.mkdir(exist_ok=False, parents=True)

            # Copy from remote... yes, this is not ideal!
            m_remote = fsspec.get_mapper(self._filepath)
            m_local = fsspec.get_mapper(str(tmpdir))
            for k, v in m_remote.items():
                m_local[k] = v

            # Load locally
            return self._load_local(str(tmpdir))

    def _load_local(self, filepath: str) -> BaseModel:
        """Load Pydantic model from the local filepath.
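
A hedged sketch of what the folder-dataset change means in practice: a remote load now stages files under the cache directory rather than a short-lived temporary directory, so lazily-loading backends such as Spark can still read them after `_load` returns. The S3 URL is illustrative and assumes the matching fsspec backend (e.g. s3fs) is installed.

    # Sketch: the staged copy now outlives _load(), because files land under
    # get_cache_dir() instead of a deleted temp directory. Bucket/path are
    # illustrative.
    from pydantic_kedro import PydanticFolderDataSet

    ds = PydanticFolderDataSet("s3://my-bucket/models/flat_spark_model")
    model = ds.load()  # staged files persist until interpreter exit (atexit cleanup)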
29 changes: 17 additions & 12 deletions src/pydantic_kedro/datasets/zip.py
@@ -2,12 +2,15 @@

from tempfile import TemporaryDirectory
from typing import Any, Dict
from uuid import uuid4

import fsspec
from fsspec.implementations.zip import ZipFileSystem
from kedro.io.core import AbstractDataSet
from pydantic import BaseModel

from pydantic_kedro._local_caching import get_cache_dir

from .folder import PydanticFolderDataSet


@@ -50,18 +53,20 @@ def _load(self) -> BaseModel:
            Pydantic model.
        """
        filepath = self._filepath
        with TemporaryDirectory(prefix="pyd_kedro_") as tmpdir:
            m_local = fsspec.get_mapper(tmpdir)
            # Unzip via copying to folder
            with fsspec.open(filepath) as zip_file:
                zip_fs = ZipFileSystem(fo=zip_file)  # type: ignore
                m_zip = zip_fs.get_mapper()
                for k, v in m_zip.items():
                    m_local[k] = v
                zip_fs.close()
            # Load folder dataset
            pfds = PydanticFolderDataSet(tmpdir)
            res = pfds.load()
        # Making a temp directory in the current cache dir location
        tmpdir = get_cache_dir() / str(uuid4()).replace("-", "")
        tmpdir.mkdir(exist_ok=False, parents=True)
        m_local = fsspec.get_mapper(str(tmpdir))
        # Unzip via copying to folder
        with fsspec.open(filepath) as zip_file:
            zip_fs = ZipFileSystem(fo=zip_file)  # type: ignore
            m_zip = zip_fs.get_mapper()
            for k, v in m_zip.items():
                m_local[k] = v
            zip_fs.close()
        # Load folder dataset
        pfds = PydanticFolderDataSet(str(tmpdir))
        res = pfds.load()
        return res

    def _save(self, data: BaseModel) -> None:
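
The `_load` above unzips by copying every archive member through fsspec mappers into the cache directory. A standalone sketch of that same pattern, with illustrative source URL and target directory:

    # Sketch of the unzip-via-mapper pattern used in _load, in isolation.
    # Source URL and target directory are illustrative.
    import fsspec
    from fsspec.implementations.zip import ZipFileSystem

    with fsspec.open("s3://my-bucket/models/model.zip") as zip_file:
        zip_fs = ZipFileSystem(fo=zip_file)
        target = fsspec.get_mapper("/tmp/unzipped_model")
        for key, value in zip_fs.get_mapper().items():
            target[key] = value  # writes each archive member as a local file
        zip_fs.close()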
56 changes: 56 additions & 0 deletions src/test/test_ds_spark.py
@@ -0,0 +1,56 @@
"""Test dataset for PySpark specifically."""

from typing import Any, Union

import pytest
from kedro.extras.datasets.spark import SparkDataSet
from pyspark.sql import DataFrame, SparkSession

from pydantic_kedro import (
    ArbModel,
    PydanticAutoDataSet,
    PydanticFolderDataSet,
    PydanticZipDataSet,
)

Kls = Union[PydanticAutoDataSet, PydanticFolderDataSet, PydanticZipDataSet]


class _SparkModel(ArbModel):
    """Spark model, configured to use SparkDataSet (multi-file parquet)."""

    class Config(ArbModel.Config):
        kedro_map = {DataFrame: SparkDataSet}


class FlatSparkModel(_SparkModel):
    """Flat model that tests Spark using the SparkDataSet mapping."""

    df: DataFrame
    val: int


@pytest.fixture
def spark() -> SparkSession:
    """Create a Spark session for testing."""
    return SparkSession.Builder().appName("pydantic-kedro-testing").getOrCreate()


@pytest.mark.parametrize("kls", [PydanticAutoDataSet, PydanticFolderDataSet, PydanticZipDataSet])
@pytest.mark.parametrize(
    "df_raw",
    [
        [{"a": 1, "b": 2, "c": 3}],
    ],
)
def test_spark_flat_model(kls: Kls, df_raw: list[dict[str, Any]], spark: SparkSession, tmpdir):
    """Test roundtripping of the flat Spark model, using Kedro's SparkDataSet."""
    dfx = spark.createDataFrame(df_raw)
    mdl = FlatSparkModel(df=dfx, val=1)
    paths = [f"{tmpdir}/model_on_disk", f"memory://{tmpdir}/model_in_memory"]
    for path in paths:
        ds: Kls = kls(path)  # type: ignore
        ds.save(mdl)
        m2 = ds.load()
        assert isinstance(m2, FlatSparkModel)
        assert m2.df.count() == mdl.df.count()
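
Outside of pytest, an end user would rely on the same `kedro_map` entry the test exercises. A minimal sketch, assuming a local Spark session; the model class and output path are illustrative.

    # Sketch mirroring the test: the DataFrame field is serialized via SparkDataSet
    # because of the kedro_map entry. Model name and output path are illustrative.
    from kedro.extras.datasets.spark import SparkDataSet
    from pyspark.sql import DataFrame, SparkSession

    from pydantic_kedro import ArbModel, PydanticZipDataSet


    class MySparkModel(ArbModel):
        """User model with a Spark DataFrame field."""

        df: DataFrame
        name: str

        class Config(ArbModel.Config):
            kedro_map = {DataFrame: SparkDataSet}


    spark = SparkSession.Builder().appName("pydantic-kedro-example").getOrCreate()
    mdl = MySparkModel(df=spark.createDataFrame([{"x": 1}]), name="demo")
    PydanticZipDataSet("/tmp/my_spark_model.zip").save(mdl)
    reloaded = PydanticZipDataSet("/tmp/my_spark_model.zip").load()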
