Skip to content

Commit

Permalink
Add requires as alterantive direct dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Jan 13, 2025
1 parent 77426ff commit d82f5f7
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 13 deletions.
8 changes: 4 additions & 4 deletions esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from esmerald.security.oauth2.oauth import SecurityBase
from esmerald.typing import Undefined
from esmerald.utils.dependencies import is_security_scheme
from esmerald.utils.dependencies import is_base_requires
from esmerald.utils.helpers import is_class_and_subclass, is_union

ADDITIONAL_TYPES = ["bool", "list", "dict"]
Expand Down Expand Up @@ -91,19 +91,19 @@ def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str

# Making sure all the optional and union types are included
if is_union_or_optional:
if not is_security_scheme(param.field_info.default):
if not is_base_requires(param.field_info.default):
query_params.append(param.field_info)

else:
if isinstance(param.field_info.annotation, _GenericAlias) and not is_security_scheme(
if isinstance(param.field_info.annotation, _GenericAlias) and not is_base_requires(
param.field_info.default
):
query_params.append(param.field_info)
elif (
param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES_KEYS
or param.field_info.annotation.__name__ in TRANSFORMER_TYPES_KEYS
):
if not is_security_scheme(param.field_info.default):
if not is_base_requires(param.field_info.default):
query_params.append(param.field_info)

return path_params + query_params + cookie_params + header_params
Expand Down
9 changes: 7 additions & 2 deletions esmerald/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,12 @@ def __init__(
super().__init__(default=default, json_schema_extra=self.extra)


class Requires:
class BaseRequires:
"""
A class that represents a requirement with an optional dependency and caching behavior.
This object serves as a base class for other classes that require dependencies.
Attributes:
dependency (Optional[Callable[..., Any]]): An optional callable that represents the dependency.
use_cache (bool): A flag indicating whether to use caching for the dependency. Defaults to True.
Expand Down Expand Up @@ -671,7 +673,10 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({attr}{cache})"


class Security(Requires):
class Requires(BaseRequires): ...


class Security(BaseRequires):
"""
A class used to represent security requirements for a particular operation.
Expand Down
39 changes: 36 additions & 3 deletions esmerald/transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from esmerald.context import Context
from esmerald.enums import EncodingType, ParamType
from esmerald.exceptions import ImproperlyConfigured
from esmerald.params import Body, Security
from esmerald.params import Body, Requires, Security
from esmerald.parsers import ArbitraryExtraBaseModel, parse_form_data
from esmerald.requests import Request
from esmerald.transformers.signature import SignatureModel
Expand All @@ -32,7 +32,12 @@
)
from esmerald.typing import Undefined
from esmerald.utils.constants import CONTEXT, DATA, PAYLOAD, RESERVED_KWARGS
from esmerald.utils.dependencies import is_security_scheme, is_security_scope
from esmerald.utils.dependencies import (
async_resolve_dependencies,
is_requires,
is_security_scheme,
is_security_scope,
)
from esmerald.utils.schema import is_field_optional

if TYPE_CHECKING:
Expand Down Expand Up @@ -155,7 +160,7 @@ def get_security_scope_params(self) -> Dict[str, ParamSetting]:

def get_security_definition(self) -> Dict[str, ParamSetting]:
"""
Get header parameters.
Get header parameters for security.
Returns:
Set[ParamSetting]: Set of header parameters.
Expand All @@ -166,6 +171,20 @@ def get_security_definition(self) -> Dict[str, ParamSetting]:
if field.is_security and is_security_scheme(field.default_value)
}

def get_requires_definition(self) -> Dict[str, ParamSetting]:
"""
Get header parameters for requires.
Returns:
Set[ParamSetting]: Set of header parameters.
"""

return {
field.field_name: field
for field in self.get_query_params()
if is_requires(field.default_value)
}

async def to_kwargs(
self,
connection: Union["WebSocket", "Request"],
Expand Down Expand Up @@ -269,6 +288,15 @@ async def get_for_security_dependencies(

return kwargs

async def get_requires_dependencies(self, kwargs: Any) -> Any:
"""
get_requires_dependencies.
"""
for name, dependency in kwargs.items():
if isinstance(dependency, Requires):
kwargs[name] = await async_resolve_dependencies(dependency.dependency)
return kwargs

async def get_dependencies(
self, dependency: Dependency, connection: Union["WebSocket", "Request"], **kwargs: Any
) -> Any:
Expand All @@ -290,9 +318,14 @@ async def get_dependencies(
dependency=_dependency, connection=connection, **kwargs
)

# Handles with Security dependencies only
if kwargs and self.get_security_definition():
kwargs = await self.get_for_security_dependencies(connection, kwargs)

# Handles with everything that is related with a Requires
if kwargs and self.get_requires_definition():
kwargs = await self.get_requires_dependencies(kwargs)

dependency_kwargs = await signature_model.parse_values_for_connection(
connection=connection, **kwargs
)
Expand Down
26 changes: 26 additions & 0 deletions esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from esmerald.transformers.utils import get_connection_info, get_field_definition_from_param
from esmerald.typing import Undefined
from esmerald.utils.constants import IS_DEPENDENCY, SKIP_VALIDATION
from esmerald.utils.dependencies import async_resolve_dependencies, is_requires
from esmerald.utils.helpers import is_optional_union
from esmerald.utils.schema import extract_arguments
from esmerald.websockets import WebSocket
Expand Down Expand Up @@ -203,6 +204,26 @@ def encode_value(encoder: "Encoder", annotation: Any, value: Any) -> Any:

return kwargs

@classmethod
async def check_requires(cls, kwargs: Any) -> Any:
"""
Checks if any of the parameters is a requires dependency.
Args:
connection (Union[Request, WebSocket]): The connection object to check.
Raises:
BaseSystemException: If validation error occurs.
EncoderException: If encoder error occurs.
"""
if kwargs is None:
return kwargs

for key, value in kwargs.items():
if is_requires(value):
kwargs[key] = await async_resolve_dependencies(value.dependency)
return kwargs

@classmethod
async def parse_values_for_connection(
cls, connection: Union[Request, WebSocket], **kwargs: Dict[str, Any]
Expand All @@ -225,6 +246,11 @@ async def parse_values_for_connection(
try:
if cls.encoders:
kwargs = await cls.parse_encoders(kwargs)

# Checks if any of the parameters is a requires dependency
kwargs = await cls.check_requires(kwargs)

# Apply into the signature
signature = cls(**kwargs)
values = {}
for key in cls.model_fields:
Expand Down
15 changes: 15 additions & 0 deletions esmerald/utils/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def is_security_scheme(param: Any) -> bool:
return isinstance(param, params.Security)


def is_base_requires(param: Any) -> bool:
"""
Checks if the object is a base requires object.
"""
return is_class_and_subclass(param, params.BaseRequires)


def is_security_scope(param: Any) -> bool:
"""
Checks if the object is a security scope object.
Expand All @@ -41,6 +48,14 @@ def is_inject(param: Any) -> bool:

return isinstance(param, Inject)

def is_requires(param: Any) -> bool:
"""
Checks if the object is an Inject.
"""
if not param:
return False
return isinstance(param, params.Requires)


async def async_resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] = None) -> Any:
"""
Expand Down
98 changes: 94 additions & 4 deletions tests/dependencies/test_requires.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Any

import anyio
import pytest

from esmerald.param_functions import Requires
from esmerald import Gateway, Inject, Injects, JSONResponse, Requires, get
from esmerald.testclient import create_client
from esmerald.utils.dependencies import async_resolve_dependencies, resolve_dependencies


def get_user():
return {"id": 1, "name": "Alice"}


def get_current_user(user=Requires(get_user)):
def get_current_user(user: Any = Requires(get_user)):
return user


Expand All @@ -18,11 +21,11 @@ async def get_async_user():
return {"id": 2, "name": "Bob"}


async def async_endpoint(current_user=Requires(get_async_user)):
async def async_endpoint(current_user: Any = Requires(get_async_user)):
return {"message": "Hello", "user": current_user}


def endpoint(current_user=Requires(get_current_user)):
def endpoint(current_user: Any = Requires(get_current_user)):
return {"message": "Hello", "user": current_user}


Expand All @@ -36,3 +39,90 @@ async def test_required_dependency_async():
def test_required_dependency():
result = resolve_dependencies(endpoint)
assert result == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires", dependencies={"current_user": Inject(get_current_user)})
async def get_requires(current_user: Any = Injects()) -> JSONResponse:
return JSONResponse({"message": "Hello", "user": current_user})


def test_use_requires_in_function_dependencies_using_inject(test_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires),
],
) as client:
response = client.get("/requires")
assert response.status_code == 200
assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires-simple")
async def get_requires_simple(current_user: Any = Requires(endpoint)) -> JSONResponse:
return JSONResponse(current_user)


def test_use_requires_as_a_non_dependency(test_app_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_simple),
],
) as client:
response = client.get("/requires-simple")

assert response.status_code == 200
assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires-typed-error")
async def get_requires_typed_error(current_user: int = Requires(endpoint)) -> JSONResponse: ...


def test_use_requires_raise_error_for_typing(test_app_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_typed_error),
],
) as client:
response = client.get("/requires-typed-error")

assert response.status_code == 400


def test_openapi(test_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_simple),
],
enable_openapi=True,
) as client:
response = client.get("/openapi.json")

assert response.status_code == 200
assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/requires-simple": {
"get": {
"summary": "Get Requires Simple",
"description": "",
"operationId": "get_requires_simple_requires_simple_get",
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
}
},
"deprecated": False,
}
}
},
}

0 comments on commit d82f5f7

Please sign in to comment.