collect errors more reliably from websocket test client #2814

Merged
merged 17 commits into from
Dec 29, 2024
79 changes: 36 additions & 43 deletions starlette/testclient.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
@@ -9,7 +10,6 @@
import sys
import typing
from concurrent.futures import Future
from functools import cached_property
from types import GeneratorType
from urllib.parse import unquote, urljoin

@@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
@@ -97,63 +105,47 @@ def __init__(
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())

try:
_: Future[None] = self.portal.start_task_soon(self._run)
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self

@cached_property
def should_close(self) -> anyio.Event:
return anyio.Event()

async def _notify_close(self) -> None:
self.should_close.set()
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
Member

Is there an analogue of EOF in the standard library on 3.13?

Member Author

It raises an exception

break
if isinstance(message, BaseException):
raise message
raise message # pragma: no cover (defensive, should be impossible)
Member

Why should it be impossible?

The except BaseException as exc below doesn't have a pragma: no cover, so I assume it's being hit?

Member Author (@graingert, Dec 29, 2024)

This is impossible because the exit stack raises the exception out of fut.result(), so the queue is never consumed.

It can only be hit if ws.receive() is interrupted (e.g. by a KeyboardInterrupt) while waiting for an exception or message to be placed on the queue.

I'm currently sketching out another slight refactor that uses MemoryObjectStreams here instead, which should clean this up a bit.
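
The ordering described in this reply can be illustrated with a minimal standard-library sketch. It is not the Starlette implementation: `DemoSession`, `log`, and the pre-failed `Future` are hypothetical stand-ins. It only shows that `ExitStack` callbacks run in reverse registration order, that `pop_all()` defers them to `__exit__`, and that an error stored in the future is raised by `fut.result()` when the stack closes.

```python
import contextlib
from concurrent.futures import Future


class DemoSession:
    """Hypothetical stand-in for a session object; not Starlette code."""

    def __init__(self, fut: Future, log: list) -> None:
        self._fut = fut
        self._log = log

    def __enter__(self) -> "DemoSession":
        with contextlib.ExitStack() as stack:
            # Registered first, so it runs last: re-raises the task's error.
            stack.callback(self._fut.result)
            # Registered second, so it runs first (LIFO order).
            stack.callback(self._log.append, "cancel requested")
            # Setup that may raise would go here; on failure the callbacks
            # above run immediately because the with block unwinds.
            self._exit_stack = stack.pop_all()  # defer callbacks to __exit__
            return self

    def __exit__(self, *args: object) -> None:
        self._exit_stack.close()


def demo() -> tuple:
    fut: Future = Future()
    fut.set_exception(ValueError("app failed"))  # simulate a crashed task
    log: list = []
    caught = None
    try:
        with DemoSession(fut, log):
            log.append("body ran")
    except ValueError as exc:
        caught = str(exc)
    return log, caught
```

Because the error surfaces from `fut.result()` while the stack is closing, any code placed after `self._exit_stack.close()` would not run in that case, which is consistent with the "queue won't be consumed" remark.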


async def _run(self) -> None:
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""

async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
try:
with anyio.CancelScope() as cs:
task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+
Member

Should we use an if sys.version_info check here?

Member Author

I think it's better to stick to the EOF approach until someone puts up a Queue.shutdown backport, or until we can use a MemoryObjectStream with the portal.
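
The EOF approach discussed here can be sketched with the standard library alone. This is an illustrative example rather than the code under review: `producer`, `drain`, and `demo` are hypothetical helpers, and the sentinel mirrors the `_Eof` enum from the diff. On Python 3.13+, `queue.Queue.shutdown()` offers a built-in alternative in which `get()` raises `queue.ShutDown` once the queue has been shut down and drained, instead of returning a marker.

```python
from __future__ import annotations

import enum
import queue
import threading
import typing


class _Eof(enum.Enum):
    EOF = enum.auto()


# A single, identity-comparable marker that type checkers can narrow on.
EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


def producer(q: queue.Queue[str | Eof], items: list[str]) -> None:
    """Put every item on the queue, then always signal end-of-stream."""
    try:
        for item in items:
            q.put(item)
    finally:
        q.put(EOF)


def drain(q: queue.Queue[str | Eof]) -> list[str]:
    """Block on the queue until the EOF marker arrives."""
    received: list[str] = []
    while True:
        message = q.get()
        if message is EOF:
            return received
        received.append(message)


def demo() -> list[str]:
    q: queue.Queue[str | Eof] = queue.Queue()
    worker = threading.Thread(target=producer, args=(q, ["a", "b"]))
    worker.start()
    result = drain(q)
    worker.join()
    return result
```

Putting the marker in a `finally` block means the consumer is released even when the producer fails, so a blocking `get()` loop cannot hang waiting for a message that will never arrive.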


async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
@@ -202,6 +194,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
18 changes: 12 additions & 6 deletions tests/test_testclient.py
@@ -255,19 +255,25 @@ async def asgi(receive: Receive, send: Send) -> None:


def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
cancelled = False

def app(scope: Scope) -> ASGIInstance:
async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
while True:
await anyio.sleep(0.1)
nonlocal cancelled
try:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await anyio.sleep_forever()
except anyio.get_cancelled_exc_class():
cancelled = True
raise
Comment on lines +267 to +269
Member

Why is the cancelled exception here? Is this because this except was removed?

The addition of that except was on purpose, if I recall correctly.

Member Author (@graingert, Dec 28, 2024)

This replaces websocket.should_close.is_set(), which no longer exists, so the test needs another way to confirm that the async function ran and was cancelled.

Member Author (@graingert, Dec 28, 2024)

Did you mean to ask about this one: https://github.com/encode/starlette/pull/2814/files#diff-aca25e5f16c4fd49338ccdf3631f72455309335fee1e27f3d3b6016fa94ecedfL145 ?

That except was removed because it's no longer possible to get an intentional cancelled_exc there: it will be caught by the CancelScope.
If you do get a cancelled_exc there, it's because someone incorrectly issued a native asyncio cancel or manually raised a trio cancel, both of which should be propagated out of the websocket_connect context manager.

Member

Got it.

So the point is: since the WebSocket doesn't call send, it can't receive a disconnect exception, and since it doesn't call receive, it doesn't receive a websocket.disconnect message. The TestClient will then propagate the cancellation exception here, but the user shouldn't worry about it because the TestClient will catch it anyway.

Right?

Member Author

Yeah, we use a CancelScope passed out with task_status.started(cs), which is then used to cancel the task on completion and collect a result. The CancelScope catches the cancellation that it causes in the coroutine. If there's another reason for the cancellation, it is propagated: that is a fatal case, so the user should be notified.

Member

So if there's a cancelled exception that is not from the TestClient, the user will see it, right?

I'm good with this.

Member Author

Yeah, the only way is if the user was calling .cancel() directly or manually constructing and raising a trio cancel (bypassing the private constructor protections).


return asgi

client = test_client_factory(app) # type: ignore
with client.websocket_connect("/") as websocket:
with client.websocket_connect("/"):
...
assert websocket.should_close.is_set()
assert cancelled


def test_client(test_client_factory: TestClientFactory) -> None: