Skip to content

Commit

Permalink
Add test for requires
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Jan 10, 2025
1 parent e8d9271 commit 77426ff
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 2 deletions.
3 changes: 2 additions & 1 deletion esmerald/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ValidationErrorException,
)
from .interceptors.interceptor import EsmeraldInterceptor
from .param_functions import Security
from .param_functions import Requires, Security
from .params import Body, Cookie, File, Form, Header, Injects, Param, Path, Query
from .permissions import AllowAny, BasePermission, DenyAll
from .pluggables import Extension, Pluggable
Expand Down Expand Up @@ -91,6 +91,7 @@
"Query",
"Redirect",
"Request",
"Requires",
"Response",
"Router",
"Security",
Expand Down
61 changes: 60 additions & 1 deletion esmerald/utils/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any
import inspect
from typing import Any, Dict, Union

from lilya.compat import run_sync

from esmerald import params
from esmerald.security.scopes import Scopes
Expand Down Expand Up @@ -37,3 +40,59 @@ def is_inject(param: Any) -> bool:
from esmerald.injector import Inject

return isinstance(param, Inject)


async def async_resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] = None) -> Any:
"""
Resolves dependencies for an asynchronous function by inspecting its signature and
recursively resolving any dependencies specified using the `params.Requires` class.
Args:
func (Any): The target function whose dependencies need to be resolved.
overrides (Union[Dict[str, Any]], optional): A dictionary of overrides for dependencies.
This can be used for testing or customization. Defaults to None.
Returns:
Any: The result of the target function with its dependencies resolved.
Raises:
TypeError: If the target function or any of its dependencies are not callable.
"""
if overrides is None:
overrides = {}

signature = inspect.signature(func)
kwargs = {}

for name, param in signature.parameters.items():
if isinstance(param.default, params.Requires):
dep_func = param.default.dependency
dep_func = overrides.get(dep_func, dep_func) # type: ignore
if inspect.iscoroutinefunction(dep_func):
resolved = await async_resolve_dependencies(dep_func, overrides)
else:
resolved = (
resolve_dependencies(dep_func, overrides) if callable(dep_func) else dep_func
)
kwargs[name] = resolved
if inspect.iscoroutinefunction(func):
return await func(**kwargs)
else:
return func(**kwargs)


def resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] = None) -> Any:
"""
Resolves the dependencies for a given function.
Parameters:
func (Any): The function for which dependencies need to be resolved.
overrides (Union[Dict[str, Any], None], optional): A dictionary of dependency overrides. Defaults to None.
Raises:
ValueError: If the provided function is asynchronous.
Returns:
Any: The result of running the asynchronous dependency resolution function.
"""
if overrides is None:
overrides = {}
if inspect.iscoroutinefunction(func):
raise ValueError("Function is async. Use resolve_dependencies_async instead.")
return run_sync(async_resolve_dependencies(func, overrides))
38 changes: 38 additions & 0 deletions tests/dependencies/test_requires.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import anyio
import pytest

from esmerald.param_functions import Requires
from esmerald.utils.dependencies import async_resolve_dependencies, resolve_dependencies


def get_user():
return {"id": 1, "name": "Alice"}


def get_current_user(user=Requires(get_user)):
return user


async def get_async_user():
await anyio.sleep(0.1)
return {"id": 2, "name": "Bob"}


async def async_endpoint(current_user=Requires(get_async_user)):
return {"message": "Hello", "user": current_user}


def endpoint(current_user=Requires(get_current_user)):
return {"message": "Hello", "user": current_user}


@pytest.mark.asyncio
async def test_required_dependency_async():
async_result = await async_resolve_dependencies(async_endpoint)

assert async_result == {"message": "Hello", "user": {"id": 2, "name": "Bob"}}


def test_required_dependency():
result = resolve_dependencies(endpoint)
assert result == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}

0 comments on commit 77426ff

Please sign in to comment.