Skip to content

Commit

Permalink
Merge branch 'master' into raise-exception-on-bkg-task
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 28, 2024
2 parents 7ef64c9 + cef7834 commit ebe0a95
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 18 deletions.
6 changes: 3 additions & 3 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ state with `disconnected = await request.is_disconnected()`.

Request files are normally sent as multipart form data (`multipart/form-data`).

Signature: `request.form(max_files=1000, max_fields=1000)`
Signature: `request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024)`

You can configure the number of maximum fields or files with the parameters `max_files` and `max_fields`:
You can configure the number of maximum fields or files with the parameters `max_files` and `max_fields`; and part size using `max_part_size`:

```python
async with request.form(max_files=1000, max_fields=1000):
async with request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024):
...
```

Expand Down
4 changes: 2 additions & 2 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0] # pragma: no cover
exc = exc.exceptions[0]

raise exc

Expand Down
3 changes: 2 additions & 1 deletion starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ async def parse(self) -> FormData:

class MultiPartParser:
max_file_size = 1024 * 1024 # 1MB
max_part_size = 1024 * 1024 # 1MB

def __init__(
self,
Expand All @@ -132,6 +131,7 @@ def __init__(
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024, # 1MB
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
Expand All @@ -148,6 +148,7 @@ def __init__(
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_part_size = max_part_size

def on_part_begin(self) -> None:
self._current_part = MultipartPart()
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/gzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
self.compresslevel = compresslevel

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
if scope["type"] == "http": # pragma: no branch
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
Expand Down Expand Up @@ -88,7 +88,7 @@ async def send_with_gzip(self, message: Message) -> None:
await self.send(self.initial_message)
await self.send(message)

elif message_type == "http.response.body":
elif message_type == "http.response.body": # pragma: no branch
# Remaining body in streaming GZip response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
Expand Down
23 changes: 18 additions & 5 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
self._stream_consumed = True
if body:
yield body
elif message["type"] == "http.disconnect":
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
Expand All @@ -249,8 +249,14 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
if self._form is None:
async def _get_form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None: # pragma: no branch
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
Expand All @@ -264,6 +270,7 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
Expand All @@ -278,9 +285,15 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
return self._form

def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)

async def close(self) -> None:
if self._form is not None: # pragma: no branch
Expand Down
6 changes: 3 additions & 3 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
elif scope["type"] == "websocket":
elif scope["type"] == "websocket": # pragma: no branch
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
Expand Down Expand Up @@ -398,7 +398,7 @@ def routes(self) -> list[BaseRoute]:

def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, typing.Any]
if scope["type"] in ("http", "websocket"):
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
Expand Down Expand Up @@ -481,7 +481,7 @@ def routes(self) -> list[BaseRoute]:
return getattr(self.app, "routes", [])

def matches(self, scope: Scope) -> tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
if scope["type"] in ("http", "websocket"): # pragma:no branch
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
Expand Down
44 changes: 42 additions & 2 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)


def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp:
def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form(max_files=max_files, max_fields=max_fields)
data = await request.form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
output: dict[str, typing.Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
Expand Down Expand Up @@ -699,3 +699,43 @@ def test_max_part_size_exceeds_limit(
response = client.post("/", data=multipart_data, headers=headers) # type: ignore
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 1024KB."


@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_parts(max_part_size=1024 * 10), pytest.raises(MultiPartException)),
(
Starlette(routes=[Mount("/", app=make_app_max_parts(max_part_size=1024 * 10))]),
does_not_raise(),
),
],
)
def test_max_part_size_exceeds_custom_limit(
app: ASGIApp,
expectation: typing.ContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"

multipart_data = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="small"\r\n\r\n'
"small content\r\n"
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="large"\r\n\r\n'
+ ("x" * 1024 * 10 + "x") # 1MB + 1 byte of data
+ "\r\n"
f"--{boundary}--\r\n"
).encode("utf-8")

headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Transfer-Encoding": "chunked",
}

with expectation:
response = client.post("/", content=multipart_data, headers=headers)
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 10KB."
10 changes: 10 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,16 @@ def test_url_for_with_double_mount() -> None:
assert url == "/mount/static/123"


def test_url_for_with_root_path_ending_with_slash(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> JSONResponse:
return JSONResponse({"index": str(request.url_for("homepage"))})

app = Starlette(routes=[Route("/", homepage, name="homepage")])
client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path/")
response = client.get("/sub_path/")
assert response.json() == {"index": "https://www.example.org/sub_path/"}


def test_standalone_route_matches(
test_client_factory: TestClientFactory,
) -> None:
Expand Down

0 comments on commit ebe0a95

Please sign in to comment.