Skip to content

Commit

Permalink
Merge pull request #38 from NowanIlfideme/fix/scrict-classes-settings
Browse files Browse the repository at this point in the history
Fix: Strict Classes Support
  • Loading branch information
NowanIlfideme authored Jun 16, 2023
2 parents 95116f1 + b3bc1d0 commit 2cbb8b6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/pydantic_kedro/_dict_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def dict_to_model(dct: Union[Dict[str, Any], List[Any]]) -> BaseModel:
elif isinstance(value, dict):
raw[key] = _dict_manip(value)
# otherwise ignore
return pyd_kls(**raw) # Consider parse_obj_as(pyd_kls, raw) ?
keywords = dict(raw)
del keywords[KLS_MARK_STR]
return pyd_kls(**keywords) # Consider parse_obj_as(pyd_kls, keywords) ?


def model_to_dict(model: BaseModel) -> Dict[str, Any]:
Expand Down
49 changes: 49 additions & 0 deletions src/test/test_strict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Test strict models and BaseSettings subclasses."""

import pytest
from pydantic import BaseModel, BaseSettings
from typing_extensions import Literal

from pydantic_kedro import load_model, save_model


class ExSettings(BaseSettings):
"""Settings class."""

val: str


class ExModel(BaseModel):
"""Model class."""

x: int = 1
settings: ExSettings


class StrictModel(BaseModel):
"""Strict no-extras values."""

val: str

class Config:
"""Pydantic model configuration."""

extra = "forbid"


@pytest.mark.parametrize("format", ["auto", "zip", "folder", "yaml", "json"])
def test_rt_settings(tmpdir: str, format: Literal["auto", "zip", "folder", "yaml", "json"]):
"""Test settings round-trip."""
obj = ExModel(settings=ExSettings(val="val"))
save_model(obj, f"{tmpdir}/obj", format=format)
obj2 = load_model(f"{tmpdir}/obj", ExModel)
assert obj.settings == obj2.settings


@pytest.mark.parametrize("format", ["auto", "zip", "folder", "yaml", "json"])
def test_rt_strict_model(tmpdir: str, format: Literal["auto", "zip", "folder", "yaml", "json"]):
"""Test strict_model round-trip."""
obj = StrictModel(val="val")
save_model(obj, f"{tmpdir}/obj", format=format)
obj2 = load_model(f"{tmpdir}/obj", StrictModel)
assert obj.val == obj2.val

0 comments on commit 2cbb8b6

Please sign in to comment.