Skip to content

Commit

Permalink
Follow pep8 for TypeVar naming; fix typing for Context.get (#2785)
Browse files Browse the repository at this point in the history
* Follow pep8 for TypeVar naming; fix typing for Context.get

* Fix py39 and mypy issue

* Update other typevars
  • Loading branch information
sloria authored Jan 16, 2025
1 parent d62c571 commit 01fab37
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 46 deletions.
16 changes: 11 additions & 5 deletions src/marshmallow/experimental/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,35 @@ class UserSchema(Schema):
import contextvars
import typing

_ContextType = typing.TypeVar("_ContextType")
try:
from types import EllipsisType
except ImportError: # Python<3.10
EllipsisType = type(Ellipsis) # type: ignore[misc]

_ContextT = typing.TypeVar("_ContextT")
_DefaultT = typing.TypeVar("_DefaultT")
_CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context")


class Context(contextlib.AbstractContextManager, typing.Generic[_ContextType]):
class Context(contextlib.AbstractContextManager, typing.Generic[_ContextT]):
"""Context manager for setting and retrieving context.
:param context: The context to use within the context manager scope.
"""

def __init__(self, context: _ContextType) -> None:
def __init__(self, context: _ContextT) -> None:
self.context = context
self.token: contextvars.Token | None = None

def __enter__(self) -> Context[_ContextType]:
def __enter__(self) -> Context[_ContextT]:
self.token = _CURRENT_CONTEXT.set(self.context)
return self

def __exit__(self, *args, **kwargs) -> None:
_CURRENT_CONTEXT.reset(typing.cast(contextvars.Token, self.token))

@classmethod
def get(cls, default=...) -> _ContextType:
def get(cls, default: _DefaultT | EllipsisType = ...) -> _ContextT | _DefaultT:
"""Get the current context.
:param default: Default value to return if no context is set.
Expand Down
72 changes: 36 additions & 36 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"Pluck",
]

_InternalType = typing.TypeVar("_InternalType")
_InternalT = typing.TypeVar("_InternalT")


class _BaseFieldKwargs(typing.TypedDict, total=False):
Expand Down Expand Up @@ -113,7 +113,7 @@ def _resolve_field_instance(cls_or_instance: Field | type[Field]) -> Field:
return cls_or_instance


class Field(typing.Generic[_InternalType]):
class Field(typing.Generic[_InternalT]):
"""Base field from which all other fields inherit.
This class should not be used directly within Schemas.
Expand Down Expand Up @@ -252,7 +252,7 @@ def get_value(
typing.Callable[[typing.Any, str, typing.Any], typing.Any] | None
) = None,
default: typing.Any = missing_,
) -> _InternalType:
) -> _InternalT:
"""Return the value for a given key from an object.
:param obj: The object to get the value from.
Expand Down Expand Up @@ -336,7 +336,7 @@ def deserialize(
attr: str | None = None,
data: typing.Mapping[str, typing.Any] | None = None,
**kwargs,
) -> None | _InternalType: ...
) -> None | _InternalT: ...

# If value is not None, internal type is returned
@typing.overload
Expand All @@ -346,15 +346,15 @@ def deserialize(
attr: str | None = None,
data: typing.Mapping[str, typing.Any] | None = None,
**kwargs,
) -> _InternalType: ...
) -> _InternalT: ...

def deserialize(
self,
value: typing.Any,
attr: str | None = None,
data: typing.Mapping[str, typing.Any] | None = None,
**kwargs,
) -> _InternalType | None:
) -> _InternalT | None:
"""Deserialize ``value``.
:param value: The value to deserialize.
Expand Down Expand Up @@ -392,7 +392,7 @@ def _bind_to_schema(self, field_name: str, parent: Schema | Field) -> None:
)

def _serialize(
self, value: _InternalType | None, attr: str | None, obj: typing.Any, **kwargs
self, value: _InternalT | None, attr: str | None, obj: typing.Any, **kwargs
) -> typing.Any:
"""Serializes ``value`` to a basic Python datatype. Noop by default.
Concrete :class:`Field` classes should implement this method.
Expand All @@ -419,7 +419,7 @@ def _deserialize(
attr: str | None,
data: typing.Mapping[str, typing.Any] | None,
**kwargs,
) -> _InternalType:
) -> _InternalT:
"""Deserialize value. Concrete :class:`Field` classes should implement this method.
:param value: The value to be deserialized.
Expand Down Expand Up @@ -682,7 +682,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
return self._load(value, partial=partial)


class List(Field[list[typing.Optional[_InternalType]]]):
class List(Field[list[typing.Optional[_InternalT]]]):
"""A list field, composed with another `Field` class or
instance.
Expand All @@ -702,12 +702,12 @@ class List(Field[list[typing.Optional[_InternalType]]]):

def __init__(
self,
cls_or_instance: Field[_InternalType] | type[Field[_InternalType]],
cls_or_instance: Field[_InternalT] | type[Field[_InternalT]],
**kwargs: Unpack[_BaseFieldKwargs],
):
super().__init__(**kwargs)
try:
self.inner: Field[_InternalType] = _resolve_field_instance(cls_or_instance)
self.inner: Field[_InternalT] = _resolve_field_instance(cls_or_instance)
except _FieldInstanceResolutionError as error:
raise ValueError(
"The list elements must be a subclass or instance of "
Expand All @@ -725,12 +725,12 @@ def _bind_to_schema(self, field_name: str, parent: Schema | Field) -> None:
self.inner.only = self.only
self.inner.exclude = self.exclude

def _serialize(self, value, attr, obj, **kwargs) -> list[_InternalType] | None:
def _serialize(self, value, attr, obj, **kwargs) -> list[_InternalT] | None:
if value is None:
return None
return [self.inner._serialize(each, attr, obj, **kwargs) for each in value]

def _deserialize(self, value, attr, data, **kwargs) -> list[_InternalType | None]:
def _deserialize(self, value, attr, data, **kwargs) -> list[_InternalT | None]:
if not utils.is_collection(value):
raise self.make_error("invalid")

Expand All @@ -741,7 +741,7 @@ def _deserialize(self, value, attr, data, **kwargs) -> list[_InternalType | None
result.append(self.inner.deserialize(each, **kwargs))
except ValidationError as error:
if error.valid_data is not None:
result.append(typing.cast(_InternalType, error.valid_data))
result.append(typing.cast(_InternalT, error.valid_data))
errors.update({idx: error.messages})
if errors:
raise ValidationError(errors, valid_data=result)
Expand Down Expand Up @@ -896,10 +896,10 @@ def _deserialize(self, value, attr, data, **kwargs) -> uuid.UUID:
return self._validated(value)


_NumType = typing.TypeVar("_NumType")
_NumT = typing.TypeVar("_NumT")


class Number(Field[_NumType]):
class Number(Field[_NumT]):
"""Base class for number fields. This class should not be used within schemas.
:param as_string: If `True`, format the serialized value as a string.
Expand All @@ -910,7 +910,7 @@ class Number(Field[_NumType]):
Use `Integer <marshmallow.fields.Integer>`, `Float <marshmallow.fields.Float>`, or `Decimal <marshmallow.fields.Decimal>` instead.
"""

num_type: type[_NumType]
num_type: type[_NumT]

#: Default error messages.
default_error_messages = {
Expand All @@ -922,11 +922,11 @@ def __init__(self, *, as_string: bool = False, **kwargs: Unpack[_BaseFieldKwargs
self.as_string = as_string
super().__init__(**kwargs)

def _format_num(self, value) -> _NumType:
def _format_num(self, value) -> _NumT:
"""Return the number value for value, given this field's `num_type`."""
return self.num_type(value) # type: ignore

def _validated(self, value: typing.Any) -> _NumType:
def _validated(self, value: typing.Any) -> _NumT:
"""Format the value or raise a :exc:`ValidationError` if an error occurs."""
# (value is True or value is False) is ~5x faster than isinstance(value, bool)
if value is True or value is False:
Expand All @@ -938,17 +938,17 @@ def _validated(self, value: typing.Any) -> _NumType:
except OverflowError as error:
raise self.make_error("too_large", input=value) from error

def _to_string(self, value: _NumType) -> str:
def _to_string(self, value: _NumT) -> str:
return str(value)

def _serialize(self, value, attr, obj, **kwargs) -> str | _NumType | None:
def _serialize(self, value, attr, obj, **kwargs) -> str | _NumT | None:
"""Return a string if `self.as_string=True`, otherwise return this field's `num_type`."""
if value is None:
return None
ret: _NumType = self._format_num(value)
ret: _NumT = self._format_num(value)
return self._to_string(ret) if self.as_string else ret

def _deserialize(self, value, attr, data, **kwargs) -> _NumType:
def _deserialize(self, value, attr, data, **kwargs) -> _NumT:
return self._validated(value)


Expand Down Expand Up @@ -1531,10 +1531,10 @@ def _deserialize(self, value, attr, data, **kwargs) -> dt.timedelta:
raise self.make_error("invalid") from error


_MappingType = typing.TypeVar("_MappingType", bound=_Mapping)
_MappingT = typing.TypeVar("_MappingT", bound=_Mapping)


class Mapping(Field[_MappingType]):
class Mapping(Field[_MappingT]):
"""An abstract class for objects with key-value pairs. This class should not be used within schemas.
:param keys: A field class or instance for dict keys.
Expand All @@ -1551,7 +1551,7 @@ class Mapping(Field[_MappingType]):
Use `Dict <marshmallow.fields.Dict>` instead.
"""

mapping_type: type[_MappingType]
mapping_type: type[_MappingT]

#: Default error messages.
default_error_messages = {"invalid": "Not a valid mapping type."}
Expand Down Expand Up @@ -1857,10 +1857,10 @@ class IPv6Interface(IPInterface):
DESERIALIZATION_CLASS = ipaddress.IPv6Interface


_EnumType = typing.TypeVar("_EnumType", bound=EnumType)
_EnumT = typing.TypeVar("_EnumT", bound=EnumType)


class Enum(Field[_EnumType]):
class Enum(Field[_EnumT]):
"""An Enum field (de)serializing enum members by symbol (name) or by value.
:param enum: Enum class
Expand All @@ -1880,7 +1880,7 @@ class Enum(Field[_EnumType]):

def __init__(
self,
enum: type[_EnumType],
enum: type[_EnumT],
*,
by_value: bool | Field | type[Field] = False,
**kwargs: Unpack[_BaseFieldKwargs],
Expand Down Expand Up @@ -1912,7 +1912,7 @@ def __init__(
)

def _serialize(
self, value: _EnumType | None, attr: str | None, obj: typing.Any, **kwargs
self, value: _EnumT | None, attr: str | None, obj: typing.Any, **kwargs
) -> typing.Any | None:
if value is None:
return None
Expand All @@ -1922,7 +1922,7 @@ def _serialize(
val = value.name
return self.field._serialize(val, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs) -> _EnumType:
def _deserialize(self, value, attr, data, **kwargs) -> _EnumT:
if isinstance(value, self.enum):
return value
val = self.field._deserialize(value, attr, data, **kwargs)
Expand Down Expand Up @@ -2045,10 +2045,10 @@ def _deserialize(self, value, attr, data, **kwargs):
return value


_ContantType = typing.TypeVar("_ContantType")
_ContantT = typing.TypeVar("_ContantT")


class Constant(Field[_ContantType]):
class Constant(Field[_ContantT]):
"""A field that (de)serializes to a preset constant. If you only want the
constant added for serialization or deserialization, you should use
``dump_only=True`` or ``load_only=True`` respectively.
Expand All @@ -2058,16 +2058,16 @@ class Constant(Field[_ContantType]):

_CHECK_ATTRIBUTE = False

def __init__(self, constant: _ContantType, **kwargs: Unpack[_BaseFieldKwargs]):
def __init__(self, constant: _ContantT, **kwargs: Unpack[_BaseFieldKwargs]):
super().__init__(**kwargs)
self.constant = constant
self.load_default = constant
self.dump_default = constant

def _serialize(self, value, *args, **kwargs) -> _ContantType:
def _serialize(self, value, *args, **kwargs) -> _ContantT:
return self.constant

def _deserialize(self, value, *args, **kwargs) -> _ContantType:
def _deserialize(self, value, *args, **kwargs) -> _ContantT:
return self.constant


Expand Down
12 changes: 7 additions & 5 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ class Parent(Schema):
assert field.deserialize("foo") == "FOOBAR"

def test_decorated_processors_with_context(self):
NumDictContext = Context[dict[int, int]]

class MySchema(Schema):
f_1 = fields.Integer()
f_2 = fields.Integer()
Expand All @@ -196,27 +198,27 @@ class MySchema(Schema):

@pre_dump
def multiply_f_1(self, item, **kwargs):
item["f_1"] *= Context.get()[1]
item["f_1"] *= NumDictContext.get()[1]
return item

@pre_load
def multiply_f_2(self, data, **kwargs):
data["f_2"] *= Context.get()[2]
data["f_2"] *= NumDictContext.get()[2]
return data

@post_dump
def multiply_f_3(self, item, **kwargs):
item["f_3"] *= Context.get()[3]
item["f_3"] *= NumDictContext.get()[3]
return item

@post_load
def multiply_f_4(self, data, **kwargs):
data["f_4"] *= Context.get()[4]
data["f_4"] *= NumDictContext.get()[4]
return data

schema = MySchema()

with Context({1: 2, 2: 3, 3: 4, 4: 5}):
with NumDictContext({1: 2, 2: 3, 3: 4, 4: 5}):
assert schema.dump({"f_1": 1, "f_2": 1, "f_3": 1, "f_4": 1}) == {
"f_1": 2,
"f_2": 1,
Expand Down

0 comments on commit 01fab37

Please sign in to comment.