Skip to content

Commit

Permalink
Alternative approach: wrap mem streams into cm on the top level
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitagashkov committed Dec 8, 2024
1 parent 4693336 commit 5ff5f84
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def call_next(request: Request) -> Response:
app_exc: Exception | None = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
Expand Down Expand Up @@ -144,7 +145,7 @@ async def send_no_error(message: Message) -> None:
async def coro() -> None:
nonlocal app_exc

async with send_stream, recv_stream:
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
Expand All @@ -158,11 +159,7 @@ async def coro() -> None:
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except (
# we have to catch both exceptions, order of closing is not guaranteed
anyio.EndOfStream, # for `send_stream`
anyio.ClosedResourceError, # for `recv_stream`
):
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
raise RuntimeError("No response returned.")
Expand All @@ -187,9 +184,10 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:

with collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
with recv_stream, send_stream:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
Expand Down

0 comments on commit 5ff5f84

Please sign in to comment.