Skip to content

Commit

Permalink
fixup! refactor(tests): move set_input_flow to SessionDebugWrapper co…
Browse files Browse the repository at this point in the history
…ntext manager
  • Loading branch information
mmilata committed Mar 6, 2025
1 parent e8bb3ea commit 251dcff
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 50 deletions.
29 changes: 11 additions & 18 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ def reset_debug_features(self) -> None:
Clears all debugging state that might have been modified by a testcase.
"""
self.client.ui: DebugUI = DebugUI(self.client.debug) # is in main
self.client.ui.clear()
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
Expand Down Expand Up @@ -1090,25 +1090,25 @@ def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
self.client.ui, DebugUI
):
input_flow = self.client.ui.input_flow
input_flow_loops_forever = self.client.ui.input_flow_loops_forever
# input_flow_loops_forever = self.client.ui.input_flow_loops_forever
else:
input_flow = None
input_flow_loops_forever = False
# input_flow_loops_forever = False

self.reset_debug_features()

if exc_type is None:
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
if isinstance(input_flow, t.Generator) and not input_flow_loops_forever:
# Ensure that the input flow is exhausted
try:
input_flow.throw(
AssertionError("input flow continues past end of test")
)
except StopIteration:
pass
# if isinstance(input_flow, t.Generator) and not input_flow_loops_forever:
# # Ensure that the input flow is exhausted
# try:
# input_flow.throw(
# AssertionError("input flow continues past end of test")
# )
# except StopIteration:
# pass

elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
Expand Down Expand Up @@ -1251,7 +1251,6 @@ def __init__(
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)

# self.reset_debug_features()
self._seedless_session = self.get_seedless_session(new_session=True)
self.sync_responses()

Expand Down Expand Up @@ -1463,12 +1462,6 @@ def mnemonic_callback(self, _) -> str:

raise RuntimeError("Unexpected call")

def __enter__(self) -> "TrezorClientDebugLink":
return self

def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
pass


def load_device(
session: "Session",
Expand Down
4 changes: 4 additions & 0 deletions tests/device_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@


class NullUI:
@staticmethod
def clear():
pass

@staticmethod
def button_request(code):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/device_tests/bitcoin/test_getpublickey.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def input_flow():
client.debug.press_yes() # finish the flow
yield

with client:
with session:
# test XPUB display flow (without showing QR code)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub
Expand Down
2 changes: 1 addition & 1 deletion tests/device_tests/misc/test_msg_enablelabeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def input_flow():
client.debug.press_yes()

session = client.get_session()
with client, session:
with session:
session.set_input_flow(input_flow())
misc.encrypt_keyvalue(
session,
Expand Down
26 changes: 12 additions & 14 deletions tests/device_tests/test_msg_wipedevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def test_wipe_device(client: Client):
@pytest.mark.setup_client(pin=PIN4)
def test_autolock_not_retained(session: Session):
client = session.client
with client:
client.use_pin_sequence([PIN4])
device.apply_settings(session, auto_lock_delay_ms=10_000)
client.use_pin_sequence([PIN4])
device.apply_settings(session, auto_lock_delay_ms=10_000)

assert session.features.auto_lock_delay_ms == 10_000

Expand All @@ -57,21 +56,20 @@ def test_autolock_not_retained(session: Session):

assert client.features.auto_lock_delay_ms > 10_000

with client:
client.use_pin_sequence([PIN4, PIN4])
device.setup(
session,
skip_backup=True,
pin_protection=True,
passphrase_protection=False,
entropy_check_count=0,
backup_type=messages.BackupType.Bip39,
)
client.use_pin_sequence([PIN4, PIN4])
device.setup(
session,
skip_backup=True,
pin_protection=True,
passphrase_protection=False,
entropy_check_count=0,
backup_type=messages.BackupType.Bip39,
)

time.sleep(10.5)
session = client.get_session()

with session, client:
with session:
# after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked
session.set_expected_responses([messages.Address])
get_test_address(session)
10 changes: 5 additions & 5 deletions tests/device_tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_clear_session(client: Client):
cached_responses = [messages.PublicKey]
session = client.get_session()
session.lock()
with client, session:
with session:
client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB
Expand All @@ -57,7 +57,7 @@ def test_clear_session(client: Client):
session = client.get_session()

# session cache is cleared
with client, session:
with session:
client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB
Expand All @@ -76,7 +76,7 @@ def test_end_session(client: Client):
assert session.id is not None

# get_address will succeed
with session:
with session as session:
session.set_expected_responses([messages.Address])
get_test_address(session)

Expand Down Expand Up @@ -136,7 +136,7 @@ def test_end_session_only_current(client: Client):
@pytest.mark.setup_client(passphrase=True)
def test_session_recycling(client: Client):
session = client.get_session(passphrase="TREZOR")
with client, session:
with session:
session.set_expected_responses(
[
messages.PassphraseRequest,
Expand All @@ -155,7 +155,7 @@ def test_session_recycling(client: Client):

# it should still be possible to resume the original session
# TODO imo not True anymore
# with client, session:
# with session:
# # passphrase should still be cached
# session.set_expected_responses([messages.Features, messages.Address])
# client.use_passphrase("TREZOR")
Expand Down
6 changes: 3 additions & 3 deletions tests/device_tests/test_session_id_and_passphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def call(passphrase: str, expected_result: bool):
def test_hide_passphrase_from_host(client: Client):
# Without safety checks, turning it on fails
session = client.get_seedless_session()
with pytest.raises(TrezorFailure, match="Safety checks are strict"), client:
with pytest.raises(TrezorFailure, match="Safety checks are strict"):
device.apply_settings(session, hide_passphrase_from_host=True)

device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
Expand All @@ -405,7 +405,7 @@ def test_hide_passphrase_from_host(client: Client):

passphrase = "abc"
session = client.get_session(passphrase=passphrase)
with client, session:
with session:

def input_flow():
yield
Expand Down Expand Up @@ -439,7 +439,7 @@ def input_flow():
# Starting new session, otherwise the passphrase would be cached
session = client.get_session(passphrase=passphrase)

with client, session:
with session:

def input_flow():
yield
Expand Down
16 changes: 10 additions & 6 deletions tests/persistence_tests/test_wipe_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,37 @@


def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session())
session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device(
client.get_seedless_session(),
session,
MNEMONIC12,
pin,
passphrase_protection=False,
label="WIPECODE",
)

with client:
with session:
client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client.get_seedless_session())


def setup_device_core(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session())
session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device(
client.get_seedless_session(),
session,
MNEMONIC12,
pin,
passphrase_protection=False,
label="WIPECODE",
)

with client:
with session:
client.use_pin_sequence([pin, wipe_code, wipe_code])
device.change_wipe_code(client.get_seedless_session())

Expand Down
4 changes: 2 additions & 2 deletions tests/upgrade_tests/test_firmware_upgrades.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):

# Create a backup of the encrypted master secret.
assert emu.client.features.backup_availability == BackupAvailability.Required
with emu.client:
session = emu.client.get_session()
with session:
IF = InputFlowSlip39BasicBackup(emu.client, False)
session = emu.client.get_session()
session.set_input_flow(IF.get())
device.backup(session)
assert (
Expand Down

0 comments on commit 251dcff

Please sign in to comment.