Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing lambda functions to Nested #1382

Merged
merged 16 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 102 additions & 41 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
import decimal
import math
import typing
import warnings
from collections.abc import Mapping as _Mapping

Expand Down Expand Up @@ -407,58 +408,109 @@ class Nested(Field):

Examples: ::

user = fields.Nested(UserSchema)
user2 = fields.Nested('UserSchema') # Equivalent to above
collaborators = fields.Nested(UserSchema, many=True, only=('id',))
parent = fields.Nested('self')
class ChildSchema(Schema):
id = fields.Str()
name = fields.Str()
# Use lambda functions when you need two-way nesting or self-nesting
parent = fields.Nested(lambda: ParentSchema(only=("id",)), dump_only=True)
siblings = fields.List(fields.Nested(lambda: ChildSchema(only=("id", "name"))))

class ParentSchema(Schema):
id = fields.Str()
children = fields.List(
fields.Nested(ChildSchema(only=("id", "parent", "siblings")))
)
spouse = fields.Nested(lambda: ParentSchema(only=("id",)))

When passing a `Schema <marshmallow.Schema>` instance as the first argument,
the instance's ``exclude``, ``only``, and ``many`` attributes will be respected.

Therefore, when passing the ``exclude``, ``only``, or ``many`` arguments to `fields.Nested`,
you should pass a `Schema <marshmallow.Schema>` class (not an instance) as the first argument.

::

# Yes
author = fields.Nested(UserSchema, only=('id', 'name'))

# No
author = fields.Nested(UserSchema(), only=('id', 'name'))

:param Schema nested: The Schema class or class name (string)
to nest, or ``"self"`` to nest the :class:`Schema` within itself.
:param Schema nested: A `Schema` class, `Schema` instance, or class name (string)
to nest, or a callable that returns a `Schema` instance.
:param tuple exclude: A list or tuple of fields to exclude.
:param only: A list or tuple of fields to marshal. If `None`, all fields are marshalled.
This parameter takes precedence over ``exclude``.
:param bool many: Whether the field is a collection of objects.
:param unknown: Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
:param kwargs: The same keyword arguments that :class:`Field` receives.

.. versionchanged:: 3.1.0
Deprecated ``only``, ``exclude``, and ``unknown`` parameters.
deckar01 marked this conversation as resolved.
Show resolved Hide resolved
Pass these to the schema instance instead.
``many`` was also deprecated in favor of the ``List(Nested(...))`` usage.
"""

default_error_messages = {"type": "Invalid type."}

def __init__(
self, nested, *, default=missing_, exclude=tuple(), only=None, **kwargs
self,
nested: typing.Union[
SchemaABC, typing.Type[SchemaABC], str, typing.Callable[[], SchemaABC]
],
*,
default=missing_,
**kwargs
):
# Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
raise StringNotCollectionError('"only" should be a collection of strings.')
if exclude is not None and not is_collection(exclude):
raise StringNotCollectionError(
'"exclude" should be a collection of strings.'
# Raise DeprecationWarnings for only, exclude, many, and unknown
if "only" in kwargs:
only = kwargs.pop("only")
if not is_collection(only):
raise StringNotCollectionError(
'"only" should be a collection of strings.'
)
warnings.warn(
"Passing `only` to `Nested` is deprecated. "
"Pass `only` to the schema instance instead.",
DeprecationWarning,
)
else:
only = None
if "exclude" in kwargs:
exclude = kwargs.pop("exclude")
if not is_collection(exclude):
raise StringNotCollectionError(
'"exclude" should be a collection of strings.'
)
warnings.warn(
"Passing `exclude` to `Nested` is deprecated. "
"Pass `only` to the schema instance instead.",
DeprecationWarning,
)
else:
exclude = tuple()

if "many" in kwargs:
many = kwargs.pop("many")
warnings.warn(
"Passing `many` to `Nested` is deprecated. "
"Use List(Nested(...)) instead.",
DeprecationWarning,
)
else:
many = False

if "unknown" in kwargs:
unknown = kwargs.pop("unknown")
warnings.warn(
"Passing `unknown` to `Nested` is deprecated. "
"Pass `unknown` to the schema instance instead.",
DeprecationWarning,
)
else:
unknown = None

self.nested = nested
self._schema = None # Cached Schema instance
# Deprecated attributes
self.only = only
self.exclude = exclude
self.many = kwargs.get("many", False)
self.unknown = kwargs.get("unknown")
self._schema = None # Cached Schema instance
self.many = many
self.unknown = unknown
super().__init__(default=default, **kwargs)

@property
def schema(self):
def schema(self) -> SchemaABC:
"""The nested Schema object.

.. versionchanged:: 1.0.0
Expand All @@ -467,24 +519,33 @@ def schema(self):
if not self._schema:
# Inherit context from parent.
context = getattr(self.parent, "context", {})
if isinstance(self.nested, SchemaABC):
self._schema = self.nested
if callable(self.nested) and not isinstance(self.nested, type):
sloria marked this conversation as resolved.
Show resolved Hide resolved
nested = self.nested()
else:
nested = self.nested

if isinstance(nested, SchemaABC):
self._schema = nested
self._schema.context.update(context)
else:
if isinstance(self.nested, type) and issubclass(self.nested, SchemaABC):
schema_class = self.nested
elif not isinstance(self.nested, (str, bytes)):
if isinstance(nested, type) and issubclass(nested, SchemaABC):
schema_class = nested
elif not isinstance(nested, (str, bytes)):
raise ValueError(
"Nested fields must be passed a "
"Schema, not {}.".format(self.nested.__class__)
"`Nested` fields must be passed a "
"`Schema`, not {}.".format(nested.__class__)
)
elif nested == "self":
schema_class = self.root.__class__
warnings.warn(
"Passing 'self' to `Nested` is deprecated. "
"Use `Nested(lambda: {Class}(...))` instead.".format(
Class=schema_class.__name__
),
DeprecationWarning,
)
elif self.nested == "self":
ret = self
while not isinstance(ret, SchemaABC):
ret = ret.parent
schema_class = ret.__class__
else:
schema_class = class_registry.get_class(self.nested)
schema_class = class_registry.get_class(nested)
self._schema = schema_class(
many=self.many,
only=self.only,
Expand Down
57 changes: 52 additions & 5 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,11 +1081,9 @@ class Meta:


def test_nested_custom_set_not_implementing_getitem():
"""
This test checks that Marshmallow can serialize implementations of
:mod:`collections.abc.MutableSequence`, with ``__getitem__`` arguments
that are not integers.
"""
# This test checks that marshmallow can serialize implementations of
# :mod:`collections.abc.MutableSequence`, with ``__getitem__`` arguments
# that are not integers.

class ListLikeParent:
"""
Expand Down Expand Up @@ -1171,6 +1169,55 @@ class ParentSchema(Schema):
assert "bah" not in grand_child


def test_nested_lambda():
class ChildSchema(Schema):
id = fields.Str()
name = fields.Str()
parent = fields.Nested(lambda: ParentSchema(only=("id",)), dump_only=True)
siblings = fields.List(fields.Nested(lambda: ChildSchema(only=("id", "name"))))

class ParentSchema(Schema):
id = fields.Str()
spouse = fields.Nested(lambda: ParentSchema(only=("id",)))
children = fields.List(
fields.Nested(lambda: ChildSchema(only=("id", "parent", "siblings")))
)

sch = ParentSchema()
data_to_load = {
"id": "p1",
"spouse": {"id": "p2"},
"children": [{"id": "c1", "siblings": [{"id": "c2", "name": "sis"}]}],
}
loaded = sch.load(data_to_load)
assert loaded == data_to_load

data_to_dump = dict(
id="p2",
spouse=dict(id="p2"),
children=[
dict(
id="c1",
name="bar",
parent=dict(id="p2"),
siblings=[dict(id="c2", name="sis")],
)
],
)
dumped = sch.dump(data_to_dump)
assert dumped == {
"id": "p2",
"spouse": {"id": "p2"},
"children": [
{
"id": "c1",
"parent": {"id": "p2"},
"siblings": [{"id": "c2", "name": "sis"}],
}
],
}


@pytest.mark.parametrize("data_key", ("f1", "f5", None))
def test_data_key_collision(data_key):
class MySchema(Schema):
Expand Down