Skip to content

Commit

Permalink
fix: fixed issue with multiple ping loops running
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeFoodPixels committed Feb 26, 2024
1 parent 79892aa commit 3e2b923
Showing 1 changed file with 43 additions and 26 deletions.
69 changes: 43 additions & 26 deletions custom_components/robovac/tuyalocalapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def __init__(
self.writer = None
self._response_task = None
self._recieve_task = None
self._ping_task = None
self._handlers: dict[int, Callable[[Message], Coroutine]] = {
Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state,
Message.PING_COMMAND: self._async_pong_received,
Expand Down Expand Up @@ -718,11 +719,15 @@ async def async_connect(self):
loop.create_connection
self.reader, self.writer = await asyncio.open_connection(sock=sock)
self._connected = True
asyncio.create_task(self.async_ping(self.ping_interval))

if self._ping_task is None:
self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))

asyncio.create_task(self._async_handle_message())

async def async_disable(self):
self._enabled = False

await self.async_disconnect()

async def async_disconnect(self):
Expand All @@ -732,11 +737,6 @@ async def async_disconnect(self):
_LOGGER.debug("Disconnected from {}".format(self))
self._connected = False
self.last_pong = 0
if self._response_task is not None:
self._response_task.cancel()

if self._recieve_task is not None:
self._recieve_task.cancel()

if self.writer is not None:
self.writer.close()
Expand Down Expand Up @@ -780,7 +780,7 @@ async def async_ping(self, ping_interval):
self._queue.append(message)

await asyncio.sleep(ping_interval)
asyncio.create_task(self.async_ping(self.ping_interval))
self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))
if self.last_pong < self.last_ping:
await self.async_disconnect()

Expand All @@ -792,7 +792,11 @@ async def async_gratuitous_update_state(self, state_message):
await self.update_entity_state_cb()

async def async_update_state(self, state_message, _=None):
if state_message.payload and state_message.payload["dps"]:
if (
state_message is not None
and state_message.payload
and state_message.payload["dps"]
):
self._dps.update(state_message.payload["dps"])
_LOGGER.debug("Received updated state {}: {}".format(self, self._dps))

Expand All @@ -805,7 +809,7 @@ def state_setter(self, new_values):
asyncio.create_task(self.async_set(new_values))

async def _async_handle_message(self):
if self._enabled is False:
if self._enabled is False or self._connected is False:
return

try:
Expand All @@ -815,17 +819,16 @@ async def _async_handle_message(self):
await self._response_task
response_data = self._response_task.result()
message = Message.from_bytes(response_data, self.cipher)
self._response_task = None
except InvalidMessage as e:
_LOGGER.debug("Invalid message from {}: {}".format(self, e))
except MessageDecodeFailed as e:
_LOGGER.debug("Failed to decrypt message from {}".format(self))
except asyncio.IncompleteReadError as e:
self._response_task = None
if self._connected:
_LOGGER.debug("Incomplete read")
except ConnectionResetError as e:
_LOGGER.debug("Connection reset")
except Exception as e:
if isinstance(e, InvalidMessage):
_LOGGER.debug("Invalid message from {}: {}".format(self, e))
elif isinstance(e, MessageDecodeFailed):
_LOGGER.debug("Failed to decrypt message from {}".format(self))
elif isinstance(e, asyncio.IncompleteReadError):
if self._connected:
_LOGGER.debug("Incomplete read")
elif isinstance(e, ConnectionResetError):
_LOGGER.debug("Connection reset")

else:
_LOGGER.debug("Received message from {}: {}".format(self, message))
Expand All @@ -839,6 +842,7 @@ async def _async_handle_message(self):
if handler is not None:
asyncio.create_task(handler(message))

self._response_task = None
asyncio.create_task(self._async_handle_message())

async def _async_send(self, message, retries=2):
Expand All @@ -850,7 +854,8 @@ async def _async_send(self, message, retries=2):
except Exception as e:
if retries == 0:
if isinstance(e, socket.error):
asyncio.create_task(self.async_disconnect())
await self.async_disconnect()

raise ConnectionException(
"Connection to {} failed: {}".format(self, e)
)
Expand Down Expand Up @@ -881,18 +886,30 @@ async def _async_send(self, message, retries=2):
await self._async_send(message, retries=retries - 1)

async def async_recieve(self, message):
if self._connected is False:
return

if message.expect_response is True:
try:
self._recieve_task = asyncio.create_task(
asyncio.wait_for(message.listener.acquire(), timeout=self.timeout)
)
await self._recieve_task
return self._listeners.pop(message.sequence)
response = self._listeners.pop(message.sequence)

if isinstance(response, Exception):
raise response

return response
except Exception as e:
del self._listeners[message.sequence]
await self.async_disconnect()

raise ResponseTimeoutException(
"Timed out waiting for response to sequence number {}".format(
message.sequence
if isinstance(e, TimeoutError):
raise ResponseTimeoutException(
"Timed out waiting for response to sequence number {}".format(
message.sequence
)
)
)

raise e

0 comments on commit 3e2b923

Please sign in to comment.