Skip to content

Commit

Permalink
Allow validators to transform values
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Jan 16, 2025
1 parent 01fab37 commit c2e6b09
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ def get_value(
check_key = attr if self.attribute is None else self.attribute
return accessor_func(obj, check_key, default)

def _validate(self, value: typing.Any) -> None:
def _validate(self, value: _InternalT) -> _InternalT:
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
self._validate_all(value)
return self._validate_all(value)

@property
def _validate_all(self) -> typing.Callable[[typing.Any], None]:
def _validate_all(self) -> typing.Callable[[_InternalT], _InternalT]:
return And(*self.validators)

def make_error(self, key: str, **kwargs) -> ValidationError:
Expand Down Expand Up @@ -373,7 +373,7 @@ def deserialize(
if self.allow_none and value is None:
return None
output = self._deserialize(value, attr, data, **kwargs)
self._validate(output)
output = self._validate(output)
return output

# Methods for concrete classes to override.
Expand Down
6 changes: 6 additions & 0 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,10 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
)
if validated_value is missing:
data[idx].pop(field_name, None)
else:
data[idx][field_obj.attribute or field_name] = (
validated_value
)
else:
try:
value = data[field_obj.attribute or field_name]
Expand All @@ -1139,6 +1143,8 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
)
if validated_value is missing:
data.pop(field_name, None)
else:
data[field_obj.attribute or field_name] = validated_value

def _invoke_schema_validators(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __call__(self, value: typing.Any) -> typing.Any:
kwargs = {}
for validator in self.validators:
try:
validator(value)
value = validator(value)
except ValidationError as err:
kwargs.update(err.kwargs)
if isinstance(err.messages, dict):
Expand Down
7 changes: 5 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,16 @@ def assert_time_equal(t1: dt.time, t2: dt.time) -> None:

##### Validation #####

T = typing.TypeVar("T")


def predicate(
func: typing.Callable[[typing.Any], bool],
) -> typing.Callable[[typing.Any], None]:
def validate(value: typing.Any) -> None:
) -> typing.Callable[[T], T]:
def validate(value: T) -> T:
if func(value) is False:
raise ValidationError("Invalid value.")
return value

return validate

Expand Down
3 changes: 2 additions & 1 deletion tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ class ValidatesSchema(Schema):
foo = fields.Int()

@validates("foo")
def validate_foo(self, value):
def validate_foo(self, value: int) -> int:
if value != 42:
raise ValidationError("The answer to life the universe and everything.")
return value


class TestValidatesDecorator:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,46 @@ def validate_b(self, val):
errors = s.validate({"b": "data"})
assert errors == {"b": {"code": "invalid_b"}}

def test_field_validator_can_transform_deserialized_value(self):
def strip_whitespace(value: str) -> str:
return value.strip()

def uppercase(value: str) -> str:
return value.upper()

class ArtistSchema(Schema):
name = fields.Str(data_key="Name", validate=(strip_whitespace, uppercase))

s = ArtistSchema()
assert s.load({"Name": " foo "}) == {"name": "FOO"}

s_many = ArtistSchema(many=True)
assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [
{"name": "FOO"},
{"name": "BAR"},
]

def test_validator_method_can_transform_deserialized_value(self):
class ArtistSchema(Schema):
name = fields.Str(data_key="Name")

@validates("name")
def strip_whitespace(self, value: str) -> str:
return value.strip()

@validates("name")
def uppercase(self, value: str) -> str:
return value.upper()

s = ArtistSchema()
assert s.load({"Name": " foo "}) == {"name": "FOO"}

s_many = ArtistSchema(many=True)
assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [
{"name": "FOO"},
{"name": "BAR"},
]


def test_schema_repr():
class MySchema(Schema):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ def test_containsnoneof_mixing_types():
validate.ContainsNoneOf([1, 2, 3])((1,))


def is_even(value):
def is_even(value: int) -> int:
if value % 2 != 0:
raise ValidationError("Not an even value.")
return value


def test_and():
Expand Down

0 comments on commit c2e6b09

Please sign in to comment.