From f01a22cfbc30cd7be69d5dd27d9fe99ee77f8191 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Wed, 13 Nov 2024 10:41:29 +0000 Subject: [PATCH] OpenAPI - Different keywords other than data for upload files (#431) * Allow keywords different from data for upload * Fix request data and Upload files via OpenAPI --- esmerald/openapi/openapi.py | 21 +++- esmerald/routing/_internal.py | 77 +++++++++--- esmerald/routing/base.py | 7 +- tests/openapi/test_upload_using_not_data.py | 113 ++++++++++++++++++ .../test_upload_using_not_data_list.py | 113 ++++++++++++++++++ 5 files changed, 310 insertions(+), 21 deletions(-) create mode 100644 tests/openapi/test_upload_using_not_data.py create mode 100644 tests/openapi/test_upload_using_not_data_list.py diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py index 3306c6bd..febdf896 100644 --- a/esmerald/openapi/openapi.py +++ b/esmerald/openapi/openapi.py @@ -58,7 +58,7 @@ TRANSFORMER_TYPES_KEYS += ADDITIONAL_TYPES -def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]: +def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str]) -> List[Any]: """ Gets all the neded params of the request and route. """ @@ -75,6 +75,9 @@ def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]: for param in handler_query_params: is_union_or_optional = is_union(param.field_info.annotation) + if param.field_info.alias in body_fields: + continue + # Making sure all the optional and union types are included if is_union_or_optional: # field_info = should_skip_json_schema(param.field_info) @@ -143,7 +146,8 @@ def get_fields_from_routes( response_from_routes.append(response) # Get the params from the transformer - params = get_flat_params(handler) + body_fields_names = [field.alias for field in body_fields] + params = get_flat_params(handler, body_fields_names) if params: request_fields.extend(params) @@ -297,7 +301,12 @@ def get_openapi_path( if security_definitions: security_schemes.update(security_definitions) - all_route_params = get_flat_params(handler) + body_fields = [] + if handler.data_field: + body_fields.append(handler.data_field) + + body_fields_names = [field.alias for field in body_fields] + all_route_params = get_flat_params(handler, body_fields_names) operation_parameters = get_openapi_operation_parameters( all_route_params=all_route_params, field_mapping=field_mapping, @@ -323,9 +332,9 @@ def get_openapi_path( operation["requestBody"] = request_data_oai status_code = str(handler.status_code) - operation.setdefault("responses", {}).setdefault(status_code, {})[ - "description" - ] = handler.response_description + operation.setdefault("responses", {}).setdefault(status_code, {})["description"] = ( + handler.response_description + ) # Media type if route_response_media_type and is_status_code_allowed(handler.status_code): diff --git a/esmerald/routing/_internal.py b/esmerald/routing/_internal.py index a272c5e1..934c332a 100644 --- a/esmerald/routing/_internal.py +++ b/esmerald/routing/_internal.py @@ -2,14 +2,17 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Union, _GenericAlias, cast, get_args +from lilya.datastructures import DataUpload from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo +from esmerald.datastructures import UploadFile from esmerald.encoders import ENCODER_TYPES, is_body_encoder from esmerald.enums import EncodingType from esmerald.openapi.params import ResponseParam from esmerald.params import Body from esmerald.utils.constants import DATA, PAYLOAD +from esmerald.utils.helpers import is_class_and_subclass from esmerald.utils.models import create_field_model if TYPE_CHECKING: @@ -81,8 +84,60 @@ def convert_annotation_to_pydantic_model(field_annotation: Any) -> Any: return field_annotation +def handle_upload_files(data: Any, body: Body) -> Body: + """ + Handles the creation of the body field for the upload files. + """ + # For Uploads and Multi Part + args = get_args(body.annotation) + name = "File" if not args else "Files" + + model = create_field_model(field=body, name=name, model_name=body.title) + data_field = Body(annotation=model, title=body.title) + + for key, _ in data._attributes_set.items(): + if key != "annotation": + setattr(data_field, key, getattr(body, key, None)) + return data_field + + +def get_upload_body(handler: Union["HTTPHandler"]) -> Any: + """ + This function repeats some of the steps but covers all the + cases for simple use cases. + """ + for name, _ in handler.signature_model.model_fields.items(): + data = handler.signature_model.model_fields[name] + + if not isinstance(data, Body): + body = Body(alias="body") + for key, _ in data._attributes_set.items(): + setattr(body, key, getattr(data, key, None)) + else: + body = data + + # Check the annotation type + body.annotation = convert_annotation_to_pydantic_model(body.annotation) + + if not body.title: + body.title = f"Body_{handler.operation_id}" + + # For everything else that is not MULTI_PART + extra = cast("Dict[str, Any]", body.json_schema_extra) or {} + if extra.get( + "media_type", EncodingType.JSON + ) != EncodingType.MULTI_PART and not is_class_and_subclass( + body.annotation, (UploadFile, DataUpload) + ): + continue + + # For Uploads and Multi Part + data_field = handle_upload_files(data, body) + return data_field + + def get_original_data_field( - handler: Union["HTTPHandler", "WebhookHandler", Any] + handler: Union["HTTPHandler", "WebhookHandler", Any], ) -> Any: # pragma: no cover """ The field used for the payload body. @@ -116,16 +171,7 @@ def get_original_data_field( return body # For Uploads and Multi Part - args = get_args(body.annotation) - name = "File" if not args else "Files" - - model = create_field_model(field=body, name=name, model_name=body.title) - data_field = Body(annotation=model, title=body.title) - - for key, _ in data._attributes_set.items(): - if key != "annotation": - setattr(data_field, key, getattr(body, key, None)) - + data_field = handle_upload_files(data, body) return data_field @@ -197,15 +243,18 @@ def get_data_field(handler: Union["HTTPHandler", "WebhookHandler", Any]) -> Any: """ # If there are no body fields, we simply return the original # default Esmerald body parsing - if not handler.body_encoder_fields: - return get_original_data_field(handler) - is_data_or_payload = ( DATA if DATA in handler.signature_model.model_fields else (PAYLOAD if PAYLOAD in handler.signature_model.model_fields else None) ) + if not handler.body_encoder_fields and is_data_or_payload: + return get_original_data_field(handler) + + if not handler.body_encoder_fields: + return get_upload_body(handler) + if len(handler.body_encoder_fields) < 2 and is_data_or_payload is not None: return get_original_data_field(handler) return get_complex_data_field(handler, fields=handler.body_encoder_fields) diff --git a/esmerald/routing/base.py b/esmerald/routing/base.py index 3b6676ae..ff1d9880 100644 --- a/esmerald/routing/base.py +++ b/esmerald/routing/base.py @@ -206,7 +206,12 @@ async def _get_response_data( # Check if there is request data if request_data is not None: # Assign each key-value pair in the request data to kwargs - if isinstance(request_data, (UploadFile, DataUpload)): + if isinstance(request_data, (UploadFile, DataUpload)) or ( + isinstance(request_data, (list, tuple)) + and any( + isinstance(value, (UploadFile, DataUpload)) for value in request_data + ) + ): for key, _ in kwargs.items(): kwargs[key] = request_data else: diff --git a/tests/openapi/test_upload_using_not_data.py b/tests/openapi/test_upload_using_not_data.py new file mode 100644 index 00000000..254c8bc5 --- /dev/null +++ b/tests/openapi/test_upload_using_not_data.py @@ -0,0 +1,113 @@ +from typing import Dict + +from esmerald import Gateway, UploadFile, post, status +from esmerald.params import File +from esmerald.testclient import create_client + + +@post("/upload", status_code=status.HTTP_200_OK) +async def upload_file(upload: UploadFile = File()) -> Dict[str, str]: + names = [] + for file in upload: + names.append(file.filename) + return {"names": names} + + +def test_openapi_schema(test_client_factory): + with create_client( + routes=[ + Gateway(handler=upload_file), + ], + enable_openapi=True, + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/upload": { + "post": { + "summary": "Upload File", + "description": "", + "operationId": "upload_file_upload_post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_file_upload_post" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Body_upload_file_upload_post": { + "properties": { + "file": { + "type": "string", + "format": "binary", + "title": "Body_upload_file_upload_post", + } + }, + "type": "object", + "required": ["file"], + "title": "Body_upload_file_upload_post", + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } diff --git a/tests/openapi/test_upload_using_not_data_list.py b/tests/openapi/test_upload_using_not_data_list.py new file mode 100644 index 00000000..e785cdae --- /dev/null +++ b/tests/openapi/test_upload_using_not_data_list.py @@ -0,0 +1,113 @@ +from typing import Dict, List + +from esmerald import Gateway, UploadFile, post, status +from esmerald.params import File +from esmerald.testclient import create_client + + +@post("/upload", status_code=status.HTTP_200_OK) +async def upload_file(upload: List[UploadFile] = File()) -> Dict[str, str]: + names = [] + for file in upload: + names.append(file.filename) + return {"names": names} + + +def test_openapi_schema(test_client_factory): + with create_client( + routes=[ + Gateway(handler=upload_file), + ], + enable_openapi=True, + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "Highly scalable, performant, easy to learn and for every application.", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/upload": { + "post": { + "summary": "Upload File", + "description": "", + "operationId": "upload_file_upload_post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_file_upload_post" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {"type": "string"}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Body_upload_file_upload_post": { + "properties": { + "files": { + "items": {"type": "string", "format": "binary"}, + "type": "array", + "title": "Body_upload_file_upload_post", + } + }, + "type": "object", + "required": ["files"], + "title": "Body_upload_file_upload_post", + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + }