diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index d9dd4c480..ac59ba8b8 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -264,14 +264,14 @@ def get_value( check_key = attr if self.attribute is None else self.attribute return accessor_func(obj, check_key, default) - def _validate(self, value: typing.Any) -> None: + def _validate(self, value: _InternalT) -> _InternalT: """Perform validation on ``value``. Raise a :exc:`ValidationError` if validation does not succeed. """ - self._validate_all(value) + return self._validate_all(value) @property - def _validate_all(self) -> typing.Callable[[typing.Any], None]: + def _validate_all(self) -> typing.Callable[[_InternalT], _InternalT]: return And(*self.validators) def make_error(self, key: str, **kwargs) -> ValidationError: @@ -373,7 +373,7 @@ def deserialize( if self.allow_none and value is None: return None output = self._deserialize(value, attr, data, **kwargs) - self._validate(output) + output = self._validate(output) return output # Methods for concrete classes to override. diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 23d5cd279..095109b05 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -1125,6 +1125,10 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) ) if validated_value is missing: data[idx].pop(field_name, None) + else: + data[idx][field_obj.attribute or field_name] = ( + validated_value + ) else: try: value = data[field_obj.attribute or field_name] @@ -1139,6 +1143,8 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) ) if validated_value is missing: data.pop(field_name, None) + else: + data[field_obj.attribute or field_name] = validated_value def _invoke_schema_validators( self, diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py index 0f6cf3b1a..dfeb77f7a 100644 --- a/src/marshmallow/validate.py +++ b/src/marshmallow/validate.py @@ -71,7 +71,7 @@ def __call__(self, value: typing.Any) -> typing.Any: kwargs = {} for validator in self.validators: try: - validator(value) + value = validator(value) except ValidationError as err: kwargs.update(err.kwargs) if isinstance(err.messages, dict): diff --git a/tests/base.py b/tests/base.py index 28efaeffc..e7466ca2c 100644 --- a/tests/base.py +++ b/tests/base.py @@ -78,13 +78,16 @@ def assert_time_equal(t1: dt.time, t2: dt.time) -> None: ##### Validation ##### +T = typing.TypeVar("T") + def predicate( func: typing.Callable[[typing.Any], bool], -) -> typing.Callable[[typing.Any], None]: - def validate(value: typing.Any) -> None: +) -> typing.Callable[[T], T]: + def validate(value: T) -> T: if func(value) is False: raise ValidationError("Invalid value.") + return value return validate diff --git a/tests/test_decorators.py b/tests/test_decorators.py index d8a6f77d3..bc6624f4d 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -248,9 +248,10 @@ class ValidatesSchema(Schema): foo = fields.Int() @validates("foo") - def validate_foo(self, value): + def validate_foo(self, value: int) -> int: if value != 42: raise ValidationError("The answer to life the universe and everything.") + return value class TestValidatesDecorator: diff --git a/tests/test_schema.py b/tests/test_schema.py index 674b1a60d..54570e827 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1751,6 +1751,46 @@ def validate_b(self, val): errors = s.validate({"b": "data"}) assert errors == {"b": {"code": "invalid_b"}} + def test_field_validator_can_transform_deserialized_value(self): + def strip_whitespace(value: str) -> str: + return value.strip() + + def uppercase(value: str) -> str: + return value.upper() + + class ArtistSchema(Schema): + name = fields.Str(data_key="Name", validate=(strip_whitespace, uppercase)) + + s = ArtistSchema() + assert s.load({"Name": " foo "}) == {"name": "FOO"} + + s_many = ArtistSchema(many=True) + assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [ + {"name": "FOO"}, + {"name": "BAR"}, + ] + + def test_validator_method_can_transform_deserialized_value(self): + class ArtistSchema(Schema): + name = fields.Str(data_key="Name") + + @validates("name") + def strip_whitespace(self, value: str) -> str: + return value.strip() + + @validates("name") + def uppercase(self, value: str) -> str: + return value.upper() + + s = ArtistSchema() + assert s.load({"Name": " foo "}) == {"name": "FOO"} + + s_many = ArtistSchema(many=True) + assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [ + {"name": "FOO"}, + {"name": "BAR"}, + ] + def test_schema_repr(): class MySchema(Schema): diff --git a/tests/test_validate.py b/tests/test_validate.py index fcda4e816..69d514fc8 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -961,9 +961,10 @@ def test_containsnoneof_mixing_types(): validate.ContainsNoneOf([1, 2, 3])((1,)) -def is_even(value): +def is_even(value: int) -> int: if value % 2 != 0: raise ValidationError("Not an even value.") + return value def test_and():