-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from NowanIlfideme/feature/inherit-configs
Feature: Inherit configs
- Loading branch information
Showing
7 changed files
with
235 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""Functions for internal use.""" | ||
|
||
from typing import Callable, Dict, Type | ||
|
||
from kedro.extras.datasets.pickle import PickleDataSet | ||
from kedro.io.core import AbstractDataSet | ||
from pydantic import BaseModel | ||
|
||
|
||
def get_kedro_map(kls: Type[BaseModel]) -> Dict[Type, Callable[[str], AbstractDataSet]]:
    """Get the type-to-dataset mapper for a Pydantic class.

    Walks the MRO from most-base to most-derived class, merging each class's
    ``Config.kedro_map`` into the result, so that subclasses override (but do
    not erase) mappings pseudo-inherited from their bases.

    Parameters
    ----------
    kls : Type[BaseModel]
        The Pydantic model class to inspect.

    Returns
    -------
    Dict[Type, Callable[[str], AbstractDataSet]]
        Mapping from field type to a dataset factory taking a path string.

    Raises
    ------
    TypeError
        If `kls` is not a BaseModel subclass, if any `kedro_map` is not a
        dict, or if a `kedro_map` has non-type keys or non-callable values.
    """
    if not (isinstance(kls, type) and issubclass(kls, BaseModel)):
        raise TypeError(f"Must pass a BaseModel subclass; got {kls!r}")
    kedro_map: Dict[Type, Callable[[str], AbstractDataSet]] = {}
    # Go through bases of `kls` in order, bases first, so that more-derived
    # classes override entries inherited from their ancestors.
    for base_i in reversed(kls.mro()):
        # Get config class (if it's defined)
        cfg_i = getattr(base_i, "__config__", None)
        if cfg_i is None:
            continue
        # Get kedro_map (if it's defined)
        upd = getattr(cfg_i, "kedro_map", None)
        if upd is None:
            continue
        if not isinstance(upd, dict):
            raise TypeError(
                f"The `kedro_map` in config class {base_i.__qualname__}"
                f" must be a dict, but got {upd!r}"
            )
        # Detailed checks (to help users fix stuff); validate the whole dict
        # BEFORE merging, so a bad map is never partially applied.
        bad_keys = [k for k in upd if not isinstance(k, type)]
        bad_vals = [v for k, v in upd.items() if isinstance(k, type) and not callable(v)]
        if bad_keys:
            raise TypeError(
                f"Keys in `kedro_map` of config class {base_i.__qualname__}"
                f" must be types, but got bad keys: {bad_keys}"
            )
        if bad_vals:
            raise TypeError(
                f"Values in `kedro_map` of config class {base_i.__qualname__}"
                f" must be callable (or types), but got bad values: {bad_vals}"
            )
        kedro_map.update(upd)  # TODO: Check callable signatures?
    return kedro_map
|
||
|
||
def get_kedro_default(kls: Type[BaseModel]) -> Callable[[str], AbstractDataSet]:
    """Get the default Kedro dataset creator for a Pydantic class.

    Walks the MRO from most-derived to most-base class and returns the first
    ``Config.kedro_default`` found, so a subclass's default wins over those
    of its bases. Falls back to `PickleDataSet` when no class defines one.

    Parameters
    ----------
    kls : Type[BaseModel]
        The Pydantic model class to inspect.

    Returns
    -------
    Callable[[str], AbstractDataSet]
        A dataset factory taking a path string.

    Raises
    ------
    TypeError
        If a found `kedro_default` is not callable, or is a type that is not
        an AbstractDataSet subclass.
    """
    # NOTE: `kls.mro()` is already ordered most-derived first, so the first
    # `kedro_default` found is the *nearest* one in the hierarchy.
    for base_i in kls.mro():
        # Get config class (if defined)
        cfg_i = getattr(base_i, "__config__", None)
        if cfg_i is None:
            continue
        # Get kedro_default (if it's defined)
        default = getattr(cfg_i, "kedro_default", None)
        if default is None:
            continue
        if not callable(default):
            raise TypeError(
                "The `kedro_default` must be an AbstractDataSet or callable that creates one,"
                f" but got {default!r}"
            )
        # Special check for types: a class must be a dataset class; other
        # callables are assumed to return one. TODO: Check callable signature?
        if isinstance(default, type) and not issubclass(default, AbstractDataSet):
            raise TypeError(
                "The `kedro_default` must be an AbstractDataSet or callable that creates one,"
                f" but got {default!r}"
            )
        return default
    # Nothing defined anywhere in the hierarchy: pickle is the safe fallback.
    return PickleDataSet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
"""Tests for proper inheritence of classes.""" | ||
|
||
from pathlib import Path | ||
from typing import Type, Union | ||
|
||
import pandas as pd | ||
import pytest | ||
from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet | ||
from kedro.extras.datasets.pickle import PickleDataSet | ||
from pydantic import BaseModel | ||
|
||
from pydantic_kedro import PydanticFolderDataSet | ||
|
||
dfx = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"]) | ||
|
||
|
||
class BaseA(BaseModel):
    """First model in hierarchy, using Parquet for Pandas."""

    class Config:
        """Config for pydantic-kedro."""

        # Required so Pydantic accepts pd.DataFrame fields.
        arbitrary_types_allowed = True
        # Save any pd.DataFrame field via a ParquetDataSet.
        kedro_map = {pd.DataFrame: ParquetDataSet}
|
||
|
||
class Model1A(BaseA):
    """Model with Parquet dataset base."""

    # Saved per BaseA's kedro_map, i.e. as Parquet.
    df: pd.DataFrame
|
||
|
||
def csv_ds(path: str) -> CSVDataSet:
    """Build a CSVDataSet for *path* that saves without the index column."""
    save_args = {"index": False}
    load_args = {}
    return CSVDataSet(path, save_args=save_args, load_args=load_args)
|
||
|
||
class BaseB(BaseA):
    """Second model in hierarchy, using CSV for Pandas."""

    class Config:
        """Config for pydantic-kedro."""

        # Overrides the Parquet mapping from BaseA: DataFrames become CSV.
        kedro_map = {pd.DataFrame: csv_ds}
|
||
|
||
class Model1B(BaseB):
    """Model with CSV dataset base."""

    # Saved per BaseB's kedro_map, i.e. as CSV.
    df: pd.DataFrame
|
||
|
||
class BaseC(BaseB):
    """Third model in hierarchy, not providing any kedro_map."""

    # Deliberately no Config: the CSV mapping from BaseB should carry over.
|
||
|
||
class Model1C(BaseC):
    """Model with CSV dataset base (again)."""

    # Still saved as CSV: BaseC defines no map, so BaseB's applies.
    df: pd.DataFrame
|
||
|
||
class Fake:
    """Fake class, used only as an extra kedro_map key in BaseD."""
|
||
|
||
class BaseD(BaseC):
    """Fourth model in hierarchy, providing updated kedro_map (for Fake) and updated default.

    However, since we pseudo-inherit the `{pd.DataFrame: csv_ds}` mapping from BaseB,
    pd.DataFrame fields should still be saved as CSV.
    """

    class Config:
        """Config for pydantic-kedro."""

        # Adds a mapping for Fake only; the DataFrame mapping from BaseB
        # is not erased by this override.
        kedro_map = {Fake: PickleDataSet}
        kedro_default = ParquetDataSet  # Bad idea in practice, but this is for the test
|
||
|
||
class Model1D(BaseD):
    """Model with CSV dataset base, even though we changed other config parts."""

    # Still saved as CSV via the pseudo-inherited mapping from BaseB.
    df: pd.DataFrame
|
||
|
||
@pytest.mark.parametrize(
    ["model_type", "ds_type"],
    [[Model1A, ParquetDataSet], [Model1B, CSVDataSet], [Model1C, CSVDataSet], [Model1D, CSVDataSet]],
)
def test_pandas_flat_model(
    tmpdir,
    model_type: Type[Union[Model1A, Model1B, Model1C, Model1D]],
    ds_type: Type[Union[ParquetDataSet, CSVDataSet]],
):
    """Test roundtripping of the different dataset models.

    Saves the model to a folder dataset, then loads the inner `.df` artifact
    directly with the dataset type the (pseudo-inherited) `kedro_map` should
    have chosen, and verifies the data survived the roundtrip.
    """
    # Create and save model
    model = model_type(df=dfx)
    path = Path(f"{tmpdir}/model_on_disk")
    PydanticFolderDataSet(str(path)).save(model)
    # Try loading with the supposed dataframe type; this fails if the field
    # was written with a different dataset format.
    found_df = ds_type(str(path / ".df")).load()
    assert isinstance(found_df, pd.DataFrame)
    # Verify the data itself, not just the type of the loaded object.
    assert list(found_df.columns) == list(dfx.columns)
    assert found_df.values.tolist() == dfx.values.tolist()