diff --git a/falcon/media/json.py b/falcon/media/json.py index cf0111e82..502be0126 100644 --- a/falcon/media/json.py +++ b/falcon/media/json.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from functools import partial import json +from typing import Any, Callable, Optional, Union from falcon import errors from falcon import http_error from falcon.media.base import BaseHandler from falcon.media.base import TextBaseHandlerWS +from falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO class JSONHandler(BaseHandler): @@ -148,7 +153,11 @@ def default(self, obj): loads (func): Function to use when deserializing JSON requests. """ - def __init__(self, dumps=None, loads=None): + def __init__( + self, + dumps: Optional[Callable[[Any], Union[str, bytes]]] = None, + loads: Optional[Callable[[str], Any]] = None, + ) -> None: self._dumps = dumps or partial(json.dumps, ensure_ascii=False) self._loads = loads or json.loads @@ -156,11 +165,11 @@ def __init__(self, dumps=None, loads=None): # proper serialize implementation. result = self._dumps({'message': 'Hello World'}) if isinstance(result, str): - self.serialize = self._serialize_s - self.serialize_async = self._serialize_async_s + self.serialize = self._serialize_s # type: ignore[method-assign] + self.serialize_async = self._serialize_async_s # type: ignore[method-assign] else: - self.serialize = self._serialize_b - self.serialize_async = self._serialize_async_b + self.serialize = self._serialize_b # type: ignore[method-assign] + self.serialize_async = self._serialize_async_b # type: ignore[method-assign] # NOTE(kgriffs): To be safe, only enable the optimized protocol when # not subclassed. @@ -168,7 +177,7 @@ def __init__(self, dumps=None, loads=None): self._serialize_sync = self.serialize self._deserialize_sync = self._deserialize - def _deserialize(self, data): + def _deserialize(self, data: bytes) -> Any: if not data: raise errors.MediaNotFoundError('JSON') try: @@ -176,27 +185,41 @@ def _deserialize(self, data): except ValueError as err: raise errors.MediaMalformedError('JSON') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. - def _serialize_s(self, media, content_type=None) -> bytes: - return self._dumps(media).encode() + def _serialize_s(self, media: Any, content_type: Optional[str] = None) -> bytes: + return self._dumps(media).encode() # type: ignore[union-attr] - async def _serialize_async_s(self, media, content_type) -> bytes: - return self._dumps(media).encode() + async def _serialize_async_s( + self, media: Any, content_type: Optional[str] + ) -> bytes: + return self._dumps(media).encode() # type: ignore[union-attr] # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. - def _serialize_b(self, media, content_type=None) -> bytes: - return self._dumps(media) + def _serialize_b(self, media: Any, content_type: Optional[str] = None) -> bytes: + return self._dumps(media) # type: ignore[return-value] - async def _serialize_async_b(self, media, content_type) -> bytes: - return self._dumps(media) + async def _serialize_async_b( + self, media: Any, content_type: Optional[str] + ) -> bytes: + return self._dumps(media) # type: ignore[return-value] class JSONHandlerWS(TextBaseHandlerWS): @@ -257,7 +280,11 @@ class JSONHandlerWS(TextBaseHandlerWS): __slots__ = ['dumps', 'loads'] - def __init__(self, dumps=None, loads=None): + def __init__( + self, + dumps: Optional[Callable[[Any], str]] = None, + loads: Optional[Callable[[str], Any]] = None, + ) -> None: self._dumps = dumps or partial(json.dumps, ensure_ascii=False) self._loads = loads or json.loads diff --git a/falcon/media/msgpack.py b/falcon/media/msgpack.py index 0267e2511..5b8c587c9 100644 --- a/falcon/media/msgpack.py +++ b/falcon/media/msgpack.py @@ -1,10 +1,12 @@ -from __future__ import absolute_import # NOTE(kgriffs): Work around a Cython bug +from __future__ import annotations -from typing import Union +from typing import Any, Callable, Optional, Protocol from falcon import errors from falcon.media.base import BaseHandler from falcon.media.base import BinaryBaseHandlerWS +from falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO class MessagePackHandler(BaseHandler): @@ -28,7 +30,10 @@ class MessagePackHandler(BaseHandler): $ pip install msgpack """ - def __init__(self): + _pack: Callable[[Any], bytes] + _unpackb: UnpackMethod + + def __init__(self) -> None: import msgpack packer = msgpack.Packer(autoreset=True, use_bin_type=True) @@ -38,10 +43,10 @@ def __init__(self): # NOTE(kgriffs): To be safe, only enable the optimized protocol when # not subclassed. if type(self) is MessagePackHandler: - self._serialize_sync = self._pack + self._serialize_sync = self._pack # type: ignore[assignment] self._deserialize_sync = self._deserialize - def _deserialize(self, data): + def _deserialize(self, data: bytes) -> Any: if not data: raise errors.MediaNotFoundError('MessagePack') try: @@ -51,16 +56,26 @@ def _deserialize(self, data): except ValueError as err: raise errors.MediaMalformedError('MessagePack') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) - def serialize(self, media, content_type) -> bytes: + def serialize(self, media: Any, content_type: Optional[str]) -> bytes: return self._pack(media) - async def serialize_async(self, media, content_type) -> bytes: + async def serialize_async(self, media: Any, content_type: Optional[str]) -> bytes: return self._pack(media) @@ -81,19 +96,26 @@ class MessagePackHandlerWS(BinaryBaseHandlerWS): $ pip install msgpack """ - __slots__ = ['msgpack', 'packer'] + __slots__ = ('msgpack', 'packer') + + _pack: Callable[[Any], bytes] + _unpackb: UnpackMethod - def __init__(self): + def __init__(self) -> None: import msgpack packer = msgpack.Packer(autoreset=True, use_bin_type=True) self._pack = packer.pack self._unpackb = msgpack.unpackb - def serialize(self, media: object) -> Union[bytes, bytearray, memoryview]: + def serialize(self, media: object) -> bytes: return self._pack(media) - def deserialize(self, payload: bytes) -> object: + def deserialize(self, payload: bytes) -> Any: # NOTE(jmvrbanac): Using unpackb since we would need to manage # a buffer for Unpacker() which wouldn't gain us much. return self._unpackb(payload, raw=False) + + +class UnpackMethod(Protocol): + def __call__(self, data: bytes, raw: bool = ...) -> Any: ... diff --git a/falcon/media/urlencoded.py b/falcon/media/urlencoded.py index 17f73dd65..1d7f6cb04 100644 --- a/falcon/media/urlencoded.py +++ b/falcon/media/urlencoded.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import Any, Optional from urllib.parse import urlencode from falcon import errors from falcon.media.base import BaseHandler +from falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO from falcon.util.uri import parse_query_string @@ -28,7 +33,7 @@ class URLEncodedFormHandler(BaseHandler): when deserializing. """ - def __init__(self, keep_blank=True, csv=False): + def __init__(self, keep_blank: bool = True, csv: bool = False) -> None: self._keep_blank = keep_blank self._csv = csv @@ -40,23 +45,35 @@ def __init__(self, keep_blank=True, csv=False): # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. - def serialize(self, media, content_type=None) -> bytes: + def serialize(self, media: Any, content_type: Optional[str] = None) -> bytes: # NOTE(vytas): Setting doseq to True to mirror the parse_query_string # behaviour. return urlencode(media, doseq=True).encode() - def _deserialize(self, body): + def _deserialize(self, body: bytes) -> Any: try: # NOTE(kgriffs): According to http://goo.gl/6rlcux the # body should be US-ASCII. Enforcing this also helps # catch malicious input. - body = body.decode('ascii') - return parse_query_string(body, keep_blank=self._keep_blank, csv=self._csv) + body_str = body.decode('ascii') + return parse_query_string( + body_str, keep_blank=self._keep_blank, csv=self._csv + ) except Exception as err: raise errors.MediaMalformedError('URL-encoded') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) diff --git a/pyproject.toml b/pyproject.toml index c857130dc..214ba7bde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,10 +45,7 @@ "falcon.asgi.multipart", "falcon.asgi.response", "falcon.asgi.stream", - "falcon.media.json", - "falcon.media.msgpack", "falcon.media.multipart", - "falcon.media.urlencoded", "falcon.media.validators.*", "falcon.responders", "falcon.response_helpers",