Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow validators to transform values #2786

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ def get_value(
check_key = attr if self.attribute is None else self.attribute
return accessor_func(obj, check_key, default)

def _validate(self, value: typing.Any) -> None:
def _validate(self, value: _InternalT) -> _InternalT:
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
self._validate_all(value)
return self._validate_all(value)

@property
def _validate_all(self) -> typing.Callable[[typing.Any], None]:
def _validate_all(self) -> typing.Callable[[_InternalT], _InternalT]:
return And(*self.validators)

def make_error(self, key: str, **kwargs) -> ValidationError:
Expand Down Expand Up @@ -373,7 +373,7 @@ def deserialize(
if self.allow_none and value is None:
return None
output = self._deserialize(value, attr, data, **kwargs)
self._validate(output)
output = self._validate(output)
return output

# Methods for concrete classes to override.
Expand Down
6 changes: 6 additions & 0 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,10 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
)
if validated_value is missing:
data[idx].pop(field_name, None)
else:
data[idx][field_obj.attribute or field_name] = (
validated_value
)
else:
try:
value = data[field_obj.attribute or field_name]
Expand All @@ -1139,6 +1143,8 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool)
)
if validated_value is missing:
data.pop(field_name, None)
else:
data[field_obj.attribute or field_name] = validated_value

def _invoke_schema_validators(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __call__(self, value: typing.Any) -> typing.Any:
kwargs = {}
for validator in self.validators:
try:
validator(value)
value = validator(value)
except ValidationError as err:
kwargs.update(err.kwargs)
if isinstance(err.messages, dict):
Expand Down
7 changes: 5 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,16 @@ def assert_time_equal(t1: dt.time, t2: dt.time) -> None:

##### Validation #####

T = typing.TypeVar("T")


def predicate(
func: typing.Callable[[typing.Any], bool],
) -> typing.Callable[[typing.Any], None]:
def validate(value: typing.Any) -> None:
) -> typing.Callable[[T], T]:
def validate(value: T) -> T:
if func(value) is False:
raise ValidationError("Invalid value.")
return value

return validate

Expand Down
3 changes: 2 additions & 1 deletion tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ class ValidatesSchema(Schema):
foo = fields.Int()

@validates("foo")
def validate_foo(self, value):
def validate_foo(self, value: int) -> int:
if value != 42:
raise ValidationError("The answer to life the universe and everything.")
return value


class TestValidatesDecorator:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,46 @@ def validate_b(self, val):
errors = s.validate({"b": "data"})
assert errors == {"b": {"code": "invalid_b"}}

def test_field_validator_can_transform_deserialized_value(self):
def strip_whitespace(value: str) -> str:
return value.strip()

def uppercase(value: str) -> str:
return value.upper()

class ArtistSchema(Schema):
name = fields.Str(data_key="Name", validate=(strip_whitespace, uppercase))

s = ArtistSchema()
assert s.load({"Name": " foo "}) == {"name": "FOO"}

s_many = ArtistSchema(many=True)
assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [
{"name": "FOO"},
{"name": "BAR"},
]

def test_validator_method_can_transform_deserialized_value(self):
class ArtistSchema(Schema):
name = fields.Str(data_key="Name")

@validates("name")
def strip_whitespace(self, value: str) -> str:
return value.strip()

@validates("name")
def uppercase(self, value: str) -> str:
return value.upper()

s = ArtistSchema()
assert s.load({"Name": " foo "}) == {"name": "FOO"}

s_many = ArtistSchema(many=True)
assert s_many.load([{"Name": " foo "}, {"Name": " bar "}]) == [
{"name": "FOO"},
{"name": "BAR"},
]


def test_schema_repr():
class MySchema(Schema):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ def test_containsnoneof_mixing_types():
validate.ContainsNoneOf([1, 2, 3])((1,))


def is_even(value):
def is_even(value: int) -> int:
if value % 2 != 0:
raise ValidationError("Not an even value.")
return value


def test_and():
Expand Down
Loading