diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 56c1dfa344f..3af5711a7cf 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -67,6 +67,8 @@ def __call__( wait: bool | None = None, ) -> "LayoutContent": ... + InputFlowType = Generator[None, messages.ButtonRequest, None] + EXPECTED_RESPONSES_CONTEXT_LINES = 3 @@ -1108,7 +1110,7 @@ def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: return msg def set_input_flow( - self, input_flow: Generator[None, messages.ButtonRequest | None, None] + self, input_flow: InputFlowType | Callable[[], InputFlowType] ) -> None: """Configure a sequence of input events for the current with-block. @@ -1142,7 +1144,7 @@ def set_input_flow( if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow - input_flow.send(None) # start the generator + next(input_flow) # start the generator def watch_layout(self, watch: bool = True) -> None: """Enable or disable watching layout changes. @@ -1190,7 +1192,8 @@ def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: input_flow.throw(exc_type, value, traceback) def set_expected_responses( - self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + self, + expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]], ) -> None: """Set a sequence of expected responses to client calls. diff --git a/tests/input_flows.py b/tests/input_flows.py index 2609cada306..00bd63ea083 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -12,7 +12,7 @@ from __future__ import annotations import time -from typing import Callable, Generator +from typing import Callable, Generator, Sequence from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType @@ -2049,7 +2049,7 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowSlip39BasicRecovery(InputFlowBase): - def __init__(self, client: Client, shares: list[str], pin: str | None = None): + def __init__(self, client: Client, shares: Sequence[str], pin: str | None = None): super().__init__(client) self.shares = shares self.pin = pin diff --git a/tests/input_flows_helpers.py b/tests/input_flows_helpers.py index 3ad5f580ed3..d51d3a83153 100644 --- a/tests/input_flows_helpers.py +++ b/tests/input_flows_helpers.py @@ -301,7 +301,7 @@ def input_mnemonic(self, mnemonic: list[str]) -> BRGeneratorType: def input_all_slip39_shares( self, - shares: list[str], + shares: t.Sequence[str], has_groups: bool = False, click_info: bool = False, ) -> BRGeneratorType: