diff --git a/httpx_ws/transport.py b/httpx_ws/transport.py index 616e5ea..445d927 100644 --- a/httpx_ws/transport.py +++ b/httpx_ws/transport.py @@ -51,7 +51,9 @@ async def __aenter__( ) self._aentered = True async with contextlib.AsyncExitStack() as stack: - self._task_group = await stack.enter_async_context(anyio.create_task_group()) + self._task_group = await stack.enter_async_context( + anyio.create_task_group() + ) self._task_group.start_soon(self._run) await self.send({"type": "websocket.connect"}) @@ -60,6 +62,7 @@ async def __aenter__( stack.push_async_callback(self.aclose) if message["type"] == "websocket.close": + await stack.aclose() raise WebSocketDisconnect(message["code"], message.get("reason")) assert message["type"] == "websocket.accept" @@ -67,7 +70,7 @@ async def __aenter__( self._exit_stack = stack.pop_all() return retval - async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: + async def __aexit__(self, exc_type, exc_val, exc_tb) -> typing.Union[bool, None]: return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def read(