Skip to content

Commit

Permalink
Merge pull request #22 from NowanIlfideme/fix/generalize-yaml-libs
Browse files Browse the repository at this point in the history
Fix: Generalize YAML libs
  • Loading branch information
NowanIlfideme authored Jun 8, 2022
2 parents 90569e9 + f255611 commit 547c76f
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pydantic_yaml/compat/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ def inject_all():
for cls in get_str_like_types():
register_str_like(cls, method=str)
for cls in get_int_like_types():
register_int_like(cls)
register_int_like(cls, method=int)
10 changes: 5 additions & 5 deletions pydantic_yaml/compat/representers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial
from typing import Any, Callable, TypeVar

from .yaml_lib import yaml, dumper_classes
from .yaml_lib import yaml, dumper_classes, representer_classes

CType = TypeVar("CType")

Expand Down Expand Up @@ -66,8 +66,8 @@ def register_str_like(cls: CType, method: Callable[[Any], str] = str) -> CType:
cls
This is the same as the input `cls`.
"""
for dump_cls in dumper_classes:
dump_cls.add_representer(cls, partial(dump_as_str, method=method))
for x_cls in dumper_classes + representer_classes:
x_cls.add_representer(cls, partial(dump_as_str, method=method))
return cls


Expand All @@ -87,6 +87,6 @@ def register_int_like(cls: CType, method: Callable[[Any], int] = int) -> CType:
cls
This is the same as the input `cls`.
"""
for dump_cls in dumper_classes:
dump_cls.add_representer(cls, partial(dump_as_int, method=method))
for x_cls in dumper_classes + representer_classes:
x_cls.add_representer(cls, partial(dump_as_int, method=method))
return cls
6 changes: 3 additions & 3 deletions pydantic_yaml/compat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict, Tuple, Union

from .representers import register_int_like, register_str_like
from .yaml_lib import yaml
from .yaml_lib import yaml_safe_dump

__all__ = ["YamlInt", "YamlIntEnum", "YamlStr", "YamlStrEnum"]

Expand Down Expand Up @@ -91,7 +91,7 @@ def __init_subclass__(cls):
res = super().__init_subclass__()
if not isabstract(cls):
vals: Dict[str, cls] = dict(cls.__members__)
yaml.safe_dump(vals)
yaml_safe_dump(vals)
return res

# def __new__(cls, v):
Expand All @@ -116,7 +116,7 @@ def __init_subclass__(cls):
res = super().__init_subclass__()
if not isabstract(cls):
vals: Dict[str, cls] = dict(cls.__members__)
yaml.safe_dump(vals)
yaml_safe_dump(vals)
return res


Expand Down
52 changes: 51 additions & 1 deletion pydantic_yaml/compat/yaml_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

# flake8: noqa

from typing import Any, Optional
from io import BytesIO, StringIO, IOBase


try:
import ruamel.yaml as yaml # type: ignore

Expand Down Expand Up @@ -31,5 +35,51 @@
except Exception:
pass

representer_classes = []
for _fld in dir(yaml):
try:
if "Representer" in _fld:
_obj = getattr(yaml, _fld)
representer_classes.append(_obj)
except Exception:
pass


def yaml_safe_load(stream) -> Any:
"""Wrapper around YAML library loader."""
if __yaml_lib__ in ["ruamel-old", "pyyaml"]:
return yaml.safe_load(stream)
# Fixing deprecation warning in new ruamel.yaml versions
assert __yaml_lib__ == "ruamel-new"
ruamel_obj = yaml.YAML(typ="safe", pure=True)
if isinstance(stream, str):
return ruamel_obj.load(StringIO(stream))
elif isinstance(stream, bytes):
return ruamel_obj.load(BytesIO(stream))
# we hope it's a stream, but don't enforce it
return ruamel_obj.load(stream)


def yaml_safe_dump(data: Any, stream=None, **kwds) -> Optional[Any]:
"""Wrapper around YAML library dumper."""
if __yaml_lib__ in ["ruamel-old", "pyyaml"]:
return yaml.safe_dump(data, stream=stream, **kwds)
# Fixing deprecation warning in new ruamel.yaml versions
assert __yaml_lib__ == "ruamel-new"
ruamel_obj = yaml.YAML(typ="safe", pure=True)
# Hacking some options that aren't
for kw in ["encoding", "default_flow_style", "default_style", "indent"]:
if kw in kwds:
setattr(ruamel_obj, kw, kwds[kw])

if stream is None:
text_stream = StringIO()
ruamel_obj.dump(data, stream=text_stream)
text_stream.seek(0) # otherwise we always get ''
return text_stream.read()
else:
ruamel_obj.dump(data, stream=stream)
return None


__all__ = ["yaml", "__yaml_lib__"]
__all__ = ["yaml_safe_dump", "yaml_safe_load", "yaml", "__yaml_lib__"]
4 changes: 3 additions & 1 deletion pydantic_yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
__all__ = [
"__version__",
"yaml",
"yaml_safe_dump",
"yaml_safe_load",
"YamlEnum",
"YamlInt",
"YamlIntEnum",
Expand All @@ -18,7 +20,7 @@
from .compat.old_enums import YamlEnum
from .compat.hacks import inject_all as _inject_yaml_hacks
from .compat.types import YamlInt, YamlIntEnum, YamlStr, YamlStrEnum
from .compat.yaml_lib import yaml
from .compat.yaml_lib import yaml, yaml_safe_dump, yaml_safe_load
from .ext.semver import SemVer
from .ext.versioned_model import VersionedYamlModel
from .mixin import YamlModelMixin, YamlModelMixinConfig
Expand Down
6 changes: 3 additions & 3 deletions pydantic_yaml/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

Model = TypeVar("Model", bound="BaseModel")

from .compat.yaml_lib import yaml
from .compat.yaml_lib import yaml_safe_dump, yaml_safe_load


ExtendedProto = Union[Protocol, Literal["yaml"]]
Expand Down Expand Up @@ -58,8 +58,8 @@ def is_yaml_requested(
class YamlModelMixinConfig:
"""Additional configuration for YamlModelMixin."""

yaml_loads: Callable[[str], Any] = yaml.safe_load # type: ignore
yaml_dumps: Callable[..., str] = yaml.safe_dump # type: ignore
yaml_loads: Callable[[str], Any] = yaml_safe_load # type: ignore
yaml_dumps: Callable[..., str] = yaml_safe_dump # type: ignore


class YamlModelMixin(metaclass=ModelMetaclass):
Expand Down
10 changes: 5 additions & 5 deletions pydantic_yaml/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from pydantic_yaml.compat.types import YamlInt, YamlIntEnum, YamlStr, YamlStrEnum
from pydantic_yaml.compat.yaml_lib import yaml
from pydantic_yaml.compat.yaml_lib import yaml_safe_dump, yaml_safe_load


class XSE(YamlStrEnum):
Expand Down Expand Up @@ -31,8 +31,8 @@ def test_str_enum():

x1 = XSE.a
x2 = XSE("b")
assert yaml.safe_load(yaml.safe_dump(x1)) == "a"
assert yaml.safe_load(yaml.safe_dump(x2)) == "b"
assert yaml_safe_load(yaml_safe_dump(x1)) == "a"
assert yaml_safe_load(yaml_safe_dump(x2)) == "b"

with pytest.raises(ValueError):
XSE("c")
Expand All @@ -42,8 +42,8 @@ def test_int_enum():
"""Test for YamlIntEnum class."""
x1 = XIE.a
x2 = XIE(2)
assert yaml.safe_load(yaml.safe_dump(x1)) == 1
assert yaml.safe_load(yaml.safe_dump(x2)) == 2
assert yaml_safe_load(yaml_safe_dump(x1)) == 1
assert yaml_safe_load(yaml_safe_dump(x2)) == 2

with pytest.raises(ValueError):
XIE(3)
Expand Down

0 comments on commit 547c76f

Please sign in to comment.