From d104a5c3a07d97c4f7579be90048faf234bfed4c Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Wed, 11 Sep 2019 22:02:46 -0400 Subject: [PATCH] Fix propagating `only` and `exclude` to nested field instances (#1385) * Respect dotted `only` and `exclude` on nested schema instances * Refactor: use _init_fields to re-initialize fields based on `only` and `exclude` Also, be explicit about which helper methods are private * Update changelog * Refactor only/exclude propagation This generalizes slightly better to accommodate #1382 * Fix behavior when dumping multiple times Copy schema to avoid unwanted sharing of `only` and `exclude` across instances * Remove unnecessary falsy checks * Update changelog --- CHANGELOG.rst | 8 +++++ src/marshmallow/fields.py | 11 +++++- src/marshmallow/schema.py | 47 +++++++++++++----------- tests/test_fields.py | 76 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 116 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5d010b989..f547143c5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Changelog --------- +3.0.4 (unreleased) +++++++++++++++++++ + +Bug fixes: + +- Fix propagating dot-delimited `only` and `exclude` parameters to nested schema instances (:issue:`1384`). +- Includes bug fix from 2.20.4 (:issue:`1160`). + 3.0.3 (2019-09-04) ++++++++++++++++++ diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 2858cde77..8c5b8568c 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -468,8 +468,17 @@ def schema(self): # Inherit context from parent. context = getattr(self.parent, "context", {}) if isinstance(self.nested, SchemaABC): - self._schema = self.nested + self._schema = copy.deepcopy(self.nested) self._schema.context.update(context) + # Respect only and exclude passed from parent and re-initialize fields + set_class = self._schema.set_class + if self.only is not None: + original = self._schema.only + self._schema.only = set_class(self.only).intersection(original) + if self.exclude: + original = self._schema.exclude + self._schema.exclude = set_class(self.exclude).union(original) + self._schema._init_fields() else: if isinstance(self.nested, type) and issubclass(self.nested, SchemaABC): schema_class = self.nested diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index afd2f5db5..7614f5755 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -378,16 +378,10 @@ def __init__( self.context = context or {} self._normalize_nested_options() #: Dictionary mapping field_names -> :class:`Field` objects - self.fields = self._init_fields() - self.dump_fields, self.load_fields = self.dict_class(), self.dict_class() - for field_name, field_obj in self.fields.items(): - if field_obj.load_only: - self.load_fields[field_name] = field_obj - elif field_obj.dump_only: - self.dump_fields[field_name] = field_obj - else: - self.load_fields[field_name] = field_obj - self.dump_fields[field_name] = field_obj + self.fields = None + self.load_fields = None + self.dump_fields = None + self._init_fields() messages = {} messages.update(self._default_error_messages) for cls in reversed(self.__class__.__mro__): @@ -757,6 +751,7 @@ def _do_load( self, data, *, many=None, partial=None, unknown=None, postprocess=True ): """Deserialize `data`, returning the deserialized result. + This method is private API. :param data: The data to deserialize. :param bool many: Whether to deserialize `data` as a collection. If `None`, the @@ -844,7 +839,9 @@ def _do_load( return result def _normalize_nested_options(self): - """Apply then flatten nested schema options""" + """Apply then flatten nested schema options. + This method is private API. + """ if self.only is not None: # Apply the only option to nested fields. self.__apply_nested_option("only", self.only, "intersection") @@ -878,7 +875,9 @@ def __apply_nested_option(self, option_name, field_names, set_operation): setattr(self.declared_fields[key], option_name, new_options) def _init_fields(self): - """Update fields based on schema options.""" + """Update self.fields, self.load_fields, and self.dump_fields based on schema options. + This method is private API. + """ if self.opts.fields: available_field_names = self.set_class(self.opts.fields) else: @@ -913,10 +912,19 @@ def _init_fields(self): self._bind_field(field_name, field_obj) fields_dict[field_name] = field_obj + load_fields, dump_fields = self.dict_class(), self.dict_class() + for field_name, field_obj in fields_dict.items(): + if field_obj.load_only: + load_fields[field_name] = field_obj + elif field_obj.dump_only: + dump_fields[field_name] = field_obj + else: + load_fields[field_name] = field_obj + dump_fields[field_name] = field_obj + dump_data_keys = [ field_obj.data_key if field_obj.data_key is not None else name - for name, field_obj in fields_dict.items() - if not field_obj.load_only + for name, field_obj in dump_fields.items() ] if len(dump_data_keys) != len(set(dump_data_keys)): data_keys_duplicates = { @@ -928,12 +936,7 @@ def _init_fields(self): "Check the following field names and " "data_key arguments: {}".format(list(data_keys_duplicates)) ) - - load_attributes = [ - obj.attribute or name - for name, obj in fields_dict.items() - if not obj.dump_only - ] + load_attributes = [obj.attribute or name for name, obj in load_fields.items()] if len(load_attributes) != len(set(load_attributes)): attributes_duplicates = { x for x in load_attributes if load_attributes.count(x) > 1 @@ -945,7 +948,9 @@ def _init_fields(self): "attribute arguments: {}".format(list(attributes_duplicates)) ) - return fields_dict + self.fields = fields_dict + self.dump_fields = dump_fields + self.load_fields = load_fields def on_bind_field(self, field_name, field_obj): """Hook to modify a field when it is bound to the `Schema`. diff --git a/tests/test_fields.py b/tests/test_fields.py index 751f7cd9e..f4cb3bd24 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -294,10 +294,15 @@ class Family(Schema): assert getattr(schema.fields["children"].inner.schema, param) == {"name"} @pytest.mark.parametrize( - ("param", "expected"), - (("only", {"name"}), ("exclude", {"name", "surname", "age"})), + ("param", "expected_attribute", "expected_dump"), + ( + ("only", {"name"}, {"children": [{"name": "Lily"}]}), + ("exclude", {"name", "surname", "age"}, {"children": [{}]}), + ), ) - def test_list_nested_only_and_exclude_merged_with_nested(self, param, expected): + def test_list_nested_class_only_and_exclude_merged_with_nested( + self, param, expected_attribute, expected_dump + ): class Child(Schema): name = fields.String() surname = fields.String() @@ -307,7 +312,70 @@ class Family(Schema): children = fields.List(fields.Nested(Child, **{param: ("name", "surname")})) schema = Family(**{param: ["children.name", "children.age"]}) - assert getattr(schema.fields["children"].inner, param) == expected + assert getattr(schema.fields["children"].inner, param) == expected_attribute + + family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]} + assert schema.dump(family) == expected_dump + + def test_list_nested_class_multiple_dumps(self): + class Child(Schema): + name = fields.String() + surname = fields.String() + age = fields.Integer() + + class Family(Schema): + children = fields.List(fields.Nested(Child, only=("name", "age"))) + + family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]} + assert Family(only=("children.age",)).dump(family) == { + "children": [{"age": 15}] + } + assert Family(only=("children.name",)).dump(family) == { + "children": [{"name": "Lily"}] + } + + @pytest.mark.parametrize( + ("param", "expected_attribute", "expected_dump"), + ( + ("only", {"name"}, {"children": [{"name": "Lily"}]}), + ("exclude", {"name", "surname", "age"}, {"children": [{}]}), + ), + ) + def test_list_nested_instance_only_and_exclude_merged_with_nested( + self, param, expected_attribute, expected_dump + ): + class Child(Schema): + name = fields.String() + surname = fields.String() + age = fields.Integer() + + class Family(Schema): + children = fields.List(fields.Nested(Child(**{param: ("name", "surname")}))) + + schema = Family(**{param: ["children.name", "children.age"]}) + assert ( + getattr(schema.fields["children"].inner.schema, param) == expected_attribute + ) + + family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]} + assert schema.dump(family) == expected_dump + + def test_list_nested_instance_multiple_dumps(self): + class Child(Schema): + name = fields.String() + surname = fields.String() + age = fields.Integer() + + class Family(Schema): + children = fields.List(fields.Nested(Child(only=("name", "age")))) + + family = {"children": [{"name": "Lily", "surname": "Martinez", "age": 15}]} + assert Family(only=("children.age",)).dump(family) == { + "children": [{"age": 15}] + } + assert Family(only=("children.name",)).dump(family) == { + "children": [{"name": "Lily"}] + } def test_list_nested_partial_propagated_to_nested(self): class Child(Schema):