Skip to content

Commit

Permalink
Use znp.disconnect instead of znp.close
Browse files Browse the repository at this point in the history
  • Loading branch information
puddly committed Oct 27, 2024
1 parent 377e029 commit 7a390ca
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 26 deletions.
24 changes: 12 additions & 12 deletions tests/api/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_connect_no_test(make_znp_server):
# Nothing will be sent
assert znp_server._uart.data_received.call_count == 0

znp.close()
await znp.disconnect()


@pytest.mark.parametrize("work_after_attempt", [1, 2, 3])
Expand All @@ -44,7 +44,7 @@ def ping_rsp(req):

await znp.connect(test_port=True)

znp.close()
await znp.disconnect()


async def test_connect_skip_bootloader_batched_rsp(make_znp_server, mocker):
Expand Down Expand Up @@ -82,7 +82,7 @@ def ping_rsp(req):

await znp.connect(test_port=True)

znp.close()
await znp.disconnect()


async def test_connect_skip_bootloader_failure(make_znp_server):
Expand All @@ -92,7 +92,7 @@ async def test_connect_skip_bootloader_failure(make_znp_server):
with pytest.raises(asyncio.TimeoutError):
await znp.connect(test_port=True)

znp.close()
await znp.disconnect()


async def test_connect_skip_bootloader_rts_dtr_pins(make_znp_server, mocker):
Expand All @@ -112,7 +112,7 @@ async def test_connect_skip_bootloader_rts_dtr_pins(make_znp_server, mocker):
assert serial._mock_dtr_prop.mock_calls == [call(False), call(False), call(False)]
assert serial._mock_rts_prop.mock_calls == [call(False), call(True), call(False)]

znp.close()
await znp.disconnect()


async def test_connect_skip_bootloader_config(make_znp_server, mocker):
Expand All @@ -133,24 +133,24 @@ async def test_connect_skip_bootloader_config(make_znp_server, mocker):
assert serial._mock_dtr_prop.called is False
assert serial._mock_rts_prop.called is False

znp.close()
await znp.disconnect()


async def test_api_close(connected_znp, mocker):
znp, znp_server = connected_znp
uart = znp._uart
mocker.spy(uart, "close")

znp.close()
await znp.disconnect()

# Make sure our UART was actually closed
assert znp._uart is None
assert znp._app is None
assert uart.close.call_count == 1

# ZNP.close should not throw any errors if called multiple times
znp.close()
znp.close()
# ZNP.disconnect should not throw any errors if called multiple times
await znp.disconnect()
await znp.disconnect()

def dict_minus(d, minus):
return {k: v for k, v in d.items() if k not in minus}
Expand All @@ -165,8 +165,8 @@ def dict_minus(d, minus):
znp2.__dict__, ignored_keys
)

znp2.close()
znp2.close()
await znp2.disconnect()
await znp2.disconnect()

assert dict_minus(znp.__dict__, ignored_keys) == dict_minus(
znp2.__dict__, ignored_keys
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def test_api_cancel_listeners(connected_znp, mocker):
)

assert not future.done()
znp.close()
await znp.disconnect()

with pytest.raises(asyncio.CancelledError):
await future
Expand Down
10 changes: 5 additions & 5 deletions tests/api/test_network_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def test_state_transfer(from_device, to_device, make_connected_znp):
formed_znp, _ = await make_connected_znp(server_cls=from_device)

await formed_znp.load_network_info()
formed_znp.close()
await formed_znp.disconnect()

empty_znp, _ = await make_connected_znp(server_cls=to_device)

Expand Down Expand Up @@ -72,15 +72,15 @@ async def test_broken_cc2531_load_state(device, make_connected_znp, caplog):
await znp.load_network_info()
assert "inconsistent" in caplog.text

znp.close()
await znp.disconnect()


@pytest.mark.parametrize("device", [FormedZStack3CC2531])
async def test_state_write_tclk_zstack3(device, make_connected_znp, caplog):
formed_znp, _ = await make_connected_znp(server_cls=device)

await formed_znp.load_network_info()
formed_znp.close()
await formed_znp.disconnect()

empty_znp, _ = await make_connected_znp(server_cls=device)

Expand All @@ -106,7 +106,7 @@ async def test_state_write_tclk_zstack3(device, make_connected_znp, caplog):
async def test_write_settings_fast(device, make_connected_znp):
formed_znp, _ = await make_connected_znp(server_cls=FormedLaunchpadCC26X2R1)
await formed_znp.load_network_info()
formed_znp.close()
await formed_znp.disconnect()

znp, _ = await make_connected_znp(server_cls=device)

Expand All @@ -126,7 +126,7 @@ async def test_write_settings_fast(device, make_connected_znp):
async def test_formation_failure_on_corrupted_nvram(device, make_connected_znp):
formed_znp, _ = await make_connected_znp(server_cls=FormedLaunchpadCC26X2R1)
await formed_znp.load_network_info()
formed_znp.close()
await formed_znp.disconnect()

znp, znp_server = await make_connected_znp(server_cls=device)

Expand Down
5 changes: 5 additions & 0 deletions zigpy_znp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,11 @@ async def connect(self, *, test_port=True) -> None:

LOGGER.debug("Connected to %s", self._uart.url)

def connection_made(self) -> None:
"""
Called by the UART object to indicate that the port was opened.
"""

def connection_lost(self, exc) -> None:
"""
Called by the UART object to indicate that the port was closed. Propagates up
Expand Down
2 changes: 1 addition & 1 deletion zigpy_znp/tools/flash_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def main(argv):
await znp.connect(test_port=False)

data = await read_firmware(znp)
znp.close()
await znp.disconnect()

f.write(data)

Expand Down
2 changes: 1 addition & 1 deletion zigpy_znp/tools/flash_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def main(argv):

await write_firmware(znp=znp, firmware=firmware, reset_nvram=args.reset)

znp.close()
await znp.disconnect()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion zigpy_znp/tools/network_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def main(argv: list[str]) -> None:
await znp.connect()

backup_obj = await backup_network(znp)
znp.close()
await znp.disconnect()

f.write(json.dumps(backup_obj, indent=4))

Expand Down
2 changes: 1 addition & 1 deletion zigpy_znp/tools/network_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def restore_network(
await znp.connect()
await znp.write_network_info(network_info=network_info, node_info=node_info)
await znp.reset()
znp.close()
await znp.disconnect()


async def main(argv: list[str]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions zigpy_znp/tools/network_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def network_scan(
await znp.nvram.osal_write(OsalNvIds.NIB, previous_nib, create=True)

await znp.nvram.osal_write(OsalNvIds.CHANLIST, previous_channels)
znp.close()
await znp.disconnect()


async def main(argv):
Expand Down Expand Up @@ -151,7 +151,7 @@ async def main(argv):
duplicates=args.allow_duplicates,
)

znp.close()
await znp.disconnect()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion zigpy_znp/tools/nvram_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def main(argv):
await znp.connect()

obj = await nvram_read(znp)
znp.close()
await znp.disconnect()

f.write(json.dumps(obj, indent=4) + "\n")

Expand Down
8 changes: 7 additions & 1 deletion zigpy_znp/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __init__(self, api, *, url: str | None = None) -> None:

def close(self) -> None:
"""Closes the port."""
self._api = None
super().close()
self._api = None

def connection_lost(self, exc: Exception | None) -> None:
"""Connection lost."""
Expand All @@ -38,6 +38,12 @@ def connection_lost(self, exc: Exception | None) -> None:
if self._api is not None:
self._api.connection_lost(exc)

def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)

if self._api is not None:
self._api.connection_made()

def data_received(self, data: bytes) -> None:
"""Callback when data is received."""
super().data_received(data)
Expand Down

0 comments on commit 7a390ca

Please sign in to comment.