diff --git a/esmerald/__init__.py b/esmerald/__init__.py index adbe73a9..9c15f319 100644 --- a/esmerald/__init__.py +++ b/esmerald/__init__.py @@ -23,7 +23,7 @@ ValidationErrorException, ) from .interceptors.interceptor import EsmeraldInterceptor -from .param_functions import Security +from .param_functions import Requires, Security from .params import Body, Cookie, File, Form, Header, Injects, Param, Path, Query from .permissions import AllowAny, BasePermission, DenyAll from .pluggables import Extension, Pluggable @@ -91,6 +91,7 @@ "Query", "Redirect", "Request", + "Requires", "Response", "Router", "Security", diff --git a/esmerald/utils/dependencies.py b/esmerald/utils/dependencies.py index 41e7820b..3620edcf 100644 --- a/esmerald/utils/dependencies.py +++ b/esmerald/utils/dependencies.py @@ -1,4 +1,7 @@ -from typing import Any +import inspect +from typing import Any, Dict, Union + +from lilya.compat import run_sync from esmerald import params from esmerald.security.scopes import Scopes @@ -37,3 +40,59 @@ def is_inject(param: Any) -> bool: from esmerald.injector import Inject return isinstance(param, Inject) + + +async def async_resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] = None) -> Any: + """ + Resolves dependencies for an asynchronous function by inspecting its signature and + recursively resolving any dependencies specified using the `params.Requires` class. + Args: + func (Any): The target function whose dependencies need to be resolved. + overrides (Union[Dict[str, Any]], optional): A dictionary of overrides for dependencies. + This can be used for testing or customization. Defaults to None. + Returns: + Any: The result of the target function with its dependencies resolved. + Raises: + TypeError: If the target function or any of its dependencies are not callable. + """ + if overrides is None: + overrides = {} + + signature = inspect.signature(func) + kwargs = {} + + for name, param in signature.parameters.items(): + if isinstance(param.default, params.Requires): + dep_func = param.default.dependency + dep_func = overrides.get(dep_func, dep_func) # type: ignore + if inspect.iscoroutinefunction(dep_func): + resolved = await async_resolve_dependencies(dep_func, overrides) + else: + resolved = ( + resolve_dependencies(dep_func, overrides) if callable(dep_func) else dep_func + ) + kwargs[name] = resolved + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) + + +def resolve_dependencies(func: Any, overrides: Union[Dict[str, Any]] = None) -> Any: + """ + Resolves the dependencies for a given function. + + Parameters: + func (Any): The function for which dependencies need to be resolved. + overrides (Union[Dict[str, Any], None], optional): A dictionary of dependency overrides. Defaults to None. + Raises: + ValueError: If the provided function is asynchronous. + + Returns: + Any: The result of running the asynchronous dependency resolution function. + """ + if overrides is None: + overrides = {} + if inspect.iscoroutinefunction(func): + raise ValueError("Function is async. Use resolve_dependencies_async instead.") + return run_sync(async_resolve_dependencies(func, overrides)) diff --git a/tests/dependencies/test_requires.py b/tests/dependencies/test_requires.py new file mode 100644 index 00000000..7837b3d5 --- /dev/null +++ b/tests/dependencies/test_requires.py @@ -0,0 +1,38 @@ +import anyio +import pytest + +from esmerald.param_functions import Requires +from esmerald.utils.dependencies import async_resolve_dependencies, resolve_dependencies + + +def get_user(): + return {"id": 1, "name": "Alice"} + + +def get_current_user(user=Requires(get_user)): + return user + + +async def get_async_user(): + await anyio.sleep(0.1) + return {"id": 2, "name": "Bob"} + + +async def async_endpoint(current_user=Requires(get_async_user)): + return {"message": "Hello", "user": current_user} + + +def endpoint(current_user=Requires(get_current_user)): + return {"message": "Hello", "user": current_user} + + +@pytest.mark.asyncio +async def test_required_dependency_async(): + async_result = await async_resolve_dependencies(async_endpoint) + + assert async_result == {"message": "Hello", "user": {"id": 2, "name": "Bob"}} + + +def test_required_dependency(): + result = resolve_dependencies(endpoint) + assert result == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}