From ab4657451cf032f5cea16c2b1dd0cb958c6fe3b2 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 26 Jun 2022 16:15:13 -0400 Subject: [PATCH] Implement the new zigpy radio API (#117) * Begin implementing the new radio API * Use the correct signature for `load_network_info` * Use command and response IDs * Parse ZiGate logging messages * Implement `Status` type * Erase PDM when writing new settings * Handle responses and status callbacks in any order * Implement `GET_DEVICES_LIST` * Rename `ADDRESS_MODE` to `AddressMode` * Initialize the ZiGate device on startup * Only permit joins via the coordinator in `permit_ncp` * Use `schedule_initialize` to prevent double initialization * Add unhandled `NODE_DESCRIPTOR_RSP` response * Add a stub for `add_endpoint` * Set the TCLK's partner IEEE * Set the network information `source` and `metadata` * Fix unit tests * Bump minimum required zigpy version to 0.47.0 --- setup.py | 2 +- tests/test_api.py | 2 +- tests/test_application.py | 41 ++- tests/test_types.py | 10 +- zigpy_zigate/api.py | 419 ++++++++++++++++++++--------- zigpy_zigate/types.py | 106 +++++++- zigpy_zigate/zigbee/application.py | 189 ++++++++----- 7 files changed, 554 insertions(+), 215 deletions(-) diff --git a/setup.py b/setup.py index 542fa12..db5ed21 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ def is_raspberry_pi(raise_on_errors=False): 'pyserial-asyncio>=0.5; platform_system!="Windows"', 'pyserial-asyncio!=0.5; platform_system=="Windows"', # 0.5 broke writesv 'pyusb>=1.1.0', - 'zigpy>=0.22.2', + 'zigpy>=0.47.0', ] if is_raspberry_pi(): diff --git a/tests/test_api.py b/tests/test_api.py index 41e20d1..8da6513 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -103,7 +103,7 @@ async def test_probe_fail(mock_connect, mock_raw_mode, exception): @pytest.mark.asyncio -@patch.object(asyncio, "wait_for", side_effect=asyncio.TimeoutError) +@patch.object(asyncio, "wait", return_value=([], [])) async def test_api_command(mock_command, api): """Test command method.""" try: diff --git a/tests/test_application.py b/tests/test_application.py index b5164a9..04d48f6 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -3,6 +3,7 @@ import pytest import zigpy.types as zigpy_types +import zigpy.exceptions import zigpy_zigate.config as config import zigpy_zigate.types as t @@ -31,7 +32,7 @@ def test_zigpy_ieee(app): data = b"\x01\x02\x03\x04\x05\x06\x07\x08" zigate_ieee, _ = t.EUI64.deserialize(data) - app._ieee = zigpy_types.EUI64(zigate_ieee) + app.state.node_info.ieee = zigpy_types.EUI64(zigate_ieee) dst_addr = app.get_dst_address(cluster) assert dst_addr.serialize() == b"\x03" + data[::-1] + b"\x01" @@ -44,20 +45,47 @@ def test_model_detection(app): @pytest.mark.asyncio async def test_form_network_success(app): + app._api.erase_persistent_data = AsyncMock() app._api.set_channel = AsyncMock() + app._api.set_extended_panid = AsyncMock() app._api.reset = AsyncMock() + async def mock_start_network(): return [[0x00, 0x1234, 0x0123456789abcdef], 0] app._api.start_network = mock_start_network + + async def mock_get_network_state(): + return [ + [ + 0x0000, + t.EUI64([0xef, 0xcd, 0xab, 0x89, 0x67, 0x45, 0x23, 0x01]), + 0x1234, + 0x1234abcdef012345, + 0x11, + ], + 0, + ] + + app._api.get_network_state = mock_get_network_state + await app.form_network() - assert app._nwk == 0x1234 - assert app._ieee == 0x0123456789abcdef + await app.load_network_info() + assert app.state.node_info.nwk == 0x0000 + assert app.state.node_info.ieee == zigpy.types.EUI64.convert( + "01:23:45:67:89:ab:cd:ef" + ) + assert app.state.network_info.pan_id == 0x1234 + assert app.state.network_info.extended_pan_id == zigpy.types.ExtendedPanId.convert( + "12:34:ab:cd:ef:01:23:45" + ) assert app._api.reset.call_count == 0 @pytest.mark.asyncio async def test_form_network_failed(app): + app._api.erase_persistent_data = AsyncMock() app._api.set_channel = AsyncMock() + app._api.set_extended_panid = AsyncMock() app._api.reset = AsyncMock() async def mock_start_network(): return [[0x06], 0] @@ -65,7 +93,6 @@ async def mock_start_network(): async def mock_get_network_state(): return [[0xffff, 0x0123456789abcdef, 0x1234, 0, 0x11], 0] app._api.get_network_state = mock_get_network_state - await app.form_network() - assert app._nwk == 0 - assert app._ieee == 0 - assert app._api.reset.call_count == 1 + + with pytest.raises(zigpy.exceptions.FormationFailure): + await app.form_network() diff --git a/tests/test_types.py b/tests/test_types.py index 74f6409..a88603c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -23,8 +23,8 @@ def test_deserialize(): assert result[2] == 0x0001 assert result[3] == 0x01 assert result[4] == 0x01 - assert result[5] == t.Address(address_mode=t.ADDRESS_MODE.NWK, address=t.NWK(0x1234)) - assert result[6] == t.Address(address_mode=t.ADDRESS_MODE.NWK, address=t.NWK(0xabcd)) + assert result[5] == t.Address(address_mode=t.AddressMode.NWK, address=t.NWK(0x1234)) + assert result[6] == t.Address(address_mode=t.AddressMode.NWK, address=t.NWK(0xabcd)) assert result[7] == b'\x01\x00\xBE\xEF' assert rest == b'' @@ -34,7 +34,7 @@ def test_deserialize(): assert result[0] == 0x00 assert result[1] == 0x01 assert result[2] == 0x01 - assert result[3] == t.Address(address_mode=t.ADDRESS_MODE.NWK, address=t.NWK(0x1234)) + assert result[3] == t.Address(address_mode=t.AddressMode.NWK, address=t.NWK(0x1234)) assert result[4] == 0xff data = b'\x00\x01\x01\x03\x12\x34\x56\x78\x9a\xbc\xde\xf0\xff' @@ -43,7 +43,7 @@ def test_deserialize(): assert result[0] == 0x00 assert result[1] == 0x01 assert result[2] == 0x01 - assert result[3] == t.Address(address_mode=t.ADDRESS_MODE.IEEE, + assert result[3] == t.Address(address_mode=t.AddressMode.IEEE, address=t.EUI64.deserialize(b'\x12\x34\x56\x78\x9a\xbc\xde\xf0')[0]) assert result[4] == 0xff @@ -73,7 +73,7 @@ def test_deserialize(): assert result[0] == 0x00 assert result[1] == 0x01 assert result[2] == 0x01 - assert result[3] == t.Address(address_mode=t.ADDRESS_MODE.NWK, + assert result[3] == t.Address(address_mode=t.AddressMode.NWK, address=t.NWK.deserialize(b'\xbc\x8c')[0]) assert result[4] == 0x73 assert len(result) == 5 diff --git a/zigpy_zigate/api.py b/zigpy_zigate/api.py index 2cd172e..7a70b2a 100644 --- a/zigpy_zigate/api.py +++ b/zigpy_zigate/api.py @@ -19,40 +19,155 @@ COMMAND_TIMEOUT = 1.5 PROBE_TIMEOUT = 3.0 + +class CommandId(enum.IntEnum): + SET_RAWMODE = 0x0002 + NETWORK_STATE_REQ = 0x0009 + GET_VERSION = 0x0010 + RESET = 0x0011 + ERASE_PERSISTENT_DATA = 0x0012 + GET_DEVICES_LIST = 0x0015 + SET_TIMESERVER = 0x0016 + GET_TIMESERVER = 0x0017 + SET_LED = 0x0018 + SET_CE_FCC = 0x0019 + SET_EXT_PANID = 0x0020 + SET_CHANNELMASK = 0x0021 + START_NETWORK = 0x0024 + NETWORK_REMOVE_DEVICE = 0x0026 + PERMIT_JOINING_REQUEST = 0x0049 + MANAGEMENT_NETWORK_UPDATE_REQUEST = 0x004A + SEND_RAW_APS_DATA_PACKET = 0x0530 + AHI_SET_TX_POWER = 0x0806 + + +class ResponseId(enum.IntEnum): + DEVICE_ANNOUNCE = 0x004D + STATUS = 0x8000 + LOG = 0x8001 + DATA_INDICATION = 0x8002 + PDM_LOADED = 0x0302 + NODE_NON_FACTORY_NEW_RESTART = 0x8006 + NODE_FACTORY_NEW_RESTART = 0x8007 + HEART_BEAT = 0x8008 + NETWORK_STATE_RSP = 0x8009 + VERSION_LIST = 0x8010 + ACK_DATA = 0x8011 + APS_DATA_CONFIRM = 0x8012 + PERMIT_JOIN_RSP = 0x8014 + GET_DEVICES_LIST_RSP = 0x8015 + GET_TIMESERVER_LIST = 0x8017 + NETWORK_JOINED_FORMED = 0x8024 + PDM_EVENT = 0x8035 + NODE_DESCRIPTOR_RSP = 0x8042 + LEAVE_INDICATION = 0x8048 + ROUTE_DISCOVERY_CONFIRM = 0x8701 + APS_DATA_CONFIRM_FAILED = 0x8702 + AHI_SET_TX_POWER_RSP = 0x8806 + EXTENDED_ERROR = 0x9999 + + + + +class NonFactoryNewRestartStatus(t.uint8_t, enum.Enum): + Startup = 0 + Running = 1 + Start = 2 + +class FactoryNewRestartStatus(t.uint8_t, enum.Enum): + Startup = 0 + Start = 2 + Running = 6 + + RESPONSES = { - 0x004D: (t.NWK, t.EUI64, t.uint8_t, t.uint8_t), - 0x8000: (t.uint8_t, t.uint8_t, t.uint16_t, t.Bytes), - 0x8002: (t.uint8_t, t.uint16_t, t.uint16_t, t.uint8_t, t.uint8_t, - t.Address, t.Address, t.Bytes), - 0x0302: (t.uint8_t,), - 0x8006: (t.uint8_t,), - 0x8007: (t.uint8_t,), - 0x8009: (t.NWK, t.EUI64, t.uint16_t, t.uint64_t, t.uint8_t), - 0x8010: (t.uint16_t, t.uint16_t), - 0x8011: (t.uint8_t, t.NWK, t.uint8_t, t.uint16_t, t.uint8_t), - 0x8012: (t.uint8_t, t.uint8_t, t.uint8_t, t.Address, t.uint8_t), - 0x8017: (t.uint32_t,), - 0x8024: (t.uint8_t, t.NWK, t.EUI64, t.uint8_t), - 0x8035: (t.uint8_t, t.uint32_t), - 0x8048: (t.EUI64, t.uint8_t), - 0x8701: (t.uint8_t, t.uint8_t), - 0x8702: (t.uint8_t, t.uint8_t, t.uint8_t, t.Address, t.uint8_t), - 0x8806: (t.uint8_t,), - 0x9999: (t.uint8_t,), + ResponseId.DEVICE_ANNOUNCE: (t.NWK, t.EUI64, t.uint8_t, t.uint8_t), + ResponseId.STATUS: (t.Status, t.uint8_t, t.uint16_t, t.Bytes), + ResponseId.LOG: (t.LogLevel, t.Bytes), + ResponseId.DATA_INDICATION: ( + t.Status, + t.uint16_t, + t.uint16_t, + t.uint8_t, + t.uint8_t, + t.Address, + t.Address, + t.Bytes, + ), + ResponseId.PDM_LOADED: (t.uint8_t,), + ResponseId.NODE_NON_FACTORY_NEW_RESTART: (NonFactoryNewRestartStatus,), + ResponseId.NODE_FACTORY_NEW_RESTART: (FactoryNewRestartStatus,), + ResponseId.HEART_BEAT: (t.uint32_t,), + ResponseId.NETWORK_STATE_RSP: (t.NWK, t.EUI64, t.uint16_t, t.uint64_t, t.uint8_t), + ResponseId.VERSION_LIST: (t.uint16_t, t.uint16_t), + ResponseId.ACK_DATA: (t.Status, t.NWK, t.uint8_t, t.uint16_t, t.uint8_t), + ResponseId.APS_DATA_CONFIRM: ( + t.Status, + t.uint8_t, + t.uint8_t, + t.Address, + t.uint8_t, + ), + ResponseId.PERMIT_JOIN_RSP: (t.uint8_t,), + ResponseId.GET_DEVICES_LIST_RSP: (t.DeviceEntryArray,), + ResponseId.GET_TIMESERVER_LIST: (t.uint32_t,), + ResponseId.NETWORK_JOINED_FORMED: (t.uint8_t, t.NWK, t.EUI64, t.uint8_t), + ResponseId.PDM_EVENT: (t.Status, t.uint32_t), + ResponseId.NODE_DESCRIPTOR_RSP: ( + t.uint8_t, + t.Status, + t.NWK, + t.uint16_t, + t.uint16_t, + t.uint16_t, + t.uint16_t, + t.uint8_t, + t.uint8_t, + t.uint8_t, + t.uint16_t, + ), + ResponseId.LEAVE_INDICATION: (t.EUI64, t.uint8_t), + ResponseId.ROUTE_DISCOVERY_CONFIRM: (t.uint8_t, t.uint8_t), + ResponseId.APS_DATA_CONFIRM_FAILED: ( + t.Status, + t.uint8_t, + t.uint8_t, + t.Address, + t.uint8_t, + ), + ResponseId.AHI_SET_TX_POWER_RSP: (t.uint8_t,), + ResponseId.EXTENDED_ERROR: (t.Status,), } COMMANDS = { - 0x0002: (t.uint8_t,), - 0x0016: (t.uint32_t,), - 0x0018: (t.uint8_t,), - 0x0019: (t.uint8_t,), - 0x0020: (t.uint64_t,), - 0x0021: (t.uint32_t,), - 0x0026: (t.EUI64, t.EUI64), - 0x0049: (t.NWK, t.uint8_t, t.uint8_t), - 0x004a: (t.NWK, t.uint32_t, t.uint8_t, t.uint8_t, t.uint8_t, t.uint16_t), - 0x0530: (t.uint8_t, t.NWK, t.uint8_t, t.uint8_t, t.uint16_t, t.uint16_t, t.uint8_t, t.uint8_t, t.LBytes), - 0x0806: (t.uint8_t,), + CommandId.SET_RAWMODE: (t.uint8_t,), + CommandId.SET_TIMESERVER: (t.uint32_t,), + CommandId.SET_LED: (t.uint8_t,), + CommandId.SET_CE_FCC: (t.uint8_t,), + CommandId.SET_EXT_PANID: (t.uint64_t,), + CommandId.SET_CHANNELMASK: (t.uint32_t,), + CommandId.NETWORK_REMOVE_DEVICE: (t.EUI64, t.EUI64), + CommandId.PERMIT_JOINING_REQUEST: (t.NWK, t.uint8_t, t.uint8_t), + CommandId.MANAGEMENT_NETWORK_UPDATE_REQUEST: ( + t.NWK, + t.uint32_t, + t.uint8_t, + t.uint8_t, + t.uint8_t, + t.uint16_t, + ), + CommandId.SEND_RAW_APS_DATA_PACKET: ( + t.uint8_t, + t.NWK, + t.uint8_t, + t.uint8_t, + t.uint16_t, + t.uint16_t, + t.uint8_t, + t.uint8_t, + t.LBytes, + ), + CommandId.AHI_SET_TX_POWER: (t.uint8_t,), } @@ -61,23 +176,23 @@ def _generate_next_value_(name, start, count, last_values): return count -class PDM_EVENT(AutoEnum): - E_PDM_SYSTEM_EVENT_WEAR_COUNT_TRIGGER_VALUE_REACHED = enum.auto() - E_PDM_SYSTEM_EVENT_DESCRIPTOR_SAVE_FAILED = enum.auto() - E_PDM_SYSTEM_EVENT_PDM_NOT_ENOUGH_SPACE = enum.auto() - E_PDM_SYSTEM_EVENT_LARGEST_RECORD_FULL_SAVE_NO_LONGER_POSSIBLE = enum.auto() - E_PDM_SYSTEM_EVENT_SEGMENT_DATA_CHECKSUM_FAIL = enum.auto() - E_PDM_SYSTEM_EVENT_SEGMENT_SAVE_OK = enum.auto() - E_PDM_SYSTEM_EVENT_EEPROM_SEGMENT_HEADER_REPAIRED = enum.auto() - E_PDM_SYSTEM_EVENT_SYSTEM_INTERNAL_BUFFER_WEAR_COUNT_SWAP = enum.auto() - E_PDM_SYSTEM_EVENT_SYSTEM_DUPLICATE_FILE_SEGMENT_DETECTED = enum.auto() - E_PDM_SYSTEM_EVENT_SYSTEM_ERROR = enum.auto() - E_PDM_SYSTEM_EVENT_SEGMENT_PREWRITE = enum.auto() - E_PDM_SYSTEM_EVENT_SEGMENT_POSTWRITE = enum.auto() - E_PDM_SYSTEM_EVENT_SEQUENCE_DUPLICATE_DETECTED = enum.auto() - E_PDM_SYSTEM_EVENT_SEQUENCE_VERIFY_FAIL = enum.auto() - E_PDM_SYSTEM_EVENT_PDM_SMART_SAVE = enum.auto() - E_PDM_SYSTEM_EVENT_PDM_FULL_SAVE = enum.auto() +class PDM_EVENT(enum.IntEnum): + E_PDM_SYSTEM_EVENT_WEAR_COUNT_TRIGGER_VALUE_REACHED = 0 + E_PDM_SYSTEM_EVENT_DESCRIPTOR_SAVE_FAILED = 1 + E_PDM_SYSTEM_EVENT_PDM_NOT_ENOUGH_SPACE = 2 + E_PDM_SYSTEM_EVENT_LARGEST_RECORD_FULL_SAVE_NO_LONGER_POSSIBLE = 3 + E_PDM_SYSTEM_EVENT_SEGMENT_DATA_CHECKSUM_FAIL = 4 + E_PDM_SYSTEM_EVENT_SEGMENT_SAVE_OK = 5 + E_PDM_SYSTEM_EVENT_EEPROM_SEGMENT_HEADER_REPAIRED = 6 + E_PDM_SYSTEM_EVENT_SYSTEM_INTERNAL_BUFFER_WEAR_COUNT_SWAP = 7 + E_PDM_SYSTEM_EVENT_SYSTEM_DUPLICATE_FILE_SEGMENT_DETECTED = 8 + E_PDM_SYSTEM_EVENT_SYSTEM_ERROR = 9 + E_PDM_SYSTEM_EVENT_SEGMENT_PREWRITE = 10 + E_PDM_SYSTEM_EVENT_SEGMENT_POSTWRITE = 11 + E_PDM_SYSTEM_EVENT_SEQUENCE_DUPLICATE_DETECTED = 12 + E_PDM_SYSTEM_EVENT_SEQUENCE_VERIFY_FAIL = 13 + E_PDM_SYSTEM_EVENT_PDM_SMART_SAVE = 14 + E_PDM_SYSTEM_EVENT_PDM_FULL_SAVE = 15 class NoResponseError(zigpy.exceptions.APIException): @@ -176,8 +291,9 @@ def data_received(self, cmd, data, lqi): if cmd not in RESPONSES: LOGGER.warning('Received unhandled response 0x%04x', cmd) return + cmd = ResponseId(cmd) data, rest = t.deserialize(data, RESPONSES[cmd]) - if cmd == 0x8000: + if cmd == ResponseId.STATUS: if data[2] in self._status_awaiting: fut = self._status_awaiting.pop(data[2]) fut.set_result((data, lqi)) @@ -186,65 +302,98 @@ def data_received(self, cmd, data, lqi): fut.set_result((data, lqi)) self.handle_callback(cmd, data, lqi) + async def wait_for_status(self, cmd): + LOGGER.debug('Wait for status to command %s', cmd) + + if cmd in self._status_awaiting: + self._status_awaiting[cmd].cancel() + + status_fut = asyncio.Future() + self._status_awaiting[cmd] = status_fut + + try: + return await status_fut + finally: + if cmd in self._status_awaiting: + self._status_awaiting[cmd].cancel() + del self._status_awaiting[cmd] + + async def wait_for_response(self, wait_response): + LOGGER.debug('Wait for response %s', wait_response) + + if wait_response in self._awaiting: + self._awaiting[wait_response].cancel() + + response_fut = asyncio.Future() + self._awaiting[wait_response] = response_fut + + try: + return await response_fut + finally: + if wait_response in self._awaiting: + self._awaiting[wait_response].cancel() + del self._awaiting[wait_response] + async def command(self, cmd, data=b'', wait_response=None, wait_status=True, timeout=COMMAND_TIMEOUT): - - await self._lock.acquire() - tries = 3 - result = None - status_fut = None - response_fut = None - while tries > 0: - if self._uart is None: - # connection was lost - self._lock.release() - raise CommandError("API is not running") - if wait_status: - status_fut = asyncio.Future() - self._status_awaiting[cmd] = status_fut - if wait_response: - response_fut = asyncio.Future() - self._awaiting[wait_response] = response_fut - tries -= 1 - self._uart.send(cmd, data) + async with self._lock: + tries = 3 + + tasks = [] + status_task = None + response_task = None + + LOGGER.debug( + "Sending %s (%s), waiting for status: %s, waiting for response: %s", + cmd, + data, + wait_status, + wait_response, + ) + if wait_status: - LOGGER.debug('Wait for status to command 0x%04x', cmd) - try: - result = await asyncio.wait_for(status_fut, timeout=timeout) - LOGGER.debug('Got status for 0x%04x : %s', cmd, result) - except asyncio.TimeoutError: - if cmd in self._status_awaiting: - del self._status_awaiting[cmd] - if response_fut and wait_response in self._awaiting: - del self._awaiting[wait_response] - LOGGER.warning("No response to command 0x%04x", cmd) - LOGGER.debug('Tries count %s', tries) - if tries > 0: - LOGGER.warning("Retry command 0x%04x", cmd) - continue - else: - self._lock.release() - raise NoStatusError - if wait_response: - LOGGER.debug('Wait for response 0x%04x', wait_response) - try: - result = await asyncio.wait_for(response_fut, timeout=timeout) - LOGGER.debug('Got response 0x%04x : %s', wait_response, result) - except asyncio.TimeoutError: - if wait_response in self._awaiting: - del self._awaiting[wait_response] - LOGGER.warning("No response waiting for 0x%04x", wait_response) - LOGGER.debug('Tries count %s', tries) - if tries > 0: - LOGGER.warning("Retry command 0x%04x", cmd) - continue - else: - self._lock.release() - raise NoResponseError - self._lock.release() - return result + status_task = asyncio.create_task(self.wait_for_status(cmd)) + tasks.append(status_task) + + if wait_response is not None: + response_task = asyncio.create_task(self.wait_for_response(wait_response)) + tasks.append(response_task) + + try: + while tries > 0: + if self._uart is None: + # connection was lost + raise CommandError("API is not running") + + tries -= 1 + self._uart.send(cmd, data) + + done, pending = await asyncio.wait(tasks, timeout=timeout) + + if wait_status and tries == 0 and status_task in pending: + raise NoStatusError() + elif wait_response and tries == 0 and response_task in pending: + raise NoResponseError() + + if wait_response and response_task in done: + if wait_status and status_task in pending: + continue + elif wait_status: + await status_task + + return await response_task + elif wait_status and status_task in done: + return await status_task + elif not wait_response and not wait_status: + return + finally: + for t in tasks: + if not t.done(): + t.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) async def version(self): - return await self.command(0x0010, wait_response=0x8010) + return await self.command(CommandId.GET_VERSION, wait_response=ResponseId.VERSION_LIST) async def version_str(self): version, lqi = await self.version() @@ -253,17 +402,20 @@ async def version_str(self): return version async def get_network_state(self): - return await self.command(0x0009, wait_response=0x8009) + return await self.command(CommandId.NETWORK_STATE_REQ, wait_response=ResponseId.NETWORK_STATE_RSP) async def set_raw_mode(self, enable=True): - data = t.serialize([enable], COMMANDS[0x0002]) - await self.command(0x0002, data) + data = t.serialize([enable], COMMANDS[CommandId.SET_RAWMODE]) + await self.command(CommandId.SET_RAWMODE, data) - async def reset(self): - await self.command(0x0011, wait_response=0x8006) + async def reset(self, *, wait=True): + wait_response = ResponseId.NODE_NON_FACTORY_NEW_RESTART if wait else None + await self.command(CommandId.RESET, wait_response=wait_response) async def erase_persistent_data(self): - await self.command(0x0012, wait_status=False) + await self.command(CommandId.ERASE_PERSISTENT_DATA, wait_status=False, wait_response=ResponseId.PDM_LOADED, timeout=10) + await asyncio.sleep(1) + await self.command(CommandId.RESET, wait_response=ResponseId.NODE_FACTORY_NEW_RESTART) async def set_time(self, dt=None): """ set internal time @@ -271,34 +423,34 @@ async def set_time(self, dt=None): """ dt = dt or datetime.datetime.now() timestamp = int((dt - datetime.datetime(2000, 1, 1)).total_seconds()) - data = t.serialize([timestamp], COMMANDS[0x0016]) - await self.command(0x0016, data) + data = t.serialize([timestamp], COMMANDS[CommandId.SET_TIMESERVER]) + await self.command(CommandId.SET_TIMESERVER, data) async def get_time_server(self): - timestamp, lqi = await self.command(0x0017, wait_response=0x8017) + timestamp, lqi = await self.command(CommandId.GET_TIMESERVER, wait_response=ResponseId.GET_TIMESERVER_LIST) dt = datetime.datetime(2000, 1, 1) + datetime.timedelta(seconds=timestamp[0]) return dt async def set_led(self, enable=True): - data = t.serialize([enable], COMMANDS[0x0018]) - await self.command(0x0018, data) + data = t.serialize([enable], COMMANDS[CommandId.SET_LED]) + await self.command(CommandId.SET_LED, data) async def set_certification(self, typ='CE'): cert = {'CE': 1, 'FCC': 2}[typ] - data = t.serialize([cert], COMMANDS[0x0019]) - await self.command(0x0019, data) + data = t.serialize([cert], COMMANDS[CommandId.SET_CE_FCC]) + await self.command(CommandId.SET_CE_FCC, data) async def management_network_request(self): - data = t.serialize([0x0000, 0x07fff800, 0xff, 5, 0xff, 0x0000], COMMANDS[0x004a]) - return await self.command(0x004a)#, wait_response=0x804a, timeout=10) + data = t.serialize([0x0000, 0x07fff800, 0xff, 5, 0xff, 0x0000], COMMANDS[CommandId.MANAGEMENT_NETWORK_UPDATE_REQUEST]) + return await self.command(CommandId.MANAGEMENT_NETWORK_UPDATE_REQUEST)#, wait_response=0x804a, timeout=10) async def set_tx_power(self, power=63): if power > 63: power = 63 if power < 0: power = 0 - data = t.serialize([power], COMMANDS[0x0806]) - power, lqi = await self.command(0x0806, data, wait_response=0x8806) + data = t.serialize([power], COMMANDS[CommandId.AHI_SET_TX_POWER]) + power, lqi = await self.command(CommandId.AHI_SET_TX_POWER, data, wait_response=CommandId.AHI_SET_TX_POWER_RSP) return power[0] async def set_channel(self, channels=None): @@ -306,23 +458,28 @@ async def set_channel(self, channels=None): if not isinstance(channels, list): channels = [channels] mask = functools.reduce(lambda acc, x: acc ^ 2 ** x, channels, 0) - data = t.serialize([mask], COMMANDS[0x0021]) - await self.command(0x0021, data), + data = t.serialize([mask], COMMANDS[CommandId.SET_CHANNELMASK]) + await self.command(CommandId.SET_CHANNELMASK, data) async def set_extended_panid(self, extended_pan_id): - data = t.serialize([extended_pan_id], COMMANDS[0x0020]) - await self.command(0x0020, data) + data = t.serialize([extended_pan_id], COMMANDS[CommandId.SET_EXT_PANID]) + await self.command(CommandId.SET_EXT_PANID, data) + + async def get_devices_list(self): + (entries,), lqi = await self.command(CommandId.GET_DEVICES_LIST, wait_response=ResponseId.GET_DEVICES_LIST_RSP) + + return list(entries or []) async def permit_join(self, duration=60): - data = t.serialize([0xfffc, duration, 0], COMMANDS[0x0049]) - return await self.command(0x0049, data) + data = t.serialize([0x0000, duration, 1], COMMANDS[CommandId.PERMIT_JOINING_REQUEST]) + return await self.command(CommandId.PERMIT_JOINING_REQUEST, data) async def start_network(self): - return await self.command(0x0024, wait_response=0x8024) + return await self.command(CommandId.START_NETWORK, wait_response=ResponseId.NETWORK_JOINED_FORMED) async def remove_device(self, zigate_ieee, ieee): - data = t.serialize([zigate_ieee, ieee], COMMANDS[0x0026]) - return await self.command(0x0026, data) + data = t.serialize([zigate_ieee, ieee], COMMANDS[CommandId.NETWORK_REMOVE_DEVICE]) + return await self.command(CommandId.NETWORK_REMOVE_DEVICE, data) async def raw_aps_data_request(self, addr, src_ep, dst_ep, profile, cluster, payload, addr_mode=2, security=0): @@ -332,8 +489,8 @@ async def raw_aps_data_request(self, addr, src_ep, dst_ep, profile, radius = 0 data = t.serialize([addr_mode, addr, src_ep, dst_ep, cluster, profile, - security, radius, payload], COMMANDS[0x0530]) - return await self.command(0x0530, data) + security, radius, payload], COMMANDS[CommandId.SEND_RAW_APS_DATA_PACKET]) + return await self.command(CommandId.SEND_RAW_APS_DATA_PACKET, data) def handle_callback(self, *args): """run application callback handler""" diff --git a/zigpy_zigate/types.py b/zigpy_zigate/types.py index 3e4ae8d..5f49502 100644 --- a/zigpy_zigate/types.py +++ b/zigpy_zigate/types.py @@ -141,7 +141,7 @@ def __str__(self): return "0x{:04x}".format(self) -class ADDRESS_MODE(uint8_t, enum.Enum): +class AddressMode(uint8_t, enum.Enum): # Address modes used in zigate protocol GROUP = 0x01 @@ -149,6 +149,76 @@ class ADDRESS_MODE(uint8_t, enum.Enum): IEEE = 0x03 +class Status(uint8_t, enum.Enum): + Success = 0x00 + IncorrectParams = 0x01 + UnhandledCommand = 0x02 + CommandFailed = 0x03 + Busy = 0x04 + StackAlreadyStarted = 0x05 + + # Errors below are due to resource shortage, retrying may succeed OR There are no + # free Network PDUs. The number of NPDUs is set in the “Number of NPDUs” property + # of the “PDU Manager” section of the config editor + ResourceShortage = 0x80 + # There are no free Application PDUs. The number of APDUs is set in the “Instances” + # property of the appropriate “APDU” child of the “PDU Manager” section of the + # config editor + NoFreeAppPDUs = 0x81 + # There are no free simultaneous data request handles. The number of handles is set + # in the “Maximum Number of Simultaneous Data Requests” field of the “APS layer + # configuration” section of the config editor + NoFreeDataReqHandles = 0x82 + # There are no free APS acknowledgement handles. The number of handles is set in + # the “Maximum Number of Simultaneous Data Requests with Acks” field of the “APS + # layer configuration” section of the config editor + NoFreeAPSAckHandles = 0x83 + # There are no free fragment record handles. The number of handles is set in + # the “Maximum Number of Transmitted Simultaneous Fragmented Messages” field of + # the “APS layer configuration” section of the config editor + NoFreeFragRecHandles = 0x84 + # There are no free MCPS request descriptors. There are 8 MCPS request descriptors. + # These are only ever likely to be exhausted under very heavy network load or when + # trying to transmit too many frames too close together. + NoFreeMCPSReqDesc = 0x85 + # The loop back send is currently busy. There can be only one loopback request at a + # time. + LoopbackSendBusy = 0x86 + # There are no free entries in the extended address table. The extended address + # table is configured in the config editor + NoFreeExtAddrTableEntries = 0x87 + # The simple descriptor does not exist for this endpoint / cluster. + SimpleDescDoesNotExist = 0x88 + # A bad parameter has been found while processing an APSDE request or response + BadAPSDEParam = 0x89 + # No free Routing table entries left + NoFreeRoutingTableEntries = 0x8A + # No free BTR entries left. + NoFreeBTREntries = 0x8B + + @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + raise ValueError(f"{value} is not a valid {cls.__name__}") + + new_member = cls._member_type_.__new__(cls, value) + new_member._name_ = f"unknown_0x{value:02X}" + new_member._value_ = cls._member_type_(value) + + return new_member + + +class LogLevel(uint8_t, enum.Enum): + Emergency = 0 + Alert = 1 + Critical = 2 + Error = 3 + Warning = 4 + Notice = 5 + Information = 6 + Debug = 7 + + class Struct: _fields = [] @@ -191,7 +261,7 @@ def __repr__(self): class Address(Struct): _fields = [ - ('address_mode', ADDRESS_MODE), + ('address_mode', AddressMode), ('address', EUI64), ] @@ -205,9 +275,37 @@ def deserialize(cls, data): mode, data = field_type.deserialize(data) setattr(r, field_name, mode) v = None - if mode in [ADDRESS_MODE.GROUP, ADDRESS_MODE.NWK]: + if mode in [AddressMode.GROUP, AddressMode.NWK]: v, data = NWK.deserialize(data) - elif mode == ADDRESS_MODE.IEEE: + elif mode == AddressMode.IEEE: v, data = EUI64.deserialize(data) setattr(r, cls._fields[1][0], v) return r, data + + +class DeviceEntry(Struct): + _fields = [ + ("id", uint8_t), + ("short_addr", NWK), + ("ieee_addr", EUI64), + ("power_source", uint8_t), + ("link_quality", uint8_t), + ] + + +class DeviceEntryArray(tuple): + @classmethod + def deserialize(cls, data): + if len(data) % 13 != 0: + raise ValueError("Data is not an array of DeviceEntry") + + entries = [] + + while data: + entry, data = DeviceEntry.deserialize(data) + entries.append(entry) + + return cls(entries), data + + def serialize(self): + return b"".join([e.serialize() for e in self]) diff --git a/zigpy_zigate/zigbee/application.py b/zigpy_zigate/zigbee/application.py index 8edce92..3624249 100644 --- a/zigpy_zigate/zigbee/application.py +++ b/zigpy_zigate/zigbee/application.py @@ -7,10 +7,13 @@ import zigpy.device import zigpy.types import zigpy.util +import zigpy.zdo +import zigpy.exceptions +import zigpy_zigate from zigpy_zigate import types as t from zigpy_zigate import common as c -from zigpy_zigate.api import NoResponseError, ZiGate, PDM_EVENT +from zigpy_zigate.api import NoResponseError, ZiGate, CommandId, ResponseId, PDM_EVENT from zigpy_zigate.config import CONF_DEVICE, CONF_DEVICE_PATH, CONFIG_SCHEMA, SCHEMA_DEVICE LOGGER = logging.getLogger(__name__) @@ -29,80 +32,134 @@ def __init__(self, config: Dict[str, Any]): self._pending = {} self._pending_join = [] - self._nwk = 0 - self._ieee = 0 self.version = '' - async def startup(self, auto_form=False): - """Perform a complete application startup""" - self._api = await ZiGate.new(self._config[CONF_DEVICE], self) - await self._api.set_raw_mode() - await self._api.set_time() - version, lqi = await self._api.version() - version = '{:x}'.format(version[1]) - version = '{}.{}'.format(version[0], version[1:]) - self.version = version - if version < '3.21': - LOGGER.warning('Old ZiGate firmware detected, you should upgrade to 3.21 or newer') + async def connect(self): + api = await ZiGate.new(self._config[CONF_DEVICE], self) + await api.set_raw_mode() + await api.set_time() + version, lqi = await api.version() - network_state, lqi = await self._api.get_network_state() - should_form = not network_state or network_state[0] == 0xffff or network_state[3] == 0 + hex_version = f"{version[1]:x}" + self.version = f"{hex_version[0]}.{hex_version[1:]}" + self._api = api - if auto_form and should_form: - await self.form_network() - if should_form: - network_state, lqi = await self._api.get_network_state() - self._nwk = network_state[0] - self._ieee = zigpy.types.EUI64(network_state[1]) + if self.version < '3.21': + LOGGER.warning('Old ZiGate firmware detected, you should upgrade to 3.21 or newer') - dev = ZiGateDevice(self, self._ieee, self._nwk) - self.devices[dev.ieee] = dev + async def disconnect(self): + # TODO: how do you stop the network? Is it possible? + await self._api.reset(wait=False) - async def shutdown(self): - """Shutdown application.""" if self._api: self._api.close() + self._api = None + + async def start_network(self): + # TODO: how do you start the network? Is it always automatically started? + dev = ZiGateDevice(self, self.state.node_info.ieee, self.state.node_info.nwk) + self.devices[dev.ieee] = dev + await dev.schedule_initialize() + + async def load_network_info(self, *, load_devices: bool = False): + network_state, lqi = await self._api.get_network_state() + + if not network_state or network_state[3] == 0 or network_state[0] == 0xffff: + raise zigpy.exceptions.NetworkNotFormed() + + self.state.node_info = zigpy.state.NodeInfo( + nwk=zigpy.types.NWK(network_state[0]), + ieee=zigpy.types.EUI64(network_state[1]), + logical_type=zigpy.zdo.types.LogicalType.Coordinator, + ) + + epid, _ = zigpy.types.ExtendedPanId.deserialize(zigpy.types.uint64_t(network_state[3]).serialize()) + + self.state.network_info = zigpy.state.NetworkInfo( + source=f"zigpy-zigate@{zigpy_zigate.__version__}", + extended_pan_id=epid, + pan_id=zigpy.types.PanId(network_state[2]), + nwk_update_id=0, + nwk_manager_id=zigpy.types.NWK(0x0000), + channel=network_state[4], + channel_mask=zigpy.types.Channels.from_channel_list([network_state[4]]), + security_level=5, + # TODO: is it possible to read keys? + # network_key=zigpy.state.Key(), + # tc_link_key=zigpy.state.Key(), + children=[], + key_table=[], + nwk_addresses={}, + stack_specific={}, + metadata={ + "zigate": { + "version": self.version, + } + } + ) - async def form_network(self, channel=None, pan_id=None, extended_pan_id=None): - await self._api.set_channel(channel) - if pan_id: - LOGGER.warning('Setting pan_id is not supported by ZiGate') -# self._api.set_panid(pan_id) - if extended_pan_id: - await self._api.set_extended_panid(extended_pan_id) + self.state.network_info.tc_link_key.partner_ieee = self.state.node_info.ieee + + if not load_devices: + return + + for device in await self._api.get_devices_list(): + if device.power_source != 0: # only battery-powered devices + continue + + ieee = zigpy.types.EUI64(device.ieee_addr) + self.state.network_info.children.append(ieee) + self.state.network_info.nwk_addresses[ieee] = zigpy.types.NWK(device.short_addr) + + async def write_network_info(self, *, network_info, node_info): + LOGGER.warning('Setting the pan_id is not supported by ZiGate') + + await self._api.erase_persistent_data() + + await self._api.set_channel(network_info.channel) + + epid, _ = zigpy.types.uint64_t.deserialize(network_info.extended_pan_id.serialize()) + await self._api.set_extended_panid(epid) network_formed, lqi = await self._api.start_network() - if network_formed[0] in (0, 1, 4): - LOGGER.info('Network started %s %s', - network_formed[1], - network_formed[2]) - self._nwk = network_formed[1] - self._ieee = network_formed[2] - else: - LOGGER.warning('Starting network got status %s, wait...', network_formed[0]) - tries = 3 - while tries > 0: - await asyncio.sleep(1) - tries -= 1 - network_state, lqi = await self._api.get_network_state() - if network_state and network_state[3] != 0 and network_state[0] != 0xffff: - break - if tries <= 0: - LOGGER.error('Failed to start network error %s', network_formed[0]) - LOGGER.debug('Resetting ZiGate') - await self._api.reset() + + if network_formed[0] not in ( + t.Status.Success, + t.Status.IncorrectParams, + t.Status.Busy, + ): + raise zigpy.exceptions.FormationFailure( + f"Unexpected error starting network: {network_formed!r}" + ) + + LOGGER.warning('Starting network got status %s, wait...', network_formed[0]) + for attempt in range(3): + await asyncio.sleep(1) + + try: + await self.load_network_info() + except zigpy.exceptions.NetworkNotFormed as e: + if attempt == 2: + raise zigpy.exceptions.FormationFailure() from e + + async def permit_with_key(self, node, code, time_s = 60): + LOGGER.warning("ZiGate does not support joins with install codes") async def force_remove(self, dev): - await self._api.remove_device(self._ieee, dev.ieee) + await self._api.remove_device(self.state.node_info.ieee, dev.ieee) + + async def add_endpoint(self, descriptor): + # ZiGate does not support adding new endpoints + pass def zigate_callback_handler(self, msg, response, lqi): - LOGGER.debug('zigate_callback_handler {}'.format(response)) + LOGGER.debug('zigate_callback_handler %s %s', msg, response) - if msg == 0x8048: # leave + if msg == ResponseId.LEAVE_INDICATION: nwk = 0 ieee = zigpy.types.EUI64(response[0]) self.handle_leave(nwk, ieee) - elif msg == 0x004D: # join + elif msg == ResponseId.DEVICE_ANNOUNCE: nwk = response[0] ieee = zigpy.types.EUI64(response[1]) parent_nwk = 0 @@ -117,7 +174,7 @@ def zigate_callback_handler(self, msg, response, lqi): # else: # LOGGER.debug('Start pairing {} (1st device announce)'.format(nwk)) # self._pending_join.append(nwk) - elif msg == 0x8002: + elif msg == ResponseId.DATA_INDICATION: if response[1] == 0x0 and response[2] == 0x13: nwk = zigpy.types.NWK(response[5].address) ieee = zigpy.types.EUI64(response[7][3:11]) @@ -125,9 +182,9 @@ def zigate_callback_handler(self, msg, response, lqi): self.handle_join(nwk, ieee, parent_nwk) return try: - if response[5].address_mode == t.ADDRESS_MODE.NWK: + if response[5].address_mode == t.AddressMode.NWK: device = self.get_device(nwk = zigpy.types.NWK(response[5].address)) - elif response[5].address_mode == t.ADDRESS_MODE.IEEE: + elif response[5].address_mode == t.AddressMode.IEEE: device = self.get_device(ieee=zigpy.types.EUI64(response[5].address)) else: LOGGER.error("No such device %s", response[5].address) @@ -140,22 +197,22 @@ def zigate_callback_handler(self, msg, response, lqi): self.handle_message(device, response[1], response[2], response[3], response[4], response[-1]) - elif msg == 0x8011: # ACK Data + elif msg == ResponseId.ACK_DATA: LOGGER.debug('ACK Data received %s %s', response[4], response[0]) # disabled because of https://github.com/fairecasoimeme/ZiGate/issues/324 # self._handle_frame_failure(response[4], response[0]) - elif msg == 0x8012: # ZPS Event + elif msg == ResponseId.APS_DATA_CONFIRM: LOGGER.debug('ZPS Event APS data confirm, message routed to %s %s', response[3], response[0]) - elif msg == 0x8035: # PDM Event + elif msg == ResponseId.PDM_EVENT: try: event = PDM_EVENT(response[0]).name except ValueError: event = 'Unknown event' LOGGER.debug('PDM Event %s %s, record %s', response[0], event, response[1]) - elif msg == 0x8702: # APS Data confirm Fail + elif msg == ResponseId.APS_DATA_CONFIRM_FAILED: LOGGER.debug('APS Data confirm Fail %s %s', response[4], response[0]) self._handle_frame_failure(response[4], response[0]) - elif msg == 0x9999: # ZCL event + elif msg == ResponseId.EXTENDED_ERROR: LOGGER.warning('Extended error code %s', response[0]) def _handle_frame_failure(self, message_tag, status): @@ -190,7 +247,7 @@ async def _request(self, nwk, profile, cluster, src_ep, dst_ep, sequence, data, send_fut = asyncio.Future() self._pending[req_id] = send_fut - if v[0] != 0: + if v[0] != t.Status.Success: self._pending.pop(req_id) return v[0], "Message send failure {}".format(v[0]) @@ -207,7 +264,7 @@ async def _request(self, nwk, profile, cluster, src_ep, dst_ep, sequence, data, async def permit_ncp(self, time_s=60): assert 0 <= time_s <= 254 status, lqi = await self._api.permit_join(time_s) - if status[0] != 0: + if status[0] != t.Status.Success: await self._api.reset() async def broadcast(self, profile, cluster, src_ep, dst_ep, grpid, radius,