Skip to content

Commit

Permalink
Query Parameters & OpenAPI (#374)
Browse files Browse the repository at this point in the history
* Add missing types
* Update internal for generics
* Add compatibility for list and object in query
* Add fix for dict in query string
* Add hashing for unique types
* Add tests for OpenAPI types
* Fix dependency tests and logic
  • Loading branch information
tarsil authored Aug 2, 2024
1 parent d915830 commit ec7ebeb
Show file tree
Hide file tree
Showing 8 changed files with 469 additions and 65 deletions.
14 changes: 10 additions & 4 deletions esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import json
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, _GenericAlias, cast

from lilya._internal._path import clean_path
from lilya.middleware import DefineMiddleware
Expand Down Expand Up @@ -43,6 +43,10 @@
from esmerald.utils.constants import DATA, PAYLOAD
from esmerald.utils.helpers import is_class_and_subclass, is_union

ADDITIONAL_TYPES = ["bool", "list", "dict"]
TRANSFORMER_TYPES_KEYS = list(TRANSFORMER_TYPES.keys())
TRANSFORMER_TYPES_KEYS += ADDITIONAL_TYPES


def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]:
"""Gets all the neded params of the request and route"""
Expand All @@ -59,9 +63,11 @@ def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]:
query_params.append(param.field_info)

else:
if (
param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES.keys()
or param.field_info.annotation.__name__ in TRANSFORMER_TYPES.keys()
if isinstance(param.field_info.annotation, _GenericAlias):
query_params.append(param.field_info)
elif (
param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES_KEYS
or param.field_info.annotation.__name__ in TRANSFORMER_TYPES_KEYS
):
query_params.append(param.field_info)

Expand Down
55 changes: 0 additions & 55 deletions esmerald/transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,61 +289,6 @@ def merge_with(self, other: "TransformerModel") -> "TransformerModel":
is_optional=self.is_optional or other.is_optional,
)

def _get_request_params(
self,
connection: Union["WebSocket", "Request"],
handler: Union["HTTPHandler", "WebSocketHandler"] = None,
) -> Any:
"""
Get request parameters.
Args:
connection (Union["WebSocket", "Request"]): Connection object.
handler (Union["HTTPHandler", "WebSocketHandler"], optional): Handler object. Defaults to None.
Returns:
Any: Request parameters.
"""
connection_params: Dict[str, Any] = {}
for key, value in connection.query_params.items():
if len(value) == 1:
value = value[0]
connection_params[key] = value

query_params = get_request_params(
params=cast("MappingUnion", connection.query_params),
expected=self.query_params,
url=connection.url,
)
path_params = get_request_params(
params=cast("MappingUnion", connection.path_params),
expected=self.path_params,
url=connection.url,
)
headers = get_request_params(
params=cast("MappingUnion", connection.headers),
expected=self.headers,
url=connection.url,
)
cookies = get_request_params(
params=cast("MappingUnion", connection.cookies),
expected=self.cookies,
url=connection.url,
)

if not self.reserved_kwargs:
return {**query_params, **path_params, **headers, **cookies}

return self.handle_reserved_kwargs(
connection=connection,
connection_params=connection_params,
path_params=path_params,
query_params=query_params,
headers=headers,
cookies=cookies,
handler=handler,
)

def handle_reserved_kwargs(
self,
connection: Union["WebSocket", "Request"],
Expand Down
7 changes: 6 additions & 1 deletion esmerald/transformers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ def __init__(
"If it should receive any value, use 'Any' as type."
)
self.annotation = parameter.annotation
self.default = parameter.default

self.default = (
tuple(parameter.default)
if isinstance(parameter.default, (list, dict, set))
else parameter.default
)
self.param_name = param_name
self.name = param_name
self.optional = is_optional_union(self.annotation)
Expand Down
37 changes: 32 additions & 5 deletions esmerald/transformers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Set, Tuple, Type, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
NamedTuple,
Set,
Tuple,
Type,
Union,
cast,
get_origin,
)

from lilya.datastructures import URL
from pydantic.fields import FieldInfo
Expand All @@ -10,6 +23,7 @@
from esmerald.requests import Request
from esmerald.typing import Undefined
from esmerald.utils.constants import REQUIRED
from esmerald.utils.helpers import is_class_and_subclass

if TYPE_CHECKING: # pragma: no cover
from esmerald.injector import Inject
Expand Down Expand Up @@ -166,7 +180,9 @@ def _get_missing_required_params(params: Any, expected: Set[ParamSetting]) -> Li
return missing_params


def get_request_params(params: Any, expected: Set[ParamSetting], url: URL) -> Any:
def get_request_params(
params: Mapping[Union[int, str], Any], expected: Set[ParamSetting], url: URL
) -> Any:
"""
Gather the parameters from the request.
Expand All @@ -187,9 +203,20 @@ def get_request_params(params: Any, expected: Set[ParamSetting], url: URL) -> An
f"Missing required parameter(s) {', '.join(missing_params)} for URL {url}."
)

values = {
param.field_name: params.get(param.field_alias, param.default_value) for param in expected
}
values = {}
for param in expected:
annotation = get_origin(param.field_info.annotation)
origin = annotation or param.field_info.annotation

if is_class_and_subclass(origin, (list, tuple)):
values[param.field_name] = params.values()
elif is_class_and_subclass(origin, dict):
if not params.items():
values[param.field_name] = None
else:
values[param.field_name] = dict(params.items()) # type: ignore[assignment]
else:
values[param.field_name] = params.get(param.field_alias, param.default_value)
return values


Expand Down
Empty file.
109 changes: 109 additions & 0 deletions tests/openapi/uniques/test_openapi_bool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from esmerald import Gateway, JSONResponse, get
from esmerald.testclient import create_client


@get("/bool")
async def check_bool(a_value: bool) -> JSONResponse:
return JSONResponse({"value": a_value})


def test_query_param(test_client_factory):
with create_client(routes=Gateway(handler=check_bool)) as client:

response = client.get("/bool?a_value=true")

assert response.status_code == 200
assert response.json() == {"value": True}

response = client.get("/bool?a_value=1")
assert response.json() == {"value": True}

response = client.get("/bool?a_value=0")
assert response.json() == {"value": False}

response = client.get("/bool?a_value=false")
assert response.json() == {"value": False}


def test_open_api(test_app_client_factory):
with create_client(routes=Gateway(handler=check_bool)) as client:
response = client.get("/openapi.json")

assert response.status_code == 200, response.text

assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/bool": {
"get": {
"summary": "Check Bool",
"operationId": "check_bool_bool_get",
"parameters": [
{
"name": "a_value",
"in": "query",
"required": True,
"deprecated": False,
"allowEmptyValue": False,
"allowReserved": False,
"schema": {"type": "boolean", "title": "A Value"},
}
],
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"deprecated": False,
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
}
},
}
102 changes: 102 additions & 0 deletions tests/openapi/uniques/test_openapi_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Any, Dict

from esmerald import Gateway, JSONResponse, get
from esmerald.testclient import create_client


@get("/dict")
async def check_dict(a_value: Dict[str, Any]) -> JSONResponse:
return JSONResponse({"value": a_value})


def test_query_param(test_client_factory):
with create_client(routes=Gateway(handler=check_dict)) as client:
response = client.get("/dict?a_value=true&b_value=false&c_value=test")

assert response.status_code == 200
assert response.json() == {
"value": {"a_value": "true", "b_value": "false", "c_value": "test"}
}


def test_open_api(test_app_client_factory):
with create_client(routes=Gateway(handler=check_dict)) as client:
response = client.get("/openapi.json")

assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/dict": {
"get": {
"summary": "Check Dict",
"operationId": "check_dict_dict_get",
"parameters": [
{
"name": "a_value",
"in": "query",
"required": True,
"deprecated": False,
"allowEmptyValue": False,
"allowReserved": False,
"schema": {"type": "object", "title": "A Value"},
}
],
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"deprecated": False,
}
}
},
"components": {
"schemas": {
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
}
},
}
Loading

0 comments on commit ec7ebeb

Please sign in to comment.