Skip to content

Commit

Permalink
Add missing_values parameter to field
Browse files Browse the repository at this point in the history
Allows specifying which values are treated as "missing".

Addresses #713
  • Loading branch information
sloria committed Sep 8, 2019
1 parent 4290d0f commit afda7cb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
19 changes: 15 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class Field(FieldABC):
"validator_failed": "Invalid value.",
}

default_missing_values = tuple()

def __init__(
self,
*,
Expand All @@ -147,6 +149,7 @@ def __init__(
load_only=False,
dump_only=False,
error_messages=None,
missing_values=None,
**metadata
):
self.default = default
Expand All @@ -168,9 +171,14 @@ def __init__(
"or a collection of callables."
)

# If missing=None, None should be considered valid by default
self.missing_values = (
missing_values
if missing_values is not None
else self.default_missing_values
)
# If missing=None or None is in missing_values, None should be considered valid by default
if allow_none is None:
if missing is None:
if missing is None or self._is_missing_value(None):
self.allow_none = True
else:
self.allow_none = False
Expand Down Expand Up @@ -223,6 +231,9 @@ def get_value(self, obj, attr, accessor=None, default=missing_):
check_key = attr if attribute is None else attribute
return accessor_func(obj, check_key, default)

def _is_missing_value(self, value):
return value is missing_ or value in self.missing_values

def _validate(self, value):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
Expand Down Expand Up @@ -279,7 +290,7 @@ def _validate_missing(self, value):
"""Validate missing values. Raise a :exc:`ValidationError` if
`value` should be considered missing.
"""
if value is missing_:
if self._is_missing_value(value):
if hasattr(self, "required") and self.required:
raise self.make_error("required")
if value is None:
Expand Down Expand Up @@ -319,7 +330,7 @@ def deserialize(self, value, attr=None, data=None, **kwargs):
# Validate required fields, deserialize, then validate
# deserialized value
self._validate_missing(value)
if value is missing_:
if self._is_missing_value(value):
_miss = self.missing
return _miss() if callable(_miss) else _miss
if getattr(self, "allow_none", False) is True and value is None:
Expand Down
49 changes: 48 additions & 1 deletion tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_fields_dont_allow_none_by_default(self, FieldClass):
with pytest.raises(ValidationError, match="Field may not be null."):
field.deserialize(None)

def test_allow_none_is_true_if_missing_is_true(self):
def test_allow_none_is_true_if_missing_is_none(self):
field = fields.Field(missing=None)
assert field.allow_none is True
field.deserialize(None) is None
Expand Down Expand Up @@ -1338,6 +1338,53 @@ class AliasingUserSerializer(Schema):
assert result["name"] == "Mick"
assert result["years"] is None

# https://github.com/marshmallow-code/marshmallow/issues/713
@pytest.mark.parametrize(
("missing", "missing_values", "input_data", "expected"),
[
(None, {None}, {"name": None}, {"name": None}),
(None, {None}, {"name": ""}, {"name": ""}),
(None, {""}, {"name": ""}, {"name": None}),
(None, {""}, {}, {"name": None}),
("", {""}, {"name": ""}, {"name": ""}),
("", {None}, {"name": None}, {"name": ""}),
("", {None}, {}, {"name": ""}),
],
)
def test_deserialize_with_custom_missing_values(
self, missing, missing_values, input_data, expected
):
class ArtistSchema(Schema):
name = fields.String(missing=missing, missing_values=missing_values)

schema = ArtistSchema()
assert schema.load(input_data) == expected

def test_deserialize_required_field_with_custom_missing_values(self):
class ArtistSchema(Schema):
album_names = fields.List(
fields.Str(), required=True, missing_values=([], ())
)

with pytest.raises(ValidationError, match="required"):
ArtistSchema().load({"album_names": []})

def test_setting_default_missing_values(self, monkeypatch):
monkeypatch.setattr(fields.Field, "default_missing_values", ("",))
monkeypatch.setattr(fields.List, "default_missing_values", ([], ()))

class ArtistSchema(Schema):
name = fields.String(missing=None)
dob = fields.DateTime(missing=None)
album_names = fields.List(fields.Str(), required=True)

schema = ArtistSchema()
loaded = schema.load({"name": "", "dob": "", "album_names": ["Hunky Dory"]})
assert loaded == {"name": None, "dob": None, "album_names": ["Hunky Dory"]}

with pytest.raises(ValidationError, match="required"):
assert schema.load({"name": "", "dob": "", "album_names": []})

def test_deserialization_raises_with_errors(self):
bad_data = {"email": "invalid-email", "colors": "burger", "age": -1}
v = Validator()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_function_field_load_only(self):
field = fields.Function(deserialize=lambda obj: None)
assert field.load_only

def test_function_field_passed_serialize_with_context(self, user, monkeypatch):
def test_function_field_passed_serialize_with_context(self, user):
class Parent(Schema):
pass

Expand Down

0 comments on commit afda7cb

Please sign in to comment.