Skip to content

Commit

Permalink
Fix propagating only and exclude to nested field instances (#1385)
Browse files Browse the repository at this point in the history
* Respect dotted `only` and `exclude` on nested schema instances

* Refactor: use _init_fields to re-initialize fields based on `only` and `exclude`

Also, be explicit about which helper methods are private

* Update changelog

* Refactor only/exclude propagation

This generalizes slightly better to accommodate #1382

* Fix behavior when dumping multiple times

Copy schema to avoid unwanted sharing of `only` and `exclude`
across instances

* Remove unnecessary falsy checks

* Update changelog
  • Loading branch information
sloria authored Sep 12, 2019
1 parent 8b3a326 commit d104a5c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 26 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
---------

3.0.4 (unreleased)
++++++++++++++++++

Bug fixes:

- Fix propagating dot-delimited `only` and `exclude` parameters to nested schema instances (:issue:`1384`).
- Includes bug fix from 2.20.4 (:issue:`1160`).

3.0.3 (2019-09-04)
++++++++++++++++++

Expand Down
11 changes: 10 additions & 1 deletion src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,17 @@ def schema(self):
# Inherit context from parent.
context = getattr(self.parent, "context", {})
if isinstance(self.nested, SchemaABC):
self._schema = self.nested
self._schema = copy.deepcopy(self.nested)
self._schema.context.update(context)
# Respect only and exclude passed from parent and re-initialize fields
set_class = self._schema.set_class
if self.only is not None:
original = self._schema.only
self._schema.only = set_class(self.only).intersection(original)
if self.exclude:
original = self._schema.exclude
self._schema.exclude = set_class(self.exclude).union(original)
self._schema._init_fields()
else:
if isinstance(self.nested, type) and issubclass(self.nested, SchemaABC):
schema_class = self.nested
Expand Down
47 changes: 26 additions & 21 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,10 @@ def __init__(
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
self.fields = self._init_fields()
self.dump_fields, self.load_fields = self.dict_class(), self.dict_class()
for field_name, field_obj in self.fields.items():
if field_obj.load_only:
self.load_fields[field_name] = field_obj
elif field_obj.dump_only:
self.dump_fields[field_name] = field_obj
else:
self.load_fields[field_name] = field_obj
self.dump_fields[field_name] = field_obj
self.fields = None
self.load_fields = None
self.dump_fields = None
self._init_fields()
messages = {}
messages.update(self._default_error_messages)
for cls in reversed(self.__class__.__mro__):
Expand Down Expand Up @@ -757,6 +751,7 @@ def _do_load(
self, data, *, many=None, partial=None, unknown=None, postprocess=True
):
"""Deserialize `data`, returning the deserialized result.
This method is private API.
:param data: The data to deserialize.
:param bool many: Whether to deserialize `data` as a collection. If `None`, the
Expand Down Expand Up @@ -844,7 +839,9 @@ def _do_load(
return result

def _normalize_nested_options(self):
"""Apply then flatten nested schema options"""
"""Apply then flatten nested schema options.
This method is private API.
"""
if self.only is not None:
# Apply the only option to nested fields.
self.__apply_nested_option("only", self.only, "intersection")
Expand Down Expand Up @@ -878,7 +875,9 @@ def __apply_nested_option(self, option_name, field_names, set_operation):
setattr(self.declared_fields[key], option_name, new_options)

def _init_fields(self):
"""Update fields based on schema options."""
"""Update self.fields, self.load_fields, and self.dump_fields based on schema options.
This method is private API.
"""
if self.opts.fields:
available_field_names = self.set_class(self.opts.fields)
else:
Expand Down Expand Up @@ -913,10 +912,19 @@ def _init_fields(self):
self._bind_field(field_name, field_obj)
fields_dict[field_name] = field_obj

load_fields, dump_fields = self.dict_class(), self.dict_class()
for field_name, field_obj in fields_dict.items():
if field_obj.load_only:
load_fields[field_name] = field_obj
elif field_obj.dump_only:
dump_fields[field_name] = field_obj
else:
load_fields[field_name] = field_obj
dump_fields[field_name] = field_obj

dump_data_keys = [
field_obj.data_key if field_obj.data_key is not None else name
for name, field_obj in fields_dict.items()
if not field_obj.load_only
for name, field_obj in dump_fields.items()
]
if len(dump_data_keys) != len(set(dump_data_keys)):
data_keys_duplicates = {
Expand All @@ -928,12 +936,7 @@ def _init_fields(self):
"Check the following field names and "
"data_key arguments: {}".format(list(data_keys_duplicates))
)

load_attributes = [
obj.attribute or name
for name, obj in fields_dict.items()
if not obj.dump_only
]
load_attributes = [obj.attribute or name for name, obj in load_fields.items()]
if len(load_attributes) != len(set(load_attributes)):
attributes_duplicates = {
x for x in load_attributes if load_attributes.count(x) > 1
Expand All @@ -945,7 +948,9 @@ def _init_fields(self):
"attribute arguments: {}".format(list(attributes_duplicates))
)

return fields_dict
self.fields = fields_dict
self.dump_fields = dump_fields
self.load_fields = load_fields

def on_bind_field(self, field_name, field_obj):
"""Hook to modify a field when it is bound to the `Schema`.
Expand Down
76 changes: 72 additions & 4 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,15 @@ class Family(Schema):
assert getattr(schema.fields["children"].inner.schema, param) == {"name"}

@pytest.mark.parametrize(
("param", "expected"),
(("only", {"name"}), ("exclude", {"name", "surname", "age"})),
("param", "expected_attribute", "expected_dump"),
(
("only", {"name"}, {"children": [{"name": "Lily"}]}),
("exclude", {"name", "surname", "age"}, {"children": [{}]}),
),
)
def test_list_nested_only_and_exclude_merged_with_nested(self, param, expected):
def test_list_nested_class_only_and_exclude_merged_with_nested(
self, param, expected_attribute, expected_dump
):
class Child(Schema):
name = fields.String()
surname = fields.String()
Expand All @@ -307,7 +312,70 @@ class Family(Schema):
children = fields.List(fields.Nested(Child, **{param: ("name", "surname")}))

schema = Family(**{param: ["children.name", "children.age"]})
assert getattr(schema.fields["children"].inner, param) == expected
assert getattr(schema.fields["children"].inner, param) == expected_attribute

family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]}
assert schema.dump(family) == expected_dump

def test_list_nested_class_multiple_dumps(self):
class Child(Schema):
name = fields.String()
surname = fields.String()
age = fields.Integer()

class Family(Schema):
children = fields.List(fields.Nested(Child, only=("name", "age")))

family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]}
assert Family(only=("children.age",)).dump(family) == {
"children": [{"age": 15}]
}
assert Family(only=("children.name",)).dump(family) == {
"children": [{"name": "Lily"}]
}

@pytest.mark.parametrize(
("param", "expected_attribute", "expected_dump"),
(
("only", {"name"}, {"children": [{"name": "Lily"}]}),
("exclude", {"name", "surname", "age"}, {"children": [{}]}),
),
)
def test_list_nested_instance_only_and_exclude_merged_with_nested(
self, param, expected_attribute, expected_dump
):
class Child(Schema):
name = fields.String()
surname = fields.String()
age = fields.Integer()

class Family(Schema):
children = fields.List(fields.Nested(Child(**{param: ("name", "surname")})))

schema = Family(**{param: ["children.name", "children.age"]})
assert (
getattr(schema.fields["children"].inner.schema, param) == expected_attribute
)

family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]}
assert schema.dump(family) == expected_dump

def test_list_nested_instance_multiple_dumps(self):
class Child(Schema):
name = fields.String()
surname = fields.String()
age = fields.Integer()

class Family(Schema):
children = fields.List(fields.Nested(Child(only=("name", "age"))))

family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]}
assert Family(only=("children.age",)).dump(family) == {
"children": [{"age": 15}]
}
assert Family(only=("children.name",)).dump(family) == {
"children": [{"name": "Lily"}]
}

def test_list_nested_partial_propagated_to_nested(self):
class Child(Schema):
Expand Down

0 comments on commit d104a5c

Please sign in to comment.