Skip to content

Commit

Permalink
Fix optional data payload encoding validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Nov 18, 2024
1 parent 884de68 commit e94ac3b
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 45 deletions.
19 changes: 5 additions & 14 deletions esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,21 @@
get_schema_from_model_field,
is_status_code_allowed,
)
from esmerald.params import Param, Security
from esmerald.params import Param
from esmerald.routing import gateways, router
from esmerald.routing._internal import (
convert_annotation_to_pydantic_model,
)
from esmerald.security.oauth2.oauth import SecurityBase
from esmerald.transformers.model import ParamSetting
from esmerald.typing import Undefined
from esmerald.utils.dependencies import is_security_scheme
from esmerald.utils.helpers import is_class_and_subclass, is_union

ADDITIONAL_TYPES = ["bool", "list", "dict"]
TRANSFORMER_TYPES_KEYS = list(TRANSFORMER_TYPES.keys())
TRANSFORMER_TYPES_KEYS += ADDITIONAL_TYPES


def is_security_scheme(param: ParamSetting) -> bool:
"""
Checks if the object is a security scheme.
"""
if not param.field_info.default:
return False
return isinstance(param.field_info.default, Security)


def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str]) -> List[Any]:
"""
Gets all the neded params of the request and route.
Expand All @@ -93,19 +84,19 @@ def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str

# Making sure all the optional and union types are included
if is_union_or_optional:
if not is_security_scheme(param):
if not is_security_scheme(param.field_info.default):
query_params.append(param.field_info)

else:
if isinstance(param.field_info.annotation, _GenericAlias) and not is_security_scheme(
param
param.field_info.default
):
query_params.append(param.field_info)
elif (
param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES_KEYS
or param.field_info.annotation.__name__ in TRANSFORMER_TYPES_KEYS
):
if not is_security_scheme(param):
if not is_security_scheme(param.field_info.default):
query_params.append(param.field_info)

return path_params + query_params + cookie_params + header_params
Expand Down
4 changes: 2 additions & 2 deletions esmerald/param_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def Security(
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)


def DirectInjects(
def Requires(
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
Expand All @@ -31,7 +31,7 @@ def DirectInjects(
This function should be only called if Inject/Injects is not used in the dependencies.
This is a simple wrapper of the classic Depends().
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Requires(dependency=dependency, use_cache=use_cache)

def Form(
default: Any = _PyUndefined,
Expand Down
46 changes: 44 additions & 2 deletions esmerald/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,33 @@ def __init__(
super().__init__(default=default, json_schema_extra=self.extra)


class Depends:
class Requires:
"""
A class that represents a requirement with an optional dependency and caching behavior.
Attributes:
dependency (Optional[Callable[..., Any]]): An optional callable that represents the dependency.
use_cache (bool): A flag indicating whether to use caching for the dependency. Defaults to True.
Methods:
__repr__(): Returns a string representation of the Requires instance.
"""

def __init__(self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True):
"""
Initializes a Requires instance.
Args:
dependency (Optional[Callable[..., Any]]): An optional callable that represents the dependency.
use_cache (bool): A flag indicating whether to use caching for the dependency. Defaults to True.
"""

"""
Returns a string representation of the Requires instance.
Returns:
str: A string representation of the Requires instance.
"""
self.dependency = dependency
self.use_cache = use_cache

Expand All @@ -638,7 +663,24 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({attr}{cache})"


class Security(Depends):
class Security(Requires):
"""
A class used to represent security requirements for a particular operation.
Attributes:
----------
dependency : Optional[Callable[..., Any]]
A callable that represents the dependency required for security.
scopes : Optional[Sequence[str]]
A sequence of scopes required for the security. Defaults to an empty list.
use_cache : bool
A flag indicating whether to use cache. Defaults to True.
Methods:
-------
__init__(self, dependency: Optional[Callable[..., Any]] = None, *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True)
Initializes the Security class with the given dependency, scopes, and use_cache flag.
"""
def __init__(
self,
dependency: Optional[Callable[..., Any]] = None,
Expand Down
23 changes: 22 additions & 1 deletion esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from orjson import loads
from pydantic import ValidationError, create_model
from pydantic.fields import FieldInfo

from esmerald.encoders import ENCODER_TYPES, Encoder
from esmerald.exceptions import (
Expand All @@ -31,8 +32,9 @@
from esmerald.transformers.constants import CLASS_SPECIAL_WORDS, UNDEFINED, VALIDATION_NAMES
from esmerald.transformers.utils import get_connection_info, get_field_definition_from_param
from esmerald.typing import Undefined
from esmerald.utils.dependency import is_dependency_field, should_skip_dependency_validation
from esmerald.utils.constants import IS_DEPENDENCY, SKIP_VALIDATION
from esmerald.utils.helpers import is_optional_union
from esmerald.utils.schema import extract_arguments
from esmerald.websockets import WebSocket

if TYPE_CHECKING:
Expand All @@ -59,6 +61,16 @@ def is_server_error(error: Any, klass: Type["SignatureModel"]) -> bool:
return False


def is_dependency_field(val: Any) -> bool:
json_schema_extra = getattr(val, "json_schema_extra", None) or {}
return bool(isinstance(val, FieldInfo) and bool(json_schema_extra.get(IS_DEPENDENCY)))


def should_skip_dependency_validation(val: Any) -> bool:
json_schema_extra = getattr(val, "json_schema_extra", None) or {}
return bool(is_dependency_field(val) and json_schema_extra.get(SKIP_VALIDATION))


class Parameter(ArbitraryBaseModel):
"""
Represents a function parameter with associated metadata.
Expand Down Expand Up @@ -178,6 +190,15 @@ def encode_value(encoder: "Encoder", annotation: Any, value: Any) -> Any:
encoder_info: Dict[str, "Encoder"] = cls.encoders[key] # type: ignore
encoder: "Encoder" = encoder_info["encoder"]
annotation = encoder_info["annotation"]

if is_optional_union(annotation) and not value:
kwargs[key] = None
continue

if is_optional_union(annotation) and value:
decoded_list = extract_arguments(annotation)
annotation = decoded_list[0] # type: ignore

kwargs[key] = encode_value(encoder, annotation, value)

return kwargs
Expand Down
20 changes: 20 additions & 0 deletions esmerald/utils/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any

from esmerald import params
from esmerald.utils.helpers import is_class_and_subclass


def is_requires_scheme(param: Any) -> bool:
"""
Checks if the object is a security scheme.
"""
return is_class_and_subclass(param, params.Requires)


def is_security_scheme(param: Any) -> bool:
"""
Checks if the object is a security scheme.
"""
if not param:
return False
return isinstance(param, params.Security)
15 changes: 0 additions & 15 deletions esmerald/utils/dependency.py

This file was deleted.

25 changes: 25 additions & 0 deletions esmerald/utils/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,28 @@ def should_skip_json_schema(field_info: Union[FieldInfo, Any]) -> FieldInfo:
arguments = tuple(arguments) # type: ignore
field_info.annotation = Union[arguments]
return field_info


def extract_arguments(
param: Union[Any, None] = None, argument_list: Union[List[Any], None] = None
) -> List[Type[type]]:
"""
Recursively extracts unique types from a parameter's type annotation.
Args:
param (Union[Parameter, None], optional): The parameter with type annotation to extract from.
argument_list (Union[List[Any], None], optional): The list of arguments extracted so far.
Returns:
List[Type[type]]: A list of unique types extracted from the parameter's type annotation.
"""
arguments: List[Any] = list(argument_list) if argument_list is not None else []
args = get_args(param)

for arg in args:
if isinstance(arg, _GenericAlias):
arguments.extend(extract_arguments(param=arg, argument_list=arguments))
else:
if arg not in arguments:
arguments.append(arg)
return arguments
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ testing = [
"ujson>=5.7.0,<6",
"anyio[trio]>=3.6.2,<5.0.0",
"brotli>=1.0.9,<2.0.0",
"edgy[postgres]>=0.16.0",
"edgy[postgres]>=0.21.0",
"databasez>=0.9.7",
"flask>=1.1.2,<4.0.0",
"freezegun>=1.2.2,<2.0.0",
Expand Down
18 changes: 9 additions & 9 deletions tests/dependencies/test_injects_with_fastapi_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from esmerald import Esmerald, Gateway, get
from esmerald.param_functions import DirectInjects
from esmerald.param_functions import Requires
from esmerald.testclient import EsmeraldTestClient


Expand Down Expand Up @@ -49,53 +49,53 @@ async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:


@get("/callable-dependency")
async def get_callable_dependency(value: str = DirectInjects(callable_dependency)) -> str:
async def get_callable_dependency(value: str = Requires(callable_dependency)) -> str:
return value


@get("/callable-gen-dependency")
async def get_callable_gen_dependency(value: str = DirectInjects(callable_gen_dependency)) -> str:
async def get_callable_gen_dependency(value: str = Requires(callable_gen_dependency)) -> str:
return value


@get("/async-callable-dependency")
async def get_async_callable_dependency(
value: str = DirectInjects(async_callable_dependency),
value: str = Requires(async_callable_dependency),
) -> str:
return value


@get("/async-callable-gen-dependency")
async def get_async_callable_gen_dependency(
value: str = DirectInjects(async_callable_gen_dependency),
value: str = Requires(async_callable_gen_dependency),
) -> str:
return value


@get("/synchronous-method-dependency")
async def get_synchronous_method_dependency(
value: str = DirectInjects(methods_dependency.synchronous),
value: str = Requires(methods_dependency.synchronous),
) -> str:
return value


@get("/synchronous-method-gen-dependency")
async def get_synchronous_method_gen_dependency(
value: str = DirectInjects(methods_dependency.synchronous_gen),
value: str = Requires(methods_dependency.synchronous_gen),
) -> str:
return value


@get("/asynchronous-method-dependency")
async def get_asynchronous_method_dependency(
value: str = DirectInjects(methods_dependency.asynchronous),
value: str = Requires(methods_dependency.asynchronous),
) -> str:
return value


@get("/asynchronous-method-gen-dependency")
async def get_asynchronous_method_gen_dependency(
value: str = DirectInjects(methods_dependency.asynchronous_gen),
value: str = Requires(methods_dependency.asynchronous_gen),
) -> str:
return value

Expand Down
50 changes: 50 additions & 0 deletions tests/encoding/test_encoder_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Optional

from pydantic import BaseModel

from esmerald import Gateway, post
from esmerald.testclient import create_client


class User(BaseModel):
username: str


@post("/optional")
async def create(data: Optional[User]) -> Any:
return data if data else {}


def test_optional():
with create_client(routes=[Gateway(handler=create)]) as client:
response = client.post("/optional", json={"username": "test"})
assert response.status_code == 201
assert response.json() == {"username": "test"}

response = client.post("/optional", json={})
assert response.status_code == 201
assert response.json() == {}

response = client.post("/optional")
assert response.status_code == 201
assert response.json() == {}


@post("/union")
async def create_union(data: Optional[User]) -> Any:
return data if data else {}


def test_union():
with create_client(routes=[Gateway(handler=create_union)]) as client:
response = client.post("/union", json={"username": "test"})
assert response.status_code == 201
assert response.json() == {"username": "test"}

response = client.post("/union", json={})
assert response.status_code == 201
assert response.json() == {}

response = client.post("/union")
assert response.status_code == 201
assert response.json() == {}
2 changes: 1 addition & 1 deletion tests/security/oauth/test_security_oauth2_optional_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_security_oauth2_password_other_header():
assert response.json() == {"username": "Other footokenbar"}


def xtest_security_oauth2_password_bearer_no_header():
def test_security_oauth2_password_bearer_no_header():
with create_client(
routes=[Gateway(handler=read_users_me)], security=[reusable_oauth2]
) as client:
Expand Down

0 comments on commit e94ac3b

Please sign in to comment.