diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py index 12927c90..d0fd0638 100644 --- a/esmerald/openapi/openapi.py +++ b/esmerald/openapi/openapi.py @@ -2,7 +2,7 @@ import inspect import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, _GenericAlias, cast from lilya._internal._path import clean_path from lilya.middleware import DefineMiddleware @@ -43,6 +43,10 @@ from esmerald.utils.constants import DATA, PAYLOAD from esmerald.utils.helpers import is_class_and_subclass, is_union +ADDITIONAL_TYPES = ["bool", "list", "dict"] +TRANSFORMER_TYPES_KEYS = list(TRANSFORMER_TYPES.keys()) +TRANSFORMER_TYPES_KEYS += ADDITIONAL_TYPES + def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]: """Gets all the neded params of the request and route""" @@ -59,9 +63,11 @@ def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]: query_params.append(param.field_info) else: - if ( - param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES.keys() - or param.field_info.annotation.__name__ in TRANSFORMER_TYPES.keys() + if isinstance(param.field_info.annotation, _GenericAlias): + query_params.append(param.field_info) + elif ( + param.field_info.annotation.__class__.__name__ in TRANSFORMER_TYPES_KEYS + or param.field_info.annotation.__name__ in TRANSFORMER_TYPES_KEYS ): query_params.append(param.field_info) diff --git a/esmerald/transformers/model.py b/esmerald/transformers/model.py index b059f723..0daadd4d 100644 --- a/esmerald/transformers/model.py +++ b/esmerald/transformers/model.py @@ -289,61 +289,6 @@ def merge_with(self, other: "TransformerModel") -> "TransformerModel": is_optional=self.is_optional or other.is_optional, ) - def _get_request_params( - self, - connection: Union["WebSocket", "Request"], - handler: Union["HTTPHandler", "WebSocketHandler"] = None, - ) -> Any: - """ - Get request parameters. - - Args: - connection (Union["WebSocket", "Request"]): Connection object. - handler (Union["HTTPHandler", "WebSocketHandler"], optional): Handler object. Defaults to None. - - Returns: - Any: Request parameters. - """ - connection_params: Dict[str, Any] = {} - for key, value in connection.query_params.items(): - if len(value) == 1: - value = value[0] - connection_params[key] = value - - query_params = get_request_params( - params=cast("MappingUnion", connection.query_params), - expected=self.query_params, - url=connection.url, - ) - path_params = get_request_params( - params=cast("MappingUnion", connection.path_params), - expected=self.path_params, - url=connection.url, - ) - headers = get_request_params( - params=cast("MappingUnion", connection.headers), - expected=self.headers, - url=connection.url, - ) - cookies = get_request_params( - params=cast("MappingUnion", connection.cookies), - expected=self.cookies, - url=connection.url, - ) - - if not self.reserved_kwargs: - return {**query_params, **path_params, **headers, **cookies} - - return self.handle_reserved_kwargs( - connection=connection, - connection_params=connection_params, - path_params=path_params, - query_params=query_params, - headers=headers, - cookies=cookies, - handler=handler, - ) - def handle_reserved_kwargs( self, connection: Union["WebSocket", "Request"], diff --git a/esmerald/transformers/signature.py b/esmerald/transformers/signature.py index e5087eb6..865ab02d 100644 --- a/esmerald/transformers/signature.py +++ b/esmerald/transformers/signature.py @@ -98,7 +98,12 @@ def __init__( "If it should receive any value, use 'Any' as type." ) self.annotation = parameter.annotation - self.default = parameter.default + + self.default = ( + tuple(parameter.default) + if isinstance(parameter.default, (list, dict, set)) + else parameter.default + ) self.param_name = param_name self.name = param_name self.optional = is_optional_union(self.annotation) diff --git a/esmerald/transformers/utils.py b/esmerald/transformers/utils.py index 9aaf24e3..47c48a3c 100644 --- a/esmerald/transformers/utils.py +++ b/esmerald/transformers/utils.py @@ -1,4 +1,17 @@ -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Set, Tuple, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + NamedTuple, + Set, + Tuple, + Type, + Union, + cast, + get_origin, +) from lilya.datastructures import URL from pydantic.fields import FieldInfo @@ -10,6 +23,7 @@ from esmerald.requests import Request from esmerald.typing import Undefined from esmerald.utils.constants import REQUIRED +from esmerald.utils.helpers import is_class_and_subclass if TYPE_CHECKING: # pragma: no cover from esmerald.injector import Inject @@ -166,7 +180,9 @@ def _get_missing_required_params(params: Any, expected: Set[ParamSetting]) -> Li return missing_params -def get_request_params(params: Any, expected: Set[ParamSetting], url: URL) -> Any: +def get_request_params( + params: Mapping[Union[int, str], Any], expected: Set[ParamSetting], url: URL +) -> Any: """ Gather the parameters from the request. @@ -187,9 +203,20 @@ def get_request_params(params: Any, expected: Set[ParamSetting], url: URL) -> An f"Missing required parameter(s) {', '.join(missing_params)} for URL {url}." ) - values = { - param.field_name: params.get(param.field_alias, param.default_value) for param in expected - } + values = {} + for param in expected: + annotation = get_origin(param.field_info.annotation) + origin = annotation or param.field_info.annotation + + if is_class_and_subclass(origin, (list, tuple)): + values[param.field_name] = params.values() + elif is_class_and_subclass(origin, dict): + if not params.items(): + values[param.field_name] = None + else: + values[param.field_name] = dict(params.items()) # type: ignore[assignment] + else: + values[param.field_name] = params.get(param.field_alias, param.default_value) return values diff --git a/tests/openapi/uniques/__init__.py b/tests/openapi/uniques/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/openapi/uniques/test_openapi_bool.py b/tests/openapi/uniques/test_openapi_bool.py new file mode 100644 index 00000000..12dcfd27 --- /dev/null +++ b/tests/openapi/uniques/test_openapi_bool.py @@ -0,0 +1,109 @@ +from esmerald import Gateway, JSONResponse, get +from esmerald.testclient import create_client + + +@get("/bool") +async def check_bool(a_value: bool) -> JSONResponse: + return JSONResponse({"value": a_value}) + + +def test_query_param(test_client_factory): + with create_client(routes=Gateway(handler=check_bool)) as client: + + response = client.get("/bool?a_value=true") + + assert response.status_code == 200 + assert response.json() == {"value": True} + + response = client.get("/bool?a_value=1") + assert response.json() == {"value": True} + + response = client.get("/bool?a_value=0") + assert response.json() == {"value": False} + + response = client.get("/bool?a_value=false") + assert response.json() == {"value": False} + + +def test_open_api(test_app_client_factory): + with create_client(routes=Gateway(handler=check_bool)) as client: + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/bool": { + "get": { + "summary": "Check Bool", + "operationId": "check_bool_bool_get", + "parameters": [ + { + "name": "a_value", + "in": "query", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": {"type": "boolean", "title": "A Value"}, + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } diff --git a/tests/openapi/uniques/test_openapi_dict.py b/tests/openapi/uniques/test_openapi_dict.py new file mode 100644 index 00000000..8e84be27 --- /dev/null +++ b/tests/openapi/uniques/test_openapi_dict.py @@ -0,0 +1,102 @@ +from typing import Any, Dict + +from esmerald import Gateway, JSONResponse, get +from esmerald.testclient import create_client + + +@get("/dict") +async def check_dict(a_value: Dict[str, Any]) -> JSONResponse: + return JSONResponse({"value": a_value}) + + +def test_query_param(test_client_factory): + with create_client(routes=Gateway(handler=check_dict)) as client: + response = client.get("/dict?a_value=true&b_value=false&c_value=test") + + assert response.status_code == 200 + assert response.json() == { + "value": {"a_value": "true", "b_value": "false", "c_value": "test"} + } + + +def test_open_api(test_app_client_factory): + with create_client(routes=Gateway(handler=check_dict)) as client: + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/dict": { + "get": { + "summary": "Check Dict", + "operationId": "check_dict_dict_get", + "parameters": [ + { + "name": "a_value", + "in": "query", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": {"type": "object", "title": "A Value"}, + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } diff --git a/tests/openapi/uniques/test_openapi_list.py b/tests/openapi/uniques/test_openapi_list.py new file mode 100644 index 00000000..eb062868 --- /dev/null +++ b/tests/openapi/uniques/test_openapi_list.py @@ -0,0 +1,210 @@ +from typing import List + +from typing_extensions import Annotated + +from esmerald import Gateway, JSONResponse, Query, get +from esmerald.testclient import create_client + + +@get("/list") +async def check_list(a_value: List[str]) -> JSONResponse: + return JSONResponse({"value": a_value}) + + +@get("/another-list") +async def check_another_list( + a_value: Annotated[list, Query()] = ["true", "false", "test"] # noqa +) -> JSONResponse: + return JSONResponse({"value": a_value}) + + +def test_query_param(test_client_factory): + with create_client( + routes=[Gateway(handler=check_list), Gateway(handler=check_another_list)] + ) as client: + + response = client.get("/list?a_value=true&a_value=false&a_value=test") + + assert response.status_code == 200 + assert response.json() == {"value": ["true", "false", "test"]} + + response = client.get("/another-list?a_value=true&a_value=false&a_value=test") + + assert response.status_code == 200 + assert response.json() == {"value": ["true", "false", "test"]} + + +def test_open_api(test_app_client_factory): + with create_client(routes=Gateway(handler=check_list)) as client: + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/list": { + "get": { + "summary": "Check List", + "operationId": "check_list_list_get", + "parameters": [ + { + "name": "a_value", + "in": "query", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": { + "items": {"type": "string"}, + "type": "array", + "title": "A Value", + }, + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } + + +def test_open_api_annotated(test_app_client_factory): + with create_client(routes=Gateway(handler=check_another_list)) as client: + response = client.get("/openapi.json") + + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/another-list": { + "get": { + "summary": "Check Another List", + "operationId": "check_another_list_another_list_get", + "parameters": [ + { + "name": "a_value", + "in": "query", + "required": False, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": { + "items": {}, + "type": "array", + "title": "A Value", + "default": ["true", "false", "test"], + }, + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + }