Skip to content

Commit

Permalink
OpenAPI - Different keywords other than data for upload files (#431)
Browse files Browse the repository at this point in the history
* Allow keywords different from data for upload
* Fix request data and Upload files via OpenAPI
  • Loading branch information
tarsil authored Nov 13, 2024
1 parent 5df293a commit f01a22c
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 21 deletions.
21 changes: 15 additions & 6 deletions esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
TRANSFORMER_TYPES_KEYS += ADDITIONAL_TYPES


def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]:
def get_flat_params(route: Union[router.HTTPHandler, Any], body_fields: List[str]) -> List[Any]:
"""
Gets all the neded params of the request and route.
"""
Expand All @@ -75,6 +75,9 @@ def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]:
for param in handler_query_params:
is_union_or_optional = is_union(param.field_info.annotation)

if param.field_info.alias in body_fields:
continue

# Making sure all the optional and union types are included
if is_union_or_optional:
# field_info = should_skip_json_schema(param.field_info)
Expand Down Expand Up @@ -143,7 +146,8 @@ def get_fields_from_routes(
response_from_routes.append(response)

# Get the params from the transformer
params = get_flat_params(handler)
body_fields_names = [field.alias for field in body_fields]
params = get_flat_params(handler, body_fields_names)
if params:
request_fields.extend(params)

Expand Down Expand Up @@ -297,7 +301,12 @@ def get_openapi_path(
if security_definitions:
security_schemes.update(security_definitions)

all_route_params = get_flat_params(handler)
body_fields = []
if handler.data_field:
body_fields.append(handler.data_field)

body_fields_names = [field.alias for field in body_fields]
all_route_params = get_flat_params(handler, body_fields_names)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params,
field_mapping=field_mapping,
Expand All @@ -323,9 +332,9 @@ def get_openapi_path(
operation["requestBody"] = request_data_oai

status_code = str(handler.status_code)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = handler.response_description
operation.setdefault("responses", {}).setdefault(status_code, {})["description"] = (
handler.response_description
)

# Media type
if route_response_media_type and is_status_code_allowed(handler.status_code):
Expand Down
77 changes: 63 additions & 14 deletions esmerald/routing/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Union, _GenericAlias, cast, get_args

from lilya.datastructures import DataUpload
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo

from esmerald.datastructures import UploadFile
from esmerald.encoders import ENCODER_TYPES, is_body_encoder
from esmerald.enums import EncodingType
from esmerald.openapi.params import ResponseParam
from esmerald.params import Body
from esmerald.utils.constants import DATA, PAYLOAD
from esmerald.utils.helpers import is_class_and_subclass
from esmerald.utils.models import create_field_model

if TYPE_CHECKING:
Expand Down Expand Up @@ -81,8 +84,60 @@ def convert_annotation_to_pydantic_model(field_annotation: Any) -> Any:
return field_annotation


def handle_upload_files(data: Any, body: Body) -> Body:
"""
Handles the creation of the body field for the upload files.
"""
# For Uploads and Multi Part
args = get_args(body.annotation)
name = "File" if not args else "Files"

model = create_field_model(field=body, name=name, model_name=body.title)
data_field = Body(annotation=model, title=body.title)

for key, _ in data._attributes_set.items():
if key != "annotation":
setattr(data_field, key, getattr(body, key, None))
return data_field


def get_upload_body(handler: Union["HTTPHandler"]) -> Any:
"""
This function repeats some of the steps but covers all the
cases for simple use cases.
"""
for name, _ in handler.signature_model.model_fields.items():
data = handler.signature_model.model_fields[name]

if not isinstance(data, Body):
body = Body(alias="body")
for key, _ in data._attributes_set.items():
setattr(body, key, getattr(data, key, None))
else:
body = data

# Check the annotation type
body.annotation = convert_annotation_to_pydantic_model(body.annotation)

if not body.title:
body.title = f"Body_{handler.operation_id}"

# For everything else that is not MULTI_PART
extra = cast("Dict[str, Any]", body.json_schema_extra) or {}
if extra.get(
"media_type", EncodingType.JSON
) != EncodingType.MULTI_PART and not is_class_and_subclass(
body.annotation, (UploadFile, DataUpload)
):
continue

# For Uploads and Multi Part
data_field = handle_upload_files(data, body)
return data_field


def get_original_data_field(
handler: Union["HTTPHandler", "WebhookHandler", Any]
handler: Union["HTTPHandler", "WebhookHandler", Any],
) -> Any: # pragma: no cover
"""
The field used for the payload body.
Expand Down Expand Up @@ -116,16 +171,7 @@ def get_original_data_field(
return body

# For Uploads and Multi Part
args = get_args(body.annotation)
name = "File" if not args else "Files"

model = create_field_model(field=body, name=name, model_name=body.title)
data_field = Body(annotation=model, title=body.title)

for key, _ in data._attributes_set.items():
if key != "annotation":
setattr(data_field, key, getattr(body, key, None))

data_field = handle_upload_files(data, body)
return data_field


Expand Down Expand Up @@ -197,15 +243,18 @@ def get_data_field(handler: Union["HTTPHandler", "WebhookHandler", Any]) -> Any:
"""
# If there are no body fields, we simply return the original
# default Esmerald body parsing
if not handler.body_encoder_fields:
return get_original_data_field(handler)

is_data_or_payload = (
DATA
if DATA in handler.signature_model.model_fields
else (PAYLOAD if PAYLOAD in handler.signature_model.model_fields else None)
)

if not handler.body_encoder_fields and is_data_or_payload:
return get_original_data_field(handler)

if not handler.body_encoder_fields:
return get_upload_body(handler)

if len(handler.body_encoder_fields) < 2 and is_data_or_payload is not None:
return get_original_data_field(handler)
return get_complex_data_field(handler, fields=handler.body_encoder_fields)
Expand Down
7 changes: 6 additions & 1 deletion esmerald/routing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,12 @@ async def _get_response_data(
# Check if there is request data
if request_data is not None:
# Assign each key-value pair in the request data to kwargs
if isinstance(request_data, (UploadFile, DataUpload)):
if isinstance(request_data, (UploadFile, DataUpload)) or (
isinstance(request_data, (list, tuple))
and any(
isinstance(value, (UploadFile, DataUpload)) for value in request_data
)
):
for key, _ in kwargs.items():
kwargs[key] = request_data
else:
Expand Down
113 changes: 113 additions & 0 deletions tests/openapi/test_upload_using_not_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Dict

from esmerald import Gateway, UploadFile, post, status
from esmerald.params import File
from esmerald.testclient import create_client


@post("/upload", status_code=status.HTTP_200_OK)
async def upload_file(upload: UploadFile = File()) -> Dict[str, str]:
names = []
for file in upload:
names.append(file.filename)
return {"names": names}


def test_openapi_schema(test_client_factory):
with create_client(
routes=[
Gateway(handler=upload_file),
],
enable_openapi=True,
) as client:
response = client.get("/openapi.json")
assert response.status_code == 200, response.text

assert response.json() == {
"openapi": "3.1.0",
"info": {
"title": "Esmerald",
"summary": "Esmerald application",
"description": "Highly scalable, performant, easy to learn and for every application.",
"contact": {"name": "admin", "email": "[email protected]"},
"version": client.app.version,
},
"servers": [{"url": "/"}],
"paths": {
"/upload": {
"post": {
"summary": "Upload File",
"description": "",
"operationId": "upload_file_upload_post",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_upload_file_upload_post"
}
}
},
"required": True,
},
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "string"}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"deprecated": False,
}
}
},
"components": {
"schemas": {
"Body_upload_file_upload_post": {
"properties": {
"file": {
"type": "string",
"format": "binary",
"title": "Body_upload_file_upload_post",
}
},
"type": "object",
"required": ["file"],
"title": "Body_upload_file_upload_post",
},
"HTTPValidationError": {
"properties": {
"detail": {
"items": {"$ref": "#/components/schemas/ValidationError"},
"type": "array",
"title": "Detail",
}
},
"type": "object",
"title": "HTTPValidationError",
},
"ValidationError": {
"properties": {
"loc": {
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
"type": "array",
"title": "Location",
},
"msg": {"type": "string", "title": "Message"},
"type": {"type": "string", "title": "Error Type"},
},
"type": "object",
"required": ["loc", "msg", "type"],
"title": "ValidationError",
},
}
},
}
Loading

0 comments on commit f01a22c

Please sign in to comment.