From 2d6259544e8902eb0674c22b8815983df3136780 Mon Sep 17 00:00:00 2001
From: Mostafa Khalil
Date: Sun, 6 Oct 2024 06:58:55 +0300
Subject: [PATCH] feat: Optional max_file_size when parsing form

---
 docs/requests.md          |  7 ++++++
 starlette/formparsers.py  |  4 ++++
 starlette/requests.py     | 11 ++++++---
 tests/test_formparsers.py | 48 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/docs/requests.md b/docs/requests.md
index 9140bb964f..4046540d35 100644
--- a/docs/requests.md
+++ b/docs/requests.md
@@ -122,6 +122,13 @@ async with request.form(max_files=1000, max_fields=1000):
     ...
 ```
 
+You can configure the maximum size of each uploaded file with the `max_file_size` parameter:
+
+```python
+async with request.form(max_file_size=100*1024*1024):  # 100 MB limit per file
+    ...
+```
+
 !!! info
     These limits are for security reasons, allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields.
 
diff --git a/starlette/formparsers.py b/starlette/formparsers.py
index 48b0f2aad6..9420b56324 100644
--- a/starlette/formparsers.py
+++ b/starlette/formparsers.py
@@ -126,12 +126,14 @@ def __init__(
         *,
         max_files: int | float = 1000,
         max_fields: int | float = 1000,
+        max_part_file_size: int | None = None,
     ) -> None:
         assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
         self.headers = headers
         self.stream = stream
         self.max_files = max_files
         self.max_fields = max_fields
+        self.max_part_file_size = max_part_file_size
         self.items: list[tuple[str, str | UploadFile]] = []
         self._current_files = 0
         self._current_fields = 0
@@ -247,6 +249,8 @@ async def parse(self) -> FormData:
                 # the main thread.
for part, data in self._file_parts_to_write: assert part.file # for type checkers + if self.max_part_file_size and part.file.size + len(data) > self.max_part_file_size: + raise MultiPartException(f"File exceeds maximum size of {self.max_part_file_size} bytes.") await part.file.write(data) for part in self._file_parts_to_finish: assert part.file # for type checkers diff --git a/starlette/requests.py b/starlette/requests.py index 9c6776f0c4..1ebec2b611 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -245,7 +245,9 @@ async def json(self) -> typing.Any: self._json = json.loads(body) return self._json - async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData: + async def _get_form( + self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None + ) -> FormData: if self._form is None: assert ( parse_options_header is not None @@ -260,6 +262,7 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl self.stream(), max_files=max_files, max_fields=max_fields, + max_part_file_size=max_file_size, ) self._form = await multipart_parser.parse() except MultiPartException as exc: @@ -274,9 +277,11 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl return self._form def form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None = None ) -> AwaitableOrContextManager[FormData]: - return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields)) + return AwaitableOrContextManagerWrapper( + self._get_form(max_files=max_files, max_fields=max_fields, max_file_size=max_file_size) + ) async def close(self) -> None: if self._form is not None: # pragma: no branch diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 61c1bede19..5a96a62f57 100644 --- 
a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -4,6 +4,7 @@
 import typing
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
+from secrets import token_bytes
 
 import pytest
 
@@ -127,6 +128,29 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
     return app
 
 
+def make_app_max_file_size(max_file_size: int) -> ASGIApp:  # test app factory with a per-file size limit
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        request = Request(scope, receive)
+        data = await request.form(max_file_size=max_file_size)  # the limit under test
+        output: dict[str, typing.Any] = {}
+        for key, value in data.items():
+            if isinstance(value, UploadFile):
+                content = await value.read()
+                output[key] = {
+                    "filename": value.filename,
+                    "size": value.size,
+                    "content": content.decode(),  # decode() defaults to UTF-8
+                    "content_type": value.content_type,
+                }
+            else:  # plain (non-file) form field
+                output[key] = value
+        await request.close()
+        response = JSONResponse(output)  # echo the parsed form back as JSON
+        await response(scope, receive, send)
+
+    return app
+
+
 def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
@@ -580,6 +604,30 @@ def test_too_many_files_and_fields_raise(
         assert res.text == "Too many files. Maximum number of files is 1000."
 
 
+@pytest.mark.parametrize(
+    "app,expectation",
+    [
+        (make_app_max_file_size(1024), pytest.raises(MultiPartException)),  # raw ASGI app: expect the exception
+        (Starlette(routes=[Mount("/", app=make_app_max_file_size(1024))]), does_not_raise()),
+    ],
+)
+def test_max_part_file_size_raise(
+    tmpdir: Path,
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
+    path = os.path.join(tmpdir, "test.txt")
+    with open(path, "wb") as file:
+        file.write(token_bytes(1024 + 1))  # one byte over the 1024-byte limit
+
+    client = test_client_factory(app)
+    with open(path, "rb") as f, expectation:  # the asserts below only run if post() does not raise
+        response = client.post("/", files={"test": f})
+        assert response.status_code == 400
+        assert response.text == "File exceeds maximum size of 1024 bytes."
+
+
 @pytest.mark.parametrize(
     "app,expectation",
     [