Skip to content

Commit

Permalink
[Implementation] - Redesign signature to allow multiple types of enco…
Browse files Browse the repository at this point in the history
…ders. (#299)

* Add extra encoder validation
* Change implementation of the signature
* Allow esmerald custom encoder to be declared
* Rename exception names in signature
* Add encoders to esmerald properties
* Add tests to add encoders
  • Loading branch information
tarsil committed May 3, 2024
1 parent 04ab931 commit b4df3b4
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/references/esmerald.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from esmerald import Esmerald
- add_child_esmerald
- add_router
- add_pluggable
- register_encoder

::: esmerald.ChildEsmerald
options:
Expand Down
62 changes: 51 additions & 11 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class Esmerald(Lilya):
"timezone",
"title",
"version",
"encoders",
)

def __init__(
Expand Down Expand Up @@ -1083,6 +1084,46 @@ async def another(request: Request) -> str:
"""
),
] = None,
encoders: Annotated[
Sequence[Optional[Encoder]],
Doc(
"""
A `list` of encoders to be used by the application once it
starts.
Returns:
List of encoders
**Example**
```python
from typing import Any
from attrs import asdict, define, field, has
from esmerald.encoders import Encoder
class AttrsEncoder(Encoder):
def is_type(self, value: Any) -> bool:
return has(value)
def serialize(self, obj: Any) -> Any:
return asdict(obj)
def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)
class AppSettings(EsmeraldAPISettings):
@property
def encoders(self) -> Union[List[Encoder], None]:
return [AttrsEncoder]
```
"""
),
] = None,
exception_handlers: Annotated[
Optional["ExceptionHandlerMap"],
Doc(
Expand Down Expand Up @@ -1582,6 +1623,9 @@ def extend(self, config: PluggableConfig) -> None:
] = State()
self.async_exit_config = esmerald_settings.async_exit_config

self.encoders = self.load_settings_value("encoders", encoders) or []
self._register_application_encoders()

self.router: "Router" = Router(
on_shutdown=self.on_shutdown,
on_startup=self.on_startup,
Expand All @@ -1593,14 +1637,20 @@ def extend(self, config: PluggableConfig) -> None:
redirect_slashes=self.redirect_slashes,
)
self.get_default_exception_handlers()
self.register_default_encoders()
self.user_middleware = self.build_user_middleware_stack()
self.middleware_stack = self.build_middleware_stack()
self.pluggable_stack = self.build_pluggable_stack()
self.template_engine = self.get_template_engine(self.template_config)

self._configure()

def _register_application_encoders(self) -> None:
self.register_encoder(cast(Encoder[Any], PydanticEncoder))
self.register_encoder(cast(Encoder[Any], MsgSpecEncoder))

for encoder in self.encoders:
self.register_encoder(encoder)

def _configure(self) -> None:
"""
Starts the Esmerald configurations.
Expand Down Expand Up @@ -2287,16 +2337,6 @@ def get_default_exception_handlers(self) -> None:

self.exception_handlers.setdefault(ValidationError, pydantic_validation_error_handler)

def register_default_encoders(self) -> None:
"""
Registers the default encoders supported by Esmerald.
The default Encoders are simple validation libraries like Pydantic/MsgSpec
that out of the box, Esmerald will make sure it does understand them.
"""
self.register_encoder(cast(Encoder[Any], PydanticEncoder))
self.register_encoder(cast(Encoder[Any], MsgSpecEncoder))

def build_routes_exception_handlers(
self,
route: "RouteParent",
Expand Down
40 changes: 40 additions & 0 deletions esmerald/conf/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from esmerald.config.asyncexit import AsyncExitConfig
from esmerald.datastructures import Secret
from esmerald.encoders import Encoder
from esmerald.interceptors.types import Interceptor
from esmerald.permissions.types import Permission
from esmerald.pluggables import Pluggable
Expand Down Expand Up @@ -1425,6 +1426,45 @@ def pluggables(self) -> Dict[str, "Pluggable"]:
"""
return {}

@property
def encoders(self) -> Union[List[Encoder], None]:
"""
A `list` of encoders to be used by the application once it
starts.
Returns:
List of encoders
**Example**
```python
from typing import Any
from attrs import asdict, define, field, has
from esmerald.encoders import Encoder
class AttrsEncoder(Encoder):
def is_type(self, value: Any) -> bool:
return has(value)
def serialize(self, obj: Any) -> Any:
return asdict(obj)
def encode(self, annotation: Any, value: Any) -> Any:
return annotation(**value)
class AppSettings(EsmeraldAPISettings):
@property
def encoders(self) -> Union[List[Encoder], None]:
return [AttrsEncoder]
```
"""
return []

def __hash__(self) -> int:
values: Dict[str, Any] = {}
for key, value in self.__dict__.items():
Expand Down
46 changes: 41 additions & 5 deletions esmerald/encoders.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
from __future__ import annotations

from typing import Any
from typing import Any, TypeVar

import msgspec
from lilya._internal._encoders import json_encoder as json_encoder # noqa
from lilya._utils import is_class_and_subclass
from lilya.encoders import (
ENCODER_TYPES as ENCODER_TYPES, # noqa
Encoder as Encoder, # noqa
Encoder as LilyaEncoder, # noqa
register_encoder as register_encoder, # noqa
)
from msgspec import Struct
from pydantic import BaseModel

from esmerald.exceptions import ImproperlyConfigured

T = TypeVar("T")


class Encoder(LilyaEncoder[T]):

def is_type(self, value: Any) -> bool:
"""
Function that checks if the function is
an instance of a given type
"""
raise NotImplementedError("All Esmerald encoders must implement is_type() method")

def encode(self, annotation: Any, value: Any) -> Any:
"""
Function that transforms the kwargs into a structure
"""
raise NotImplementedError("All Esmerald encoders must implement encode() method")


class MsgSpecEncoder(Encoder):
__type__ = Struct

def is_type(self, value: Any) -> bool:
return isinstance(value, Struct) or is_class_and_subclass(value, Struct)

def serialize(self, obj: Any) -> Any:
"""
Expand All @@ -26,15 +46,31 @@ def serialize(self, obj: Any) -> Any:
"""
return msgspec.json.decode(msgspec.json.encode(obj))

def encode(self, annotation: Any, value: Any) -> Any:
return msgspec.json.decode(msgspec.json.encode(value), type=annotation)


class PydanticEncoder(Encoder):
__type__ = BaseModel

def is_type(self, value: Any) -> bool:
return isinstance(value, BaseModel) or is_class_and_subclass(value, BaseModel)

def serialize(self, obj: BaseModel) -> dict[str, Any]:
return obj.model_dump()

def encode(self, annotation: Any, value: Any) -> Any:
if isinstance(value, BaseModel):
return value
return annotation(**value)


def register_esmerald_encoder(encoder: Encoder[Any]) -> None:
"""
Registers an esmerald encoder into available Lilya encoders
"""
if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): # type: ignore
raise ImproperlyConfigured(f"{type(encoder)} must be a subclass of Encoder")
register_encoder(encoder)

encoder_types = {encoder.__class__.__name__ for encoder in ENCODER_TYPES}
if encoder.__name__ not in encoder_types:
register_encoder(encoder)
3 changes: 3 additions & 0 deletions esmerald/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pydantic import AnyUrl

from esmerald.applications import Esmerald
from esmerald.encoders import Encoder
from esmerald.utils.crypto import get_random_secret_key

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -123,6 +124,7 @@ def create_client(
redirect_slashes: Optional[bool] = None,
tags: Optional[List[str]] = None,
webhooks: Optional[Sequence["WebhookGateway"]] = None,
encoders: Optional[Sequence[Encoder]] = None,
) -> EsmeraldTestClient:
return EsmeraldTestClient(
app=Esmerald(
Expand Down Expand Up @@ -167,6 +169,7 @@ def create_client(
tags=tags,
webhooks=webhooks,
pluggables=pluggables,
encoders=encoders,
),
base_url=base_url,
backend=backend,
Expand Down
43 changes: 25 additions & 18 deletions esmerald/transformers/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from inspect import Parameter as InspectParameter, Signature
from typing import Any, ClassVar, Dict, Optional, Set, Union

import msgspec
from msgspec import ValidationError as MsgspecValidationError
from orjson import loads
from pydantic import ValidationError

from esmerald.encoders import Encoder
from esmerald.exceptions import ImproperlyConfigured, InternalServerError, ValidationErrorException
from esmerald.parsers import ArbitraryBaseModel
from esmerald.requests import Request
Expand All @@ -19,17 +18,25 @@
class EsmeraldSignature(ArbitraryBaseModel):
dependency_names: ClassVar[Set[str]]
return_annotation: ClassVar[Any]
msgspec_structs: ClassVar[Dict[str, msgspec.Struct]]
encoders: ClassVar[Dict[str, Any]]

@classmethod
def parse_msgspec_structures(cls, kwargs: Any) -> Any:
def parse_encoders(cls, kwargs: Any) -> Any:
"""
Parses the kwargs for a possible msgspec Struct and instantiates it.
Parses the kwargs into a proper structure of the encoder
itself.
The encoders **must** be of Esmerald encoder or else
it will default to Lilya encoders.
Lilya encoders do not implement custom `encode()` functionality.
"""
for k, v in kwargs.items():
if k in cls.msgspec_structs:
kwargs[k] = msgspec.json.decode(
msgspec.json.encode(v), type=cls.msgspec_structs[k]
for key, value in kwargs.items():
if key in cls.encoders:
encoder: "Encoder" = cls.encoders[key]["encoder"]
annotation = cls.encoders[key]["annotation"]
kwargs[key] = (
encoder.encode(annotation, value) if isinstance(encoder, Encoder) else value
)
return kwargs

Expand All @@ -38,23 +45,23 @@ def parse_values_for_connection(
cls, connection: Union[Request, WebSocket], **kwargs: Any
) -> Any:
try:
if cls.msgspec_structs:
kwargs = cls.parse_msgspec_structures(kwargs)
if cls.encoders:
kwargs = cls.parse_encoders(kwargs)

signature = cls(**kwargs)
values = {}
for key in cls.model_fields:
values[key] = signature.field_value(key)
return values
except ValidationError as e:
raise cls.build_exception(connection, e) from e
except MsgspecValidationError as e:
raise cls.build_msgspec_exception(connection, e) from e
raise cls.build_base_system_exception(connection, e) from e
except Exception as e:
raise cls.build_encoder_exception(connection, e) from e

@classmethod
def build_msgspec_exception(
cls, connection: Union[Request, WebSocket], exception: MsgspecValidationError
) -> ValidationErrorException:
def build_encoder_exception(
cls, connection: Union[Request, WebSocket], exception: Exception
) -> Exception:
"""
Builds the exceptions for the message spec.
"""
Expand All @@ -73,7 +80,7 @@ def build_msgspec_exception(
return ValidationErrorException(detail=error_message, extra=[message])

@classmethod
def build_exception(
def build_base_system_exception(
cls, connection: Union[Request, WebSocket], exception: ValidationError
) -> Union[InternalServerError, ValidationErrorException]:
server_errors = []
Expand Down
Loading

0 comments on commit b4df3b4

Please sign in to comment.