From 0ee361c9e9e7a94f13795ab554bf17d89561d1ab Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 26 Nov 2024 12:38:08 +0100 Subject: [PATCH] Changes: - Use assigned encoders at requests for json_encoder. - save encoders as instances to application, - allow temporary use of a different encoder set --- docs/en/docs/release-notes.md | 8 +++++ esmerald/applications.py | 22 ++++++++---- esmerald/encoders.py | 33 ++++++++++-------- esmerald/responses/base.py | 34 +++++++++++++++---- esmerald/responses/encoders.py | 31 ++++++++++++++--- esmerald/routing/_internal.py | 4 +-- esmerald/transformers/signature.py | 4 +-- .../test_simple_case_injected_annotation.py | 1 - tests/encoding/test_attrs_encoders.py | 17 +++++++--- tests/encoding/test_register_encoder.py | 16 +++++++-- 10 files changed, 123 insertions(+), 47 deletions(-) diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index 5fa8cff7..9bc90409 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -5,6 +5,14 @@ hide: # Release Notes +## Unreleased + +### Changed + +- Use assigned encoders at requests for json_encoder. +- Allow overwriting the `LILYA_ENCODER_TYPES` for different encoder sets or tests. +- Use more orjson for encoding requests. + ## 3.5.0 ### Added diff --git a/esmerald/applications.py b/esmerald/applications.py index 5fa98dc9..81f640d3 100644 --- a/esmerald/applications.py +++ b/esmerald/applications.py @@ -1,14 +1,14 @@ import warnings +from collections.abc import Callable, Iterable, Sequence from datetime import timezone as dtimezone from functools import cached_property +from inspect import isclass from typing import ( TYPE_CHECKING, Any, - Callable, Dict, List, Optional, - Sequence, Type, TypeVar, Union, @@ -1024,7 +1024,7 @@ async def another(request: Request) -> str: ), ] = None, encoders: Annotated[ - Sequence[Optional[Encoder]], + Optional[Sequence[Union[Encoder, type[Encoder]]]], Doc( """ A `list` of encoders to be used by the application once it @@ -1608,7 +1608,15 @@ def extend(self, config: PluggableConfig) -> None: ] = State() self.async_exit_config = esmerald_settings.async_exit_config - self.encoders = self.load_settings_value("encoders", encoders) or [] + self.encoders = list( + cast( + Iterable[Union[Encoder]], + ( + encoder if isclass(encoder) else encoder + for encoder in self.load_settings_value("encoders", encoders) or [] + ), + ) + ) self._register_application_encoders() if self.enable_scheduler: @@ -1662,8 +1670,8 @@ def _register_application_encoders(self) -> None: This way, the support still remains but using the Lilya Encoders. """ - self.register_encoder(cast(Encoder[Any], PydanticEncoder)) - self.register_encoder(cast(Encoder[Any], MsgSpecEncoder)) + self.register_encoder(cast(Encoder, PydanticEncoder)) + self.register_encoder(cast(Encoder, MsgSpecEncoder)) for encoder in self.encoders: self.register_encoder(encoder) @@ -2611,7 +2619,7 @@ def on_event(self, event_type: str) -> Callable: # pragma: nocover def add_event_handler(self, event_type: str, func: Callable) -> None: # pragma: no cover self.router.add_event_handler(event_type, func) - def register_encoder(self, encoder: Encoder[Any]) -> None: + def register_encoder(self, encoder: Encoder) -> None: """ Registers a Encoder into the list of predefined encoders of the system. """ diff --git a/esmerald/encoders.py b/esmerald/encoders.py index 5dc2df1f..7ad4dd1c 100644 --- a/esmerald/encoders.py +++ b/esmerald/encoders.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import isclass from typing import Any, TypeVar, get_args import msgspec @@ -26,7 +27,7 @@ class Encoder(LilyaEncoder[T]): def is_type(self, value: Any) -> bool: """ Function that checks if the function is - an instance of a given type + an instance of a given type (and also for the subclass of the type in case of encode) """ raise NotImplementedError("All Esmerald encoders must implement is_type() method.") @@ -44,6 +45,19 @@ def encode(self, annotation: Any, value: Any) -> Any: raise NotImplementedError("All Esmerald encoders must implement encode() method.") +def register_esmerald_encoder(encoder: Encoder | type[Encoder]) -> None: + """ + Registers an esmerald encoder into available Lilya encoders + """ + encoder_type = encoder if isclass(encoder) else type(encoder) + if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): + raise ImproperlyConfigured(f"{encoder_type} must be a subclass of Encoder") + + encoder_types = {_encoder.__class__.__name__ for _encoder in LILYA_ENCODER_TYPES.get()} + if encoder_type.__name__ not in encoder_types: + register_encoder(encoder) + + class MsgSpecEncoder(Encoder): def is_type(self, value: Any) -> bool: return isinstance(value, Struct) or is_class_and_subclass(value, Struct) @@ -75,28 +89,17 @@ def encode(self, annotation: Any, value: Any) -> Any: return annotation(**value) -def register_esmerald_encoder(encoder: Encoder[Any]) -> None: - """ - Registers an esmerald encoder into available Lilya encoders - """ - if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): # type: ignore - raise ImproperlyConfigured(f"{type(encoder)} must be a subclass of Encoder") - - encoder_types = {encoder.__class__.__name__ for encoder in ENCODER_TYPES} - if encoder.__name__ not in encoder_types: - register_encoder(encoder) - - def is_body_encoder(value: Any) -> bool: """ Function that checks if the value is a body encoder. """ + encoder_types = LILYA_ENCODER_TYPES.get() if not is_union(value): - return any(encoder.is_type(value) for encoder in ENCODER_TYPES) + return any(encoder.is_type(value) for encoder in encoder_types) union_arguments = get_args(value) if not union_arguments: return False return any( - any(encoder.is_type(argument) for encoder in ENCODER_TYPES) for argument in union_arguments + any(encoder.is_type(argument) for encoder in encoder_types) for argument in union_arguments ) diff --git a/esmerald/responses/base.py b/esmerald/responses/base.py index 22297f7a..7cdda69d 100644 --- a/esmerald/responses/base.py +++ b/esmerald/responses/base.py @@ -1,3 +1,5 @@ +from functools import partial +from inspect import isclass from typing import ( TYPE_CHECKING, Any, @@ -23,10 +25,10 @@ Response as LilyaResponse, # noqa StreamingResponse as StreamingResponse, # noqa ) -from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps +from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps, loads from typing_extensions import Annotated, Doc -from esmerald.encoders import Encoder, json_encoder +from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder, json_encoder from esmerald.enums import MediaType from esmerald.exceptions import ImproperlyConfigured @@ -180,9 +182,21 @@ def transform(value: Any) -> Dict[str, Any]: Supports all the default encoders from Lilya and custom from Esmerald. """ - return cast(Dict[str, Any], json_encoder(value)) + return cast( + Dict[str, Any], json_encoder(value, json_encode_fn=dumps, post_transform_fn=loads) + ) def make_response(self, content: Any) -> Union[bytes, str]: + encoders = ( + ( + ( + *(encoder() if isclass(encoder) else encoder for encoder in self.encoders), + *LILYA_ENCODER_TYPES.get(), + ) + ) + if self.encoders + else None + ) try: if ( content is None @@ -195,10 +209,16 @@ def make_response(self, content: Any) -> Union[bytes, str]: ): return b"" if self.media_type == MediaType.JSON: - return dumps( - content, - default=self.transform, - option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS, + return cast( + bytes, + json_encoder( + content, + json_encode_fn=partial( + dumps, option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS + ), + post_transform_fn=None, + with_encoders=encoders, + ), ) return super().make_response(content) except (AttributeError, ValueError, TypeError) as e: # pragma: no cover diff --git a/esmerald/responses/encoders.py b/esmerald/responses/encoders.py index bf04f905..ebc14f1f 100644 --- a/esmerald/responses/encoders.py +++ b/esmerald/responses/encoders.py @@ -1,7 +1,10 @@ -from typing import Any +from functools import partial +from inspect import isclass +from typing import Any, cast import orjson +from esmerald.encoders import LILYA_ENCODER_TYPES, json_encoder from esmerald.responses.json import BaseJSONResponse try: @@ -18,10 +21,26 @@ class ORJSONResponse(BaseJSONResponse): """ def make_response(self, content: Any) -> bytes: - return orjson.dumps( - content, - default=self.transform, - option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS, + encoders = ( + ( + ( + *(encoder() if isclass(encoder) else encoder for encoder in self.encoders), + *LILYA_ENCODER_TYPES.get(), + ) + ) + if self.encoders + else None + ) + return cast( + bytes, + json_encoder( + content, + json_encode_fn=partial( + orjson.dumps, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS + ), + post_transform_fn=None, + with_encoders=encoders, + ), ) @@ -34,6 +53,8 @@ class UJSONResponse(BaseJSONResponse): def make_response(self, content: Any) -> bytes: assert ujson is not None, "You must install the encoders or ujson to use UJSONResponse" + # UJSON is actually in maintainance mode, recommends switch to ORJSON + # https://github.com/ultrajson/ultrajson return ujson.dumps(content, ensure_ascii=False).encode("utf-8") diff --git a/esmerald/routing/_internal.py b/esmerald/routing/_internal.py index 934c332a..d5cfe988 100644 --- a/esmerald/routing/_internal.py +++ b/esmerald/routing/_internal.py @@ -7,7 +7,7 @@ from pydantic.fields import FieldInfo from esmerald.datastructures import UploadFile -from esmerald.encoders import ENCODER_TYPES, is_body_encoder +from esmerald.encoders import LILYA_ENCODER_TYPES, is_body_encoder from esmerald.enums import EncodingType from esmerald.openapi.params import ResponseParam from esmerald.params import Body @@ -62,7 +62,7 @@ def convert_annotation_to_pydantic_model(field_annotation: Any) -> Any: if ( not isinstance(field_annotation, BaseModel) - and any(encoder.is_type(field_annotation) for encoder in ENCODER_TYPES) + and any(encoder.is_type(field_annotation) for encoder in LILYA_ENCODER_TYPES.get()) and inspect.isclass(field_annotation) ): field_definitions: Dict[str, Any] = {} diff --git a/esmerald/transformers/signature.py b/esmerald/transformers/signature.py index 865ab02d..f48cc411 100644 --- a/esmerald/transformers/signature.py +++ b/esmerald/transformers/signature.py @@ -19,7 +19,7 @@ from orjson import loads from pydantic import ValidationError, create_model -from esmerald.encoders import ENCODER_TYPES, Encoder +from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder from esmerald.exceptions import ( HTTPException, ImproperlyConfigured, @@ -516,7 +516,7 @@ def _find_encoder(self, annotation: Any) -> Any: Any: The encoder found, or None if no encoder matches. """ origin = get_origin(annotation) - for encoder in ENCODER_TYPES: + for encoder in LILYA_ENCODER_TYPES.get(): if not origin and encoder.is_type(annotation): return encoder elif origin: diff --git a/tests/dependencies/test_simple_case_injected_annotation.py b/tests/dependencies/test_simple_case_injected_annotation.py index 21c95a68..9f668e58 100644 --- a/tests/dependencies/test_simple_case_injected_annotation.py +++ b/tests/dependencies/test_simple_case_injected_annotation.py @@ -25,7 +25,6 @@ async def create( def test_injection(): - with create_client(routes=[Gateway(handler=DocumentAPIView)]) as client: response = client.post("/", json={"name": "test", "content": "test"}) assert response.status_code == 201 diff --git a/tests/encoding/test_attrs_encoders.py b/tests/encoding/test_attrs_encoders.py index 7585bc3c..1bb54936 100644 --- a/tests/encoding/test_attrs_encoders.py +++ b/tests/encoding/test_attrs_encoders.py @@ -1,14 +1,15 @@ +from collections import deque from typing import Any +import pytest from attrs import asdict, define, field, has from esmerald import Gateway, post -from esmerald.encoders import Encoder, register_esmerald_encoder +from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder, register_esmerald_encoder from esmerald.testclient import create_client class AttrsEncoder(Encoder): - def is_type(self, value: Any) -> bool: return has(value) @@ -19,7 +20,14 @@ def encode(self, annotation: Any, value: Any) -> Any: return annotation(**value) -register_esmerald_encoder(AttrsEncoder) +@pytest.fixture(autouse=True, scope="function") +def additional_encoders(): + token = LILYA_ENCODER_TYPES.set(deque(LILYA_ENCODER_TYPES.get())) + try: + register_esmerald_encoder(AttrsEncoder) + yield + finally: + LILYA_ENCODER_TYPES.reset(token) @define @@ -30,9 +38,9 @@ class AttrItem: def test_can_parse_attrs(test_app_client_factory): - @post("/create") async def create(data: AttrItem) -> AttrItem: + assert type(LILYA_ENCODER_TYPES.get()[0]) is AttrsEncoder return data with create_client(routes=[Gateway(handler=create)]) as client: @@ -44,7 +52,6 @@ async def create(data: AttrItem) -> AttrItem: def test_can_parse_attrs_errors(test_app_client_factory): - @define class Item: sku: str = field() diff --git a/tests/encoding/test_register_encoder.py b/tests/encoding/test_register_encoder.py index e57a2a5e..2a28709a 100644 --- a/tests/encoding/test_register_encoder.py +++ b/tests/encoding/test_register_encoder.py @@ -1,14 +1,15 @@ +from collections import deque from typing import Any +import pytest from attrs import asdict, define, field, has from esmerald import Esmerald, Gateway, post -from esmerald.encoders import Encoder +from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder from esmerald.testclient import EsmeraldTestClient, create_client class AttrsEncoder(Encoder): - def is_type(self, value: Any) -> bool: return has(value) @@ -26,9 +27,19 @@ class AttrItem: email: str +@pytest.fixture(autouse=True, scope="function") +def additional_encoders(): + token = LILYA_ENCODER_TYPES.set(deque(LILYA_ENCODER_TYPES.get())) + try: + yield + finally: + LILYA_ENCODER_TYPES.reset(token) + + def test_can_parse_attrs(test_app_client_factory): @post("/create") async def create(data: AttrItem) -> AttrItem: + assert type(LILYA_ENCODER_TYPES.get()[0]) is AttrsEncoder return data app = Esmerald(routes=[Gateway(handler=create)], encoders=[AttrsEncoder]) @@ -40,7 +51,6 @@ async def create(data: AttrItem) -> AttrItem: def test_can_parse_attrs_errors(test_app_client_factory): - @define class Item: sku: str = field()