Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trezorlib API improvements #4490

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/developers/hello_world_feature_TT.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ if TYPE_CHECKING:
from .protobuf import MessageType


@expect(messages.HelloWorldResponse, field="text", ret_type=str)
def say_hello(
client: "TrezorClient",
name: str,
Expand All @@ -166,8 +165,9 @@ def say_hello(
name=name,
amount=amount,
show_display=show_display,
)
)
),
expect=messages.HelloWorldResponse,
).text
```

Code above is sending `HelloWorldRequest` into Trezor and is expecting to get `HelloWorldResponse` back (from which it extracts the `text` string as a response).
Expand Down
1 change: 1 addition & 0 deletions python/.changelog.d/4464.added.1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added an `expect` argument to `TrezorClient.call()`, to enforce the returned message type.
25 changes: 15 additions & 10 deletions python/src/trezorlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

import logging
import os
import warnings
Expand All @@ -24,14 +26,15 @@
from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .protobuf import MessageType
from .tools import expect, parse_path, session

if TYPE_CHECKING:
from .protobuf import MessageType
from .transport import Transport
from .ui import TrezorClientUI

UI = TypeVar("UI", bound="TrezorClientUI")
romanz marked this conversation as resolved.
Show resolved Hide resolved
MT = TypeVar("MT", bound=MessageType)

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -149,12 +152,12 @@ def close(self) -> None:
def cancel(self) -> None:
self._raw_write(messages.Cancel())

def call_raw(self, msg: "MessageType") -> "MessageType":
def call_raw(self, msg: MessageType) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()

def _raw_write(self, msg: "MessageType") -> None:
def _raw_write(self, msg: MessageType) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug(
f"sending message: {msg.__class__.__name__}",
Expand All @@ -167,7 +170,7 @@ def _raw_write(self, msg: "MessageType") -> None:
)
self.transport.write(msg_type, msg_bytes)

def _raw_read(self) -> "MessageType":
def _raw_read(self) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read()
LOG.log(
Expand All @@ -181,7 +184,7 @@ def _raw_read(self) -> "MessageType":
)
return msg

def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
Expand All @@ -204,12 +207,12 @@ def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
else:
return resp

def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
available_on_device = Capability.PassphraseEntry in self.features.capabilities

def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> "MessageType":
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
Expand Down Expand Up @@ -244,15 +247,15 @@ def send_passphrase(

return send_passphrase(passphrase, on_device=False)

def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
self.ui.button_request(msg)
return self._raw_read()

@session
def call(self, msg: "MessageType") -> "MessageType":
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
Expand All @@ -266,6 +269,8 @@ def call(self, msg: "MessageType") -> "MessageType":
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp

Expand Down Expand Up @@ -397,7 +402,7 @@ def ping(
self,
msg: str,
button_protection: bool = False,
) -> "MessageType":
) -> MessageType:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
Expand Down
12 changes: 11 additions & 1 deletion python/src/trezorlib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .messages import Failure
from .protobuf import MessageType


class TrezorException(Exception):
pass


class TrezorFailure(TrezorException):
def __init__(self, failure: "Failure") -> None:
def __init__(self, failure: Failure) -> None:
self.failure = failure
self.code = failure.code
self.message = failure.message
Expand Down Expand Up @@ -55,3 +58,10 @@ class Cancelled(TrezorException):

class OutdatedFirmwareError(TrezorException):
pass


class UnexpectedMessageError(TrezorException):
def __init__(self, expected: type[MessageType], actual: MessageType) -> None:
self.expected = expected
self.actual = actual
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")