Skip to content

Commit

Permalink
Add option timeout to WebSocketTestSession::receive
Browse files Browse the repository at this point in the history
  • Loading branch information
mbergen authored and mbergen committed Nov 7, 2023
1 parent 2c066af commit dd08c5a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/testclient.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ May raise `starlette.websockets.WebSocketDisconnect` if the application does not
* `.receive_bytes()` - Wait for incoming bytestring sent by the application and return it.
* `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to receive JSON over binary data frames.

Pass `timeout=timeout_in_seconds` to raise `Empty` when no message is received within `timeout_in_seconds` to prevent blocking.
May raise `starlette.websockets.WebSocketDisconnect`.

#### Closing the connection
Expand Down
18 changes: 10 additions & 8 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,25 +156,27 @@ def send_json(self, data: typing.Any, mode: str = "text") -> None:
def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

def receive(self) -> Message:
message = self._send_queue.get()
def receive(self, timeout: typing.Optional[float] = None) -> Message:
message = self._send_queue.get(timeout=timeout)
if isinstance(message, BaseException):
raise message
return message

def receive_text(self) -> str:
message = self.receive()
def receive_text(self, timeout: typing.Optional[float] = None) -> str:
message = self.receive(timeout=timeout)
self._raise_on_close(message)
return typing.cast(str, message["text"])

def receive_bytes(self) -> bytes:
message = self.receive()
def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes:
message = self.receive(timeout=timeout)
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])

def receive_json(self, mode: str = "text") -> typing.Any:
def receive_json(
self, mode: str = "text", timeout: typing.Optional[float] = None
) -> typing.Any:
assert mode in ["text", "binary"]
message = self.receive()
message = self.receive(timeout=timeout)
self._raise_on_close(message)
if mode == "text":
text = message["text"]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from asyncio import current_task as asyncio_current_task
from contextlib import asynccontextmanager
from queue import Empty
from typing import Callable

import anyio
Expand Down Expand Up @@ -254,6 +255,20 @@ async def asgi(receive, send):
assert data == {"message": "test"}


def test_websocket_timeout_receive(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(Empty):
websocket.receive_json(timeout=1.0)


def test_client(test_client_factory):
async def app(scope, receive, send):
client = scope.get("client")
Expand Down

0 comments on commit dd08c5a

Please sign in to comment.