Skip to content

Commit

Permalink
feat: Optional max_file_size when parsing form
Browse files Browse the repository at this point in the history
  • Loading branch information
khadrawy committed Oct 6, 2024
1 parent e116840 commit bc3b9c3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 3 deletions.
7 changes: 7 additions & 0 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ async with request.form(max_files=1000, max_fields=1000):
...
```

You can configure maximum size per file uploaded with the parameter `max_file_size`:

```python
async with request.form(max_file_size=100*1024*1024): # 100 MB limit per file
...
```

!!! info
These limits are for security reasons, allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields.

Expand Down
8 changes: 8 additions & 0 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ def __init__(
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_file_size: int | float | None = None,
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.max_part_file_size = max_part_file_size
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
Expand Down Expand Up @@ -247,6 +249,12 @@ async def parse(self) -> FormData:
# the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
if (
self.max_part_file_size is not None
and part.file.size is not None
and part.file.size + len(data) > self.max_part_file_size
):
raise MultiPartException(f"File exceeds maximum size of {self.max_part_file_size} bytes.")
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
Expand Down
11 changes: 8 additions & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
async def _get_form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
Expand All @@ -260,6 +262,7 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_file_size=max_file_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
Expand All @@ -274,9 +277,11 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
return self._form

def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None = None
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields, max_file_size=max_file_size)
)

async def close(self) -> None:
if self._form is not None: # pragma: no branch
Expand Down
48 changes: 48 additions & 0 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from secrets import token_bytes

import pytest

Expand Down Expand Up @@ -127,6 +128,29 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
return app


def make_app_max_file_size(max_file_size: int) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form(max_file_size=max_file_size)
output: dict[str, typing.Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
output[key] = {
"filename": value.filename,
"size": value.size,
"content": content.decode(),
"content_type": value.content_type,
}
else:
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)

return app


def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
Expand Down Expand Up @@ -580,6 +604,30 @@ def test_too_many_files_and_fields_raise(
assert res.text == "Too many files. Maximum number of files is 1000."


@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_file_size(1024), pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=make_app_max_file_size(1024))]), does_not_raise()),
],
)
def test_max_part_file_size_raise(
tmpdir: Path,
app: ASGIApp,
expectation: typing.ContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(token_bytes(1024 + 1))

client = test_client_factory(app)
with open(path, "rb") as f, expectation:
response = client.post("/", files={"test": f})
assert response.status_code == 400
assert response.text == "File exceeds maximum size of 1024 bytes."


@pytest.mark.parametrize(
"app,expectation",
[
Expand Down

0 comments on commit bc3b9c3

Please sign in to comment.