Skip to content

Commit

Permalink
Merge pull request #13 from NowanIlfideme/feature/yaml-datasets
Browse files Browse the repository at this point in the history
Feature: YAML datasets
  • Loading branch information
NowanIlfideme authored Apr 26, 2023
2 parents 7ff43af + e2e8021 commit eb02d48
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 11 deletions.
6 changes: 5 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and [Datasets](https://docs.kedro.org/en/stable/data/kedro_io.html).

If you have a JSON-safe Pydantic model, you can use a
[PydanticJsonDataSet][pydantic_kedro.PydanticJsonDataSet]
or [PydanticYamlDataSet][pydantic_kedro.PydanticYamlDataSet]
to save your model to any `fsspec`-supported location:

```python
Expand All @@ -47,7 +48,10 @@ read_obj = ds.load()
assert read_obj.x == 1
```

Note that specifying [custom JSON encoders](https://docs.pydantic.dev/usage/exporting_models/#json_encoders) will work as usual.
> Note: YAML support is enabled by [`pydantic-yaml`](https://pydantic-yaml.readthedocs.io/en/latest/).
Note that specifying [custom JSON encoders](https://docs.pydantic.dev/usage/exporting_models/#json_encoders)
will work as usual, even for YAML models.

However, if your custom type is difficult or impossible to encode/decode via
JSON, read on to [Arbitrary Types](./arbitrary_types.md).
14 changes: 13 additions & 1 deletion docs/reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,16 @@

This is where to find auto-generated Python API docs.

::: pydantic_kedro
<!-- For simple models -->

::: pydantic_kedro.PydanticJsonDataSet

::: pydantic_kedro.PydanticYamlDataSet

<!-- For arbitrary models -->

::: pydantic_kedro.ArbModel

::: pydantic_kedro.PydanticFolderDataSet

::: pydantic_kedro.PydanticZipDataSet
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
]
dependencies = [
"pydantic>=1.10.0,<2",
# "pydantic-yaml", # not currently needed
"pydantic-yaml>=1.0.0a2",
"kedro",
"fsspec",
]
Expand Down
2 changes: 2 additions & 0 deletions src/pydantic_kedro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
"ArbModel",
"PydanticFolderDataSet",
"PydanticJsonDataSet",
"PydanticYamlDataSet",
"PydanticZipDataSet",
"__version__",
]

from .datasets.folder import PydanticFolderDataSet
from .datasets.json import PydanticJsonDataSet
from .datasets.yaml import PydanticYamlDataSet
from .datasets.zip import PydanticZipDataSet
from .models import ArbModel
from .version import __version__
4 changes: 3 additions & 1 deletion src/pydantic_kedro/datasets/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class PydanticJsonDataSet(AbstractDataSet[BaseModel, BaseModel]):
"""A Pydantic model with JSON-based load/save.
Please note that the Pydantic model must not have any JSON-unfriendly fields.
Please note that the Pydantic model must be JSON-serializable.
That means the fields are "pure" Pydantic fields,
or you have added `json_encoders` to the model config.
Example:
-------
Expand Down
95 changes: 95 additions & 0 deletions src/pydantic_kedro/datasets/yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""YAML dataset definition for Pydantic."""

from pathlib import PurePosixPath
from typing import Any, Dict, no_type_check

import fsspec
from fsspec import AbstractFileSystem
from kedro.io.core import AbstractDataSet, get_filepath_str, get_protocol_and_path
from pydantic import BaseModel, Field, create_model
from pydantic.utils import import_string
from pydantic_yaml import parse_yaml_file_as, to_yaml_file

# Name of the extra key written into the YAML document to record the import
# path of the saved model's class. It is used both as the alias of the
# pre-loader field and as the name of the field injected at save time.
KLS_MARK_STR = "class"


class _YamlPreLoader(BaseModel):
    """Minimal model used to read only the class marker from a YAML document."""

    # Required string, populated from the YAML key named by KLS_MARK_STR.
    kls_mark_str: str = Field(..., alias=KLS_MARK_STR)  # type: ignore


class PydanticYamlDataSet(AbstractDataSet[BaseModel, BaseModel]):
    """A Pydantic model with YAML-based load/save.

    Please note that the Pydantic model must be JSON-serializable.
    That means the fields are "pure" Pydantic fields,
    or you have added `json_encoders` to the model config.

    The saved document contains one extra key (named by `KLS_MARK_STR`) that
    holds the import path of the model class; loading uses it to decide which
    class to parse the document into.

    Example:
    -------
    ```python
    class MyModel(BaseModel):
        x: str

    ds = PydanticYamlDataSet('memory://path/to/model.yaml')  # using memory to avoid tempfile
    ds.save(MyModel(x="example"))
    assert ds.load().x == "example"
    ```
    """

    def __init__(self, filepath: str) -> None:
        """Create a new instance of PydanticYamlDataSet to load/save Pydantic models for given filepath.

        Args:
        ----
        filepath : The location of the YAML file.
        """
        # parse the path and protocol (e.g. file, http, s3, etc.)
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._filepath = PurePosixPath(path)
        self._fs: AbstractFileSystem = fsspec.filesystem(self._protocol)

    def _load(self) -> BaseModel:
        """Load Pydantic model from the filepath.

        Returns
        -------
        Pydantic model.

        Raises
        ------
        TypeError
            If the class path stored in the file does not resolve to a
            subclass of `pydantic.BaseModel`.
        """
        # using get_filepath_str ensures that the protocol and path
        # are appended correctly for different filesystems
        load_path = get_filepath_str(self._filepath, self._protocol)

        # First pass: read only the class marker to find out which model to use.
        with self._fs.open(load_path, mode="r") as f:
            preloader = parse_yaml_file_as(_YamlPreLoader, f)

        # NOTE(security): the dotted path comes from the file being loaded, so
        # this resolves whatever object the file names. Only load trusted files.
        pyd_kls = import_string(preloader.kls_mark_str)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # `isinstance(..., type)` guards `issubclass` against non-class objects.
        if not (isinstance(pyd_kls, type) and issubclass(pyd_kls, BaseModel)):
            raise TypeError(f"Type must be a Pydantic model, got {pyd_kls!r}.")

        # Second pass: parse the same document as the actual model class.
        with self._fs.open(load_path, mode="r") as f:
            res = parse_yaml_file_as(pyd_kls, f)
        return res  # type: ignore

    @no_type_check
    def _save(self, data: BaseModel) -> None:
        """Save Pydantic model to the filepath.

        Raises
        ------
        ValueError
            If the model already declares a field named like the class marker.
        """
        # Add metadata to our Pydantic model
        pyd_kls = type(data)
        if KLS_MARK_STR in pyd_kls.__fields__:
            raise ValueError(f"Marker {KLS_MARK_STR!r} already exists as a field; can't dump model.")
        pyd_kls_path = f"{pyd_kls.__module__}.{pyd_kls.__qualname__}"
        # Temporary subclass carrying one extra string field whose default is
        # the class import path, so the path is written alongside the data.
        tmp_kls = create_model(
            pyd_kls.__name__,
            __base__=pyd_kls,
            __module__=pyd_kls.__module__,
            **{KLS_MARK_STR: (str, pyd_kls_path)},
        )
        tmp_obj = tmp_kls(**data.dict())

        # Open file and write to it
        save_path = get_filepath_str(self._filepath, self._protocol)
        with self._fs.open(save_path, mode="w") as f:
            to_yaml_file(f, tmp_obj)

    def _describe(self) -> Dict[str, Any]:
        """Return a dict that describes the attributes of the dataset."""
        return dict(filepath=self._filepath, protocol=self._protocol)
16 changes: 9 additions & 7 deletions src/test/test_ds_simple.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
"""Tests the JSON dataset on a simple model."""
"""Tests the datasets on a simple model."""

from typing import Optional, Union
from typing import Optional

import pytest
from kedro.io.core import AbstractDataSet
from pydantic import BaseModel

from pydantic_kedro import (
PydanticFolderDataSet,
PydanticJsonDataSet,
PydanticYamlDataSet,
PydanticZipDataSet,
)


class SimpleTestModel(BaseModel):
"""My very own model.
NOTE: Since this is defined in `__main__`, this can only be loaded if ran in this notebook.
NOTE: Since this is defined in `__main__`, this can only be loaded if ran in this file.
"""

name: str
alter_ego: Optional[str] = None


Kls = Union[PydanticFolderDataSet, PydanticJsonDataSet, PydanticZipDataSet]
types = [PydanticJsonDataSet, PydanticYamlDataSet, PydanticFolderDataSet, PydanticZipDataSet]


@pytest.mark.parametrize(
"mdl", [SimpleTestModel(name="user"), SimpleTestModel(name="Dr. Jekyll", alter_ego="Mr. Hyde")]
)
@pytest.mark.parametrize("kls", [PydanticJsonDataSet, PydanticFolderDataSet, PydanticZipDataSet])
def test_simple_model_rt(mdl: SimpleTestModel, kls: Kls, tmpdir):
@pytest.mark.parametrize("kls", types)
def test_simple_model_rt(mdl: SimpleTestModel, kls: AbstractDataSet, tmpdir): # type: ignore
"""Tests whether a simple model survives roundtripping."""
paths = [f"{tmpdir}/model_on_disk", f"memory://{tmpdir}/model_in_memory"]
for path in paths:
ds: Kls = kls(path) # type: ignore
ds: AbstractDataSet = kls(path) # type: ignore
ds.save(mdl)
m2 = ds.load()
assert isinstance(m2, SimpleTestModel)
Expand Down

0 comments on commit eb02d48

Please sign in to comment.