Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow overwriting the encoder set / use encoders of requests / more orjson for more speed #440

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ hide:

# Release Notes

## Unreleased

### Changed

- Use assigned encoders at requests for json_encoder.
- Allow overwriting the `LILYA_ENCODER_TYPES` for different encoder sets or tests.
- Use more orjson for encoding requests.

## 3.5.0

### Added
Expand Down
22 changes: 15 additions & 7 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import warnings
from collections.abc import Callable, Iterable, Sequence
from datetime import timezone as dtimezone
from functools import cached_property
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -1024,7 +1024,7 @@ async def another(request: Request) -> str:
),
] = None,
encoders: Annotated[
Sequence[Optional[Encoder]],
Optional[Sequence[Union[Encoder, type[Encoder]]]],
Doc(
"""
A `list` of encoders to be used by the application once it
Expand Down Expand Up @@ -1608,7 +1608,15 @@ def extend(self, config: PluggableConfig) -> None:
] = State()
self.async_exit_config = esmerald_settings.async_exit_config

self.encoders = self.load_settings_value("encoders", encoders) or []
self.encoders = list(
cast(
Iterable[Union[Encoder]],
(
encoder if isclass(encoder) else encoder
for encoder in self.load_settings_value("encoders", encoders) or []
),
)
)
self._register_application_encoders()

if self.enable_scheduler:
Expand Down Expand Up @@ -1662,8 +1670,8 @@ def _register_application_encoders(self) -> None:

This way, the support still remains but using the Lilya Encoders.
"""
self.register_encoder(cast(Encoder[Any], PydanticEncoder))
self.register_encoder(cast(Encoder[Any], MsgSpecEncoder))
self.register_encoder(cast(Encoder, PydanticEncoder))
self.register_encoder(cast(Encoder, MsgSpecEncoder))

for encoder in self.encoders:
self.register_encoder(encoder)
Expand Down Expand Up @@ -2611,7 +2619,7 @@ def on_event(self, event_type: str) -> Callable: # pragma: nocover
def add_event_handler(self, event_type: str, func: Callable) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)

def register_encoder(self, encoder: Encoder[Any]) -> None:
def register_encoder(self, encoder: Encoder) -> None:
"""
Registers a Encoder into the list of predefined encoders of the system.
"""
Expand Down
33 changes: 18 additions & 15 deletions esmerald/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from inspect import isclass
from typing import Any, TypeVar, get_args

import msgspec
Expand All @@ -26,7 +27,7 @@ class Encoder(LilyaEncoder[T]):
def is_type(self, value: Any) -> bool:
"""
Function that checks if the function is
an instance of a given type
an instance of a given type (and also for the subclass of the type in case of encode)
"""
raise NotImplementedError("All Esmerald encoders must implement is_type() method.")

Expand All @@ -44,6 +45,19 @@ def encode(self, annotation: Any, value: Any) -> Any:
raise NotImplementedError("All Esmerald encoders must implement encode() method.")


def register_esmerald_encoder(encoder: Encoder | type[Encoder]) -> None:
"""
Registers an esmerald encoder into available Lilya encoders
"""
encoder_type = encoder if isclass(encoder) else type(encoder)
if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder):
raise ImproperlyConfigured(f"{encoder_type} must be a subclass of Encoder")

encoder_types = {_encoder.__class__.__name__ for _encoder in LILYA_ENCODER_TYPES.get()}
if encoder_type.__name__ not in encoder_types:
register_encoder(encoder)


class MsgSpecEncoder(Encoder):
def is_type(self, value: Any) -> bool:
return isinstance(value, Struct) or is_class_and_subclass(value, Struct)
Expand Down Expand Up @@ -75,28 +89,17 @@ def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)


def register_esmerald_encoder(encoder: Encoder[Any]) -> None:
"""
Registers an esmerald encoder into available Lilya encoders
"""
if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): # type: ignore
raise ImproperlyConfigured(f"{type(encoder)} must be a subclass of Encoder")

encoder_types = {encoder.__class__.__name__ for encoder in ENCODER_TYPES}
if encoder.__name__ not in encoder_types:
register_encoder(encoder)


def is_body_encoder(value: Any) -> bool:
"""
Function that checks if the value is a body encoder.
"""
encoder_types = LILYA_ENCODER_TYPES.get()
if not is_union(value):
return any(encoder.is_type(value) for encoder in ENCODER_TYPES)
return any(encoder.is_type(value) for encoder in encoder_types)

union_arguments = get_args(value)
if not union_arguments:
return False
return any(
any(encoder.is_type(argument) for encoder in ENCODER_TYPES) for argument in union_arguments
any(encoder.is_type(argument) for encoder in encoder_types) for argument in union_arguments
)
34 changes: 27 additions & 7 deletions esmerald/responses/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -23,10 +25,10 @@
Response as LilyaResponse, # noqa
StreamingResponse as StreamingResponse, # noqa
)
from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps
from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps, loads
from typing_extensions import Annotated, Doc

from esmerald.encoders import Encoder, json_encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder, json_encoder
from esmerald.enums import MediaType
from esmerald.exceptions import ImproperlyConfigured

Expand Down Expand Up @@ -180,9 +182,21 @@ def transform(value: Any) -> Dict[str, Any]:

Supports all the default encoders from Lilya and custom from Esmerald.
"""
return cast(Dict[str, Any], json_encoder(value))
return cast(
Dict[str, Any], json_encoder(value, json_encode_fn=dumps, post_transform_fn=loads)
)

def make_response(self, content: Any) -> Union[bytes, str]:
encoders = (
(
(
*(encoder() if isclass(encoder) else encoder for encoder in self.encoders),
*LILYA_ENCODER_TYPES.get(),
)
)
if self.encoders
else None
)
try:
if (
content is None
Expand All @@ -195,10 +209,16 @@ def make_response(self, content: Any) -> Union[bytes, str]:
):
return b""
if self.media_type == MediaType.JSON:
return dumps(
content,
default=self.transform,
option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS,
return cast(
bytes,
json_encoder(
content,
json_encode_fn=partial(
dumps, option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS
),
post_transform_fn=None,
with_encoders=encoders,
),
)
return super().make_response(content)
except (AttributeError, ValueError, TypeError) as e: # pragma: no cover
Expand Down
31 changes: 26 additions & 5 deletions esmerald/responses/encoders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any
from functools import partial
from inspect import isclass
from typing import Any, cast

import orjson

from esmerald.encoders import LILYA_ENCODER_TYPES, json_encoder
from esmerald.responses.json import BaseJSONResponse

try:
Expand All @@ -18,10 +21,26 @@ class ORJSONResponse(BaseJSONResponse):
"""

def make_response(self, content: Any) -> bytes:
return orjson.dumps(
content,
default=self.transform,
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS,
encoders = (
(
(
*(encoder() if isclass(encoder) else encoder for encoder in self.encoders),
*LILYA_ENCODER_TYPES.get(),
)
)
if self.encoders
else None
)
return cast(
bytes,
json_encoder(
content,
json_encode_fn=partial(
orjson.dumps, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS
),
post_transform_fn=None,
with_encoders=encoders,
),
)


Expand All @@ -34,6 +53,8 @@ class UJSONResponse(BaseJSONResponse):

def make_response(self, content: Any) -> bytes:
assert ujson is not None, "You must install the encoders or ujson to use UJSONResponse"
# UJSON is actually in maintainance mode, recommends switch to ORJSON
# https://github.com/ultrajson/ultrajson
return ujson.dumps(content, ensure_ascii=False).encode("utf-8")


Expand Down
4 changes: 2 additions & 2 deletions esmerald/routing/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic.fields import FieldInfo

from esmerald.datastructures import UploadFile
from esmerald.encoders import ENCODER_TYPES, is_body_encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, is_body_encoder
from esmerald.enums import EncodingType
from esmerald.openapi.params import ResponseParam
from esmerald.params import Body
Expand Down Expand Up @@ -62,7 +62,7 @@ def convert_annotation_to_pydantic_model(field_annotation: Any) -> Any:

if (
not isinstance(field_annotation, BaseModel)
and any(encoder.is_type(field_annotation) for encoder in ENCODER_TYPES)
and any(encoder.is_type(field_annotation) for encoder in LILYA_ENCODER_TYPES.get())
and inspect.isclass(field_annotation)
):
field_definitions: Dict[str, Any] = {}
Expand Down
4 changes: 2 additions & 2 deletions esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from orjson import loads
from pydantic import ValidationError, create_model

from esmerald.encoders import ENCODER_TYPES, Encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder
from esmerald.exceptions import (
HTTPException,
ImproperlyConfigured,
Expand Down Expand Up @@ -516,7 +516,7 @@ def _find_encoder(self, annotation: Any) -> Any:
Any: The encoder found, or None if no encoder matches.
"""
origin = get_origin(annotation)
for encoder in ENCODER_TYPES:
for encoder in LILYA_ENCODER_TYPES.get():
if not origin and encoder.is_type(annotation):
return encoder
elif origin:
Expand Down
1 change: 0 additions & 1 deletion tests/dependencies/test_simple_case_injected_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def create(


def test_injection():

with create_client(routes=[Gateway(handler=DocumentAPIView)]) as client:
response = client.post("/", json={"name": "test", "content": "test"})
assert response.status_code == 201
Expand Down
26 changes: 21 additions & 5 deletions tests/encoding/test_attrs_encoders.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from collections import deque
from typing import Any

import pytest
from attrs import asdict, define, field, has

from esmerald import Gateway, post
from esmerald.encoders import Encoder, register_esmerald_encoder
from esmerald.encoders import (
ENCODER_TYPES,
LILYA_ENCODER_TYPES,
Encoder,
register_esmerald_encoder,
)
from esmerald.testclient import create_client


class AttrsEncoder(Encoder):

def is_type(self, value: Any) -> bool:
return has(value)

Expand All @@ -19,7 +25,14 @@ def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)


register_esmerald_encoder(AttrsEncoder)
@pytest.fixture(autouse=True, scope="function")
def additional_encoders():
token = LILYA_ENCODER_TYPES.set(deque(LILYA_ENCODER_TYPES.get()))
try:
register_esmerald_encoder(AttrsEncoder)
yield
finally:
LILYA_ENCODER_TYPES.reset(token)


@define
Expand All @@ -29,10 +42,14 @@ class AttrItem:
email: str


def test_can_parse_attrs(test_app_client_factory):
def test_working_overwrite():
assert LILYA_ENCODER_TYPES.get() is not ENCODER_TYPES


def test_can_parse_attrs(test_app_client_factory):
@post("/create")
async def create(data: AttrItem) -> AttrItem:
assert type(LILYA_ENCODER_TYPES.get()[0]) is AttrsEncoder
return data

with create_client(routes=[Gateway(handler=create)]) as client:
Expand All @@ -44,7 +61,6 @@ async def create(data: AttrItem) -> AttrItem:


def test_can_parse_attrs_errors(test_app_client_factory):

@define
class Item:
sku: str = field()
Expand Down
Loading