Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- Use assigned encoders at requests for json_encoder.
- save encoders as instances to application,
- allow temporary use of a different encoder set
  • Loading branch information
devkral committed Nov 26, 2024
1 parent 0e1e695 commit 0ee361c
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 47 deletions.
8 changes: 8 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ hide:

# Release Notes

## Unreleased

### Changed

- Use assigned encoders at requests for json_encoder.
- Allow overwriting the `LILYA_ENCODER_TYPES` for different encoder sets or tests.
- Use more orjson for encoding requests.

## 3.5.0

### Added
Expand Down
22 changes: 15 additions & 7 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import warnings
from collections.abc import Callable, Iterable, Sequence
from datetime import timezone as dtimezone
from functools import cached_property
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -1024,7 +1024,7 @@ async def another(request: Request) -> str:
),
] = None,
encoders: Annotated[
Sequence[Optional[Encoder]],
Optional[Sequence[Union[Encoder, type[Encoder]]]],
Doc(
"""
A `list` of encoders to be used by the application once it
Expand Down Expand Up @@ -1608,7 +1608,15 @@ def extend(self, config: PluggableConfig) -> None:
] = State()
self.async_exit_config = esmerald_settings.async_exit_config

self.encoders = self.load_settings_value("encoders", encoders) or []
self.encoders = list(
cast(
Iterable[Union[Encoder]],
(
encoder if isclass(encoder) else encoder
for encoder in self.load_settings_value("encoders", encoders) or []
),
)
)
self._register_application_encoders()

if self.enable_scheduler:
Expand Down Expand Up @@ -1662,8 +1670,8 @@ def _register_application_encoders(self) -> None:
This way, the support still remains but using the Lilya Encoders.
"""
self.register_encoder(cast(Encoder[Any], PydanticEncoder))
self.register_encoder(cast(Encoder[Any], MsgSpecEncoder))
self.register_encoder(cast(Encoder, PydanticEncoder))
self.register_encoder(cast(Encoder, MsgSpecEncoder))

for encoder in self.encoders:
self.register_encoder(encoder)
Expand Down Expand Up @@ -2611,7 +2619,7 @@ def on_event(self, event_type: str) -> Callable: # pragma: nocover
def add_event_handler(self, event_type: str, func: Callable) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)

def register_encoder(self, encoder: Encoder[Any]) -> None:
def register_encoder(self, encoder: Encoder) -> None:
"""
Registers a Encoder into the list of predefined encoders of the system.
"""
Expand Down
33 changes: 18 additions & 15 deletions esmerald/encoders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from inspect import isclass
from typing import Any, TypeVar, get_args

import msgspec
Expand All @@ -26,7 +27,7 @@ class Encoder(LilyaEncoder[T]):
def is_type(self, value: Any) -> bool:
"""
Function that checks if the function is
an instance of a given type
an instance of a given type (and also for the subclass of the type in case of encode)
"""
raise NotImplementedError("All Esmerald encoders must implement is_type() method.")

Expand All @@ -44,6 +45,19 @@ def encode(self, annotation: Any, value: Any) -> Any:
raise NotImplementedError("All Esmerald encoders must implement encode() method.")


def register_esmerald_encoder(encoder: Encoder | type[Encoder]) -> None:
"""
Registers an esmerald encoder into available Lilya encoders
"""
encoder_type = encoder if isclass(encoder) else type(encoder)
if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder):
raise ImproperlyConfigured(f"{encoder_type} must be a subclass of Encoder")

encoder_types = {_encoder.__class__.__name__ for _encoder in LILYA_ENCODER_TYPES.get()}
if encoder_type.__name__ not in encoder_types:
register_encoder(encoder)


class MsgSpecEncoder(Encoder):
def is_type(self, value: Any) -> bool:
return isinstance(value, Struct) or is_class_and_subclass(value, Struct)
Expand Down Expand Up @@ -75,28 +89,17 @@ def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)


def register_esmerald_encoder(encoder: Encoder[Any]) -> None:
"""
Registers an esmerald encoder into available Lilya encoders
"""
if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): # type: ignore
raise ImproperlyConfigured(f"{type(encoder)} must be a subclass of Encoder")

encoder_types = {encoder.__class__.__name__ for encoder in ENCODER_TYPES}
if encoder.__name__ not in encoder_types:
register_encoder(encoder)


def is_body_encoder(value: Any) -> bool:
"""
Function that checks if the value is a body encoder.
"""
encoder_types = LILYA_ENCODER_TYPES.get()
if not is_union(value):
return any(encoder.is_type(value) for encoder in ENCODER_TYPES)
return any(encoder.is_type(value) for encoder in encoder_types)

union_arguments = get_args(value)
if not union_arguments:
return False
return any(
any(encoder.is_type(argument) for encoder in ENCODER_TYPES) for argument in union_arguments
any(encoder.is_type(argument) for encoder in encoder_types) for argument in union_arguments
)
34 changes: 27 additions & 7 deletions esmerald/responses/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -23,10 +25,10 @@
Response as LilyaResponse, # noqa
StreamingResponse as StreamingResponse, # noqa
)
from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps
from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps, loads
from typing_extensions import Annotated, Doc

from esmerald.encoders import Encoder, json_encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder, json_encoder
from esmerald.enums import MediaType
from esmerald.exceptions import ImproperlyConfigured

Expand Down Expand Up @@ -180,9 +182,21 @@ def transform(value: Any) -> Dict[str, Any]:
Supports all the default encoders from Lilya and custom from Esmerald.
"""
return cast(Dict[str, Any], json_encoder(value))
return cast(
Dict[str, Any], json_encoder(value, json_encode_fn=dumps, post_transform_fn=loads)
)

def make_response(self, content: Any) -> Union[bytes, str]:
encoders = (
(
(
*(encoder() if isclass(encoder) else encoder for encoder in self.encoders),
*LILYA_ENCODER_TYPES.get(),
)
)
if self.encoders
else None
)
try:
if (
content is None
Expand All @@ -195,10 +209,16 @@ def make_response(self, content: Any) -> Union[bytes, str]:
):
return b""
if self.media_type == MediaType.JSON:
return dumps(
content,
default=self.transform,
option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS,
return cast(
bytes,
json_encoder(
content,
json_encode_fn=partial(
dumps, option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS
),
post_transform_fn=None,
with_encoders=encoders,
),
)
return super().make_response(content)
except (AttributeError, ValueError, TypeError) as e: # pragma: no cover
Expand Down
31 changes: 26 additions & 5 deletions esmerald/responses/encoders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any
from functools import partial
from inspect import isclass
from typing import Any, cast

import orjson

from esmerald.encoders import LILYA_ENCODER_TYPES, json_encoder
from esmerald.responses.json import BaseJSONResponse

try:
Expand All @@ -18,10 +21,26 @@ class ORJSONResponse(BaseJSONResponse):
"""

def make_response(self, content: Any) -> bytes:
return orjson.dumps(
content,
default=self.transform,
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS,
encoders = (
(
(
*(encoder() if isclass(encoder) else encoder for encoder in self.encoders),
*LILYA_ENCODER_TYPES.get(),
)
)
if self.encoders
else None
)
return cast(
bytes,
json_encoder(
content,
json_encode_fn=partial(
orjson.dumps, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_OMIT_MICROSECONDS
),
post_transform_fn=None,
with_encoders=encoders,
),
)


Expand All @@ -34,6 +53,8 @@ class UJSONResponse(BaseJSONResponse):

def make_response(self, content: Any) -> bytes:
assert ujson is not None, "You must install the encoders or ujson to use UJSONResponse"
# UJSON is actually in maintainance mode, recommends switch to ORJSON
# https://github.com/ultrajson/ultrajson
return ujson.dumps(content, ensure_ascii=False).encode("utf-8")


Expand Down
4 changes: 2 additions & 2 deletions esmerald/routing/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic.fields import FieldInfo

from esmerald.datastructures import UploadFile
from esmerald.encoders import ENCODER_TYPES, is_body_encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, is_body_encoder
from esmerald.enums import EncodingType
from esmerald.openapi.params import ResponseParam
from esmerald.params import Body
Expand Down Expand Up @@ -62,7 +62,7 @@ def convert_annotation_to_pydantic_model(field_annotation: Any) -> Any:

if (
not isinstance(field_annotation, BaseModel)
and any(encoder.is_type(field_annotation) for encoder in ENCODER_TYPES)
and any(encoder.is_type(field_annotation) for encoder in LILYA_ENCODER_TYPES.get())
and inspect.isclass(field_annotation)
):
field_definitions: Dict[str, Any] = {}
Expand Down
4 changes: 2 additions & 2 deletions esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from orjson import loads
from pydantic import ValidationError, create_model

from esmerald.encoders import ENCODER_TYPES, Encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder
from esmerald.exceptions import (
HTTPException,
ImproperlyConfigured,
Expand Down Expand Up @@ -516,7 +516,7 @@ def _find_encoder(self, annotation: Any) -> Any:
Any: The encoder found, or None if no encoder matches.
"""
origin = get_origin(annotation)
for encoder in ENCODER_TYPES:
for encoder in LILYA_ENCODER_TYPES.get():
if not origin and encoder.is_type(annotation):
return encoder
elif origin:
Expand Down
1 change: 0 additions & 1 deletion tests/dependencies/test_simple_case_injected_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def create(


def test_injection():

with create_client(routes=[Gateway(handler=DocumentAPIView)]) as client:
response = client.post("/", json={"name": "test", "content": "test"})
assert response.status_code == 201
Expand Down
17 changes: 12 additions & 5 deletions tests/encoding/test_attrs_encoders.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from collections import deque
from typing import Any

import pytest
from attrs import asdict, define, field, has

from esmerald import Gateway, post
from esmerald.encoders import Encoder, register_esmerald_encoder
from esmerald.encoders import LILYA_ENCODER_TYPES, Encoder, register_esmerald_encoder
from esmerald.testclient import create_client


class AttrsEncoder(Encoder):

def is_type(self, value: Any) -> bool:
return has(value)

Expand All @@ -19,7 +20,14 @@ def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)


register_esmerald_encoder(AttrsEncoder)
@pytest.fixture(autouse=True, scope="function")
def additional_encoders():
token = LILYA_ENCODER_TYPES.set(deque(LILYA_ENCODER_TYPES.get()))
try:
register_esmerald_encoder(AttrsEncoder)
yield
finally:
LILYA_ENCODER_TYPES.reset(token)


@define
Expand All @@ -30,9 +38,9 @@ class AttrItem:


def test_can_parse_attrs(test_app_client_factory):

@post("/create")
async def create(data: AttrItem) -> AttrItem:
assert type(LILYA_ENCODER_TYPES.get()[0]) is AttrsEncoder
return data

with create_client(routes=[Gateway(handler=create)]) as client:
Expand All @@ -44,7 +52,6 @@ async def create(data: AttrItem) -> AttrItem:


def test_can_parse_attrs_errors(test_app_client_factory):

@define
class Item:
sku: str = field()
Expand Down
Loading

0 comments on commit 0ee361c

Please sign in to comment.