From 48c7ef1e0d3221afa3ce649ad7d9a317bd955229 Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 13 Jan 2025 12:13:07 +0100 Subject: [PATCH] Add requires for generic types --- esmerald/openapi/openapi.py | 2 +- esmerald/transformers/model.py | 7 ++++ esmerald/transformers/utils.py | 11 ++++++ esmerald/utils/dependencies.py | 1 + tests/dependencies/test_requires.py | 55 +++++++++++++++++++++++++---- 5 files changed, 68 insertions(+), 8 deletions(-) diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py index 74199144..6dc91c46 100644 --- a/esmerald/openapi/openapi.py +++ b/esmerald/openapi/openapi.py @@ -86,7 +86,7 @@ def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str if param.field_info.alias in body_fields: continue - if param.is_security: + if param.is_security or param.is_requires_dependency: continue # Making sure all the optional and union types are included diff --git a/esmerald/transformers/model.py b/esmerald/transformers/model.py index 86d1e06a..ec1ee30b 100644 --- a/esmerald/transformers/model.py +++ b/esmerald/transformers/model.py @@ -459,9 +459,15 @@ def get_parameter_settings( for field_name, model_field in signature_fields.items(): if field_name not in ignored_keys: allow_none = getattr(model_field, "allow_none", True) + + # Flag if its a Security dependency is_security = is_security_scheme(model_field.default) or is_security_scope( model_field.annotation ) + + # Flag if its a Requires() dependency + is_requires_dependency = is_requires(model_field.default) + parameter_definitions.add( create_parameter_setting( allow_none=allow_none, @@ -469,6 +475,7 @@ def get_parameter_settings( field_info=model_field, path_parameters=path_parameters, is_security=is_security, + is_requires_dependency=is_requires_dependency, ) ) diff --git a/esmerald/transformers/utils.py b/esmerald/transformers/utils.py index c0ce449d..eaefae3f 100644 --- a/esmerald/transformers/utils.py +++ b/esmerald/transformers/utils.py @@ -24,6 +24,7 @@ from esmerald.requests import Request from esmerald.typing import Undefined from esmerald.utils.constants import REQUIRED +from esmerald.utils.dependencies import is_requires from esmerald.utils.helpers import is_class_and_subclass, is_union from esmerald.utils.schema import should_skip_json_schema @@ -41,6 +42,7 @@ class ParamSetting(NamedTuple): param_type: ParamType field_info: FieldInfo is_security: bool = False + is_requires_dependency: bool = False class Dependency(HashableBaseModel, ArbitraryExtraBaseModel): @@ -113,6 +115,7 @@ def create_parameter_setting( field_name: str, path_parameters: Set[str], is_security: bool, + is_requires_dependency: bool, ) -> ParamSetting: """ Create a setting definition for a parameter. @@ -164,6 +167,7 @@ def create_parameter_setting( field_info=param, is_required=is_required and (default_value is None and not allow_none), is_security=is_security, + is_requires_dependency=is_requires_dependency, ) return param_settings @@ -213,6 +217,13 @@ async def get_request_params( values: Dict[Any, Any] = {} for param in expected: + is_requires_dependency = is_requires(param.default_value) + + # Using the default value if the parameter is a dependency requires + if is_requires_dependency: + values[param.field_name] = param.default_value + continue + if not is_union(param.field_info.annotation): annotation = get_origin(param.field_info.annotation) origin = annotation or param.field_info.annotation diff --git a/esmerald/utils/dependencies.py b/esmerald/utils/dependencies.py index 657bc165..12419710 100644 --- a/esmerald/utils/dependencies.py +++ b/esmerald/utils/dependencies.py @@ -48,6 +48,7 @@ def is_inject(param: Any) -> bool: return isinstance(param, Inject) + def is_requires(param: Any) -> bool: """ Checks if the object is an Inject. diff --git a/tests/dependencies/test_requires.py b/tests/dependencies/test_requires.py index 8ef07fe5..9a35cfdf 100644 --- a/tests/dependencies/test_requires.py +++ b/tests/dependencies/test_requires.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, List, Optional, Union import anyio import pytest @@ -57,8 +57,29 @@ def test_use_requires_in_function_dependencies_using_inject(test_client_factory) assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}} +@get("/requires-simple-union") +async def get_requires_simple_union( + current_user: Union[Dict[str, Any], None] = Requires(endpoint), +) -> JSONResponse: + return JSONResponse(current_user) + + +def test_use_requires_as_a_non_dependency_union(test_app_client_factory): + with create_client( + routes=[ + Gateway(handler=get_requires_simple_union), + ], + ) as client: + response = client.get("/requires-simple-union") + + assert response.status_code == 200 + assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}} + + @get("/requires-simple") -async def get_requires_simple(current_user: Any = Requires(endpoint)) -> JSONResponse: +async def get_requires_simple( + current_user: Any = Requires(endpoint), +) -> JSONResponse: return JSONResponse(current_user) @@ -74,15 +95,20 @@ def test_use_requires_as_a_non_dependency(test_app_client_factory): assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}} -@get("/requires-typed-error") -async def get_requires_typed_error(current_user: int = Requires(endpoint)) -> JSONResponse: ... +@pytest.mark.parametrize( + "type_field", [List[str], int, float, frozenset, Union[List[str], None], Optional[List[str]]], ids=["list", "int", "float", "frozenset", "union", "optional"] +) +def test_use_requires_raise_error_for_typing(test_app_client_factory, type_field): + @get("/requires-typed-error") + async def get_requires_typed_error( + current_user: type_field = Requires(endpoint), + ) -> JSONResponse: ... - -def test_use_requires_raise_error_for_typing(test_app_client_factory): with create_client( routes=[ Gateway(handler=get_requires_typed_error), ], + debug=False, ) as client: response = client.get("/requires-typed-error") @@ -93,6 +119,7 @@ def test_openapi(test_client_factory): with create_client( routes=[ Gateway(handler=get_requires_simple), + Gateway(handler=get_requires_simple_union), ], enable_openapi=True, ) as client: @@ -123,6 +150,20 @@ def test_openapi(test_client_factory): }, "deprecated": False, } - } + }, + "/requires-simple-union": { + "get": { + "summary": "Get Requires Simple Union", + "description": "", + "operationId": "get_requires_simple_union_requires_simple_union_get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + } + }, + "deprecated": False, + } + }, }, }