Skip to content

Commit

Permalink
Add requires for generic types
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Jan 13, 2025
1 parent d82f5f7 commit 48c7ef1
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 8 deletions.
2 changes: 1 addition & 1 deletion esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str
if param.field_info.alias in body_fields:
continue

if param.is_security:
if param.is_security or param.is_requires_dependency:
continue

# Making sure all the optional and union types are included
Expand Down
7 changes: 7 additions & 0 deletions esmerald/transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,23 @@ def get_parameter_settings(
for field_name, model_field in signature_fields.items():
if field_name not in ignored_keys:
allow_none = getattr(model_field, "allow_none", True)

# Flag if its a Security dependency
is_security = is_security_scheme(model_field.default) or is_security_scope(
model_field.annotation
)

# Flag if its a Requires() dependency
is_requires_dependency = is_requires(model_field.default)

parameter_definitions.add(
create_parameter_setting(
allow_none=allow_none,
field_name=field_name,
field_info=model_field,
path_parameters=path_parameters,
is_security=is_security,
is_requires_dependency=is_requires_dependency,
)
)

Expand Down
11 changes: 11 additions & 0 deletions esmerald/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from esmerald.requests import Request
from esmerald.typing import Undefined
from esmerald.utils.constants import REQUIRED
from esmerald.utils.dependencies import is_requires
from esmerald.utils.helpers import is_class_and_subclass, is_union
from esmerald.utils.schema import should_skip_json_schema

Expand All @@ -41,6 +42,7 @@ class ParamSetting(NamedTuple):
param_type: ParamType
field_info: FieldInfo
is_security: bool = False
is_requires_dependency: bool = False


class Dependency(HashableBaseModel, ArbitraryExtraBaseModel):
Expand Down Expand Up @@ -113,6 +115,7 @@ def create_parameter_setting(
field_name: str,
path_parameters: Set[str],
is_security: bool,
is_requires_dependency: bool,
) -> ParamSetting:
"""
Create a setting definition for a parameter.
Expand Down Expand Up @@ -164,6 +167,7 @@ def create_parameter_setting(
field_info=param,
is_required=is_required and (default_value is None and not allow_none),
is_security=is_security,
is_requires_dependency=is_requires_dependency,
)
return param_settings

Expand Down Expand Up @@ -213,6 +217,13 @@ async def get_request_params(

values: Dict[Any, Any] = {}
for param in expected:
is_requires_dependency = is_requires(param.default_value)

# Using the default value if the parameter is a dependency requires
if is_requires_dependency:
values[param.field_name] = param.default_value
continue

if not is_union(param.field_info.annotation):
annotation = get_origin(param.field_info.annotation)
origin = annotation or param.field_info.annotation
Expand Down
1 change: 1 addition & 0 deletions esmerald/utils/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def is_inject(param: Any) -> bool:

return isinstance(param, Inject)


def is_requires(param: Any) -> bool:
"""
Checks if the object is an Inject.
Expand Down
55 changes: 48 additions & 7 deletions tests/dependencies/test_requires.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, List, Optional, Union

import anyio
import pytest
Expand Down Expand Up @@ -57,8 +57,29 @@ def test_use_requires_in_function_dependencies_using_inject(test_client_factory)
assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires-simple-union")
async def get_requires_simple_union(
current_user: Union[Dict[str, Any], None] = Requires(endpoint),
) -> JSONResponse:
return JSONResponse(current_user)


def test_use_requires_as_a_non_dependency_union(test_app_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_simple_union),
],
) as client:
response = client.get("/requires-simple-union")

assert response.status_code == 200
assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires-simple")
async def get_requires_simple(current_user: Any = Requires(endpoint)) -> JSONResponse:
async def get_requires_simple(
current_user: Any = Requires(endpoint),
) -> JSONResponse:
return JSONResponse(current_user)


Expand All @@ -74,15 +95,20 @@ def test_use_requires_as_a_non_dependency(test_app_client_factory):
assert response.json() == {"message": "Hello", "user": {"id": 1, "name": "Alice"}}


@get("/requires-typed-error")
async def get_requires_typed_error(current_user: int = Requires(endpoint)) -> JSONResponse: ...
@pytest.mark.parametrize(
"type_field", [List[str], int, float, frozenset, Union[List[str], None], Optional[List[str]]], ids=["list", "int", "float", "frozenset", "union", "optional"]
)
def test_use_requires_raise_error_for_typing(test_app_client_factory, type_field):
@get("/requires-typed-error")
async def get_requires_typed_error(
current_user: type_field = Requires(endpoint),
) -> JSONResponse: ...


def test_use_requires_raise_error_for_typing(test_app_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_typed_error),
],
debug=False,
) as client:
response = client.get("/requires-typed-error")

Expand All @@ -93,6 +119,7 @@ def test_openapi(test_client_factory):
with create_client(
routes=[
Gateway(handler=get_requires_simple),
Gateway(handler=get_requires_simple_union),
],
enable_openapi=True,
) as client:
Expand Down Expand Up @@ -123,6 +150,20 @@ def test_openapi(test_client_factory):
},
"deprecated": False,
}
}
},
"/requires-simple-union": {
"get": {
"summary": "Get Requires Simple Union",
"description": "",
"operationId": "get_requires_simple_union_requires_simple_union_get",
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
}
},
"deprecated": False,
}
},
},
}

0 comments on commit 48c7ef1

Please sign in to comment.