Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debuglink timing out fix #4655

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,8 @@ def _write(self, msg: protobuf.MessageType) -> None:
)
self.transport.write(msg_type, msg_bytes)

def _read(self) -> protobuf.MessageType:
ret_type, ret_bytes = self.transport.read()
def _read(self, timeout: float | None = None) -> protobuf.MessageType:
ret_type, ret_bytes = self.transport.read(timeout=timeout)
LOG.log(
DUMP_BYTES,
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
Expand All @@ -513,9 +513,9 @@ def _read(self) -> protobuf.MessageType:
)
return msg

def _call(self, msg: protobuf.MessageType) -> Any:
def _call(self, msg: protobuf.MessageType, timeout: float | None = None) -> Any:
self._write(msg)
return self._read()
return self._read(timeout=timeout)

def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState:
if wait_type is None:
Expand Down Expand Up @@ -640,7 +640,10 @@ def _decision(self, decision: messages.DebugLinkDecision) -> None:
if self.model is models.T1B1:
return
# When the call below returns, we know that `decision` has been processed in Core.
self._call(messages.DebugLinkGetState(return_empty_state=True))
# XXX Due to a bug, the reply may get lost at the end of a workflow.
# We assume that no single input event takes more than 5 seconds to process,
# and give up waiting after that.
self._call(messages.DebugLinkGetState(return_empty_state=True), timeout=5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just to be sure - we raise an exception here, right?
Asking, since IIRC there was a suggestion to ignore the timeout (i.e. to assume the request has been processed by the device), but then I would suggest also logging that there was an ignored timeout.

Copy link
Contributor

@romanz romanz Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can use an environment variable to set the timeout in the general case?
IIUC, it may take the device more than 5 seconds to respond, e.g. when running WipeDevice.


press_yes = _make_input_func(button=messages.DebugButton.YES)
"""Confirm current layout. See `_decision` for more details."""
Expand Down
37 changes: 14 additions & 23 deletions python/src/trezorlib/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

import logging
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar

from ..exceptions import TrezorException

Expand Down Expand Up @@ -84,23 +77,23 @@ def begin_session(self) -> None:
def end_session(self) -> None:
raise NotImplementedError

def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError

def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError

def find_debug(self: "T") -> "T":
def find_debug(self: T) -> T:
raise NotImplementedError

@classmethod
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
cls: type[T], models: Iterable[TrezorModel] | None = None
) -> Iterable[T]:
raise NotImplementedError

@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
for device in cls.enumerate():
if (
path is None
Expand All @@ -112,13 +105,13 @@ def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")


def all_transports() -> Iterable[Type["Transport"]]:
def all_transports() -> Iterable[type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport

transports: Tuple[Type["Transport"], ...] = (
transports: Tuple[type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
Expand All @@ -128,9 +121,9 @@ def all_transports() -> Iterable[Type["Transport"]]:


def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
models: Iterable[TrezorModel] | None = None,
) -> Sequence[Transport]:
devices: list[Transport] = []
for transport in all_transports():
name = transport.__name__
try:
Expand All @@ -145,9 +138,7 @@ def enumerate_devices(
return devices


def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
if path is None:
try:
return next(iter(enumerate_devices()))
Expand Down
48 changes: 29 additions & 19 deletions python/src/trezorlib/transport/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

import logging
import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Iterable

import requests
from typing_extensions import Self

from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
Expand All @@ -45,9 +48,11 @@ def __init__(self, path: str, status: int, message: str) -> None:
super().__init__(f"trezord: {path} failed with code {status}: {message}")


def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
def call_bridge(
path: str, data: str | None = None, timeout: float | None = None
) -> requests.Response:
url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data)
r = CONNECTION.post(url, data=data, timeout=timeout)
if r.status_code != 200:
raise BridgeException(path, r.status_code, r.json()["error"])
return r
Expand All @@ -63,7 +68,7 @@ class BridgeHandle:
def __init__(self, transport: "BridgeTransport") -> None:
self.transport = transport

def read_buf(self) -> bytes:
def read_buf(self, timeout: float | None = None) -> bytes:
raise NotImplementedError

def write_buf(self, buf: bytes) -> None:
Expand All @@ -75,28 +80,28 @@ def write_buf(self, buf: bytes) -> None:
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
self.transport._call("post", data=buf.hex())

def read_buf(self) -> bytes:
data = self.transport._call("read")
def read_buf(self, timeout: float | None = None) -> bytes:
data = self.transport._call("read", timeout=timeout)
LOG.log(DUMP_PACKETS, f"received message: {data.text}")
return bytes.fromhex(data.text)


class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport)
self.request: Optional[str] = None
self.request: str | None = None

def write_buf(self, buf: bytes) -> None:
if self.request is not None:
raise TransportException("Can't write twice on legacy Bridge")
self.request = buf.hex()

def read_buf(self) -> bytes:
def read_buf(self, timeout: float | None = None) -> bytes:
if self.request is None:
raise TransportException("Can't read without write on legacy Bridge")
try:
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
data = self.transport._call("call", data=self.request)
data = self.transport._call("call", data=self.request, timeout=timeout)
LOG.log(DUMP_PACKETS, f"received response: {data.text}")
return bytes.fromhex(data.text)
finally:
Expand All @@ -112,13 +117,13 @@ class BridgeTransport(Transport):
ENABLED: bool = True

def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False
self, device: dict[str, Any], legacy: bool, debug: bool = False
) -> None:
if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge")

self.device = device
self.session: Optional[str] = None
self.session: str | None = None
self.debug = debug
self.legacy = legacy

Expand All @@ -130,21 +135,26 @@ def __init__(
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path']}"

def find_debug(self) -> "BridgeTransport":
def find_debug(self) -> Self:
if not self.device.get("debug"):
raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True)

def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
return self.__class__(self.device, self.legacy, debug=True)

def _call(
self,
action: str,
data: str | None = None,
timeout: float | None = None,
) -> requests.Response:
session = self.session or "null"
uri = action + "/" + str(session)
if self.debug:
uri = "debug/" + uri
return call_bridge(uri, data=data)
return call_bridge(uri, data=data, timeout=timeout)

@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
cls, _models: Iterable[TrezorModel] | None = None
) -> Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
Expand Down Expand Up @@ -173,8 +183,8 @@ def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)

def read(self) -> MessagePayload:
data = self.handle.read_buf()
def read(self, timeout: float | None = None) -> MessagePayload:
data = self.handle.read_buf(timeout=timeout)
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen]
17 changes: 11 additions & 6 deletions python/src/trezorlib/transport/hid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

import logging
import sys
import time
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable

from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
Expand Down Expand Up @@ -91,13 +93,16 @@ def write_chunk(self, chunk: bytes) -> None:
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
self.handle.write(chunk)

def read_chunk(self) -> bytes:
def read_chunk(self, timeout: float | None = None) -> bytes:
start = time.time()
while True:
# hidapi seems to return lists of ints instead of bytes
chunk = bytes(self.handle.read(64))
if chunk:
break
else:
if timeout is not None and time.time() - start > timeout:
raise TransportException("Timeout reading HID packet")
time.sleep(0.001)

LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
Expand Down Expand Up @@ -134,13 +139,13 @@ def get_path(self) -> str:

@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
) -> Iterable[HidTransport]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]

devices: List["HidTransport"] = []
devices: list[HidTransport] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
Expand All @@ -154,7 +159,7 @@ def enumerate(
devices.append(HidTransport(dev))
return devices

def find_debug(self) -> "HidTransport":
def find_debug(self) -> HidTransport:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
Expand Down
23 changes: 12 additions & 11 deletions python/src/trezorlib/transport/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from __future__ import annotations

import logging
import struct
from typing import Tuple

from typing_extensions import Protocol as StructuralType

Expand Down Expand Up @@ -48,7 +49,7 @@ def open(self) -> None: ...

def close(self) -> None: ...

def read_chunk(self) -> bytes: ...
def read_chunk(self, timeout: float | None = None) -> bytes: ...

def write_chunk(self, chunk: bytes) -> None: ...

Expand Down Expand Up @@ -86,7 +87,7 @@ def end_session(self) -> None:
if self.session_counter == 0:
self.handle.close()

def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError

def write(self, message_type: int, message_data: bytes) -> None:
Expand All @@ -106,8 +107,8 @@ def __init__(self, protocol: Protocol) -> None:
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)

def read(self) -> MessagePayload:
return self.protocol.read()
def read(self, timeout: float | None = None) -> MessagePayload:
return self.protocol.read(timeout=timeout)

def begin_session(self) -> None:
self.protocol.begin_session()
Expand All @@ -134,10 +135,10 @@ def write(self, message_type: int, message_data: bytes) -> None:
self.handle.write_chunk(chunk)
buffer = buffer[63:]

def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
buffer.extend(first_chunk)

# Read the rest of the message
Expand All @@ -146,8 +147,8 @@ def read(self) -> MessagePayload:

return msg_type, buffer[:datalen]

def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
Expand All @@ -158,8 +159,8 @@ def read_first(self) -> Tuple[int, int, bytes]:
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data

def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
def read_next(self, timeout: float | None = None) -> bytes:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]
Loading
Loading