From b4df3b4de7dc4b882e0c24fa87e05a9d71c5e3bb Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Thu, 25 Apr 2024 09:54:23 +0100 Subject: [PATCH] [Implementation] - Redesign signature to allow multiple types of encoders. (#299) * Add extra encoder validation * Change implementation of the signature * Allow esmerald custom encoder to be declared * Rename exception names in signature * Add encoders to esmerald properties * Add tests to add encoders --- docs/references/esmerald.md | 1 + esmerald/applications.py | 62 +++++++++++++++++++----- esmerald/conf/global_settings.py | 40 ++++++++++++++++ esmerald/encoders.py | 46 ++++++++++++++++-- esmerald/testclient.py | 3 ++ esmerald/transformers/datastructures.py | 43 ++++++++++------- esmerald/transformers/signature.py | 38 ++++++++++++++- tests/encoding/__init__.py | 0 tests/encoding/test_attrs_encoders.py | 64 +++++++++++++++++++++++++ tests/encoding/test_register_encoder.py | 60 +++++++++++++++++++++++ 10 files changed, 321 insertions(+), 36 deletions(-) create mode 100644 tests/encoding/__init__.py create mode 100644 tests/encoding/test_attrs_encoders.py create mode 100644 tests/encoding/test_register_encoder.py diff --git a/docs/references/esmerald.md b/docs/references/esmerald.md index 4dd011f7..fce570bd 100644 --- a/docs/references/esmerald.md +++ b/docs/references/esmerald.md @@ -22,6 +22,7 @@ from esmerald import Esmerald - add_child_esmerald - add_router - add_pluggable + - register_encoder ::: esmerald.ChildEsmerald options: diff --git a/esmerald/applications.py b/esmerald/applications.py index 505be386..50c10d70 100644 --- a/esmerald/applications.py +++ b/esmerald/applications.py @@ -140,6 +140,7 @@ class Esmerald(Lilya): "timezone", "title", "version", + "encoders", ) def __init__( @@ -1083,6 +1084,46 @@ async def another(request: Request) -> str: """ ), ] = None, + encoders: Annotated[ + Sequence[Optional[Encoder]], + Doc( + """ + A `list` of encoders to be used by the application once it + starts. + + Returns: + List of encoders + + **Example** + + ```python + from typing import Any + + from attrs import asdict, define, field, has + from esmerald.encoders import Encoder + + + class AttrsEncoder(Encoder): + + def is_type(self, value: Any) -> bool: + return has(value) + + def serialize(self, obj: Any) -> Any: + return asdict(obj) + + def encode(self, annotation: Any, value: Any) -> Any: + return annotation(**value) + + + class AppSettings(EsmeraldAPISettings): + + @property + def encoders(self) -> Union[List[Encoder], None]: + return [AttrsEncoder] + ``` + """ + ), + ] = None, exception_handlers: Annotated[ Optional["ExceptionHandlerMap"], Doc( @@ -1582,6 +1623,9 @@ def extend(self, config: PluggableConfig) -> None: ] = State() self.async_exit_config = esmerald_settings.async_exit_config + self.encoders = self.load_settings_value("encoders", encoders) or [] + self._register_application_encoders() + self.router: "Router" = Router( on_shutdown=self.on_shutdown, on_startup=self.on_startup, @@ -1593,7 +1637,6 @@ def extend(self, config: PluggableConfig) -> None: redirect_slashes=self.redirect_slashes, ) self.get_default_exception_handlers() - self.register_default_encoders() self.user_middleware = self.build_user_middleware_stack() self.middleware_stack = self.build_middleware_stack() self.pluggable_stack = self.build_pluggable_stack() @@ -1601,6 +1644,13 @@ def extend(self, config: PluggableConfig) -> None: self._configure() + def _register_application_encoders(self) -> None: + self.register_encoder(cast(Encoder[Any], PydanticEncoder)) + self.register_encoder(cast(Encoder[Any], MsgSpecEncoder)) + + for encoder in self.encoders: + self.register_encoder(encoder) + def _configure(self) -> None: """ Starts the Esmerald configurations. @@ -2287,16 +2337,6 @@ def get_default_exception_handlers(self) -> None: self.exception_handlers.setdefault(ValidationError, pydantic_validation_error_handler) - def register_default_encoders(self) -> None: - """ - Registers the default encoders supported by Esmerald. - - The default Encoders are simple validation libraries like Pydantic/MsgSpec - that out of the box, Esmerald will make sure it does understand them. - """ - self.register_encoder(cast(Encoder[Any], PydanticEncoder)) - self.register_encoder(cast(Encoder[Any], MsgSpecEncoder)) - def build_routes_exception_handlers( self, route: "RouteParent", diff --git a/esmerald/conf/global_settings.py b/esmerald/conf/global_settings.py index 378bcf3e..a7edb719 100644 --- a/esmerald/conf/global_settings.py +++ b/esmerald/conf/global_settings.py @@ -18,6 +18,7 @@ ) from esmerald.config.asyncexit import AsyncExitConfig from esmerald.datastructures import Secret +from esmerald.encoders import Encoder from esmerald.interceptors.types import Interceptor from esmerald.permissions.types import Permission from esmerald.pluggables import Pluggable @@ -1425,6 +1426,45 @@ def pluggables(self) -> Dict[str, "Pluggable"]: """ return {} + @property + def encoders(self) -> Union[List[Encoder], None]: + """ + A `list` of encoders to be used by the application once it + starts. + + Returns: + List of encoders + + **Example** + + ```python + from typing import Any + + from attrs import asdict, define, field, has + from esmerald.encoders import Encoder + + + class AttrsEncoder(Encoder): + + def is_type(self, value: Any) -> bool: + return has(value) + + def serialize(self, obj: Any) -> Any: + return asdict(obj) + + def encode(self, annotation: Any, value: Any) -> Any: + return annotation(**value) + + + class AppSettings(EsmeraldAPISettings): + + @property + def encoders(self) -> Union[List[Encoder], None]: + return [AttrsEncoder] + ``` + """ + return [] + def __hash__(self) -> int: values: Dict[str, Any] = {} for key, value in self.__dict__.items(): diff --git a/esmerald/encoders.py b/esmerald/encoders.py index f7dbde38..a495d0ce 100644 --- a/esmerald/encoders.py +++ b/esmerald/encoders.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import Any +from typing import Any, TypeVar import msgspec from lilya._internal._encoders import json_encoder as json_encoder # noqa from lilya._utils import is_class_and_subclass from lilya.encoders import ( ENCODER_TYPES as ENCODER_TYPES, # noqa - Encoder as Encoder, # noqa + Encoder as LilyaEncoder, # noqa register_encoder as register_encoder, # noqa ) from msgspec import Struct @@ -15,9 +15,29 @@ from esmerald.exceptions import ImproperlyConfigured +T = TypeVar("T") + + +class Encoder(LilyaEncoder[T]): + + def is_type(self, value: Any) -> bool: + """ + Function that checks if the function is + an instance of a given type + """ + raise NotImplementedError("All Esmerald encoders must implement is_type() method") + + def encode(self, annotation: Any, value: Any) -> Any: + """ + Function that transforms the kwargs into a structure + """ + raise NotImplementedError("All Esmerald encoders must implement encode() method") + class MsgSpecEncoder(Encoder): - __type__ = Struct + + def is_type(self, value: Any) -> bool: + return isinstance(value, Struct) or is_class_and_subclass(value, Struct) def serialize(self, obj: Any) -> Any: """ @@ -26,15 +46,31 @@ def serialize(self, obj: Any) -> Any: """ return msgspec.json.decode(msgspec.json.encode(obj)) + def encode(self, annotation: Any, value: Any) -> Any: + return msgspec.json.decode(msgspec.json.encode(value), type=annotation) + class PydanticEncoder(Encoder): - __type__ = BaseModel + + def is_type(self, value: Any) -> bool: + return isinstance(value, BaseModel) or is_class_and_subclass(value, BaseModel) def serialize(self, obj: BaseModel) -> dict[str, Any]: return obj.model_dump() + def encode(self, annotation: Any, value: Any) -> Any: + if isinstance(value, BaseModel): + return value + return annotation(**value) + def register_esmerald_encoder(encoder: Encoder[Any]) -> None: + """ + Registers an esmerald encoder into available Lilya encoders + """ if not isinstance(encoder, Encoder) and not is_class_and_subclass(encoder, Encoder): # type: ignore raise ImproperlyConfigured(f"{type(encoder)} must be a subclass of Encoder") - register_encoder(encoder) + + encoder_types = {encoder.__class__.__name__ for encoder in ENCODER_TYPES} + if encoder.__name__ not in encoder_types: + register_encoder(encoder) diff --git a/esmerald/testclient.py b/esmerald/testclient.py index d0e8faaa..84c2ad73 100644 --- a/esmerald/testclient.py +++ b/esmerald/testclient.py @@ -17,6 +17,7 @@ from pydantic import AnyUrl from esmerald.applications import Esmerald +from esmerald.encoders import Encoder from esmerald.utils.crypto import get_random_secret_key if TYPE_CHECKING: # pragma: no cover @@ -123,6 +124,7 @@ def create_client( redirect_slashes: Optional[bool] = None, tags: Optional[List[str]] = None, webhooks: Optional[Sequence["WebhookGateway"]] = None, + encoders: Optional[Sequence[Encoder]] = None, ) -> EsmeraldTestClient: return EsmeraldTestClient( app=Esmerald( @@ -167,6 +169,7 @@ def create_client( tags=tags, webhooks=webhooks, pluggables=pluggables, + encoders=encoders, ), base_url=base_url, backend=backend, diff --git a/esmerald/transformers/datastructures.py b/esmerald/transformers/datastructures.py index a826d7ab..1e4e095c 100644 --- a/esmerald/transformers/datastructures.py +++ b/esmerald/transformers/datastructures.py @@ -2,11 +2,10 @@ from inspect import Parameter as InspectParameter, Signature from typing import Any, ClassVar, Dict, Optional, Set, Union -import msgspec -from msgspec import ValidationError as MsgspecValidationError from orjson import loads from pydantic import ValidationError +from esmerald.encoders import Encoder from esmerald.exceptions import ImproperlyConfigured, InternalServerError, ValidationErrorException from esmerald.parsers import ArbitraryBaseModel from esmerald.requests import Request @@ -19,17 +18,25 @@ class EsmeraldSignature(ArbitraryBaseModel): dependency_names: ClassVar[Set[str]] return_annotation: ClassVar[Any] - msgspec_structs: ClassVar[Dict[str, msgspec.Struct]] + encoders: ClassVar[Dict[str, Any]] @classmethod - def parse_msgspec_structures(cls, kwargs: Any) -> Any: + def parse_encoders(cls, kwargs: Any) -> Any: """ - Parses the kwargs for a possible msgspec Struct and instantiates it. + Parses the kwargs into a proper structure of the encoder + itself. + + The encoders **must** be of Esmerald encoder or else + it will default to Lilya encoders. + + Lilya encoders do not implement custom `encode()` functionality. """ - for k, v in kwargs.items(): - if k in cls.msgspec_structs: - kwargs[k] = msgspec.json.decode( - msgspec.json.encode(v), type=cls.msgspec_structs[k] + for key, value in kwargs.items(): + if key in cls.encoders: + encoder: "Encoder" = cls.encoders[key]["encoder"] + annotation = cls.encoders[key]["annotation"] + kwargs[key] = ( + encoder.encode(annotation, value) if isinstance(encoder, Encoder) else value ) return kwargs @@ -38,8 +45,8 @@ def parse_values_for_connection( cls, connection: Union[Request, WebSocket], **kwargs: Any ) -> Any: try: - if cls.msgspec_structs: - kwargs = cls.parse_msgspec_structures(kwargs) + if cls.encoders: + kwargs = cls.parse_encoders(kwargs) signature = cls(**kwargs) values = {} @@ -47,14 +54,14 @@ def parse_values_for_connection( values[key] = signature.field_value(key) return values except ValidationError as e: - raise cls.build_exception(connection, e) from e - except MsgspecValidationError as e: - raise cls.build_msgspec_exception(connection, e) from e + raise cls.build_base_system_exception(connection, e) from e + except Exception as e: + raise cls.build_encoder_exception(connection, e) from e @classmethod - def build_msgspec_exception( - cls, connection: Union[Request, WebSocket], exception: MsgspecValidationError - ) -> ValidationErrorException: + def build_encoder_exception( + cls, connection: Union[Request, WebSocket], exception: Exception + ) -> Exception: """ Builds the exceptions for the message spec. """ @@ -73,7 +80,7 @@ def build_msgspec_exception( return ValidationErrorException(detail=error_message, extra=[message]) @classmethod - def build_exception( + def build_base_system_exception( cls, connection: Union[Request, WebSocket], exception: ValidationError ) -> Union[InternalServerError, ValidationErrorException]: server_errors = [] diff --git a/esmerald/transformers/signature.py b/esmerald/transformers/signature.py index 18d66cf4..08f37002 100644 --- a/esmerald/transformers/signature.py +++ b/esmerald/transformers/signature.py @@ -16,6 +16,7 @@ import msgspec from pydantic import create_model +from esmerald.encoders import ENCODER_TYPES from esmerald.exceptions import ImproperlyConfigured from esmerald.parsers import ArbitraryExtraBaseModel from esmerald.transformers.constants import CLASS_SPECIAL_WORDS, VALIDATION_NAMES @@ -118,6 +119,39 @@ def handle_msgspec_structs( msgpspec_structs[param.name] = param.annotation return msgpspec_structs + def handle_encoders( + self, parameters: Generator[Parameter, None, None] = None + ) -> Dict[str, Any]: + """ + Handles the extraction of any of the passed encoders. + """ + custom_encoders: Dict[str, Any] = {} + + if parameters is None: + parameters = self.parameters + + for param in parameters: + origin = get_origin(param.annotation) + + for encoder in ENCODER_TYPES: + if not origin: + if encoder.is_type(param.annotation): + custom_encoders[param.name] = { + "encoder": encoder, + "annotation": param.annotation, + } + continue + else: + arguments: List[Type[type]] = self.extract_arguments(param=param.annotation) + + if any(encoder.is_type(value) for value in arguments): + custom_encoders[param.name] = { + "encoder": encoder, + "annotation": param.annotation, + } + continue + return custom_encoders + def create_signature(self) -> Type[EsmeraldSignature]: """ Creates the EsmeraldSignature based on the type of parameteres. @@ -125,7 +159,7 @@ def create_signature(self) -> Type[EsmeraldSignature]: This allows to understand if the msgspec is also available and allowed. """ try: - msgpspec_structs: Dict[str, msgspec.Struct] = self.handle_msgspec_structs() + encoders: Dict[str, Any] = self.handle_encoders() for param in self.parameters: self.validate_missing_dependency(param) @@ -144,7 +178,7 @@ def create_signature(self) -> Type[EsmeraldSignature]: ) model.return_annotation = self.signature.return_annotation model.dependency_names = self.dependency_names - model.msgspec_structs = msgpspec_structs + model.encoders = encoders return model except TypeError as e: raise ImproperlyConfigured( # pragma: no cover diff --git a/tests/encoding/__init__.py b/tests/encoding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/encoding/test_attrs_encoders.py b/tests/encoding/test_attrs_encoders.py new file mode 100644 index 00000000..7585bc3c --- /dev/null +++ b/tests/encoding/test_attrs_encoders.py @@ -0,0 +1,64 @@ +from typing import Any + +from attrs import asdict, define, field, has + +from esmerald import Gateway, post +from esmerald.encoders import Encoder, register_esmerald_encoder +from esmerald.testclient import create_client + + +class AttrsEncoder(Encoder): + + def is_type(self, value: Any) -> bool: + return has(value) + + def serialize(self, obj: Any) -> Any: + return asdict(obj) + + def encode(self, annotation: Any, value: Any) -> Any: + return annotation(**value) + + +register_esmerald_encoder(AttrsEncoder) + + +@define +class AttrItem: + name: str = field() + age: int = field() + email: str + + +def test_can_parse_attrs(test_app_client_factory): + + @post("/create") + async def create(data: AttrItem) -> AttrItem: + return data + + with create_client(routes=[Gateway(handler=create)]) as client: + response = client.post( + "/create", json={"name": "test", "age": 2, "email": "test@foobar.com"} + ) + assert response.status_code == 201 + assert response.json() == {"name": "test", "age": 2, "email": "test@foobar.com"} + + +def test_can_parse_attrs_errors(test_app_client_factory): + + @define + class Item: + sku: str = field() + + @sku.validator + def check(self, attribute, value): + if not isinstance(value, str): + raise ValueError(f"'{attribute.name}' must be a string.") + + @post("/create") + async def create(data: Item) -> AttrItem: + return data + + with create_client(routes=[Gateway(handler=create)]) as client: + response = client.post("/create", json={"sku": 1}) + assert response.status_code == 400 + assert response.json()["errors"] == ["'sku' must be a string."] diff --git a/tests/encoding/test_register_encoder.py b/tests/encoding/test_register_encoder.py new file mode 100644 index 00000000..e57a2a5e --- /dev/null +++ b/tests/encoding/test_register_encoder.py @@ -0,0 +1,60 @@ +from typing import Any + +from attrs import asdict, define, field, has + +from esmerald import Esmerald, Gateway, post +from esmerald.encoders import Encoder +from esmerald.testclient import EsmeraldTestClient, create_client + + +class AttrsEncoder(Encoder): + + def is_type(self, value: Any) -> bool: + return has(value) + + def serialize(self, obj: Any) -> Any: + return asdict(obj) + + def encode(self, annotation: Any, value: Any) -> Any: + return annotation(**value) + + +@define +class AttrItem: + name: str = field() + age: int = field() + email: str + + +def test_can_parse_attrs(test_app_client_factory): + @post("/create") + async def create(data: AttrItem) -> AttrItem: + return data + + app = Esmerald(routes=[Gateway(handler=create)], encoders=[AttrsEncoder]) + client = EsmeraldTestClient(app) + + response = client.post("/create", json={"name": "test", "age": 2, "email": "test@foobar.com"}) + assert response.status_code == 201 + assert response.json() == {"name": "test", "age": 2, "email": "test@foobar.com"} + + +def test_can_parse_attrs_errors(test_app_client_factory): + + @define + class Item: + sku: str = field() + + @sku.validator + def check(self, attribute, value): + if not isinstance(value, str): + raise ValueError(f"'{attribute.name}' must be a string.") + + @post("/create") + async def create(data: Item) -> AttrItem: + return data + + with create_client(routes=[Gateway(handler=create)], encoders=[AttrsEncoder]) as client: + response = client.post("/create", json={"sku": 1}) + assert response.status_code == 400 + assert response.json()["errors"] == ["'sku' must be a string."]