Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid collapsing exception groups from user code #2830

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
42 changes: 29 additions & 13 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import functools
import sys
import typing
from contextlib import contextmanager
from contextlib import asynccontextmanager

import anyio.abc

from starlette.types import Scope

Expand All @@ -13,12 +15,14 @@
else: # pragma: no cover
from typing_extensions import TypeGuard

has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
from exceptiongroup import BaseExceptionGroup
except ImportError:
has_exceptiongroups = False

class BaseExceptionGroup(BaseException): # type: ignore[no-redef]
pass


T = typing.TypeVar("T")
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
Expand Down Expand Up @@ -70,16 +74,28 @@ async def __aexit__(self, *args: typing.Any) -> None | bool:
return None


@contextmanager
def collapse_excgroups() -> typing.Generator[None, None, None]:
@asynccontextmanager
async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc
async with anyio.create_task_group() as tg:
yield tg
except BaseExceptionGroup as excs:
if len(excs.exceptions) != 1:
graingert marked this conversation as resolved.
Show resolved Hide resolved
raise

exc = excs.exceptions[0]
context = exc.__context__
tb = exc.__traceback__
cause = exc.__cause__
sc = exc.__suppress_context__
try:
raise exc
finally:
exc.__traceback__ = tb
exc.__context__ = context
exc.__cause__ = cause
exc.__suppress_context__ = sc
del exc, cause, tb, context


def get_route_path(scope: Scope) -> str:
Expand Down
6 changes: 3 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import anyio

from starlette._utils import collapse_excgroups
from starlette._utils import create_collapsing_task_group
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -174,8 +174,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:

streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
with recv_stream, send_stream:
async with create_collapsing_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
Expand Down
3 changes: 2 additions & 1 deletion starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import create_collapsing_task_group
from starlette.types import Receive, Scope, Send

warnings.warn(
Expand Down Expand Up @@ -102,7 +103,7 @@ async def __call__(self, receive: Receive, send: Send) -> None:
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)

async with anyio.create_task_group() as task_group:
async with create_collapsing_task_group() as task_group:
task_group.start_soon(self.sender, send)
async with self.stream_send:
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
Expand Down
3 changes: 2 additions & 1 deletion starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import anyio
import anyio.to_thread

from starlette._utils import create_collapsing_task_group
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
Expand Down Expand Up @@ -258,7 +259,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
except OSError:
raise ClientDisconnect()
else:
async with anyio.create_task_group() as task_group:
async with create_collapsing_task_group() as task_group:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also have an issue with StreamingResponse standalone?

With my comment on the other PR I was trying to avoid the changes here. 🤔

Copy link
Member Author

@graingert graingert Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's the same sort of issue, but it's wrapped in two TaskGroups before the user gets the exception. Collapsing is ok here because if wait_for_disconnect raises an exception it is always 'catastrophic' eg won't be caught


async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
Expand Down
20 changes: 16 additions & 4 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextvars
import sys
from collections.abc import AsyncGenerator, AsyncIterator, Generator
from contextlib import AsyncExitStack
from typing import Any
Expand All @@ -21,6 +22,9 @@
from starlette.websockets import WebSocket
from tests.types import TestClientFactory

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand All @@ -41,6 +45,10 @@ def exc(request: Request) -> None:
raise Exception("Exc")


def eg(request: Request) -> None:
raise ExceptionGroup("my exception group", [ValueError("TEST")])


def exc_stream(request: Request) -> StreamingResponse:
return StreamingResponse(_generate_faulty_stream())

Expand Down Expand Up @@ -76,6 +84,7 @@ async def websocket_endpoint(session: WebSocket) -> None:
routes=[
Route("/", endpoint=homepage),
Route("/exc", endpoint=exc),
Route("/eg", endpoint=eg),
Route("/exc-stream", endpoint=exc_stream),
Route("/no-response", endpoint=NoResponse),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
Expand All @@ -89,13 +98,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

with pytest.raises(Exception) as ctx:
with pytest.raises(Exception) as ctx1:
response = client.get("/exc")
assert str(ctx.value) == "Exc"
assert str(ctx1.value) == "Exc"

with pytest.raises(Exception) as ctx:
with pytest.raises(Exception) as ctx2:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"
assert str(ctx2.value) == "Faulty Stream"

with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"):
client.get("/eg")

with pytest.raises(RuntimeError):
response = client.get("/no-response")
Expand Down
3 changes: 1 addition & 2 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import pytest

from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from tests.types import TestClientFactory

Expand Down Expand Up @@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None:
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(RuntimeError), collapse_excgroups():
with pytest.raises(RuntimeError):
client.get("/")


Expand Down
34 changes: 33 additions & 1 deletion tests/test__utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import functools
import sys
from typing import Any

import pytest

from starlette._utils import get_route_path, is_async_callable
from starlette._utils import create_collapsing_task_group, get_route_path, is_async_callable
from starlette.types import Scope

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


def test_async_func() -> None:
async def async_func() -> None: ... # pragma: no cover
Expand Down Expand Up @@ -94,3 +98,31 @@ async def async_func(
)
def test_get_route_path(scope: Scope, expected_result: str) -> None:
assert get_route_path(scope) == expected_result


@pytest.mark.anyio
async def test_collapsing_task_group_one_exc() -> None:
class MyException(Exception):
pass

with pytest.raises(MyException):
async with create_collapsing_task_group():
raise MyException


@pytest.mark.anyio
async def test_collapsing_task_group_two_exc() -> None:
class MyException(Exception):
pass

async def raise_exc() -> None:
raise MyException

with pytest.raises(ExceptionGroup) as exc:
async with create_collapsing_task_group() as task_group:
task_group.start_soon(raise_exc)
raise MyException

exc1, exc2 = exc.value.exceptions
assert isinstance(exc1, MyException)
assert isinstance(exc2, MyException)
Loading