Skip to content

Commit

Permalink
fix(python): revive trezorctl --script
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
mmilata committed Mar 5, 2025
1 parent 634aec5 commit 6da8704
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 37 deletions.
18 changes: 15 additions & 3 deletions python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import click

from .. import exceptions, transport, ui
from ..client import ProtocolVersion, TrezorClient
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
return PASSPHRASE_ON_DEVICE

env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
Expand Down Expand Up @@ -158,6 +158,8 @@ def get_session(

if empty_passphrase:
passphrase = ""
elif self.script:
passphrase = None
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
Expand Down Expand Up @@ -188,7 +190,17 @@ def get_transport(self) -> "Transport":
return _TRANSPORT

def get_client(self) -> TrezorClient:
return get_client(self.get_transport())
client = get_client(self.get_transport())
if self.script:
client.button_callback = ui.ScriptUI.button_request
client.passphrase_callback = ui.ScriptUI.get_passphrase
client.pin_callback = ui.ScriptUI.get_pin
else:
click_ui = ui.ClickUI()
client.button_callback = click_ui.button_request
client.passphrase_callback = click_ui.get_passphrase
client.pin_callback = click_ui.get_pin
return client

def get_seedless_session(self) -> Session:
client = self.get_client()
Expand Down
1 change: 0 additions & 1 deletion python/src/trezorlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def get_default_client(
If path is specified, does a prefix-search for the specified device. Otherwise, uses
the value of TREZOR_PATH env variable, or finds first connected Trezor.
If no UI is supplied, instantiates the default CLI UI.
"""

if path is None:
Expand Down
100 changes: 67 additions & 33 deletions python/src/trezorlib/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

import os
import sys
from typing import Any, Callable, Optional, Union
import typing as t

import click
from mnemonic import Mnemonic
from typing_extensions import Protocol

from . import device, messages
from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
from .exceptions import Cancelled
from .messages import PinMatrixRequestType, WordRequestType
from .client import MAX_PIN_LENGTH
from .exceptions import Cancelled, PinException
from .messages import Capability, PinMatrixRequestType, WordRequestType
from .transport.session import Session

PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad or lowercase letters to describe number positions.
Expand Down Expand Up @@ -62,19 +62,11 @@
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()


class TrezorClientUI(Protocol):
def button_request(self, br: messages.ButtonRequest) -> None: ...

def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...

def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...


def echo(*args: Any, **kwargs: Any) -> None:
def echo(*args: t.Any, **kwargs: t.Any) -> None:
return click.echo(*args, err=True, **kwargs)


def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
# Disallowing hidden input and warning user when it would cause issues
if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
hide_input = False
Expand All @@ -99,14 +91,16 @@ def _prompt_for_button(self, br: messages.ButtonRequest) -> str:

return "Please confirm action on your Trezor device."

def button_request(self, br: messages.ButtonRequest) -> None:
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
prompt = self._prompt_for_button(br)
if prompt != self.last_prompt_shown:
echo(prompt)
if not self.always_prompt:
self.last_prompt_shown = prompt
return session.call_raw(messages.ButtonAck())

def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code == PIN_CURRENT:
desc = "current PIN"
elif code == PIN_NEW:
Expand All @@ -129,6 +123,7 @@ def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
try:
pin = prompt(f"Please enter {desc}", hide_input=True)
except click.Abort:
session.call_raw(messages.Cancel())
raise Cancelled from None

# translate letters to numbers if letters were used
Expand All @@ -142,16 +137,33 @@ def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
elif len(pin) > MAX_PIN_LENGTH:
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
else:
return pin

def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp

def get_passphrase(
self, session: Session, request: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device and not self.passphrase_on_host:
return PASSPHRASE_ON_DEVICE
return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)

env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
return session.call_raw(
messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
)

while True:
try:
Expand All @@ -163,20 +175,24 @@ def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
)
# In case user sees the input on the screen, we do not need confirmation
if not CAN_HANDLE_HIDDEN_INPUT:
return passphrase
break
second = prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
break
else:
echo("Passphrase did not match. Please try again.")
except click.Abort:
raise Cancelled from None

return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)


class ScriptUI:
"""Interface to be used by scripts, not directly by user.
Expand All @@ -190,13 +206,14 @@ class ScriptUI:
"""

@staticmethod
def button_request(br: messages.ButtonRequest) -> None:
# TODO: send name={br.name} when it will be supported
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
code = br.code.name if br.code else None
print(f"?BUTTON code={code} pages={br.pages}")
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
return session.call_raw(messages.ButtonAck())

@staticmethod
def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code is None:
print("?PIN")
else:
Expand All @@ -208,10 +225,22 @@ def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
elif not pin.startswith(":"):
raise RuntimeError("Sent PIN must start with ':'")
else:
return pin[1:]
pin = pin[1:]
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp

@staticmethod
def get_passphrase(available_on_device: bool) -> Union[str, object]:
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device:
print("?PASSPHRASE available_on_device")
else:
Expand All @@ -221,16 +250,21 @@ def get_passphrase(available_on_device: bool) -> Union[str, object]:
if passphrase == "CANCEL":
raise Cancelled from None
elif passphrase == "ON_DEVICE":
return PASSPHRASE_ON_DEVICE
return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)
elif not passphrase.startswith(":"):
raise RuntimeError("Sent passphrase must start with ':'")
else:
return passphrase[1:]
passphrase = passphrase[1:]
return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)


def mnemonic_words(
expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]:
) -> t.Callable[[WordRequestType], str]:
if expand:
wordlist = Mnemonic(language).wordlist
else:
Expand Down

0 comments on commit 6da8704

Please sign in to comment.