Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix validation of parameters requiring a body #467

Merged
merged 4 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ hide:

### Fixed

- `data` and `payload` special kwargs are now allowed when a not-bodyless method is available for the handler. They default to None.
- `bytes` won't be encoded as json when returned from a handler. This would unexpectly lead to a base64 encoding.
- SessionConfig has a unneccessarily heavily restricted secret_key parameter.
- Gracefully handle situations where cookies are None in `get_cookies`.
- Fix validation of parameters requiring a body.

## 3.6.1

Expand Down
43 changes: 35 additions & 8 deletions esmerald/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from esmerald.interceptors.types import Interceptor
from esmerald.openapi.datastructures import OpenAPIResponse
from esmerald.openapi.utils import is_status_code_allowed
from esmerald.params import Form
from esmerald.requests import Request
from esmerald.responses import Response
from esmerald.routing._internal import OpenAPIFieldInfoMixin
Expand All @@ -66,7 +67,7 @@
from esmerald.transformers.utils import get_signature
from esmerald.typing import Void, VoidType
from esmerald.utils.constants import DATA, PAYLOAD, REDIRECT_STATUS_CODES, REQUEST, SOCKET
from esmerald.utils.helpers import is_async_callable, is_class_and_subclass
from esmerald.utils.helpers import is_async_callable, is_class_and_subclass, is_optional_union
from esmerald.websockets import WebSocket, WebSocketClose

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -1917,6 +1918,9 @@ def wrapper(func: Callable) -> Callable:
return wrapper


_body_less_methods = frozenset({"GET", "HEAD", "OPTIONS", "TRACE"})


class HTTPHandler(Dispatcher, OpenAPIFieldInfoMixin, LilyaPath):
__slots__ = (
"path",
Expand Down Expand Up @@ -2218,24 +2222,47 @@ def validate_annotations(self) -> None: # pragma: no cover
]:
self.media_type = MediaType.TEXT

def validate_bodyless_kwargs(self) -> None:
if _body_less_methods.isdisjoint(self.methods):
return
body_less_only = _body_less_methods.issuperset(self.methods)
for special in [DATA, PAYLOAD]:
if special in self.handler_signature.parameters:
if body_less_only:
raise ImproperlyConfigured(
f"'{special}' argument unsupported when only body-less methods like 'GET' and 'HEAD' are handled"
)
elif not is_optional_union(self.handler_signature.parameters[special].annotation):
raise ImproperlyConfigured(
f"'{special}' argument must be optional when body-less methods like 'GET' and 'HEAD' are handled"
)
for parameter_name, parameter in self.handler_signature.parameters.items():
# don't check twice
if parameter_name == DATA or parameter_name == PAYLOAD:
continue
if isinstance(parameter.default, Form):
if body_less_only:
raise ImproperlyConfigured(
f"'{special}' uses Form() which is unsupported when only body-less methods "
"like 'GET' and 'HEAD' are handled"
)
elif not is_optional_union(parameter.annotation):
raise ImproperlyConfigured(
f"'{special}' argument must be optional when body-less methods like 'GET' and 'HEAD' are handled"
)

def validate_reserved_kwargs(self) -> None: # pragma: no cover
"""
Validates if special words are in the signature.
"""
if DATA in self.handler_signature.parameters and "GET" in self.methods:
raise ImproperlyConfigured("'data' argument is unsupported for 'GET' request handlers")

if PAYLOAD in self.handler_signature.parameters and "GET" in self.methods:
raise ImproperlyConfigured(
"'payload' argument is unsupported for 'GET' request handlers"
)

if SOCKET in self.handler_signature.parameters:
raise ImproperlyConfigured("The 'socket' argument is not supported with http handlers")

def validate_handler(self) -> None:
self.check_handler_function()
self.validate_annotations()
self.validate_bodyless_kwargs()
self.validate_reserved_kwargs()

async def to_response(self, app: "Esmerald", data: Any) -> LilyaResponse:
Expand Down
50 changes: 28 additions & 22 deletions tests/forms/test_forms_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydantic_dataclass

from esmerald import Form, Gateway, post
from esmerald import Form, Gateway, post, route
from esmerald.exceptions import ImproperlyConfigured
from esmerald.testclient import create_client


Expand All @@ -25,27 +27,11 @@ class UserModel(BaseModel):
name: str


@post("/form")
async def test_form(data: Any = Form()) -> Dict[str, str]:
return {"name": data["name"]}


@post("/complex-form-pydantic")
async def test_complex_form_pydantic_dataclass(data: User = Form()) -> User:
return data


@post("/complex-form-dataclass")
async def test_complex_form_dataclass(data: UserOut = Form()) -> UserOut:
return data


@post("/complex-form-basemodel")
async def test_complex_form_basemodel(data: UserModel = Form()) -> UserModel:
return data


def test_send_form(test_client_factory):
@post("/form")
async def test_form(data: Any = Form()) -> Dict[str, str]:
return {"name": data["name"]}

data = {"name": "Test"}

with create_client(routes=[Gateway(handler=test_form)]) as client:
Expand All @@ -56,6 +42,10 @@ def test_send_form(test_client_factory):


def test_send_complex_form_pydantic_dataclass(test_client_factory):
@post("/complex-form-pydantic")
async def test_complex_form_pydantic_dataclass(data: User = Form()) -> User:
return data

data = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_pydantic_dataclass)],
Expand All @@ -67,6 +57,10 @@ def test_send_complex_form_pydantic_dataclass(test_client_factory):


def test_send_complex_form_normal_dataclass(test_client_factory):
@post("/complex-form-dataclass")
async def test_complex_form_dataclass(data: UserOut = Form()) -> UserOut:
return data

data = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_dataclass)],
Expand All @@ -78,6 +72,10 @@ def test_send_complex_form_normal_dataclass(test_client_factory):


def test_send_complex_form_base_model(test_client_factory):
@post("/complex-form-basemodel")
async def test_complex_form_basemodel(data: UserModel = Form()) -> UserModel:
return data

data = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_basemodel)],
Expand All @@ -86,3 +84,11 @@ def test_send_complex_form_base_model(test_client_factory):
response = client.post("/complex-form-basemodel", data=data)
assert response.status_code == 201, response.text
assert response.json() == {"id": 1, "name": "Test"}


def test_get_and_head_data():
with pytest.raises(ImproperlyConfigured):

@route(methods=["GET", "HEAD"])
async def start(data: Optional[UserModel]) -> bytes:
return b"hello world"
50 changes: 28 additions & 22 deletions tests/forms/test_forms_payload.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydantic_dataclass

from esmerald import Form, Gateway, post
from esmerald import Form, Gateway, post, route
from esmerald.exceptions import ImproperlyConfigured
from esmerald.testclient import create_client


Expand All @@ -25,27 +27,11 @@ class UserModel(BaseModel):
name: str


@post("/form")
async def test_form(payload: Any = Form()) -> Dict[str, str]:
return {"name": payload["name"]}


@post("/complex-form-pydantic")
async def test_complex_form_pydantic_dataclass(payload: User = Form()) -> User:
return payload


@post("/complex-form-dataclass")
async def test_complex_form_dataclass(payload: UserOut = Form()) -> UserOut:
return payload


@post("/complex-form-basemodel")
async def test_complex_form_basemodel(payload: UserModel = Form()) -> UserModel:
return payload


def test_send_form(test_client_factory):
@post("/form")
async def test_form(payload: Any = Form()) -> Dict[str, str]:
return {"name": payload["name"]}

payload = {"name": "Test"}

with create_client(routes=[Gateway(handler=test_form)]) as client:
Expand All @@ -56,6 +42,10 @@ def test_send_form(test_client_factory):


def test_send_complex_form_pydantic_dataclass(test_client_factory):
@post("/complex-form-pydantic")
async def test_complex_form_pydantic_dataclass(payload: User = Form()) -> User:
return payload

payload = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_pydantic_dataclass)],
Expand All @@ -67,6 +57,10 @@ def test_send_complex_form_pydantic_dataclass(test_client_factory):


def test_send_complex_form_normal_dataclass(test_client_factory):
@post("/complex-form-dataclass")
async def test_complex_form_dataclass(payload: UserOut = Form()) -> UserOut:
return payload

payload = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_dataclass)],
Expand All @@ -78,6 +72,10 @@ def test_send_complex_form_normal_dataclass(test_client_factory):


def test_send_complex_form_base_model(test_client_factory):
@post("/complex-form-basemodel")
async def test_complex_form_basemodel(payload: UserModel = Form()) -> UserModel:
return payload

payload = {"id": 1, "name": "Test"}
with create_client(
routes=[Gateway(handler=test_complex_form_basemodel)],
Expand All @@ -86,3 +84,11 @@ def test_send_complex_form_base_model(test_client_factory):
response = client.post("/complex-form-basemodel", data=payload)
assert response.status_code == 201, response.text
assert response.json() == {"id": 1, "name": "Test"}


def test_get_and_head_payload():
with pytest.raises(ImproperlyConfigured):

@route(methods=["GET", "HEAD"])
async def start(payload: Optional[UserModel]) -> bytes:
return b"hello world"
50 changes: 50 additions & 0 deletions tests/forms/test_forms_route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, Union

import pytest
from pydantic import BaseModel

from esmerald import Esmerald, Form, Request
from esmerald.exceptions import ImproperlyConfigured
from esmerald.routing.gateways import Gateway
from esmerald.routing.handlers import route
from esmerald.testclient import EsmeraldTestClient


class Model(BaseModel):
id: str


def test_get_and_post():
@route(methods=["GET", "POST"])
async def start(request: Request, form: Union[Model, None] = Form()) -> bytes:
return b"hello world"

app = Esmerald(
debug=True,
routes=[Gateway("/", handler=start)],
)
client = EsmeraldTestClient(app)
response = client.get("/")
assert response.status_code == 200


def test_get_and_post_optional():
@route(methods=["GET", "POST"])
async def start(request: Request, form: Optional[Model] = Form()) -> bytes:
return b"hello world"

app = Esmerald(
debug=True,
routes=[Gateway("/", handler=start)],
)
client = EsmeraldTestClient(app)
response = client.get("/")
assert response.status_code == 200


def test_get_and_head_form():
with pytest.raises(ImproperlyConfigured):

@route(methods=["GET", "HEAD"])
async def start(form: Optional[Model] = Form()) -> bytes:
return b"hello world"
Loading
Loading