Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Experiment with validation of subfields #14

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion exca/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _add_name(
obj: pydantic.BaseModel, propagate_defaults: bool = False
) -> pydantic.BaseModel:
"""Provide owner object to the infra"""

private = obj.__pydantic_private__ or {}
params = collections.ChainMap(dict(obj), private)
for name, val in params.items():
Expand Down Expand Up @@ -84,10 +85,11 @@ def _add_name(


@pydantic.model_validator(mode="before")
def model_with_infra_validator_before(obj: tp.Any) -> tp.Any:
def model_with_infra_validator_before(cls, obj: tp.Any) -> tp.Any:
"""Provide owner object to the infra
(this is set to the owner class during __set_name__)
"""
utils.check_extra_forbid(cls)
if not isinstance(obj, dict):
return obj # should not happen
for name, val in obj.items():
Expand Down Expand Up @@ -149,6 +151,7 @@ def __set_name__(self, owner: tp.Type[pydantic.BaseModel], name: str) -> None:
msg += f"{cls} must inherit from pydantic.BaseModel"
raise RuntimeError(msg)
owner.model_config.setdefault("extra", "forbid")
owner.model_config.setdefault("revalidate_instances", "always")
self._infra_name = name
# set mechanism to provide owner obj to the infra:
owner._model_with_infra_validator_after = model_with_infra_validator_after # type: ignore
Expand Down
17 changes: 17 additions & 0 deletions exca/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,23 @@ def test_weird_types(tmp_path: Path) -> None:
whatever.build()


def test_recursive_model() -> None:
class SubData(pydantic.BaseModel):
x: int = 12

class Recursive(pydantic.BaseModel):
submodels: list["Recursive"] = []
subd: SubData = SubData()
infra: TaskInfra = TaskInfra()

r = Recursive(submodels=[{}, {}], subd={"y": 3}) # type: ignore
print("Cfg", SubData.model_config)
r = Recursive(submodels=[{}, {}], subd={"y": 3}) # type: ignore
r.subd.y = 13
print(r.subd.y)
raise


def test_defined_in_main() -> None:
try:
import neuralset as ns
Expand Down
30 changes: 30 additions & 0 deletions exca/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,38 @@ def find_models(
return out


_CHECKED: list[tp.Type[pydantic.BaseModel]] = []


def check_extra_forbid(cls: tp.Type[pydantic.BaseModel]) -> None:
print("Checking", cls)
if cls in _CHECKED:
print("Bypassing", cls)
return
_CHECKED.append(cls)
cfg = cls.model_config
if "extra" not in cfg:
msg = f"Automatically setting extra='forbid' for {cls.__name__} "
msg += "(bypass by explicitely setting: model_config = pydantic.ConfigDict(extra='forbid'))"
logging.debug(msg)
print(f"setting extra to {cls}")
cfg["extra"] = "forbid"
cls.model_config = cfg
print(cls.model_config)
print(cls.model_fields)
for val in cls.model_fields.values():
# print(f"Here is {val}")
for annot in _pydantic_hints(val.annotation):
print("SubChecking", annot)
try:
check_extra_forbid(annot)
except Exception as e:
raise ValueError(f"Failing for {val.annotation} ({annot=})") from e


def _pydantic_hints(hint: tp.Any) -> tp.List[tp.Type[pydantic.BaseModel]]:
"""Checks if a type hint contains pydantic models"""
print("checking sub")
try:
if issubclass(hint, pydantic.BaseModel):
return [hint]
Expand Down
Loading