diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml index ced73797188..7ad31afa6d7 100644 --- a/.github/workflows/core.yml +++ b/.github/workflows/core.yml @@ -54,10 +54,19 @@ jobs: model: [T2T1, T3B1, T3T1, T3W1] coins: [universal, btconly] type: ${{ fromJSON(github.event_name == 'schedule' && '["normal", "debuglink", "production"]' || '["normal", "debuglink"]') }} + protocol: [v1] include: - model: D001 coins: universal type: normal + - model: T2T1 + coins: universal + type: debuglink + protocol: v2 + - model: T2T1 + coins: btconly + type: debuglink + protocol: v2 exclude: - model: T3W1 type: production @@ -67,6 +76,7 @@ jobs: PYOPT: ${{ matrix.type == 'debuglink' && '0' || '1' }} PRODUCTION: ${{ matrix.type == 'production' && '1' || '0' }} BOOTLOADER_DEVEL: ${{ matrix.model == 'T3W1' && '1' || '0' }} + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -87,7 +97,7 @@ jobs: if: matrix.coins == 'btconly' && matrix.type != 'debuglink' - uses: actions/upload-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }} + name: core-firmware-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-protocol_${{ matrix.protocol }} path: | core/build/boardloader/*.bin core/build/bootloader/*.bin @@ -109,15 +119,28 @@ jobs: # type: [normal, debuglink] type: [debuglink] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1] exclude: - type: normal asan: asan + include: + - model: T2T1 + coins: universal + type: debuglink + asan: noasan + protocol: v2 + - model: T2T1 + coins: btconly + type: debuglink + asan: noasan + protocol: v2 env: TREZOR_MODEL: ${{ matrix.model == 'T2T1' && 'T' || matrix.model }} BITCOIN_ONLY: ${{ matrix.coins == 'universal' && '0' || '1' }} PYOPT: ${{ matrix.type == 'debuglink' && '0' || '1' }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -129,7 +152,7 @@ jobs: - run: cp core/build/unix/trezor-emu-core core/build/unix/trezor-emu-core-${{ matrix.model }}-${{ matrix.coins }} - uses: actions/upload-artifact@v4 with: - name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.type }}-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: | core/build/unix/trezor-emu-core* core/build/bootloader_emu/bootloader.elf @@ -174,7 +197,7 @@ jobs: retention-days: 2 core_unit_python_test: - name: Python unit tests + name: Python unit tests (${{ matrix.model }}, ${{ matrix.asan }}, protocol_${{ matrix.protocol}}) runs-on: ubuntu-latest needs: param strategy: @@ -182,10 +205,12 @@ jobs: matrix: model: [T2T1, T3B1, T3T1, T3W1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1, v2] env: TREZOR_MODEL: ${{ matrix.model == 'T2T1' && 'T' || matrix.model }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -195,7 +220,7 @@ jobs: - run: nix-shell --run "poetry run make -C core test" core_unit_rust_test: - name: Rust unit tests + name: Rust unit tests (${{ matrix.model }}, ${{ matrix.asan }}, protocol_${{ matrix.protocol}}) runs-on: ubuntu-latest needs: - param @@ -205,12 +230,14 @@ jobs: matrix: model: [T2T1, T3B1, T3T1, T3W1] asan: ${{ fromJSON(needs.param.outputs.asan) }} + protocol: [v1, v2] env: TREZOR_MODEL: ${{ matrix.model == 'T2T1' && 'T' || matrix.model }} ADDRESS_SANITIZER: ${{ matrix.asan == 'asan' && '1' || '0' }} RUSTC_BOOTSTRAP: ${{ matrix.asan == 'asan' && '1' || '0' }} RUSTFLAGS: ${{ matrix.asan == 'asan' && '-Z sanitizer=address' || '' }} LSAN_OPTIONS: "suppressions=../../asan_suppressions.txt" + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: @@ -234,7 +261,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-noasan + name: core-emu-${{ matrix.model }}-universal-debuglink-noasan-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -257,6 +284,13 @@ jobs: coins: [universal, btconly] asan: ${{ fromJSON(needs.param.outputs.asan) }} lang: ${{ fromJSON(needs.param.outputs.test_lang) }} + protocol: [v1] + include: + - model: T2T1 + coins: universal + asan: noasan + lang: en + protocol: v2 env: TREZOR_PROFILING: ${{ matrix.asan == 'noasan' && '1' || '0' }} TREZOR_MODEL: ${{ matrix.model == 'T2T1' && 'T' || matrix.model }} @@ -265,13 +299,14 @@ jobs: PYTEST_TIMEOUT: ${{ matrix.asan == 'asan' && 600 || 400 }} ACTIONS_DO_UI_TEST: ${{ matrix.coins == 'universal' && matrix.asan == 'noasan' }} TEST_LANG: ${{ matrix.lang }} + THP: ${{ matrix.protocol == 'v2' && '1' || '0'}} steps: - uses: actions/checkout@v4 with: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-${{ matrix.coins }}-debuglink-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -280,7 +315,7 @@ jobs: if: failure() - uses: actions/upload-artifact@v4 with: - name: core-test-device-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.lang }}-${{ matrix.asan }} + name: core-test-device-${{ matrix.model }}-${{ matrix.coins }}-${{ matrix.lang }}-${{ matrix.asan }}-protocol_${{ matrix.protocol }} path: tests/trezor.log retention-days: 7 if: always() @@ -319,7 +354,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -329,7 +364,7 @@ jobs: if: ${{ matrix.asan == 'asan' }} - uses: actions/upload-artifact@v4 with: - name: core-test-click-${{ matrix.model }}-${{ matrix.lang }}-${{ matrix.asan }} + name: core-test-click-${{ matrix.model }}-${{ matrix.lang }}-${{ matrix.asan }}-protocol_v1 path: tests/trezor.log retention-days: 7 if: always() @@ -367,7 +402,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -397,7 +432,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -430,7 +465,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-noasan + name: core-emu-${{ matrix.model }}-universal-debuglink-noasan-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment # XXX poetry maybe not needed @@ -488,7 +523,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-universal-normal # FIXME: s/normal/debuglink/ + name: core-firmware-${{ matrix.model }}-universal-normal-protocol_v1 # FIXME: s/normal/debuglink/ path: core/build - uses: ./.github/actions/environment - run: nix-shell --run "poetry run core/tools/size/checker.py core/build/firmware/firmware.elf" @@ -512,7 +547,7 @@ jobs: fetch-depth: 0 - uses: actions/download-artifact@v4 with: - name: core-firmware-${{ matrix.model }}-universal-normal + name: core-firmware-${{ matrix.model }}-universal-normal-protocol_v1 path: core/build - uses: ./.github/actions/environment - run: nix-shell --run "poetry run core/tools/size/compare_master.py core/build/firmware/firmware.elf -r firmware_elf_size_report.txt" @@ -543,7 +578,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: cachix/install-nix-action@v23 @@ -584,7 +619,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -619,7 +654,7 @@ jobs: submodules: recursive - uses: actions/download-artifact@v4 with: - name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }} + name: core-emu-${{ matrix.model }}-universal-debuglink-${{ matrix.asan }}-protocol_v1 path: core/build - run: chmod +x core/build/unix/trezor-emu-core* - uses: ./.github/actions/environment @@ -724,7 +759,7 @@ jobs: steps: - uses: actions/download-artifact@v4 with: - pattern: core-emu*debuglink-noasan + pattern: core-emu*debuglink-noasan-protocol_v* merge-multiple: true - name: Configure aws credentials uses: aws-actions/configure-aws-credentials@v4 @@ -747,7 +782,7 @@ jobs: steps: - uses: actions/download-artifact@v4 with: - pattern: core-emu*debuglink-noasan + pattern: core-emu*debuglink-noasan-protocol_v* merge-multiple: true - name: Configure aws credentials uses: aws-actions/configure-aws-credentials@v4 diff --git a/ci/build.yml b/ci/build.yml index 8faa98da557..3f2861edf04 100644 --- a/ci/build.yml +++ b/ci/build.yml @@ -307,6 +307,7 @@ core unix frozen debug build: needs: [] variables: PYOPT: "0" + THP: "1" script: - $NIX_SHELL --run "poetry run make -C core build_unix_frozen" artifacts: diff --git a/common/protob/messages-common.proto b/common/protob/messages-common.proto index 3e8cb9537ce..a11eaff2091 100644 --- a/common/protob/messages-common.proto +++ b/common/protob/messages-common.proto @@ -39,6 +39,8 @@ message Failure { Failure_PinMismatch = 12; Failure_WipeCodeMismatch = 13; Failure_InvalidSession = 14; + Failure_ThpUnallocatedSession=15; + Failure_InvalidProtocol=16; Failure_FirmwareError = 99; } } diff --git a/common/protob/messages-debug.proto b/common/protob/messages-debug.proto index 08f0f30e5f5..9e64d9944f6 100644 --- a/common/protob/messages-debug.proto +++ b/common/protob/messages-debug.proto @@ -110,6 +110,8 @@ message DebugLinkGetState { // trezor-core only - wait until current layout changes // changed in 2.6.4: multiple wait types instead of true/false. optional DebugWaitType wait_layout = 3 [default=IMMEDIATE]; + // THP only - it is used to get information from specified channel + optional bytes thp_channel_id=4; } /** @@ -130,6 +132,9 @@ message DebugLinkState { optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow optional management.BackupType mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39) repeated string tokens = 13; // current layout represented as a list of string tokens + optional uint32 thp_pairing_code_entry_code = 14; + optional bytes thp_pairing_code_qr_code = 15; + optional bytes thp_pairing_code_nfc_unidirectional = 16; } /** diff --git a/common/protob/messages-thp.proto b/common/protob/messages-thp.proto index c05d9f64d71..5fb24d055c3 100644 --- a/common/protob/messages-thp.proto +++ b/common/protob/messages-thp.proto @@ -9,6 +9,218 @@ import "options.proto"; option (include_in_bitcoin_only) = true; +/** + * Mapping between Trezor wire identifier (uint) and a Thp protobuf message + */ +enum ThpMessageType { + reserved 0 to 999; // Values reserved by other messages, see messages.proto + + ThpMessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true]; + ThpMessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true]; + ThpMessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true]; + ThpMessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true]; + ThpMessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true]; + ThpMessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true]; + ThpMessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true]; + ThpMessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true]; + ThpMessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true]; + ThpMessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true]; + ThpMessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true]; + ThpMessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true]; + ThpMessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true]; + ThpMessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true]; + ThpMessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true]; + ThpMessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true]; + ThpMessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true]; + ThpMessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true]; + + reserved 1100 to 2147483647; // Values reserved by other messages, see messages.proto +} + + +/** + * Numeric identifiers of pairing methods. + * @embed + */ +enum ThpPairingMethod { + NoMethod = 1; // Trust without MITM protection. + CodeEntry = 2; // User types code diplayed on Trezor into the host application. + QrCode = 3; // User scans code displayed on Trezor into host application. + NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC. +} + +/** + * @embed + */ +message ThpDeviceProperties { + optional string internal_model = 1; // Internal model name e.g. "T2B1". + optional uint32 model_variant = 2; // Encodes the device properties such as color. + optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode. + optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware. + repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor. +} + +/** + * @embed + */ +message ThpHandshakeCompletionReqNoisePayload { + optional bytes host_pairing_credential = 1; // Host's pairing credential + repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host +} + +/** + * Request: Ask device for a new session with given passphrase. + * @start + * @next ThpNewSession + */ +message ThpCreateNewSession{ + optional string passphrase = 1; + optional bool on_device = 2; // User wants to enter passphrase on the device + optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only +} + +/** + * Response: Contains session_id of the newly created session. + * @end + */ +message ThpNewSession{ + optional uint32 new_session_id = 1; +} + +/** + * Request: Start pairing process. + * @start + * @next ThpCodeEntryCommitment + * @next ThpPairingPreparationsFinished + */ +message ThpStartPairingRequest{ + optional string host_name = 1; // Human-readable host name +} + +/** + * Response: Pairing is ready for user input / OOB communication. + * @next ThpCodeEntryCpace + * @next ThpQrCodeTag + * @next ThpNfcUnidirectionalTag + */ + message ThpPairingPreparationsFinished{ +} + +/** + * Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment. + * @next ThpCodeEntryChallenge + */ +message ThpCodeEntryCommitment { + optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret +} + +/** + * Response: Host responds to Trezor's Code Entry commitment with a challenge. + * @next ThpPairingPreparationsFinished + */ +message ThpCodeEntryChallenge { + optional bytes challenge = 1; // host's random 32-byte challenge +} + +/** + * Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor. + * @next ThpCodeEntryCpaceTrezor + */ +message ThpCodeEntryCpaceHost { + optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key +} + +/** + * Response: Trezor continues with the CPACE protocol. + * @next ThpCodeEntryTag + */ +message ThpCodeEntryCpaceTrezor { + optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key +} + +/** + * Response: Host continues with the CPACE protocol. + * @next ThpCodeEntrySecret + */ +message ThpCodeEntryTag { + optional bytes tag = 2; // SHA-256 of shared secret +} + +/** + * Response: Trezor finishes the CPACE protocol. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCodeEntrySecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected QR Code pairing option. Host sends a QR Tag. + * @next ThpQrCodeSecret + */ +message ThpQrCodeTag { + optional bytes tag = 1; // SHA-256 of shared secret +} + +/** + * Response: Trezor sends the QR secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpQrCodeSecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag. + * @next ThpNfcUnidirectionalSecret + */ +message ThpNfcUnidirectionalTag { + optional bytes tag = 1; // SHA-256 of shared secret +} + +/** + * Response: Trezor sends the Unidirectioal NFC secret. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpNfcUnidirectionalSecret { + optional bytes secret = 1; // Trezor's secret +} + +/** + * Request: Host requests issuance of a new pairing credential. + * @start + * @next ThpCredentialResponse + */ +message ThpCredentialRequest { + optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake. +} + +/** + * Response: Trezor issues a new pairing credential. + * @next ThpCredentialRequest + * @next ThpEndRequest + */ +message ThpCredentialResponse { + optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake. + optional bytes credential = 2; // The pairing credential issued by the Trezor to the host. +} + +/** + * Request: Host requests transition to the encrypted traffic phase. + * @start + * @next ThpEndResponse + */ +message ThpEndRequest {} + +/** + * Response: Trezor approves transition to the encrypted traffic phase + * @end + */ +message ThpEndResponse {} + /** * Only for internal use. * @embed diff --git a/common/protob/options.proto b/common/protob/options.proto index 6919d93ab52..f4559a44993 100644 --- a/common/protob/options.proto +++ b/common/protob/options.proto @@ -37,6 +37,10 @@ The convention to achieve this is as follows: optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode + optional bool channel_in = 50009; + optional bool channel_out = 50010; + optional bool pairing_in = 50011; + optional bool pairing_out = 50012; optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build // (messages not marked bitcoin_only will be EXCLUDED) diff --git a/common/protob/pb2py b/common/protob/pb2py index d6cbbde171e..ea12f8d9c1d 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -62,6 +62,7 @@ INT_TYPES = ( ) MESSAGE_TYPE_ENUM = "MessageType" +THP_MESSAGE_TYPE_ENUM = "ThpMessageType" LengthDelimited = c.Struct( "len" / c.VarInt, @@ -239,6 +240,9 @@ class ProtoMessage: @classmethod def from_message(cls, descriptor: "Descriptor", message): message_type = find_by_name(descriptor.message_type_enum.value, message.name) + thp_message_type = None + if not isinstance(descriptor.thp_message_type_enum,tuple): + thp_message_type = find_by_name(descriptor.thp_message_type_enum.value, message.name) # use extensions set on the message_type entry (if any) extensions = descriptor.get_extensions(message_type) # override with extensions set on the message itself @@ -248,6 +252,8 @@ class ProtoMessage: wire_type = extensions["wire_type"] elif message_type is not None: wire_type = message_type.number + elif thp_message_type is not None: + wire_type = thp_message_type.number else: wire_type = None @@ -351,10 +357,13 @@ class Descriptor: ] logging.debug(f"found {len(self.files)} bitcoin-only files") - # find message_type enum + # find message_type and thp_message_type enum top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files) self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM, ()) + top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files) + self.thp_message_type_enum = find_by_name(top_level_enums, THP_MESSAGE_TYPE_ENUM, ()) self.convert_enum_value_names(self.message_type_enum) + self.convert_enum_value_names(self.thp_message_type_enum) # find messages and enums self.messages = [] @@ -423,6 +432,8 @@ class Descriptor: self._nested_types_from_message(nested.orig) def convert_enum_value_names(self, enum): + if isinstance(enum,tuple): + return for value in enum.value: value.name = strip_enum_prefix(enum.name, value.name) @@ -558,6 +569,8 @@ class RustBlobRenderer: enums = [] cursor = 0 for enum in sorted(self.descriptor.enums, key=lambda e: e.name): + if enum.name == "MessageType": + continue self.enum_map[enum.name] = cursor enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value)) enums.append(enum_blob) diff --git a/core/Makefile b/core/Makefile index a47adaa70b7..3ec21437ec1 100644 --- a/core/Makefile +++ b/core/Makefile @@ -311,6 +311,10 @@ build_unix: templates ## build unix port build_unix_frozen: templates build_cross ## build unix port with frozen modules $(SCONS) $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) TREZOR_EMULATOR_FROZEN=1 +build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0") + $(SCONS) $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) TREZOR_EMULATOR_FROZEN=1 \ + PYOPT=0 + build_unix_debug: templates ## build unix port $(SCONS) --max-drift=1 $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \ TREZOR_EMULATOR_ASAN=1 TREZOR_EMULATOR_DEBUGGABLE=1 diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 0140bab9ee4..bc1fcbd8530 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -564,14 +564,23 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) + if not THP or PYOPT == '0': + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', - exclude=[ - SOURCE_PY_DIR + 'storage/sd_salt.py', - ] if not SDCARD else [] - )) + + exclude_list = [] + if 'sd_card' not in FEATURES_AVAILABLE: + exclude_list.append(SOURCE_PY_DIR + 'storage/sd_salt.py') + if THP and PYOPT == '1': + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_codec.py') + if not THP: + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_thp.py') + + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=exclude_list)) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/messages/__init__.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/*.py', diff --git a/core/SConscript.unix b/core/SConscript.unix index a7d98ad47c7..8a04c2217ec 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -635,14 +635,23 @@ if FROZEN: ] if not EVERYTHING else [] )) + if THP: + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py')) + if not THP or PYOPT == '0': + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) - SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', - exclude=[ - SOURCE_PY_DIR + 'storage/sd_salt.py', - ] if 'sd_card' not in FEATURES_AVAILABLE else [] - )) + + exclude_list = [] + if 'sd_card' not in FEATURES_AVAILABLE: + exclude_list.append(SOURCE_PY_DIR + 'storage/sd_salt.py') + if THP and PYOPT == '1': + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_codec.py') + if not THP: + exclude_list.append(SOURCE_PY_DIR + 'storage/cache_thp.py') + + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=exclude_list)) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/messages/__init__.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/*.py', diff --git a/core/emu.py b/core/emu.py index 0cf88a6ca91..ac568e62b04 100755 --- a/core/emu.py +++ b/core/emu.py @@ -282,9 +282,11 @@ def cli( label = "Emulator" assert emulator.client is not None - trezorlib.device.wipe(emulator.client) + trezorlib.device.wipe(emulator.client.get_management_session()) + emulator.client = emulator.client.get_new_client() + trezorlib.debuglink.load_device( - emulator.client, + emulator.client.get_management_session(), mnemonics, pin=None, passphrase_protection=False, diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 05f3a3330a5..183755390c3 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -51,6 +51,8 @@ import storage.cache_codec storage.cache_common import storage.cache_common +storage.cache_thp +import storage.cache_thp storage.common import storage.common storage.debug @@ -419,10 +421,52 @@ import apps.workflow_handlers if utils.USE_THP: + trezor.enums.ThpMessageType + import trezor.enums.ThpMessageType + trezor.enums.ThpPairingMethod + import trezor.enums.ThpPairingMethod + trezor.wire.thp + import trezor.wire.thp + trezor.wire.thp.alternating_bit_protocol + import trezor.wire.thp.alternating_bit_protocol + trezor.wire.thp.channel + import trezor.wire.thp.channel + trezor.wire.thp.channel_manager + import trezor.wire.thp.channel_manager + trezor.wire.thp.checksum + import trezor.wire.thp.checksum + trezor.wire.thp.control_byte + import trezor.wire.thp.control_byte + trezor.wire.thp.cpace + import trezor.wire.thp.cpace + trezor.wire.thp.crypto + import trezor.wire.thp.crypto + trezor.wire.thp.interface_manager + import trezor.wire.thp.interface_manager + trezor.wire.thp.memory_manager + import trezor.wire.thp.memory_manager + trezor.wire.thp.pairing_context + import trezor.wire.thp.pairing_context + trezor.wire.thp.received_message_handler + import trezor.wire.thp.received_message_handler + trezor.wire.thp.session_context + import trezor.wire.thp.session_context + trezor.wire.thp.session_manager + import trezor.wire.thp.session_manager + trezor.wire.thp.thp_main + import trezor.wire.thp.thp_main + trezor.wire.thp.transmission_loop + import trezor.wire.thp.transmission_loop + trezor.wire.thp.writer + import trezor.wire.thp.writer apps.thp import apps.thp + apps.thp.create_new_session + import apps.thp.create_new_session apps.thp.credential_manager import apps.thp.credential_manager + apps.thp.pairing + import apps.thp.pairing if not utils.BITCOIN_ONLY: trezor.enums.BinanceOrderSide diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 5552fc86ba9..ca923b7337b 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -204,33 +204,37 @@ def get_features() -> Features: return f -async def handle_Initialize(msg: Initialize) -> Features: - import storage.cache_codec as cache_codec +if not utils.USE_THP: - session_id = cache_codec.start_session(msg.session_id) + async def handle_Initialize(msg: Initialize) -> Features: + import storage.cache_codec as cache_codec - if not utils.BITCOIN_ONLY: - from storage.cache_common import APP_COMMON_DERIVE_CARDANO + session_id = cache_codec.start_session(msg.session_id) - derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) - have_seed = context.cache_is_set(APP_COMMON_SEED) - if ( - have_seed - and msg.derive_cardano is not None - and msg.derive_cardano != bool(derive_cardano) - ): - # seed is already derived, and host wants to change derive_cardano setting - # => create a new session - cache_codec.end_current_session() - session_id = cache_codec.start_session() - have_seed = False + if not utils.BITCOIN_ONLY: + from storage.cache_common import APP_COMMON_DERIVE_CARDANO + + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) + if ( + have_seed + and msg.derive_cardano is not None + and msg.derive_cardano != bool(derive_cardano) + ): + # seed is already derived, and host wants to change derive_cardano setting + # => create a new session + cache_codec.end_current_session() + session_id = cache_codec.start_session() + have_seed = False - if not have_seed: - context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) + if not have_seed: + context.cache_set_bool( + APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) + ) - features = get_features() - features.session_id = session_id - return features + features = get_features() + features.session_id = session_id + return features async def handle_GetFeatures(msg: GetFeatures) -> Features: @@ -464,8 +468,9 @@ def boot() -> None: MT = MessageType # local_cache_global # Register workflow handlers + if not utils.USE_THP: + workflow_handlers.register(MT.Initialize, handle_Initialize) for msg_type, handler in [ - (MT.Initialize, handle_Initialize), (MT.GetFeatures, handle_GetFeatures), (MT.Cancel, handle_Cancel), (MT.LockDevice, handle_LockDevice), diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 35f6b3f60ce..e4e77825aa8 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -6,7 +6,7 @@ APP_CARDANO_ICARUS_TREZOR_SECRET, APP_COMMON_DERIVE_CARDANO, ) -from trezor import wire +from trezor import utils, wire from trezor.crypto import cardano from trezor.wire import context @@ -21,6 +21,7 @@ from trezor import messages from trezor.crypto import bip32 from trezor.enums import CardanoDerivationType + from trezor.wire.protocol_common import Context from apps.common.keychain import Handler, MsgOut from apps.common.paths import Bip32Path @@ -116,7 +117,7 @@ def is_minting_path(path: Bip32Path) -> bool: return path[: len(MINTING_ROOT)] == MINTING_ROOT -def derive_and_store_secrets(passphrase: str) -> None: +def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) @@ -144,8 +145,7 @@ def derive_and_store_secrets(passphrase: str) -> None: async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: from trezor.enums import CardanoDerivationType - - from apps.common.seed import derive_and_store_roots + from trezor.wire import context if not device.is_initialized(): raise wire.NotInitialized("Device is not initialized") @@ -164,10 +164,13 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai # _get_secret secret = context.cache_get(cache_entry) - if secret is None: - await derive_and_store_roots() - secret = context.cache_get(cache_entry) - assert secret is not None + if not utils.USE_THP: + if secret is None: + from apps.common.seed import derive_and_store_roots_legacy + + await derive_and_store_roots_legacy() + secret = context.cache_get(cache_entry) + assert secret is not None root = cardano.from_secret(secret) return Keychain(root) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index fc56f42f9b7..8037aba6987 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED -from trezor import wire +from trezor import utils, wire from trezor.enums import MessageType from trezor.wire import context from trezor.wire.message_handler import filters, remove_filter @@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None: remove_filter(_repeated_backup_filter) -_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - MessageType.BackupDevice, - MessageType.WipeDevice, - MessageType.Cancel, -) +if utils.USE_THP: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) +else: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: diff --git a/core/src/apps/common/keychain.py b/core/src/apps/common/keychain.py index 16913d1529e..7959789b251 100644 --- a/core/src/apps/common/keychain.py +++ b/core/src/apps/common/keychain.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from trezor import utils from trezor.crypto import bip32 from trezor.wire import DataError @@ -172,6 +173,9 @@ async def get_keychain( ) -> Keychain: from .seed import get_seed + if not utils.USE_THP: + pass + # try to ask for passphrase here seed = await get_seed() keychain = Keychain(seed, curve, schemas, slip21_namespaces) return keychain diff --git a/core/src/apps/common/passphrase.py b/core/src/apps/common/passphrase.py index ef8bb5b1850..d150dd47369 100644 --- a/core/src/apps/common/passphrase.py +++ b/core/src/apps/common/passphrase.py @@ -1,84 +1,122 @@ from micropython import const +from typing import TYPE_CHECKING import storage.device as storage_device +from trezor import utils from trezor.wire import DataError _MAX_PASSPHRASE_LEN = const(50) +if TYPE_CHECKING: + from trezor.messages import ThpCreateNewSession + def is_enabled() -> bool: return storage_device.is_passphrase_enabled() -async def get() -> str: - from trezor import workflow - +async def get_passphrase(msg: ThpCreateNewSession) -> str: if not is_enabled(): return "" + + if msg.on_device or storage_device.get_passphrase_always_on_device(): + passphrase = await _get_on_device() else: - workflow.close_others() # request exclusive UI access - if storage_device.get_passphrase_always_on_device(): - from trezor.ui.layouts import request_passphrase_on_device + passphrase = msg.passphrase or "" + if passphrase: + await _handle_displaying_passphrase_from_host(passphrase) - passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) - else: - passphrase = await _request_on_host() - if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: - raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") + if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: + raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") - return passphrase + return passphrase -async def _request_on_host() -> str: - from trezor import TR - from trezor.messages import PassphraseAck, PassphraseRequest - from trezor.ui.layouts import request_passphrase_on_host - from trezor.wire.context import call +async def _get_on_device() -> str: + from trezor import workflow + from trezor.ui.layouts import request_passphrase_on_device - request_passphrase_on_host() + workflow.close_others() # request exclusive UI access + passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) - request = PassphraseRequest() - ack = await call(request, PassphraseAck) - passphrase = ack.passphrase # local_cache_attribute + return passphrase - if ack.on_device: - from trezor.ui.layouts import request_passphrase_on_device - if passphrase is not None: - raise DataError("Passphrase provided when it should not be") - return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) +async def _handle_displaying_passphrase_from_host(passphrase: str) -> None: + from trezor import TR + from trezor.ui.layouts import confirm_action, confirm_blob + + # We want to hide the passphrase, or show it, according to settings. + if storage_device.get_hide_passphrase_from_host(): + await confirm_action( + "passphrase_host1_hidden", + TR.passphrase__wallet, + description=TR.passphrase__from_host_not_shown, + prompt_screen=True, + prompt_title=TR.passphrase__access_wallet, + ) + else: + await confirm_action( + "passphrase_host1", + TR.passphrase__wallet, + description=TR.passphrase__next_screen_will_show_passphrase, + verb=TR.buttons__continue, + ) - if passphrase is None: - raise DataError( - "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase." + await confirm_blob( + "passphrase_host2", + TR.passphrase__title_confirm, + passphrase, ) - # non-empty passphrase - if passphrase: - from trezor.ui.layouts import confirm_action, confirm_blob - - # We want to hide the passphrase, or show it, according to settings. - if storage_device.get_hide_passphrase_from_host(): - await confirm_action( - "passphrase_host1_hidden", - TR.passphrase__wallet, - description=TR.passphrase__from_host_not_shown, - prompt_screen=True, - prompt_title=TR.passphrase__access_wallet, - ) + +if not utils.USE_THP: + + async def get() -> str: + from trezor import workflow + + if not is_enabled(): + return "" else: - await confirm_action( - "passphrase_host1", - TR.passphrase__wallet, - description=TR.passphrase__next_screen_will_show_passphrase, - verb=TR.buttons__continue, - ) + workflow.close_others() # request exclusive UI access + if storage_device.get_passphrase_always_on_device(): + from trezor.ui.layouts import request_passphrase_on_device + + passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) + else: + passphrase = await _request_on_host() + if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: + raise DataError( + f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes" + ) + + return passphrase + + async def _request_on_host() -> str: + from trezor.messages import PassphraseAck, PassphraseRequest + from trezor.ui.layouts import request_passphrase_on_host + from trezor.wire.context import call + + request_passphrase_on_host() - await confirm_blob( - "passphrase_host2", - TR.passphrase__title_confirm, - passphrase, - info=False, + request = PassphraseRequest() + ack = await call(request, PassphraseAck) + passphrase = ack.passphrase # local_cache_attribute + + if ack.on_device: + from trezor.ui.layouts import request_passphrase_on_device + + if passphrase is not None: + raise DataError("Passphrase provided when it should not be") + return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) + + if passphrase is None: + raise DataError( + "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase." ) - return passphrase + # non-empty passphrase + if passphrase: + await _handle_displaying_passphrase_from_host(passphrase) + + return passphrase diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index b09004ae698..4bb15184f80 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -5,14 +5,18 @@ from trezor import utils from trezor.crypto import hmac from trezor.wire import context +from trezor.wire.context import get_context +from trezor.wire.errors import DataError from apps.common import cache from . import mnemonic -from .passphrase import get as get_passphrase +from .passphrase import get_passphrase as get_passphrase if TYPE_CHECKING: from trezor.crypto import bip32 + from trezor.messages import ThpCreateNewSession + from trezor.wire.protocol_common import Context from .paths import Bip32Path, Slip21Path @@ -22,6 +26,9 @@ APP_COMMON_DERIVE_CARDANO, ) +if not utils.USE_THP: + from .passphrase import get as get_passphrase_legacy + class Slip21Node: """ @@ -54,51 +61,111 @@ def clone(self) -> "Slip21Node": return Slip21Node(data=self.data) -if not utils.BITCOIN_ONLY: - # === Cardano variant === - # We want to derive both the normal seed and the Cardano seed together, AND - # expose a method for Cardano to do the same +if utils.USE_THP: + + async def get_seed() -> bytes: # type: ignore [Function declaration "get_seed" is obscured by a declaration of the same name] + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + if utils.BITCOIN_ONLY: + # === Bitcoin_only variant === + # We want to derive the normal seed ONLY - async def derive_and_store_roots() -> None: - from trezor import wire + async def derive_and_store_roots( + ctx: Context, msg: ThpCreateNewSession + ) -> None: - if not storage_device.is_initialized(): - raise wire.NotInitialized("Device is not initialized") + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") - need_seed = not context.cache_is_set(APP_COMMON_SEED) - need_cardano_secret = context.cache_get_bool( - APP_COMMON_DERIVE_CARDANO - ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) + if ctx.cache.is_set(APP_COMMON_SEED): + raise Exception("Seed is already set!") - if not need_seed and not need_cardano_secret: - return + from trezor import wire - passphrase = await get_passphrase() + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") - if need_seed: + passphrase = await get_passphrase(msg) common_seed = mnemonic.get_seed(passphrase) - context.cache_set(APP_COMMON_SEED, common_seed) + ctx.cache.set(APP_COMMON_SEED, common_seed) - if need_cardano_secret: - from apps.cardano.seed import derive_and_store_secrets + else: + # === Cardano variant === + # We want to derive both the normal seed and the Cardano seed together + async def derive_and_store_roots( + ctx: Context, msg: ThpCreateNewSession + ) -> None: - derive_and_store_secrets(passphrase) + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - await derive_and_store_roots() - common_seed = context.cache_get(APP_COMMON_SEED) - assert common_seed is not None - return common_seed + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET): + raise Exception("Cardano icarus secret is already set!") + + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if msg.derive_cardano: + from apps.cardano.seed import derive_and_store_secrets + + ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True) + derive_and_store_secrets(ctx, passphrase) else: - # === Bitcoin-only variant === - # We use the simple version of `get_seed` that never needs to derive anything else. + if utils.BITCOIN_ONLY: + # === Bitcoin-only variant === + # We use the simple version of `get_seed` that never needs to derive anything else. + + @cache.stored_async(APP_COMMON_SEED) + async def get_seed() -> bytes: + passphrase = await get_passphrase_legacy() + return mnemonic.get_seed(passphrase=passphrase) + + else: + # === Cardano variant === + # We want to derive both the normal seed and the Cardano seed together, AND + # expose a method for Cardano to do the same + + @cache.stored_async(APP_COMMON_SEED) + async def get_seed() -> bytes: + await derive_and_store_roots_legacy() + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + async def derive_and_store_roots_legacy() -> None: + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + ctx = get_context() + need_seed = not ctx.cache.is_set(APP_COMMON_SEED) + need_cardano_secret = ctx.cache.get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET) + + if not need_seed and not need_cardano_secret: + return + + passphrase = await get_passphrase_legacy() + + if need_seed: + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if need_cardano_secret: + from apps.cardano.seed import derive_and_store_secrets - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - passphrase = await get_passphrase() - return mnemonic.get_seed(passphrase) + derive_and_store_secrets(ctx, passphrase) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 3bfd4772e43..46e44a0a4d3 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -1,3 +1,5 @@ +from trezor.wire import message_handler + if not __debug__: from trezor.utils import halt @@ -29,13 +31,14 @@ DebugLinkState, ) from trezor.ui import Layout - from trezor.wire import WireInterface, context + from trezor.wire import WireInterface + from trezor.wire.protocol_common import Context Handler = Callable[[Any], Awaitable[Any]] layout_change_box = loop.mailbox() - DEBUG_CONTEXT: context.Context | None = None + DEBUG_CONTEXT: Context | None = None REFRESH_INDEX = 0 @@ -70,9 +73,7 @@ def wait_until_layout_is_running(timeout: int | None = _DEADLOCK_SLEEP_MS) -> Aw "layout deadlock detected (did you send a ButtonAck?)" ) - async def return_layout_change( - ctx: wire.protocol_common.Context, detect_deadlock: bool = False - ) -> None: + async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None: # set up the wait storage.layout_watcher = True @@ -212,12 +213,12 @@ async def dispatch_DebugLinkDecision( x = msg.x # local_cache_attribute y = msg.y # local_cache_attribute - await wait_until_layout_is_running() assert isinstance(ui.CURRENT_LAYOUT, ui.Layout) layout_change_box.clear() try: + # click on specific coordinates, with possible hold if x is not None and y is not None: await _layout_click(x, y, msg.hold_ms or 0) @@ -229,7 +230,11 @@ async def dispatch_DebugLinkDecision( elif msg.button is not None: await _layout_event(msg.button) elif msg.input is not None: - ui.CURRENT_LAYOUT._emit_message(msg.input) + try: + ui.CURRENT_LAYOUT._emit_message(msg.input) + except Exception as e: + print(type(e)) + else: raise RuntimeError("Invalid DebugLinkDecision message") @@ -244,7 +249,11 @@ async def dispatch_DebugLinkDecision( # If no exception was raised, the layout did not shut down. That means that it # just updated itself. The update is already live for the caller to retrieve. - def _state() -> DebugLinkState: + def _state( + thp_pairing_code_entry_code: int | None = None, + thp_pairing_code_qr_code: bytes | None = None, + thp_pairing_code_nfc_unidirectional: bytes | None = None, + ) -> DebugLinkState: from trezor.messages import DebugLinkState from apps.common import mnemonic, passphrase @@ -263,13 +272,45 @@ def callback(*args: str) -> None: passphrase_protection=passphrase.is_enabled(), reset_entropy=storage.reset_internal_entropy, tokens=tokens, + thp_pairing_code_entry_code=thp_pairing_code_entry_code, + thp_pairing_code_qr_code=thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional, ) async def dispatch_DebugLinkGetState( msg: DebugLinkGetState, ) -> DebugLinkState | None: + + thp_pairing_code_entry_code: int | None = None + thp_pairing_code_qr_code: bytes | None = None + thp_pairing_code_nfc_unidirectional: bytes | None = None + if utils.USE_THP and msg.thp_channel_id is not None: + channel_id = int.from_bytes(msg.thp_channel_id, "big") + + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.pairing_context import PairingContext + from trezor.wire.thp.thp_main import _CHANNELS + + channel: Channel | None = None + ctx: PairingContext | None = None + try: + channel = _CHANNELS[channel_id] + ctx = channel.connection_context + except KeyError: + pass + if ctx is not None and isinstance(ctx, PairingContext): + thp_pairing_code_entry_code = ctx.display_data.code_code_entry + thp_pairing_code_qr_code = ctx.display_data.code_qr_code + thp_pairing_code_nfc_unidirectional = ( + ctx.display_data.code_nfc_unidirectional + ) + if msg.wait_layout == DebugWaitType.IMMEDIATE: - return _state() + return _state( + thp_pairing_code_entry_code, + thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional, + ) assert DEBUG_CONTEXT is not None if msg.wait_layout == DebugWaitType.NEXT_LAYOUT: @@ -280,7 +321,11 @@ async def dispatch_DebugLinkGetState( if not layout_is_ready(): return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True) else: - return _state() + return _state( + thp_pairing_code_entry_code, + thp_pairing_code_qr_code, + thp_pairing_code_nfc_unidirectional, + ) async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success: if msg.target_directory: @@ -390,7 +435,6 @@ async def handle_session(iface: WireInterface) -> None: ctx.iface.iface_num(), msg_type, ) - if msg.type not in WORKFLOW_HANDLERS: await ctx.write(wire.message_handler.unexpected_message()) continue @@ -403,7 +447,7 @@ async def handle_session(iface: WireInterface) -> None: await ctx.write(Success()) continue - req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type) + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) try: res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) except Exception as exc: diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py index 85596c0268d..2213d2c17a5 100644 --- a/core/src/apps/management/reboot_to_bootloader.py +++ b/core/src/apps/management/reboot_to_bootloader.py @@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn: boot_args = None ctx = get_context() - await ctx.write(Success(message="Rebooting")) + await ctx.write_force(Success(message="Rebooting")) # make sure the outgoing USB buffer is flushed await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) # reboot to the bootloader, pass the firmware header hash if any diff --git a/core/src/apps/management/recovery_device/__init__.py b/core/src/apps/management/recovery_device/__init__.py index 08eb02b412b..9bc116eb06c 100644 --- a/core/src/apps/management/recovery_device/__init__.py +++ b/core/src/apps/management/recovery_device/__init__.py @@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success: from trezor import TR, config, wire, workflow from trezor.enums import BackupType, ButtonRequestType from trezor.ui.layouts import confirm_action, confirm_reset_device + from trezor.wire.context import try_get_ctx_ids from apps.common import mnemonic from apps.common.request_pin import ( @@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success: if recovery_type == RecoveryType.NormalRecovery: await confirm_reset_device(TR.recovery__title_recover, recovery=True) - # wipe storage to make sure the device is in a clear state - storage.reset() + # wipe storage to make sure the device is in a clear state (except protocol cache) + storage.reset(excluded=try_get_ctx_ids()) # set up pin if requested if msg.pin_protection: diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 9899b3fe6de..3a3c1380d2b 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -3,8 +3,9 @@ import storage.device as storage_device import storage.recovery as storage_recovery import storage.recovery_shares as storage_recovery_shares -from trezor import TR, wire +from trezor import TR, utils, wire from trezor.messages import Success +from trezor.wire import message_handler from apps.common import backup_types @@ -38,18 +39,26 @@ async def recovery_process() -> Success: recovery_type = storage_recovery.get_type() - wire.message_handler.AVOID_RESTARTING_FOR = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - ) + if utils.USE_THP: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.GetFeatures, + MessageType.EndSession, + ) + else: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + ) try: return await _continue_recovery_process() except recover.RecoveryAborted: storage_recovery.end_progress() backup.deactivate_repeated_backup() if recovery_type == RecoveryType.NormalRecovery: - storage.wipe() + from trezor.wire.context import try_get_ctx_ids + + storage.wipe(excluded=try_get_ctx_ids()) raise wire.ActionCancelled @@ -59,11 +68,17 @@ async def _continue_repeated_backup() -> None: from apps.common import backup from apps.management.backup_device import perform_backup - wire.message_handler.AVOID_RESTARTING_FOR = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - ) + if utils.USE_THP: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.GetFeatures, + MessageType.EndSession, + ) + else: + message_handler.AVOID_RESTARTING_FOR = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + ) try: await perform_backup(is_repeated_backup=True) diff --git a/core/src/apps/management/reset_device/__init__.py b/core/src/apps/management/reset_device/__init__.py index 4e38bfcd8ca..f98edc54725 100644 --- a/core/src/apps/management/reset_device/__init__.py +++ b/core/src/apps/management/reset_device/__init__.py @@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success: prompt_backup, show_wallet_created_success, ) - from trezor.wire.context import call + from trezor.wire.context import call, try_get_ctx_ids from apps.common.request_pin import request_pin_confirm @@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success: # Rendering empty loader so users do not feel a freezing screen render_empty_loader(config.StorageMessage.PROCESSING_MSG) - # wipe storage to make sure the device is in a clear state - storage.reset() + # wipe storage to make sure the device is in a clear state (except protocol cache) + storage.reset(excluded=try_get_ctx_ids()) # request and set new PIN if msg.pin_protection: @@ -121,7 +121,7 @@ async def reset_device(msg: ResetDevice) -> Success: if perform_backup: await layout.show_backup_success() - return Success(message="Initialized") + return Success(message="Initialized") # TODO: Why "Initialized?" async def _backup_bip39(mnemonic: str) -> None: diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py index b6e60057a6c..1abdc3f3e61 100644 --- a/core/src/apps/management/wipe_device.py +++ b/core/src/apps/management/wipe_device.py @@ -1,12 +1,19 @@ from typing import TYPE_CHECKING +from trezor.wire.context import get_context, try_get_ctx_ids + if TYPE_CHECKING: - from trezor.messages import Success, WipeDevice + from typing import NoReturn + + from trezor.messages import WipeDevice +if __debug__: + from trezor import log -async def wipe_device(msg: WipeDevice) -> Success: + +async def wipe_device(msg: WipeDevice) -> NoReturn: import storage - from trezor import TR, config, translations + from trezor import TR, config, loop, translations from trezor.enums import ButtonRequestType from trezor.messages import Success from trezor.pin import render_empty_loader @@ -26,16 +33,22 @@ async def wipe_device(msg: WipeDevice) -> Success: br_code=ButtonRequestType.WipeDevice, ) + if __debug__: + log.debug(__name__, "Device wipe - start") + # start an empty progress screen so that the screen is not blank while waiting render_empty_loader(config.StorageMessage.PROCESSING_MSG) - # wipe storage - storage.wipe() + storage.wipe(excluded=try_get_ctx_ids()) # erase translations translations.deinit() translations.erase() + await get_context().write_force(Success(message="Device wiped")) + storage.wipe_cache() + # reload settings reload_settings_from_storage() - - return Success(message="Device wiped") + loop.clear() + if __debug__: + log.debug(__name__, "Device wipe - finished") diff --git a/core/src/apps/thp/create_new_session.py b/core/src/apps/thp/create_new_session.py new file mode 100644 index 00000000000..156b852d46d --- /dev/null +++ b/core/src/apps/thp/create_new_session.py @@ -0,0 +1,59 @@ +from trezor import log, loop +from trezor.enums import FailureType +from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession +from trezor.wire.context import get_context +from trezor.wire.errors import ActionCancelled, DataError +from trezor.wire.thp import SessionState + + +async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure: + """ + Creates a new `ThpSession` based on the provided parameters and returns a + `ThpNewSession` message containing the new session ID. + + Returns an appropriate `Failure` message if session creation fails. + """ + from trezor.wire import NotInitialized + from trezor.wire.thp.session_context import GenericSessionContext + from trezor.wire.thp.session_manager import create_new_session + + from apps.common.seed import derive_and_store_roots + + ctx = get_context() + + # Assert that context `ctx` is `GenericSessionContext` + assert isinstance(ctx, GenericSessionContext) + + channel = ctx.channel + + # Do not use `ctx` beyond this point, as it is techically + # allowed to change in between await statements + + new_session = create_new_session(channel) + try: + await derive_and_store_roots(new_session, message) + except DataError as e: + return Failure(code=FailureType.DataError, message=e.message) + except ActionCancelled as e: + return Failure(code=FailureType.ActionCancelled, message=e.message) + except NotInitialized as e: + return Failure(code=FailureType.NotInitialized, message=e.message) + # TODO handle other errors (`Exception`` when "Cardano icarus secret is already set!" + # and `RuntimeError` when accessing storage for mnemonic.get_secret - it actually + # happens for locked devices) + + new_session.set_session_state(SessionState.ALLOCATED) + channel.sessions[new_session.session_id] = new_session + loop.schedule(new_session.handle()) + new_session_id: int = new_session.session_id + + if __debug__: + log.debug( + __name__, + "create_new_session - new session created. Passphrase: %s, Session id: %d\n%s", + message.passphrase if message.passphrase is not None else "", + new_session.session_id, + str(channel.sessions), + ) + + return ThpNewSession(new_session_id=new_session_id) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py new file mode 100644 index 00000000000..a14e4d032e6 --- /dev/null +++ b/core/src/apps/thp/pairing.py @@ -0,0 +1,403 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +from trezor import loop, protobuf +from trezor.crypto.hashlib import sha256 +from trezor.enums import ThpMessageType, ThpPairingMethod +from trezor.messages import ( + Cancel, + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHost, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCodeEntryTag, + ThpCredentialMetadata, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, + ThpNfcUnidirectionalSecret, + ThpNfcUnidirectionalTag, + ThpPairingPreparationsFinished, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpStartPairingRequest, +) +from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage +from trezor.wire.thp import ChannelState, ThpError, crypto +from trezor.wire.thp.pairing_context import PairingContext + +from .credential_manager import issue_credential + +if __debug__: + from trezor import log + +if TYPE_CHECKING: + from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple + + P = ParamSpec("P") + FuncWithContext = Callable[Concatenate[PairingContext, P], Any] + +# +# Helpers - decorators + + +def check_state_and_log( + *allowed_states: ChannelState, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_state(context, *allowed_states) + if __debug__: + try: + log.debug(__name__, "started %s", f.__name__) + except AttributeError: + log.debug( + __name__, + "started a function that cannot be named, because it raises AttributeError, eg. closure", + ) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +def check_method_is_allowed( + pairing_method: ThpPairingMethod, +) -> Callable[[FuncWithContext], FuncWithContext]: + def decorator(f: FuncWithContext) -> FuncWithContext: + def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object: + _check_method_is_allowed(context, pairing_method) + return f(context, *args, **kwargs) + + return inner + + return decorator + + +# +# Pairing handlers + + +@check_state_and_log(ChannelState.TP1) +async def handle_pairing_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + + if not ThpStartPairingRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + + ctx.host_name = message.host_name or "" + + skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod) + if skip_pairing: + return await _end_pairing(ctx) + + await _prepare_pairing(ctx) + await ctx.write(ThpPairingPreparationsFinished()) + ctx.channel_ctx.set_channel_state(ChannelState.TP3) + response = await show_display_data( + ctx, _get_possible_pairing_methods_and_cancel(ctx) + ) + + if Cancel.is_type_of(response): + ctx.channel_ctx.clear() + raise SilentError("Action was cancelled by the Host") + # TODO disable NFC (if enabled) + response = await _handle_different_pairing_methods(ctx, response) + + while ThpCredentialRequest.is_type_of(response): + response = await _handle_credential_request(ctx, response) + + return await _handle_end_request(ctx, response) + + +async def _prepare_pairing(ctx: PairingContext) -> None: + + if _is_method_included(ctx, ThpPairingMethod.CodeEntry): + await _handle_code_entry_is_included(ctx) + + if _is_method_included(ctx, ThpPairingMethod.QrCode): + _handle_qr_code_is_included(ctx) + + if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional): + _handle_nfc_unidirectional_is_included(ctx) + + +async def show_display_data( + ctx: PairingContext, expected_types: Container[int] = () +) -> type[protobuf.MessageType]: + from trezorui2 import CANCELLED + + read_task = ctx.read(expected_types) + cancel_task = ctx.display_data.get_display_layout() + race = loop.race(read_task, cancel_task.get_result()) + result: type[protobuf.MessageType] = await race + + if result is CANCELLED: + raise ActionCancelled + + return result + + +@check_state_and_log(ChannelState.TP1) +async def _handle_code_entry_is_included(ctx: PairingContext) -> None: + commitment = sha256(ctx.secret).digest() + + challenge_message = await ctx.call( # noqa: F841 + ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge + ) + ctx.channel_ctx.set_channel_state(ChannelState.TP2) + + if not ThpCodeEntryChallenge.is_type_of(challenge_message): + raise UnexpectedMessage("Unexpected message") + + if challenge_message.challenge is None: + raise Exception("Invalid message") + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(challenge_message.challenge) + sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8")) + code_code_entry_hash = sha_ctx.digest() + ctx.display_data.code_code_entry = ( + int.from_bytes(code_code_entry_hash, "big") % 1000000 + ) + + +@check_state_and_log(ChannelState.TP1, ChannelState.TP2) +def _handle_qr_code_is_included(ctx: PairingContext) -> None: + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8")) + ctx.display_data.code_qr_code = sha_ctx.digest()[:16] + + +@check_state_and_log(ChannelState.TP1, ChannelState.TP2) +def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None: + sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) + sha_ctx.update(ctx.secret) + sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8")) + ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16] + + +@check_state_and_log(ChannelState.TP3) +async def _handle_different_pairing_methods( + ctx: PairingContext, response: protobuf.MessageType +) -> protobuf.MessageType: + if ThpCodeEntryCpaceHost.is_type_of(response): + return await _handle_code_entry_cpace(ctx, response) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) + raise UnexpectedMessage("Unexpected message") + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.CodeEntry) +async def _handle_code_entry_cpace( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + from trezor.wire.thp.cpace import Cpace + + # TODO check that ThpCodeEntryCpaceHost message is valid + + if TYPE_CHECKING: + assert isinstance(message, ThpCodeEntryCpaceHost) + if message.cpace_host_public_key is None: + raise ThpError("Message ThpCodeEntryCpaceHost has no public key") + + ctx.cpace = Cpace( + message.cpace_host_public_key, + ctx.channel_ctx.get_handshake_hash(), + ) + assert ctx.display_data.code_code_entry is not None + ctx.cpace.generate_keys_and_secret( + ctx.display_data.code_code_entry.to_bytes(6, "big") + ) + + ctx.channel_ctx.set_channel_state(ChannelState.TP4) + response = await ctx.call( + ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key), + ThpCodeEntryTag, + ) + return await _handle_code_entry_tag(ctx, response) + + +@check_state_and_log(ChannelState.TP4) +@check_method_is_allowed(ThpPairingMethod.CodeEntry) +async def _handle_code_entry_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + + if TYPE_CHECKING: + assert isinstance(message, ThpCodeEntryTag) + + expected_tag = sha256(ctx.cpace.shared_secret).digest() + if expected_tag != message.tag: + print( + "expected code entry tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected code entry shared secret:", + hexlify(ctx.cpace.shared_secret).decode(), + ) # TODO remove after testing + raise ThpError("Unexpected Code Entry Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpCodeEntrySecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.QrCode) +async def _handle_qr_code_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpQrCodeTag) + assert ctx.display_data.code_qr_code is not None + expected_tag = sha256(ctx.display_data.code_qr_code).digest() + if expected_tag != message.tag: + print( + "expected qr code tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + print( + "expected code qr code tag:", + hexlify(ctx.display_data.code_qr_code).decode(), + ) # TODO remove after testing + print( + "expected secret:", hexlify(ctx.secret).decode() + ) # TODO remove after testing + raise ThpError("Unexpected QR Code Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpQrCodeSecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3) +@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional) +async def _handle_nfc_unidirectional_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + if TYPE_CHECKING: + assert isinstance(message, ThpNfcUnidirectionalTag) + + expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest() + if expected_tag != message.tag: + print( + "expected nfc tag:", hexlify(expected_tag).decode() + ) # TODO remove after testing + raise ThpError("Unexpected NFC Unidirectional Tag") + + return await _handle_secret_reveal( + ctx, + msg=ThpNfcUnidirectionalSecret(secret=ctx.secret), + ) + + +@check_state_and_log(ChannelState.TP3, ChannelState.TP4) +async def _handle_secret_reveal( + ctx: PairingContext, + msg: protobuf.MessageType, +) -> protobuf.MessageType: + ctx.channel_ctx.set_channel_state(ChannelState.TC1) + return await ctx.call_any( + msg, + ThpMessageType.ThpCredentialRequest, + ThpMessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_credential_request( + ctx: PairingContext, message: protobuf.MessageType +) -> protobuf.MessageType: + ctx.secret + + if not ThpCredentialRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + if message.host_static_pubkey is None: + raise Exception("Invalid message") # TODO change failure type + + trezor_static_pubkey = crypto.get_trezor_static_pubkey() + credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name) + credential = issue_credential(message.host_static_pubkey, credential_metadata) + + return await ctx.call_any( + ThpCredentialResponse( + trezor_static_pubkey=trezor_static_pubkey, credential=credential + ), + ThpMessageType.ThpCredentialRequest, + ThpMessageType.ThpEndRequest, + ) + + +@check_state_and_log(ChannelState.TC1) +async def _handle_end_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + if not ThpEndRequest.is_type_of(message): + raise UnexpectedMessage("Unexpected message") + return await _end_pairing(ctx) + + +async def _end_pairing(ctx: PairingContext) -> ThpEndResponse: + ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + return ThpEndResponse() + + +# +# Helpers - checkers + + +def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None: + if ctx.channel_ctx.get_channel_state() not in allowed_states: + raise UnexpectedMessage("Unexpected message") + + +def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None: + if not _is_method_included(ctx, method): + raise ThpError("Unexpected pairing method") + + +def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: + return method in ctx.channel_ctx.selected_pairing_methods + + +# +# Helpers - getters + + +def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]: + r = _get_possible_pairing_methods(ctx) + mtype = Cancel.MESSAGE_WIRE_TYPE + return r + ((mtype,) if mtype is not None else ()) + + +def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]: + r = tuple( + _get_message_type_for_method(method) + for method in ctx.channel_ctx.selected_pairing_methods + ) + if __debug__: + from trezor.messages import DebugLinkGetState + + mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE + return r + ((mtype,) if mtype is not None else ()) + return r + + +def _get_message_type_for_method(method: int) -> int: + if method is ThpPairingMethod.CodeEntry: + return ThpMessageType.ThpCodeEntryCpaceHost + if method is ThpPairingMethod.NFC_Unidirectional: + return ThpMessageType.ThpNfcUnidirectionalTag + if method is ThpPairingMethod.QrCode: + return ThpMessageType.ThpQrCodeTag + raise ValueError("Unexpected pairing method - no message type available") diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index b65c853c93f..3013516382b 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -35,6 +35,13 @@ def _find_message_handler_module(msg_type: int) -> str: if __debug__ and msg_type == MessageType.BenchmarkRun: return "apps.benchmark.run" + if utils.USE_THP: + from trezor.enums import ThpMessageType + + # thp management + if msg_type == ThpMessageType.ThpCreateNewSession: + return "apps.thp.create_new_session" + # management if msg_type == MessageType.ResetDevice: return "apps.management.reset_device" diff --git a/core/src/storage/__init__.py b/core/src/storage/__init__.py index 3a012874f3d..2fe2c845d9d 100644 --- a/core/src/storage/__init__.py +++ b/core/src/storage/__init__.py @@ -1,11 +1,27 @@ # make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module +from typing import TYPE_CHECKING + from storage import cache, common, device +if TYPE_CHECKING: + from typing import Tuple + + pass -def wipe() -> None: + +def wipe(excluded: Tuple[bytes, bytes] | None) -> None: + """ + TODO REPHRASE SO THAT IT IS TRUE! Wipes the storage. Using `exclude_protocol=False` destroys the THP communication channel. + If the device should communicate after wipe, use `exclude_protocol=True` and clear cache manually later using + `wipe_cache()`. + """ from trezor import config config.wipe() + cache.clear_all(excluded) + + +def wipe_cache() -> None: cache.clear_all() @@ -21,12 +37,12 @@ def init_unlocked() -> None: common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True) -def reset() -> None: +def reset(excluded: Tuple[bytes, bytes] | None) -> None: """ Wipes storage but keeps the device id unchanged. """ device_id = device.get_device_id() - wipe() + wipe(excluded) common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 72d8a1e4188..6db224a782d 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,26 +1,47 @@ import builtins import gc +from typing import TYPE_CHECKING -from storage import cache_codec from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache +from trezor import utils + +if TYPE_CHECKING: + from typing import Tuple + + pass # Cache initialization _SESSIONLESS_CACHE = SessionlessCache() -_PROTOCOL_CACHE = cache_codec + + +if utils.USE_THP: + from storage import cache_thp + + _PROTOCOL_CACHE = cache_thp +else: + from storage import cache_codec + + _PROTOCOL_CACHE = cache_codec + _PROTOCOL_CACHE.initialize() _SESSIONLESS_CACHE.clear() gc.collect() -def clear_all() -> None: +def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None: """ Clears all data from both the protocol cache and the sessionless cache. """ global autolock_last_touch autolock_last_touch = None _SESSIONLESS_CACHE.clear() - _PROTOCOL_CACHE.clear_all() + + if utils.USE_THP and excluded is not None: + # If we want to keep THP connection alive, we do not clear communication keys + cache_thp.clear_all_except_one_session_keys(excluded) + else: + _PROTOCOL_CACHE.clear_all() def get_int_all_sessions(key: int) -> builtins.set[int]: diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py index 90cead81db5..40eee905ccd 100644 --- a/core/src/storage/cache_common.py +++ b/core/src/storage/cache_common.py @@ -14,6 +14,14 @@ APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) APP_MONERO_LIVE_REFRESH = const(7) +# Cache keys for THP channel +if utils.USE_THP: + CHANNEL_HANDSHAKE_HASH = const(0) + CHANNEL_KEY_RECEIVE = const(1) + CHANNEL_KEY_SEND = const(2) + CHANNEL_NONCE_RECEIVE = const(3) + CHANNEL_NONCE_SEND = const(4) + # Keys that are valid across sessions SESSIONLESS_FLAG = const(128) APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py new file mode 100644 index 00000000000..6ed41b8415d --- /dev/null +++ b/core/src/storage/cache_thp.py @@ -0,0 +1,363 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache + +if TYPE_CHECKING: + from typing import Tuple + + pass + + +# THP specific constants +_MAX_CHANNELS_COUNT = const(10) +_MAX_SESSIONS_COUNT = const(20) + + +_CHANNEL_STATE_LENGTH = const(1) +_WIRE_INTERFACE_LENGTH = const(1) +_SESSION_STATE_LENGTH = const(1) +_CHANNEL_ID_LENGTH = const(2) +SESSION_ID_LENGTH = const(1) +BROADCAST_CHANNEL_ID = const(0xFFFF) +KEY_LENGTH = const(32) +TAG_LENGTH = const(16) +_UNALLOCATED_STATE = const(0) +_MANAGEMENT_STATE = const(2) +MANAGEMENT_SESSION_ID = const(0) + + +class ThpDataCache(DataCache): + def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.last_usage = 0 + super().__init__() + + def clear(self) -> None: + self.channel_id[:] = b"" + self.last_usage = 0 + super().clear() + + +class ChannelCache(ThpDataCache): + def __init__(self) -> None: + self.host_ephemeral_pubkey = bytearray(KEY_LENGTH) + self.state = bytearray(_CHANNEL_STATE_LENGTH) + self.iface = bytearray(1) # TODO add decoding + self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) + self.session_id_counter = 0x00 + self.fields = ( + 32, # CHANNEL_HANDSHAKE_HASH + 32, # CHANNEL_KEY_RECEIVE + 32, # CHANNEL_KEY_SEND + 8, # CHANNEL_NONCE_RECEIVE + 8, # CHANNEL_NONCE_SEND + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray( + int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH) + self.state[:] = bytearray(_CHANNEL_STATE_LENGTH) + self.iface[:] = bytearray(1) + super().clear() + + +class SessionThpCache(ThpDataCache): + def __init__(self) -> None: + from trezor import utils + + self.session_id = bytearray(SESSION_ID_LENGTH) + self.state = bytearray(_SESSION_STATE_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED + self.session_id[:] = b"" + super().clear() + + +_CHANNELS: list[ChannelCache] = [] +_SESSIONS: list[SessionThpCache] = [] +cid_counter: int = 0 + +# Last-used counter +_usage_counter = 0 + + +def initialize() -> None: + global _CHANNELS + global _SESSIONS + global cid_counter + + for _ in range(_MAX_CHANNELS_COUNT): + _CHANNELS.append(ChannelCache()) + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionThpCache()) + + for channel in _CHANNELS: + channel.clear() + for session in _SESSIONS: + session.clear() + + from trezorcrypto import random + + cid_counter = random.uniform(0xFFFE) + + +def get_new_channel(iface: bytes) -> ChannelCache: + if len(iface) != _WIRE_INTERFACE_LENGTH: + raise Exception("Invalid WireInterface (encoded) length") + + new_cid = get_next_channel_id() + index = _get_next_channel_index() + + # clear sessions from replaced channel + if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE: + old_cid = _CHANNELS[index].channel_id + clear_sessions_with_channel_id(old_cid) + + _CHANNELS[index] = ChannelCache() + _CHANNELS[index].channel_id[:] = new_cid + _CHANNELS[index].last_usage = _get_usage_counter_and_increment() + _CHANNELS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") + ) + _CHANNELS[index].iface[:] = bytearray(iface) + return _CHANNELS[index] + + +def update_channel_last_used(channel_id: bytes) -> None: + for channel in _CHANNELS: + if channel.channel_id == channel_id: + channel.last_usage = _get_usage_counter_and_increment() + return + + +def update_session_last_used(channel_id: bytes, session_id: bytes) -> None: + for session in _SESSIONS: + if session.channel_id == channel_id and session.session_id == session_id: + session.last_usage = _get_usage_counter_and_increment() + update_channel_last_used(channel_id) + return + + +def get_all_allocated_channels() -> list[ChannelCache]: + _list: list[ChannelCache] = [] + for channel in _CHANNELS: + if _get_channel_state(channel) != _UNALLOCATED_STATE: + _list.append(channel) + return _list + + +def get_allocated_session( + channel_id: bytes, session_id: bytes +) -> SessionThpCache | None: + """ + Finds and returns the first allocated session matching the given `channel_id` and `session_id`, + or `None` if no match is found. + + Raises `Exception` if either channel_id or session_id has an invalid length. + """ + if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH: + raise Exception("At least one of arguments has invalid length") + + for session in _SESSIONS: + if _get_session_state(session) == _UNALLOCATED_STATE: + continue + if session.channel_id != channel_id: + continue + if session.session_id != session_id: + continue + return session + return None + + +def is_management_session(session_cache: SessionThpCache) -> bool: + return _get_session_state(session_cache) == _MANAGEMENT_STATE + + +def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None: + if len(key) != KEY_LENGTH: + raise Exception("Invalid key length") + channel.host_ephemeral_pubkey = key + + +def get_new_session(channel: ChannelCache) -> SessionThpCache: + new_sid = get_next_session_id(channel) + index = _get_next_session_index() + + _SESSIONS[index] = SessionThpCache() + _SESSIONS[index].channel_id[:] = channel.channel_id + _SESSIONS[index].session_id[:] = new_sid + _SESSIONS[index].last_usage = _get_usage_counter_and_increment() + channel.last_usage = ( + _get_usage_counter_and_increment() + ) # increment also use of the channel so it does not get replaced + _SESSIONS[index].state[:] = bytearray( + _UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big") + ) + return _SESSIONS[index] + + +def _get_usage_counter_and_increment() -> int: + global _usage_counter + _usage_counter += 1 + return _usage_counter + + +def _get_next_channel_index() -> int: + idx = _get_unallocated_channel_index() + if idx is not None: + return idx + return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) + + +def _get_next_session_index() -> int: + idx = _get_unallocated_session_index() + if idx is not None: + return idx + return get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT) + + +def _get_unallocated_channel_index() -> int | None: + for i in range(_MAX_CHANNELS_COUNT): + if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_unallocated_session_index() -> int | None: + for i in range(_MAX_SESSIONS_COUNT): + if (_SESSIONS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_channel_state(channel: ChannelCache) -> int: + return int.from_bytes(channel.state, "big") + + +def _get_session_state(session: SessionThpCache) -> int: + return int.from_bytes(session.state, "big") + + +def get_next_channel_id() -> bytes: + global cid_counter + while True: + cid_counter += 1 + if cid_counter >= BROADCAST_CHANNEL_ID: + cid_counter = 1 + if _is_cid_unique(): + break + return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") + + +def get_next_session_id(channel: ChannelCache) -> bytes: + while True: + if channel.session_id_counter >= 255: + channel.session_id_counter = 1 + else: + channel.session_id_counter += 1 + if _is_session_id_unique(channel): + break + new_sid = channel.session_id_counter + return new_sid.to_bytes(SESSION_ID_LENGTH, "big") + + +def _is_session_id_unique(channel: ChannelCache) -> bool: + for session in _SESSIONS: + if session.channel_id == channel.channel_id: + if session.session_id == channel.session_id_counter: + return False + return True + + +def _is_cid_unique() -> bool: + global cid_counter + cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") + for channel in _CHANNELS: + if channel.channel_id == cid_counter_bytes: + return False + return True + + +def get_least_recently_used_item( + list: list[ChannelCache] | list[SessionThpCache], max_count: int +) -> int: + global _usage_counter + lru_counter = _usage_counter + 1 + lru_item_index = 0 + for i in range(max_count): + if list[i].last_usage < lru_counter: + lru_counter = list[i].last_usage + lru_item_index = i + return lru_item_index + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_sessions_with_channel_id(channel_id: bytes) -> None: + for session in _SESSIONS: + if session.channel_id == channel_id: + session.clear() + + +def clear_session(session: SessionThpCache) -> None: + for s in _SESSIONS: + if s.channel_id == session.channel_id and s.session_id == session.session_id: + session.clear() + + +def clear_all() -> None: + for session in _SESSIONS: + session.clear() + for channel in _CHANNELS: + channel.clear() + + +def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None: + cid, sid = excluded + + for channel in _CHANNELS: + if channel.channel_id != cid: + channel.clear() + + for session in _SESSIONS: + if session.channel_id != cid and session.session_id != sid: + session.clear() + else: + s_last_usage = session.last_usage + session.clear() + session.last_usage = s_last_usage + session.state = bytearray(_MANAGEMENT_STATE.to_bytes(1, "big")) + session.session_id[:] = bytearray(sid) + session.channel_id[:] = bytearray(cid) diff --git a/core/src/trezor/enums/FailureType.py b/core/src/trezor/enums/FailureType.py index fbb2001e54c..883844307a1 100644 --- a/core/src/trezor/enums/FailureType.py +++ b/core/src/trezor/enums/FailureType.py @@ -16,4 +16,6 @@ PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 +ThpUnallocatedSession = 15 +InvalidProtocol = 16 FirmwareError = 99 diff --git a/core/src/trezor/enums/ThpMessageType.py b/core/src/trezor/enums/ThpMessageType.py new file mode 100644 index 00000000000..45a34120e5a --- /dev/null +++ b/core/src/trezor/enums/ThpMessageType.py @@ -0,0 +1,22 @@ +# Automatically generated by pb2py +# fmt: off +# isort:skip_file + +ThpCreateNewSession = 1000 +ThpNewSession = 1001 +ThpStartPairingRequest = 1008 +ThpPairingPreparationsFinished = 1009 +ThpCredentialRequest = 1010 +ThpCredentialResponse = 1011 +ThpEndRequest = 1012 +ThpEndResponse = 1013 +ThpCodeEntryCommitment = 1016 +ThpCodeEntryChallenge = 1017 +ThpCodeEntryCpaceHost = 1018 +ThpCodeEntryCpaceTrezor = 1019 +ThpCodeEntryTag = 1020 +ThpCodeEntrySecret = 1021 +ThpQrCodeTag = 1024 +ThpQrCodeSecret = 1025 +ThpNfcUnidirectionalTag = 1032 +ThpNfcUnidirectionalSecret = 1033 diff --git a/core/src/trezor/enums/ThpPairingMethod.py b/core/src/trezor/enums/ThpPairingMethod.py new file mode 100644 index 00000000000..b356cdf470b --- /dev/null +++ b/core/src/trezor/enums/ThpPairingMethod.py @@ -0,0 +1,8 @@ +# Automatically generated by pb2py +# fmt: off +# isort:skip_file + +NoMethod = 1 +CodeEntry = 2 +QrCode = 3 +NFC_Unidirectional = 4 diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py index 9d9fef32751..62ef86b8109 100644 --- a/core/src/trezor/enums/__init__.py +++ b/core/src/trezor/enums/__init__.py @@ -39,6 +39,8 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 FirmwareError = 99 class ButtonRequestType(IntEnum): @@ -347,6 +349,32 @@ class TezosBallotType(IntEnum): Nay = 1 Pass = 2 + class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpNewSession = 1001 + ThpStartPairingRequest = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceHost = 1018 + ThpCodeEntryCpaceTrezor = 1019 + ThpCodeEntryTag = 1020 + ThpCodeEntrySecret = 1021 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcUnidirectionalTag = 1032 + ThpNfcUnidirectionalSecret = 1033 + + class ThpPairingMethod(IntEnum): + NoMethod = 1 + CodeEntry = 2 + QrCode = 3 + NFC_Unidirectional = 4 + class MessageType(IntEnum): Initialize = 0 Ping = 1 diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py index f60dd153048..d230774c81a 100644 --- a/core/src/trezor/messages.py +++ b/core/src/trezor/messages.py @@ -68,6 +68,8 @@ def __getattr__(name: str) -> Any: from trezor.enums import StellarSignerType # noqa: F401 from trezor.enums import TezosBallotType # noqa: F401 from trezor.enums import TezosContractType # noqa: F401 + from trezor.enums import ThpMessageType # noqa: F401 + from trezor.enums import ThpPairingMethod # noqa: F401 from trezor.enums import WordRequestType # noqa: F401 class BenchmarkListNames(protobuf.MessageType): @@ -2866,11 +2868,13 @@ def is_type_of(cls, msg: Any) -> TypeGuard["DebugLinkRecordScreen"]: class DebugLinkGetState(protobuf.MessageType): wait_layout: "DebugWaitType" + thp_channel_id: "bytes | None" def __init__( self, *, wait_layout: "DebugWaitType | None" = None, + thp_channel_id: "bytes | None" = None, ) -> None: pass @@ -2892,6 +2896,9 @@ class DebugLinkState(protobuf.MessageType): reset_word_pos: "int | None" mnemonic_type: "BackupType | None" tokens: "list[str]" + thp_pairing_code_entry_code: "int | None" + thp_pairing_code_qr_code: "bytes | None" + thp_pairing_code_nfc_unidirectional: "bytes | None" def __init__( self, @@ -2909,6 +2916,9 @@ def __init__( recovery_word_pos: "int | None" = None, reset_word_pos: "int | None" = None, mnemonic_type: "BackupType | None" = None, + thp_pairing_code_entry_code: "int | None" = None, + thp_pairing_code_qr_code: "bytes | None" = None, + thp_pairing_code_nfc_unidirectional: "bytes | None" = None, ) -> None: pass @@ -6130,6 +6140,278 @@ def __init__( def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]: return isinstance(msg, cls) + class ThpDeviceProperties(protobuf.MessageType): + internal_model: "str | None" + model_variant: "int | None" + bootloader_mode: "bool | None" + protocol_version: "int | None" + pairing_methods: "list[ThpPairingMethod]" + + def __init__( + self, + *, + pairing_methods: "list[ThpPairingMethod] | None" = None, + internal_model: "str | None" = None, + model_variant: "int | None" = None, + bootloader_mode: "bool | None" = None, + protocol_version: "int | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]: + return isinstance(msg, cls) + + class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + host_pairing_credential: "bytes | None" + pairing_methods: "list[ThpPairingMethod]" + + def __init__( + self, + *, + pairing_methods: "list[ThpPairingMethod] | None" = None, + host_pairing_credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]: + return isinstance(msg, cls) + + class ThpCreateNewSession(protobuf.MessageType): + passphrase: "str | None" + on_device: "bool | None" + derive_cardano: "bool | None" + + def __init__( + self, + *, + passphrase: "str | None" = None, + on_device: "bool | None" = None, + derive_cardano: "bool | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]: + return isinstance(msg, cls) + + class ThpNewSession(protobuf.MessageType): + new_session_id: "int | None" + + def __init__( + self, + *, + new_session_id: "int | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]: + return isinstance(msg, cls) + + class ThpStartPairingRequest(protobuf.MessageType): + host_name: "str | None" + + def __init__( + self, + *, + host_name: "str | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]: + return isinstance(msg, cls) + + class ThpPairingPreparationsFinished(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]: + return isinstance(msg, cls) + + class ThpCodeEntryCommitment(protobuf.MessageType): + commitment: "bytes | None" + + def __init__( + self, + *, + commitment: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]: + return isinstance(msg, cls) + + class ThpCodeEntryChallenge(protobuf.MessageType): + challenge: "bytes | None" + + def __init__( + self, + *, + challenge: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceHost(protobuf.MessageType): + cpace_host_public_key: "bytes | None" + + def __init__( + self, + *, + cpace_host_public_key: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]: + return isinstance(msg, cls) + + class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + cpace_trezor_public_key: "bytes | None" + + def __init__( + self, + *, + cpace_trezor_public_key: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]: + return isinstance(msg, cls) + + class ThpCodeEntryTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]: + return isinstance(msg, cls) + + class ThpCodeEntrySecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]: + return isinstance(msg, cls) + + class ThpQrCodeTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]: + return isinstance(msg, cls) + + class ThpQrCodeSecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]: + return isinstance(msg, cls) + + class ThpNfcUnidirectionalTag(protobuf.MessageType): + tag: "bytes | None" + + def __init__( + self, + *, + tag: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]: + return isinstance(msg, cls) + + class ThpNfcUnidirectionalSecret(protobuf.MessageType): + secret: "bytes | None" + + def __init__( + self, + *, + secret: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]: + return isinstance(msg, cls) + + class ThpCredentialRequest(protobuf.MessageType): + host_static_pubkey: "bytes | None" + + def __init__( + self, + *, + host_static_pubkey: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]: + return isinstance(msg, cls) + + class ThpCredentialResponse(protobuf.MessageType): + trezor_static_pubkey: "bytes | None" + credential: "bytes | None" + + def __init__( + self, + *, + trezor_static_pubkey: "bytes | None" = None, + credential: "bytes | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]: + return isinstance(msg, cls) + + class ThpEndRequest(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]: + return isinstance(msg, cls) + + class ThpEndResponse(protobuf.MessageType): + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]: + return isinstance(msg, cls) + class ThpCredentialMetadata(protobuf.MessageType): host_name: "str | None" diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 0162d4b8d58..6590cac4a45 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -35,6 +35,10 @@ DISABLE_ANIMATION = 0 +DISABLE_ENCRYPTION: bool = False + +ALLOW_DEBUG_MESSAGES: bool = True + if __debug__: if EMULATOR: import uos diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 2662a5610aa..1bc847f2733 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -29,7 +29,12 @@ from trezor import log, loop, protobuf, utils from . import message_handler, protocol_common -from .codec.codec_context import CodecContext + +if utils.USE_THP: + from .thp import thp_main +else: + from .codec.codec_context import CodecContext + from .context import UnexpectedMessageException from .message_handler import failure @@ -40,6 +45,8 @@ _PROTOBUF_BUFFER_SIZE = const(8192) WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +if utils.USE_THP: + WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE) if TYPE_CHECKING: from trezorio import WireInterface @@ -57,57 +64,89 @@ def setup(iface: WireInterface) -> None: loop.schedule(handle_session(iface)) -async def handle_session(iface: WireInterface) -> None: - ctx = CodecContext(iface, WIRE_BUFFER) - next_msg: protocol_common.Message | None = None +if utils.USE_THP: - # Take a mark of modules that are imported at this point, so we can - # roll back and un-import any others. - modules = utils.unimport_begin() - while True: - try: - if next_msg is None: - # If the previous run did not keep an unprocessed message for us, - # wait for a new one coming from the wire. - try: - msg = await ctx.read_from_wire() - except protocol_common.WireError as exc: - if __debug__: - log.exception(__name__, exc) - await ctx.write(failure(exc)) - continue + async def handle_session(iface: WireInterface) -> None: + + thp_main.set_read_buffer(WIRE_BUFFER) + thp_main.set_write_buffer(WIRE_BUFFER_2) - else: - # Process the message from previous run. - msg = next_msg - next_msg = None + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() - do_not_restart = False + while True: try: - do_not_restart = await message_handler.handle_single_message(ctx, msg) - except UnexpectedMessageException as unexpected: - # The workflow was interrupted by an unexpected message. We need to - # process it as if it was a new message... - next_msg = unexpected.msg - # ...and we must not restart because that would lose the message. - do_not_restart = True - continue + await thp_main.thp_main_loop(iface) except Exception as exc: - # Log and ignore. The session handler can only exit explicitly in the - # following finally block. + # Log and try again. if __debug__: log.exception(__name__, exc) finally: # Unload modules imported by the workflow. Should not raise. + if __debug__: + log.debug(__name__, "utils.unimport_end(modules) and loop.clear()") utils.unimport_end(modules) + loop.clear() + return # pylint: disable=lost-exception - if not do_not_restart: - # Let the session be restarted from `main`. - loop.clear() - return # pylint: disable=lost-exception +else: - except Exception as exc: - # Log and try again. The session handler can only exit explicitly via - # loop.clear() above. - if __debug__: - log.exception(__name__, exc) + async def handle_session(iface: WireInterface) -> None: + ctx = CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None + + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() + while True: + try: + if next_msg is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one coming from the wire. + try: + msg = await ctx.read_from_wire() + except protocol_common.WireError as exc: + if __debug__: + log.exception(__name__, exc) + await ctx.write(failure(exc)) + continue + + else: + # Process the message from previous run. + msg = next_msg + next_msg = None + + do_not_restart = False + try: + do_not_restart = await message_handler.handle_single_message( + ctx, msg + ) + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_msg = unexpected.msg + # ...and we must not restart because that would lose the message. + do_not_restart = True + continue + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + utils.unimport_end(modules) + + if not do_not_restart: + # Let the session be restarted from `main`. + if __debug__: + log.debug(__name__, "loop.clear()") + loop.clear() + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. + if __debug__: + log.exception(__name__, exc) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 56df34fbc58..00bfeb77d4f 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -17,7 +17,7 @@ from storage import cache from storage.cache_common import SESSIONLESS_FLAG -from trezor import loop, protobuf +from trezor import loop, protobuf, utils from .protocol_common import Context, Message @@ -138,6 +138,17 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = None +def try_get_ctx_ids() -> tuple[bytes, bytes] | None: + ids = None + if utils.USE_THP: + from trezor.wire.thp.session_context import GenericSessionContext + + ctx = get_context() + if isinstance(ctx, GenericSessionContext): + ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big")) + return ids + + # ACCESS TO CACHE if TYPE_CHECKING: diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 376820b5834..e8b2d3feb45 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -8,6 +8,12 @@ def __init__(self, code: FailureType, message: str) -> None: self.message = message +class SilentError(Exception): + def __init__(self, message: str) -> None: + super().__init__() + self.message = message + + class UnexpectedMessage(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index 21c901dc90e..c0f201de22f 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -25,7 +25,12 @@ def wrap_protobuf_load( expected_type: type[LoadedMessageType], ) -> LoadedMessageType: try: - if __debug__ and utils.EMULATOR and utils.USE_THP: + if ( + __debug__ + and utils.EMULATOR + and utils.USE_THP + and utils.ALLOW_DEBUG_MESSAGES + ): log.debug( __name__, "Buffer to be parsed to a LoadedMessage: %s", @@ -38,7 +43,7 @@ def wrap_protobuf_load( ) return msg except Exception as e: - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.exception(__name__, e) if e.args: raise DataError("Failed to decode message: " + " ".join(e.args)) @@ -46,6 +51,25 @@ def wrap_protobuf_load( raise DataError("Failed to decode message") +if utils.USE_THP: + from trezor.enums import ThpMessageType + + def get_msg_name(msg_type: int) -> str | None: + for name in dir(ThpMessageType): + if not name.startswith("__"): # Skip built-in attributes + value = getattr(ThpMessageType, name) + if isinstance(value, int): + if value == msg_type: + return name + return None + + def get_msg_type(msg_name: str) -> int | None: + value = getattr(ThpMessageType, msg_name) + if isinstance(value, int): + return value + return None + + async def handle_single_message(ctx: Context, msg: Message) -> bool: """Handle a message that was loaded from a WireInterface by the caller. @@ -60,17 +84,27 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: the type of message is supposed to be optimized and not disrupt the running state, this function will return `True`. """ - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: try: msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME except Exception: msg_type = f"{msg.type} - unknown message type" - log.debug( - __name__, - "%d receive: <%s>", - ctx.iface.iface_num(), - msg_type, - ) + if utils.USE_THP: + cid = int.from_bytes(ctx.channel_id, "big") + log.debug( + __name__, + "%d:%d receive: <%s>", + ctx.iface.iface_num(), + cid, + msg_type, + ) + else: + log.debug( + __name__, + "%d receive: <%s>", + ctx.iface.iface_num(), + msg_type, + ) res_msg: protobuf.MessageType | None = None @@ -91,7 +125,15 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: try: # Find a protobuf.MessageType subclass that describes this # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) + + if utils.USE_THP: + name = get_msg_name(msg.type) + if name is None: + req_type = protobuf.type_for_wire(msg.type) + else: + req_type = protobuf.type_for_name(name) + else: + req_type = protobuf.type_for_wire(msg.type) # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed. @@ -132,7 +174,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: # - the message was not valid protobuf # - workflow raised some kind of an exception while running # - something canceled the workflow from the outside - if __debug__: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if isinstance(exc, ActionCancelled): log.debug(__name__, "cancelled: %s", exc.message) elif isinstance(exc, loop.TaskClosed): diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index ed4105517b1..0e54afe8c3c 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from trezorio import WireInterface - from typing import Container, TypeVar, overload + from typing import Awaitable, Container, TypeVar, overload from storage.cache_common import DataCache @@ -72,6 +72,9 @@ async def write(self, msg: protobuf.MessageType) -> None: """Write a message to the wire.""" ... + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.write(msg) + async def call( self, msg: protobuf.MessageType, diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py new file mode 100644 index 00000000000..ce61cf815a4 --- /dev/null +++ b/core/src/trezor/wire/thp/__init__.py @@ -0,0 +1,184 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import protobuf, utils +from trezor.enums import ThpPairingMethod +from trezor.messages import ThpDeviceProperties + +from ..protocol_common import WireError + +if TYPE_CHECKING: + from enum import IntEnum + + from trezor.wire import WireInterface + from typing_extensions import Self +else: + IntEnum = object + +CODEC_V1 = const(0x3F) + +HANDSHAKE_INIT_REQ = const(0x00) +HANDSHAKE_INIT_RES = const(0x01) +HANDSHAKE_COMP_REQ = const(0x02) +HANDSHAKE_COMP_RES = const(0x03) +ENCRYPTED = const(0x04) + +ACK_MESSAGE = const(0x20) +CHANNEL_ALLOCATION_REQ = const(0x40) +_CHANNEL_ALLOCATION_RES = const(0x41) +_ERROR = const(0x42) +CONTINUATION_PACKET = const(0x80) + + +class ThpError(WireError): + pass + + +class ThpDecryptionError(ThpError): + pass + + +class ThpInvalidDataError(ThpError): + pass + + +class ThpUnallocatedSessionError(ThpError): + + def __init__(self, session_id: int) -> None: + self.session_id = session_id + + +class ThpErrorType(IntEnum): + TRANSPORT_BUSY = 1 + UNALLOCATED_CHANNEL = 2 + DECRYPTION_FAILED = 3 + INVALID_DATA = 4 + + +class ChannelState(IntEnum): + UNALLOCATED = 0 + TH1 = 1 + TH2 = 2 + TP1 = 3 + TP2 = 4 + TP3 = 5 + TP4 = 6 + TC1 = 7 + ENCRYPTED_TRANSPORT = 8 + + +class SessionState(IntEnum): + UNALLOCATED = 0 + ALLOCATED = 1 + MANAGEMENT = 2 + + +class PacketHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.length = length + + def to_bytes(self) -> bytes: + return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + """ + Packs header information in the form of **intial** packet + into the provided buffer. + """ + ustruct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + """ + Packs header information in the form of **continuation** packet header + into the provided buffer. + """ + ustruct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + @classmethod + def get_error_header(cls, cid: int, length: int) -> Self: + """ + Returns header for protocol-level error messages. + """ + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_response_header(cls, length: int) -> Self: + """ + Returns header for allocation response handshake message. + """ + return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length) + + +_DEFAULT_ENABLED_PAIRING_METHODS = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.QrCode, + ThpPairingMethod.NFC_Unidirectional, +] + + +def get_enabled_pairing_methods( + iface: WireInterface | None = None, +) -> list[ThpPairingMethod]: + """ + Returns pairing methods that are currently allowed by the device + with respect to the wire interface the host communicates on. + """ + import usb + + methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy() + if iface is not None and iface is usb.iface_wire: + methods.append(ThpPairingMethod.NoMethod) + return methods + + +def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties: + # TODO define model variants + return ThpDeviceProperties( + pairing_methods=get_enabled_pairing_methods(iface), + internal_model=utils.INTERNAL_MODEL, + model_variant=0, + bootloader_mode=False, + protocol_version=2, + ) + + +def get_encoded_device_properties(iface: WireInterface) -> bytes: + props = _get_device_properties(iface) + length = protobuf.encoded_length(props) + encoded_properties = bytearray(length) + protobuf.encode(encoded_properties, props) + return encoded_properties + + +def get_channel_allocation_response( + nonce: bytes, new_cid: bytes, iface: WireInterface +) -> bytes: + props_msg = get_encoded_device_properties(iface) + return nonce + new_cid + props_msg + + +if __debug__: + + def state_to_str(state: int) -> str: + name = { + v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") + }.get(state) + if name is not None: + return name + return "UNKNOWN_STATE" diff --git a/core/src/trezor/wire/thp/alternating_bit_protocol.py b/core/src/trezor/wire/thp/alternating_bit_protocol.py new file mode 100644 index 00000000000..d8ba60c5b23 --- /dev/null +++ b/core/src/trezor/wire/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +from storage.cache_thp import ChannelCache +from trezor import log, utils +from trezor.wire.thp import ThpError + + +def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: + """ + Checks if: + - an ACK message is expected + - the received ACK message acknowledges correct sequence number (bit) + """ + if not _is_ack_expected(cache): + return False + + if not _has_ack_correct_sync_bit(cache, ack_bit): + return False + + return True + + +def _is_ack_expected(cache: ChannelCache) -> bool: + is_expected: bool = not is_sending_allowed(cache) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected: + log.debug(__name__, "Received unexpected ACK message") + return is_expected + + +def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: + is_correct: bool = get_send_seq_bit(cache) == sync_bit + if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct: + log.debug(__name__, "Received ACK message with wrong ack bit") + return is_correct + + +def is_sending_allowed(cache: ChannelCache) -> bool: + """ + Checks whether sending a message in the provided channel is allowed. + + Note: Sending a message in a channel before receipt of ACK message for the previously + sent message (in the channel) is prohibited, as it can lead to desynchronization. + """ + return bool(cache.sync >> 7) + + +def get_send_seq_bit(cache: ChannelCache) -> int: + """ + Returns the sequential number (bit) of the next message to be sent + in the provided channel. + """ + return (cache.sync & 0x20) >> 5 + + +def get_expected_receive_seq_bit(cache: ChannelCache) -> int: + """ + Returns the (expected) sequential number (bit) of the next message + to be received in the provided channel. + """ + return (cache.sync & 0x40) >> 6 + + +def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: + """ + Set the flag whether sending a message in this channel is allowed or not. + """ + cache.sync &= 0x7F + if sending_allowed: + cache.sync |= 0x80 + + +def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + """ + Set the expected sequential number (bit) of the next message to be received + in the provided channel + """ + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) + if seq_bit not in (0, 1): + raise ThpError("Unexpected receive sync bit") + + # set second bit to "seq_bit" value + cache.sync &= 0xBF + if seq_bit: + cache.sync |= 0x40 + + +def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: + if seq_bit not in (0, 1): + raise ThpError("Unexpected send seq bit") + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "setting sync send seq bit to %d", seq_bit) + # set third bit to "seq_bit" value + cache.sync &= 0xDF + if seq_bit: + cache.sync |= 0x20 + + +def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: + """ + Set the sequential bit of the "next message to be send" to the opposite value, + i.e. 1 -> 0 and 0 -> 1 + """ + _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py new file mode 100644 index 00000000000..8e6e65647f0 --- /dev/null +++ b/core/src/trezor/wire/thp/channel.py @@ -0,0 +1,405 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id +from trezor import log, loop, protobuf, utils, workflow + +from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError +from . import alternating_bit_protocol as ABP +from . import ( + control_byte, + crypto, + interface_manager, + memory_manager, + received_message_handler, +) +from .checksum import CHECKSUM_LENGTH +from .transmission_loop import TransmissionLoop +from .writer import ( + CONT_HEADER_LENGTH, + INIT_HEADER_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if __debug__: + from ubinascii import hexlify + + from . import state_to_str + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable + + from .pairing_context import PairingContext + from .session_context import GenericSessionContext + + +class Channel: + """ + THP protocol encrypted communication channel. + """ + + def __init__(self, channel_cache: ChannelCache) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "channel initialization") + self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) + self.channel_cache: ChannelCache = channel_cache + self.is_cont_packet_expected: bool = False + self.expected_payload_length: int = 0 + self.bytes_read: int = 0 + self.buffer: utils.BufferType + self.channel_id: bytes = channel_cache.channel_id + self.selected_pairing_methods = [] + self.sessions: dict[int, GenericSessionContext] = {} + self.write_task_spawn: loop.spawn | None = None + self.connection_context: PairingContext | None = None + self.transmission_loop: TransmissionLoop | None = None + self.handshake: crypto.Handshake | None = None + + def clear(self) -> None: + clear_sessions_with_channel_id(self.channel_id) + self.channel_cache.clear() + + # ACCESS TO CHANNEL_DATA + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") + + def get_channel_state(self) -> int: + state = int.from_bytes(self.channel_cache.state, "big") + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) get_channel_state: %s", + utils.get_bytes_as_str(self.channel_id), + state_to_str(state), + ) + return state + + def get_handshake_hash(self) -> bytes: + h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH) + assert h is not None + return h + + def set_channel_state(self, state: ChannelState) -> None: + self.channel_cache.state = bytearray(state.to_bytes(1, "big")) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) set_channel_state: %s", + utils.get_bytes_as_str(self.channel_id), + state_to_str(state), + ) + + def set_buffer(self, buffer: utils.BufferType) -> None: + self.buffer = buffer + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) set_buffer: %s", + utils.get_bytes_as_str(self.channel_id), + type(self.buffer), + ) + + # CALLED BY THP_MAIN_LOOP + + def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) receive_packet", + utils.get_bytes_as_str(self.channel_id), + ) + + self._handle_received_packet(packet) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) self.buffer: %s", + utils.get_bytes_as_str(self.channel_id), + utils.get_bytes_as_str(self.buffer), + ) + + if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: + self._finish_message() + return received_message_handler.handle_received_message(self, self.buffer) + elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: + self.is_cont_packet_expected = True + else: + raise ThpError( + "Read more bytes than is the expected length of the message!" + ) + return None + + def _handle_received_packet(self, packet: utils.BufferType) -> None: + ctrl_byte = packet[0] + if control_byte.is_continuation(ctrl_byte): + return self._handle_cont_packet(packet) + return self._handle_init_packet(packet) + + def _handle_init_packet(self, packet: utils.BufferType) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) handle_init_packet", + utils.get_bytes_as_str(self.channel_id), + ) + # ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption + _, _, payload_length = ustruct.unpack(">BHH", packet) + self.expected_payload_length = payload_length + packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] + + # If the channel does not "own" the buffer lock, decrypt first packet + # TODO do it only when needed! + # TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation. + # On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented + # in "memory_manager.select_buffer", because the session id is unknown (encrypted). + + # if control_byte.is_encrypted_transport(ctrl_byte): + # packet_payload = self._decrypt_single_packet_payload(packet_payload) + + self.buffer = memory_manager.select_buffer( + self.get_channel_state(), + self.buffer, + packet_payload, + payload_length, + ) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) handle_init_packet - payload len: %d", + utils.get_bytes_as_str(self.channel_id), + payload_length, + ) + log.debug( + __name__, + "(cid: %s) handle_init_packet - buffer len: %d", + utils.get_bytes_as_str(self.channel_id), + len(self.buffer), + ) + return self._buffer_packet_data(self.buffer, packet, 0) + + def _handle_cont_packet(self, packet: utils.BufferType) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) handle_cont_packet", + utils.get_bytes_as_str(self.channel_id), + ) + if not self.is_cont_packet_expected: + raise ThpError("Continuation packet is not expected, ignoring") + return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) + + def _decrypt_single_packet_payload( + self, payload: utils.BufferType + ) -> utils.BufferType: + # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) + return payload + + def decrypt_buffer( + self, message_length: int, offset: int = INIT_HEADER_LENGTH + ) -> None: + noise_buffer = memoryview(self.buffer)[ + offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH + ] + tag = self.buffer[ + message_length + - CHECKSUM_LENGTH + - TAG_LENGTH : message_length + - CHECKSUM_LENGTH + ] + if utils.DISABLE_ENCRYPTION: + is_tag_valid = tag == crypto.DUMMY_TAG + else: + key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) + nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) + + assert key_receive is not None + assert nonce_receive is not None + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) Buffer before decryption: %s", + utils.get_bytes_as_str(self.channel_id), + hexlify(noise_buffer), + ) + is_tag_valid = crypto.dec( + noise_buffer, tag, key_receive, nonce_receive, b"" + ) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) Buffer after decryption: %s", + utils.get_bytes_as_str(self.channel_id), + hexlify(noise_buffer), + ) + + self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) Is decrypted tag valid? %s", + utils.get_bytes_as_str(self.channel_id), + str(is_tag_valid), + ) + log.debug( + __name__, + "(cid: %s) Received tag: %s", + utils.get_bytes_as_str(self.channel_id), + (hexlify(tag).decode()), + ) + log.debug( + __name__, + "(cid: %s) New nonce_receive: %i", + utils.get_bytes_as_str(self.channel_id), + nonce_receive + 1, + ) + + if not is_tag_valid: + raise ThpDecryptionError() + + def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id) + ) + assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH + + noise_buffer = memoryview(buffer)[0:noise_payload_len] + + if utils.DISABLE_ENCRYPTION: + tag = crypto.DUMMY_TAG + else: + key_send = self.channel_cache.get(CHANNEL_KEY_SEND) + nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND) + + assert key_send is not None + assert nonce_send is not None + + tag = crypto.enc(noise_buffer, key_send, nonce_send, b"") + + self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "New nonce_send: %i", nonce_send + 1) + + buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag + + def _buffer_packet_data( + self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int + ) -> None: + self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + + def _finish_message(self) -> None: + self.bytes_read = 0 + self.expected_payload_length = 0 + self.is_cont_packet_expected = False + + # CALLED BY WORKFLOW / SESSION CONTEXT + + async def write( + self, + msg: protobuf.MessageType, + session_id: int = 0, + force: bool = False, + ) -> None: + if __debug__ and utils.EMULATOR: + log.debug( + __name__, + "(cid: %s) write message: %s\n%s", + utils.get_bytes_as_str(self.channel_id), + msg.MESSAGE_NAME, + utils.dump_protobuf(msg), + ) + + self.buffer = memory_manager.get_write_buffer(self.buffer, msg) + noise_payload_len = memory_manager.encode_into_buffer( + self.buffer, msg, session_id + ) + task = self.write_and_encrypt(self.buffer[:noise_payload_len], force) + if task is not None: + await task + + def write_error(self, err_type: int) -> Awaitable[None]: + msg_data = err_type.to_bytes(1, "big") + length = len(msg_data) + CHECKSUM_LENGTH + header = PacketHeader.get_error_header(self.get_channel_id_int(), length) + return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) + + def write_and_encrypt( + self, payload: bytes, force: bool = False + ) -> Awaitable[None] | None: + payload_length = len(payload) + self._encrypt(self.buffer, payload_length) + payload_length = payload_length + TAG_LENGTH + + if self.write_task_spawn is not None: + self.write_task_spawn.close() # UPS TODO might break something + print("\nCLOSED\n") + self._prepare_write() + if force: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "Writing FORCE message (without async or retransmission)." + ) + return self._write_encrypted_payload_loop( + ENCRYPTED, memoryview(self.buffer[:payload_length]) + ) + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop( + ENCRYPTED, memoryview(self.buffer[:payload_length]) + ) + ) + return None + + def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + self._prepare_write() + self.write_task_spawn = loop.spawn( + self._write_encrypted_payload_loop(ctrl_byte, payload) + ) + + def _prepare_write(self) -> None: + # TODO add condition that disallows to write when can_send_message is false + ABP.set_sending_allowed(self.channel_cache, False) + + async def _write_encrypted_payload_loop( + self, ctrl_byte: int, payload: bytes + ) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid %s) write_encrypted_payload_loop", + utils.get_bytes_as_str(self.channel_id), + ) + payload_len = len(payload) + CHECKSUM_LENGTH + sync_bit = ABP.get_send_seq_bit(self.channel_cache) + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) + header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) + self.transmission_loop = TransmissionLoop(self, header, payload) + await self.transmission_loop.start() + + ABP.set_send_seq_bit_to_opposite(self.channel_cache) + + # Let the main loop be restarted and clear loop, if there is no other + # workflow and the state is ENCRYPTED_TRANSPORT + if self._can_clear_loop(): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "(cid: %s) clearing loop from channel", + utils.get_bytes_as_str(self.channel_id), + ) + loop.clear() + + def _can_clear_loop(self) -> bool: + return ( + not workflow.tasks + ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py new file mode 100644 index 00000000000..a48f6d7fdb4 --- /dev/null +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp +from trezor import utils + +from . import ChannelState, interface_manager +from .channel import Channel + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel: + """ + Creates a new channel for the interface `iface` with the buffer `buffer`. + """ + channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface)) + r = Channel(channel_cache) + r.set_buffer(buffer) + r.set_channel_state(ChannelState.TH1) + return r + + +def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: + """ + Returns all allocated channels from cache. + """ + channels: dict[int, Channel] = {} + cached_channels = cache_thp.get_all_allocated_channels() + for c in cached_channels: + channels[int.from_bytes(c.channel_id, "big")] = Channel(c) + for c in channels.values(): + c.set_buffer(buffer) + return channels diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py new file mode 100644 index 00000000000..9c28f2e78d8 --- /dev/null +++ b/core/src/trezor/wire/thp/checksum.py @@ -0,0 +1,22 @@ +from micropython import const + +from trezor import utils +from trezor.crypto import crc + +CHECKSUM_LENGTH = const(4) + + +def compute(data: bytes | utils.BufferType) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/core/src/trezor/wire/thp/control_byte.py b/core/src/trezor/wire/thp/control_byte.py new file mode 100644 index 00000000000..5d4d69b0400 --- /dev/null +++ b/core/src/trezor/wire/thp/control_byte.py @@ -0,0 +1,50 @@ +from micropython import const + +from . import ( + ACK_MESSAGE, + CONTINUATION_PACKET, + ENCRYPTED, + HANDSHAKE_COMP_REQ, + HANDSHAKE_INIT_REQ, + ThpError, +) + +_CONTINUATION_PACKET_MASK = const(0x80) +_ACK_MASK = const(0xF7) +_DATA_MASK = const(0xE7) + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise ThpError("Unexpected acknowledgement bit") + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & _ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == ENCRYPTED + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/core/src/trezor/wire/thp/cpace.py b/core/src/trezor/wire/thp/cpace.py new file mode 100644 index 00000000000..302dd3e5e37 --- /dev/null +++ b/core/src/trezor/wire/thp/cpace.py @@ -0,0 +1,36 @@ +from trezor.crypto import elligator2, random +from trezor.crypto.curve import curve25519 +from trezor.crypto.hashlib import sha512 + +_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06" +_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20" + + +class Cpace: + """ + CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/ + """ + + def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None: + self.handshake_hash: bytes = handshake_hash + self.host_public_key: bytes = cpace_host_public_key + self.shared_secret: bytes + self.trezor_private_key: bytes + self.trezor_public_key: bytes + + def generate_keys_and_secret(self, code_code_entry: bytes) -> None: + """ + Generate ephemeral key pair and a shared secret using Elligator2 with X25519. + """ + sha_ctx = sha512(_PREFIX) + sha_ctx.update(code_code_entry) + sha_ctx.update(_PADDING) + sha_ctx.update(self.handshake_hash) + sha_ctx.update(b"\x00") + pregenerator = sha_ctx.digest()[:32] + generator = elligator2.map_to_curve25519(pregenerator) + self.trezor_private_key = random.bytes(32) + self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator) + self.shared_secret = curve25519.multiply( + self.trezor_private_key, self.host_public_key + ) diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py new file mode 100644 index 00000000000..aa7d9c146e9 --- /dev/null +++ b/core/src/trezor/wire/thp/crypto.py @@ -0,0 +1,211 @@ +from micropython import const +from trezorcrypto import aesgcm, bip32, curve25519, hmac + +from storage import device +from trezor import log, utils +from trezor.crypto.hashlib import sha256 +from trezor.wire.thp import ThpDecryptionError + +# The HARDENED flag is taken from apps.common.paths +# It is not imported to save on resources +HARDENED = const(0x8000_0000) +PUBKEY_LENGTH = const(32) +if utils.DISABLE_ENCRYPTION: + DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" + +if __debug__: + from ubinascii import hexlify + + +def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes: + """ + Encrypts the provided `buffer` with AES-GCM (in place). + Returns a 16-byte long encryption tag. + """ + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce) + iv = _get_iv_from_nonce(nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.encrypt_in_place(buffer) + return aes_ctx.finish() + + +def dec( + buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes +) -> bool: + """ + Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as + the tag computed in decryption, otherwise it returns `False`. + """ + iv = _get_iv_from_nonce(nonce) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce) + aes_ctx = aesgcm(key, iv) + aes_ctx.auth(auth_data) + aes_ctx.decrypt_in_place(buffer) + computed_tag = aes_ctx.finish() + return computed_tag == tag + + +class BusyDecoder: + def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None: + iv = _get_iv_from_nonce(nonce) + self.aes_ctx = aesgcm(key, iv) + self.aes_ctx.auth(auth_data) + + def decrypt_part(self, part: utils.BufferType) -> None: + self.aes_ctx.decrypt_in_place(part) + + def finish_and_check_tag(self, tag: bytes) -> bool: + computed_tag = self.aes_ctx.finish() + return computed_tag == tag + + +PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" +IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + +class Handshake: + """ + `Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel. + The following values should be saved for future use before disposing of this object: + - `h` - handshake hash, can be used to bind other values to the channel + - `key_receive` - key for decrypting incoming communication + - `key_send` - key for encrypting outgoing communication + """ + + def __init__(self) -> None: + self.trezor_ephemeral_privkey: bytes + self.ck: bytes + self.k: bytes + self.h: bytes + self.key_receive: bytes + self.key_send: bytes + + def handle_th1_crypto( + self, + device_properties: bytes, + host_ephemeral_pubkey: bytes, + ) -> tuple[bytes, bytes, bytes]: + + trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair() + self.trezor_ephemeral_privkey = curve25519.generate_secret() + trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey) + self.h = _hash_of_two(PROTOCOL_NAME, device_properties) + self.h = _hash_of_two(self.h, host_ephemeral_pubkey) + self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey) + point = curve25519.multiply( + self.trezor_ephemeral_privkey, host_ephemeral_pubkey + ) + self.ck, self.k = _hkdf(PROTOCOL_NAME, point) + mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey) + trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey) + aes_ctx = aesgcm(self.k, IV_1) + encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey) + if __debug__: + log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0) + aes_ctx.auth(self.h) + tag_to_encrypted_key = aes_ctx.finish() + encrypted_trezor_static_pubkey = ( + encrypted_trezor_static_pubkey + tag_to_encrypted_key + ) + self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey) + point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey) + self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point)) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + tag = aes_ctx.finish() + self.h = _hash_of_two(self.h, tag) + return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag) + + def handle_th2_crypto( + self, + encrypted_host_static_pubkey: utils.BufferType, + encrypted_payload: utils.BufferType, + ) -> None: + + aes_ctx = aesgcm(self.k, IV_2) + + # The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted. + # However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for + # authentication of the gcm tag. + aes_ctx.auth(self.h) # Authenticate with the previous value of `h` + self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value + aes_ctx.decrypt_in_place( + memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + ) + if __debug__: + log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1) + host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] + tag = aes_ctx.finish() + if tag != encrypted_host_static_pubkey[-16:]: + raise ThpDecryptionError() + + self.ck, self.k = _hkdf( + self.ck, + curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(self.k, IV_1) + aes_ctx.auth(self.h) + aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16]) + if __debug__: + log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0) + tag = aes_ctx.finish() + if tag != encrypted_payload[-16:]: + raise ThpDecryptionError() + + self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16]) + self.key_receive, self.key_send = _hkdf(self.ck, b"") + if __debug__: + log.debug( + __name__, + "(key_receive: %s, key_send: %s)", + hexlify(self.key_receive), + hexlify(self.key_send), + ) + + def get_handshake_completion_response(self, trezor_state: bytes) -> bytes: + aes_ctx = aesgcm(self.key_send, IV_1) + encrypted_trezor_state = aes_ctx.encrypt(trezor_state) + tag = aes_ctx.finish() + return encrypted_trezor_state + tag + + +def _derive_static_key_pair() -> tuple[bytes, bytes]: + node_int = HARDENED | int.from_bytes(b"\x00THP", "big") + node = bip32.from_seed(device.get_device_secret(), "curve25519") + node.derive(node_int) + + trezor_static_privkey = node.private_key() + trezor_static_pubkey = node.public_key()[1:33] + # Note: the first byte (\x01) of the public key is removed, as it + # only indicates the type of the elliptic curve used + + return trezor_static_privkey, trezor_static_pubkey + + +def get_trezor_static_pubkey() -> bytes: + _, pubkey = _derive_static_key_pair() + return pubkey + + +def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]: + temp_key = hmac(hmac.SHA256, chaining_key, input).digest() + output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest() + ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes: + ctx = sha256(part_1) + ctx.update(part_2) + return ctx.digest() + + +def _get_iv_from_nonce(nonce: int) -> bytes: + utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") diff --git a/core/src/trezor/wire/thp/interface_manager.py b/core/src/trezor/wire/thp/interface_manager.py new file mode 100644 index 00000000000..a1fecfe7d64 --- /dev/null +++ b/core/src/trezor/wire/thp/interface_manager.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +import usb + +_WIRE_INTERFACE_USB = b"\x01" +# TODO _WIRE_INTERFACE_BLE = b"\x02" + +if TYPE_CHECKING: + from trezorio import WireInterface + + +def decode_iface(cached_iface: bytes) -> WireInterface: + """Decode the cached wire interface.""" + if cached_iface == _WIRE_INTERFACE_USB: + iface = usb.iface_wire + if iface is None: + raise RuntimeError("There is no valid USB WireInterface") + return iface + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") + + +def encode_iface(iface: WireInterface) -> bytes: + """Encode wire interface into bytes.""" + if iface is usb.iface_wire: + return _WIRE_INTERFACE_USB + # TODO implement bluetooth interface + raise Exception("Unknown WireInterface") diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py new file mode 100644 index 00000000000..0a117c16f73 --- /dev/null +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -0,0 +1,179 @@ +from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH +from trezor import log, protobuf, utils +from trezor.wire.message_handler import get_msg_type + +from . import ChannelState, ThpError +from .checksum import CHECKSUM_LENGTH +from .writer import ( + INIT_HEADER_LENGTH, + MAX_PAYLOAD_LEN, + MESSAGE_TYPE_LENGTH, + PACKET_LENGTH, +) + + +def select_buffer( + channel_state: int, + channel_buffer: utils.BufferType, + packet_payload: utils.BufferType, + payload_length: int, +) -> utils.BufferType: + + if channel_state is ChannelState.ENCRYPTED_TRANSPORT: + session_id = packet_payload[0] + if session_id == 0: + pass + # TODO use small buffer + else: + pass + # TODO use big buffer but only if the channel owns the buffer lock. + # Otherwise send BUSY message and return + else: + pass + # TODO use small buffer + try: + # TODO for now, we create a new big buffer every time. It should be changed + buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer) + return buffer + except Exception as e: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.exception(__name__, e) + raise Exception("Failed to create a buffer for channel") # TODO handle better + + +def get_write_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType +) -> utils.BufferType: + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + + if required_min_size > len(buffer): + return _get_buffer_for_write(required_min_size, buffer) + return buffer + + +def encode_into_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int +) -> int: + # cannot write message without wire type + msg_type = msg.MESSAGE_WIRE_TYPE + if msg_type is None: + msg_type = get_msg_type(msg.MESSAGE_NAME) + assert msg_type is not None + + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + + _encode_session_into_buffer(memoryview(buffer), session_id) + _encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH) + _encode_message_into_buffer( + memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + + return payload_size + + +def _encode_session_into_buffer( + buffer: memoryview, session_id: int, buffer_offset: int = 0 +) -> None: + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) + + +def _encode_message_type_into_buffer( + buffer: memoryview, message_type: int, offset: int = 0 +) -> None: + msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") + utils.memcpy(buffer, offset, msg_type_bytes, 0) + + +def _encode_message_into_buffer( + buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 +) -> None: + protobuf.encode(memoryview(buffer[buffer_offset:]), message) + + +def _get_buffer_for_read( + payload_length: int, + existing_buffer: utils.BufferType, + max_length: int = MAX_PAYLOAD_LEN, +) -> utils.BufferType: + length = payload_length + INIT_HEADER_LENGTH + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "get_buffer_for_read - length: %d, %s %s", + length, + "existing buffer type:", + type(existing_buffer), + ) + if length > max_length: + raise ThpError("Message too large") + + if length > len(existing_buffer): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Allocating a new buffer") + + from .thp_main import get_raw_read_buffer + + if length > len(get_raw_read_buffer()): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "Required length is %d, where raw buffer has capacity only %d", + length, + len(get_raw_read_buffer()), + ) + raise ThpError("Message is too large") + + try: + payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length] + except MemoryError: + payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH] + raise ThpError("Message is too large") + return payload + + # reuse a part of the supplied buffer + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Reusing already allocated buffer") + return memoryview(existing_buffer)[:length] + + +def _get_buffer_for_write( + payload_length: int, + existing_buffer: utils.BufferType, + max_length: int = MAX_PAYLOAD_LEN, +) -> utils.BufferType: + length = payload_length + INIT_HEADER_LENGTH + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "get_buffer_for_write - length: %d, %s %s", + length, + "existing buffer type:", + type(existing_buffer), + ) + if length > max_length: + raise ThpError("Message too large") + + if length > len(existing_buffer): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Creating a new write buffer from raw write buffer") + + from .thp_main import get_raw_write_buffer + + if length > len(get_raw_write_buffer()): + raise ThpError("Message is too large") + + try: + payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length] + except MemoryError: + payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH] + raise ThpError("Message is too large") + return payload + + # reuse a part of the supplied buffer + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Reusing already allocated buffer") + return memoryview(existing_buffer)[:length] diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py new file mode 100644 index 00000000000..34065b122f2 --- /dev/null +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -0,0 +1,262 @@ +from typing import TYPE_CHECKING +from ubinascii import hexlify + +import trezorui2 +from trezor import loop, protobuf, workflow +from trezor.crypto import random +from trezor.wire import context, message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.errors import ActionCancelled, SilentError +from trezor.wire.protocol_common import Context, Message + +if TYPE_CHECKING: + from typing import Container + + from trezor import ui + + from .channel import Channel + from .cpace import Cpace + + pass + +if __debug__: + from trezor import log + + +class PairingDisplayData: + + def __init__(self) -> None: + self.code_code_entry: int | None = None + self.code_qr_code: bytes | None = None + self.code_nfc_unidirectional: bytes | None = None + + def get_display_layout(self) -> ui.Layout: + from trezor import ui + + # TODO have different layouts when there is only QR code or only Code Entry + qr_str = "" + code_str = "" + if self.code_qr_code is not None: + qr_str = self._get_code_qr_code_str() + if self.code_code_entry is not None: + code_str = self._get_code_code_entry_str() + + return ui.Layout( + trezorui2.show_address_details( # noqa + qr_title="Scan QR code to pair", + address=qr_str, + case_sensitive=True, + details_title="", + account="Code to rewrite:\n" + code_str, + path="", + xpubs=[], + ) + ) + + def _get_code_code_entry_str(self) -> str: + if self.code_code_entry is not None: + code_str = f"{self.code_code_entry:06}" + if __debug__: + log.debug(__name__, "code_code_entry: %s", code_str) + + return code_str[:3] + " " + code_str[3:] + raise Exception("Code entry string is not available") + + def _get_code_qr_code_str(self) -> str: + if self.code_qr_code is not None: + code_str = (hexlify(self.code_qr_code)).decode("utf-8") + if __debug__: + log.debug(__name__, "code_qr_code_hexlified: %s", code_str) + return code_str + raise Exception("QR code string is not available") + + +class PairingContext(Context): + + def __init__(self, channel_ctx: Channel) -> None: + super().__init__(channel_ctx.iface, channel_ctx.channel_id) + self.channel_ctx: Channel = channel_ctx + self.incoming_message = loop.mailbox() + self.secret: bytes = random.bytes(16) + + self.display_data: PairingDisplayData = PairingDisplayData() + self.cpace: Cpace + self.host_name: str + + async def handle(self, is_debug_session: bool = False) -> None: + # if __debug__: + # log.debug(__name__, "handle - start") + # if is_debug_session: + # import apps.debug + + # apps.debug.DEBUG_CONTEXT = self + + next_message: Message | None = None + + while True: + try: + if next_message is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one. + try: + message: Message = await self.incoming_message + except protocol_common.WireError as e: + if __debug__: + log.exception(__name__, e) + await self.write(message_handler.failure(e)) + continue + else: + # Process the message from previous run. + message = next_message + next_message = None + + try: + next_message = await handle_pairing_request_message(self, message) + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + # TODO utils.unimport_end(modules) + + if next_message is None: + + # Shut down the loop if there is no next message waiting. + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. # TODO not updated comments + if __debug__: + log.exception(__name__, exc) + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + + message: Message = await self.incoming_message + + if message.type not in expected_types: + raise UnexpectedMessageException(message) + + if expected_type is None: + name = message_handler.get_msg_name(message.type) + if name is None: + expected_type = protobuf.type_for_wire(message.type) + else: + expected_type = protobuf.type_for_name(name) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel_ctx.write(msg) + + async def call( + self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] + ) -> protobuf.MessageType: + expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME) + if expected_wire_type is None: + expected_wire_type = expected_type.MESSAGE_WIRE_TYPE + + assert expected_wire_type is not None + + await self.write(msg) + del msg + + return await self.read((expected_wire_type,), expected_type) + + async def call_any( + self, msg: protobuf.MessageType, *expected_types: int + ) -> protobuf.MessageType: + await self.write(msg) + del msg + return await self.read(expected_types) + + +async def handle_pairing_request_message( + pairing_ctx: PairingContext, + msg: protocol_common.Message, +) -> protocol_common.Message | None: + + res_msg: protobuf.MessageType | None = None + + from apps.thp.pairing import handle_pairing_request + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + name = message_handler.get_msg_name(msg.type) + if name is None: + req_type = protobuf.type_for_wire(msg.type) + else: + req_type = protobuf.type_for_name(name) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handle_pairing_request(pairing_ctx, req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + res_msg = await workflow.spawn(context.with_context(pairing_ctx, task)) + + except UnexpectedMessageException as exc: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + # TODO: + # We might handle only the few common cases here, like + # Initialize and Cancel. + return exc.msg + except SilentError as exc: + if __debug__: + log.error(__name__, "SilentError: %s", exc.message) + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = message_handler.failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await pairing_ctx.write(res_msg) + return None diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py new file mode 100644 index 00000000000..3f9cd8f693c --- /dev/null +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -0,0 +1,446 @@ +import ustruct +from typing import TYPE_CHECKING + +from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, +) +from storage.cache_thp import ( + KEY_LENGTH, + MANAGEMENT_SESSION_ID, + SESSION_ID_LENGTH, + TAG_LENGTH, + update_channel_last_used, + update_session_last_used, +) +from trezor import log, loop, protobuf, utils +from trezor.enums import FailureType +from trezor.messages import Failure + +from .. import message_handler +from ..errors import DataError +from ..protocol_common import Message +from . import ( + ACK_MESSAGE, + HANDSHAKE_COMP_RES, + HANDSHAKE_INIT_RES, + ChannelState, + PacketHeader, + SessionState, + ThpDecryptionError, + ThpError, + ThpErrorType, + ThpInvalidDataError, + ThpUnallocatedSessionError, +) +from . import alternating_bit_protocol as ABP +from . import ( + checksum, + control_byte, + get_enabled_pairing_methods, + get_encoded_device_properties, + session_manager, +) +from .checksum import CHECKSUM_LENGTH +from .crypto import PUBKEY_LENGTH, Handshake +from .writer import ( + INIT_HEADER_LENGTH, + MESSAGE_TYPE_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from typing import Awaitable + + from trezor.messages import ThpHandshakeCompletionReqNoisePayload + + from .channel import Channel + +if __debug__: + from ubinascii import hexlify + + from . import state_to_str + + +_TREZOR_STATE_UNPAIRED = b"\x00" +_TREZOR_STATE_PAIRED = b"\x01" + + +async def handle_received_message( + ctx: Channel, message_buffer: utils.BufferType +) -> None: + """Handle a message received from the channel.""" + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_received_message") + if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done + try: + import micropython + + print("micropython.mem_info() from received_message_handler.py") + micropython.mem_info() + print("Allocation count:", micropython.alloc_count()) # type: ignore ["alloc_count" is not a known attribute of module "micropython"] + except AttributeError: + print( + "To show allocation count, create the build with TREZOR_MEMPERF=1" + ) + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) + message_length = payload_length + INIT_HEADER_LENGTH + + _check_checksum(message_length, message_buffer) + + # Synchronization process + seq_bit = (ctrl_byte & 0x10) >> 4 + ack_bit = (ctrl_byte & 0x08) >> 3 + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "handle_completed_message - seq bit of message: %d, ack bit of message: %d", + seq_bit, + ack_bit, + ) + # 0: Update "last-time used" + update_channel_last_used(ctx.channel_id) + + # 1: Handle ACKs + if control_byte.is_ack(ctrl_byte): + await _handle_ack(ctx, ack_bit) + return + + if _should_have_ctrl_byte_encrypted_transport( + ctx + ) and not control_byte.is_encrypted_transport(ctrl_byte): + raise ThpError("Message is not encrypted. Ignoring") + + # 2: Handle message with unexpected sequential bit + if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Received message with an unexpected sequential bit") + await _send_ack(ctx, ack_bit=seq_bit) + raise ThpError("Received message with an unexpected sequential bit") + + # 3: Send ACK in response + await _send_ack(ctx, ack_bit=seq_bit) + + ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit) + + try: + await _handle_message_to_app_or_channel( + ctx, payload_length, message_length, ctrl_byte + ) + except ThpUnallocatedSessionError as e: + error_message = Failure(code=FailureType.ThpUnallocatedSession) + await ctx.write(error_message, e.session_id) + except ThpDecryptionError: + await ctx.write_error(ThpErrorType.DECRYPTION_FAILED) + ctx.clear() + except ThpInvalidDataError: + await ctx.write_error(ThpErrorType.INVALID_DATA) + ctx.clear() + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_received_message - end") + + +def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]: + ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) + header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, ack_bit: %d", + ctx.get_channel_id_int(), + ack_bit, + ) + return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"") + + +def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "check_checksum") + if not checksum.is_valid( + checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length], + data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH], + ): + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Invalid checksum, ignoring message.") + raise ThpError("Invalid checksum, ignoring message.") + + +async def _handle_ack(ctx: Channel, ack_bit: int) -> None: + if not ABP.is_ack_valid(ctx.channel_cache, ack_bit): + return + # ACK is expected and it has correct sync bit + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Received ACK message with correct ack bit") + if ctx.transmission_loop is not None: + ctx.transmission_loop.stop_immediately() + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Stopped transmission loop") + + ABP.set_sending_allowed(ctx.channel_cache, True) + + if ctx.write_task_spawn is not None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') + await ctx.write_task_spawn + # Note that no the write_task_spawn could result in loop.clear(), + # which will result in termination of this function - any code after + # this await might not be executed + + +def _handle_message_to_app_or_channel( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> Awaitable[None]: + state = ctx.get_channel_state() + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "state: %s", state_to_str(state)) + + if state is ChannelState.ENCRYPTED_TRANSPORT: + return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) + + if state is ChannelState.TH1: + return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) + + if state is ChannelState.TH2: + return _handle_state_TH2(ctx, message_length, ctrl_byte) + + if _is_channel_state_pairing(state): + return _handle_pairing(ctx, message_length) + + raise ThpError("Unimplemented channel state") + + +async def _handle_state_TH1( + ctx: Channel, + payload_length: int, + message_length: int, + ctrl_byte: int, +) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_TH1") + if not control_byte.is_handshake_init_req(ctrl_byte): + raise ThpError("Message received is not a handshake init request!") + if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: + raise ThpError("Message received is not a valid handshake init request!") + + ctx.handshake = Handshake() + + host_ephemeral_pubkey = bytearray( + ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] + ) + trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = ( + ctx.handshake.handle_th1_crypto( + get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey + ) + ) + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "trezor ephemeral pubkey: %s", + hexlify(trezor_ephemeral_pubkey).decode(), + ) + log.debug( + __name__, + "encrypted trezor masked static pubkey: %s", + hexlify(encrypted_trezor_static_pubkey).decode(), + ) + log.debug(__name__, "tag: %s", hexlify(tag)) + + payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag + + # send handshake init response message + ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) + ctx.set_channel_state(ChannelState.TH2) + return + + +async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None: + from apps.thp.credential_manager import validate_credential + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_TH2") + if not control_byte.is_handshake_comp_req(ctrl_byte): + raise ThpError("Message received is not a handshake completion request!") + if ctx.handshake is None: + raise Exception("Handshake object is not prepared. Retry handshake.") + + host_encrypted_static_pubkey = memoryview(ctx.buffer)[ + INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH + ] + handshake_completion_request_noise_payload = memoryview(ctx.buffer)[ + INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH + ] + + ctx.handshake.handle_th2_crypto( + host_encrypted_static_pubkey, handshake_completion_request_noise_payload + ) + + ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive) + ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send) + ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h) + ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) + + noise_payload = _decode_message( + ctx.buffer[ + INIT_HEADER_LENGTH + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + 0, + "ThpHandshakeCompletionReqNoisePayload", + ) + if TYPE_CHECKING: + assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) + enabled_methods = get_enabled_pairing_methods(ctx.iface) + for method in noise_payload.pairing_methods: + if method not in enabled_methods: + raise ThpInvalidDataError() + if method not in ctx.selected_pairing_methods: + ctx.selected_pairing_methods.append(method) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, + "host static pubkey: %s, noise payload: %s", + utils.get_bytes_as_str(host_encrypted_static_pubkey), + utils.get_bytes_as_str(handshake_completion_request_noise_payload), + ) + + # key is decoded in handshake._handle_th2_crypto + host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH] + + paired: bool = False + + if noise_payload.host_pairing_credential is not None: + try: # TODO change try-except for something better + paired = validate_credential( + noise_payload.host_pairing_credential, + host_static_pubkey, + ) + except DataError as e: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.exception(__name__, e) + pass + + trezor_state = _TREZOR_STATE_UNPAIRED + if paired: + trezor_state = _TREZOR_STATE_PAIRED + # send hanshake completion response + ctx.write_handshake_message( + HANDSHAKE_COMP_RES, + ctx.handshake.get_handshake_completion_response(trezor_state), + ) + + ctx.handshake = None + + if paired: + ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + else: + ctx.set_channel_state(ChannelState.TP1) + + +async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") + + ctx.decrypt_buffer(message_length) + session_id, message_type = ustruct.unpack( + ">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:] + ) + if session_id not in ctx.sessions: + if session_id == MANAGEMENT_SESSION_ID: + s = session_manager.create_new_management_session(ctx) + else: + s = session_manager.get_session_from_cache(ctx, session_id) + if s is None: + raise ThpUnallocatedSessionError(session_id) + ctx.sessions[session_id] = s + loop.schedule(s.handle()) + + elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED: + raise ThpUnallocatedSessionError(session_id) + + s = ctx.sessions[session_id] + update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big")) + + s.incoming_message.put( + Message( + message_type, + ctx.buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +async def _handle_pairing(ctx: Channel, message_length: int) -> None: + from .pairing_context import PairingContext + + if ctx.connection_context is None: + ctx.connection_context = PairingContext(ctx) + loop.schedule(ctx.connection_context.handle()) + + ctx.decrypt_buffer(message_length) + message_type = ustruct.unpack( + ">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] + )[0] + + ctx.connection_context.incoming_message.put( + Message( + message_type, + ctx.buffer[ + INIT_HEADER_LENGTH + + MESSAGE_TYPE_LENGTH + + SESSION_ID_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + ) + ) + + +def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool: + if ctx.get_channel_state() in [ + ChannelState.UNALLOCATED, + ChannelState.TH1, + ChannelState.TH2, + ]: + return False + return True + + +def _decode_message( + buffer: bytes, msg_type: int, message_name: str | None = None +) -> protobuf.MessageType: + if __debug__: + log.debug(__name__, "decode message") + if message_name is not None: + expected_type = protobuf.type_for_name(message_name) + else: + expected_type = protobuf.type_for_wire(msg_type) + return message_handler.wrap_protobuf_load(buffer, expected_type) + + +def _is_channel_state_pairing(state: int) -> bool: + if state in ( + ChannelState.TP1, + ChannelState.TP2, + ChannelState.TP3, + ChannelState.TP4, + ChannelState.TC1, + ): + return True + return False diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py new file mode 100644 index 00000000000..688fa46b37c --- /dev/null +++ b/core/src/trezor/wire/thp/session_context.py @@ -0,0 +1,169 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp +from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache +from trezor import log, loop, protobuf, utils +from trezor.wire import message_handler, protocol_common +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import failure + +from ..protocol_common import Context, Message +from . import SessionState + +if TYPE_CHECKING: + from typing import Awaitable, Container + + from storage.cache_common import DataCache + + from .channel import Channel + + pass + +_EXIT_LOOP = True +_REPEAT_LOOP = False + +if __debug__: + from trezor.utils import get_bytes_as_str + + +class GenericSessionContext(Context): + + def __init__(self, channel: Channel, session_id: int) -> None: + super().__init__(channel.iface, channel.channel_id) + self.channel: Channel = channel + self.session_id: int = session_id + self.incoming_message = loop.mailbox() + + async def handle(self) -> None: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._handle_debug() + + next_message: Message | None = None + + while True: + message = next_message + next_message = None + try: + if await self._handle_message(message): + loop.schedule(self.handle()) + return + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_message = unexpected.msg + continue + except Exception as exc: + # Log and try again. + if __debug__: + log.exception(__name__, exc) + + def _handle_debug(self) -> None: + log.debug( + __name__, + "handle - start (channel_id (bytes): %s, session_id: %d)", + get_bytes_as_str(self.channel_id), + self.session_id, + ) + + async def _handle_message( + self, + next_message: Message | None, + ) -> bool: + + try: + if next_message is not None: + # Process the message from previous run. + message = next_message + next_message = None + else: + # Wait for a new message from wire + message = await self.incoming_message + + except protocol_common.WireError as e: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.exception(__name__, e) + await self.write(failure(e)) + return _REPEAT_LOOP + + await message_handler.handle_single_message(self, message) + return _EXIT_LOOP + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) + message: Message = await self.incoming_message + if message.type not in expected_types: + if __debug__: + log.debug( + __name__, + "EXPECTED TYPES: %s\nRECEIVED TYPE: %s", + str(expected_types), + str(message.type), + ) + raise UnexpectedMessageException(message) + + if expected_type is None: + expected_type = protobuf.type_for_wire(message.type) + + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel.write(msg, self.session_id) + + def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel.write(msg, self.session_id, force=True) + + def get_session_state(self) -> SessionState: ... + + +class ManagementSessionContext(GenericSessionContext): + + def __init__( + self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID + ) -> None: + super().__init__(channel_ctx, session_id) + + def get_session_state(self) -> SessionState: + return SessionState.MANAGEMENT + + +class SessionContext(GenericSessionContext): + + def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None: + if channel_ctx.channel_id != session_cache.channel_id: + raise Exception( + "The session has different channel id than the provided channel context!" + ) + session_id = int.from_bytes(session_cache.session_id, "big") + super().__init__(channel_ctx, session_id) + self.session_cache = session_cache + + # ACCESS TO SESSION DATA + + def get_session_state(self) -> SessionState: + state = int.from_bytes(self.session_cache.state, "big") + return SessionState(state) + + def set_session_state(self, state: SessionState) -> None: + self.session_cache.state = bytearray(state.to_bytes(1, "big")) + + def release(self) -> None: + if self.session_cache is not None: + cache_thp.clear_session(self.session_cache) + + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + return self.session_cache diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py new file mode 100644 index 00000000000..3377ce437fb --- /dev/null +++ b/core/src/trezor/wire/thp/session_manager.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from storage import cache_thp + +from .session_context import ( + GenericSessionContext, + ManagementSessionContext, + SessionContext, +) + +if TYPE_CHECKING: + from .channel import Channel + + +def create_new_session(channel_ctx: Channel) -> SessionContext: + """ + Creates new `SessionContext` backed by cache. + """ + session_cache = cache_thp.get_new_session(channel_ctx.channel_cache) + return SessionContext(channel_ctx, session_cache) + + +def create_new_management_session( + channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID +) -> ManagementSessionContext: + """ + Creates new `ManagementSessionContext` that is not backed by cache entry. + + Seed cannot be derived with this type of session. + """ + return ManagementSessionContext(channel_ctx, session_id) + + +def get_session_from_cache( + channel_ctx: Channel, session_id: int +) -> GenericSessionContext | None: + """ + Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found. + """ + session_id_bytes = session_id.to_bytes(1, "big") + session_cache = cache_thp.get_allocated_session( + channel_ctx.channel_id, session_id_bytes + ) + if session_cache is None: + return None + elif cache_thp.is_management_session(session_cache): + return ManagementSessionContext(channel_ctx, session_id) + return SessionContext(channel_ctx, session_cache) diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py new file mode 100644 index 00000000000..2381ca06389 --- /dev/null +++ b/core/src/trezor/wire/thp/thp_main.py @@ -0,0 +1,187 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io, log, loop, utils + +from . import ( + CHANNEL_ALLOCATION_REQ, + CODEC_V1, + ChannelState, + PacketHeader, + ThpError, + ThpErrorType, + channel_manager, + checksum, + get_channel_allocation_response, + writer, +) +from .channel import Channel +from .checksum import CHECKSUM_LENGTH +from .writer import ( + INIT_HEADER_LENGTH, + MAX_PAYLOAD_LEN, + PACKET_LENGTH, + write_payload_to_wire_and_add_checksum, +) + +if TYPE_CHECKING: + from trezorio import WireInterface + +_CID_REQ_PAYLOAD_LENGTH = const(12) +_READ_BUFFER: bytearray +_WRITE_BUFFER: bytearray +_CHANNELS: dict[int, Channel] = {} + + +def set_read_buffer(buffer: bytearray) -> None: + global _READ_BUFFER + _READ_BUFFER = buffer + + +def set_write_buffer(buffer: bytearray) -> None: + global _WRITE_BUFFER + _WRITE_BUFFER = buffer + + +def get_raw_read_buffer() -> bytearray: + global _READ_BUFFER + return _READ_BUFFER + + +def get_raw_write_buffer() -> bytearray: + global _WRITE_BUFFER + return _WRITE_BUFFER + + +async def thp_main_loop(iface: WireInterface) -> None: + global _CHANNELS + global _READ_BUFFER + _CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER) + + read = loop.wait(iface.iface_num() | io.POLL_READ) + + while True: + try: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "thp_main_loop") + packet = await read + ctrl_byte, cid = ustruct.unpack(">BH", packet) + + if ctrl_byte == CODEC_V1: + await _handle_codec_v1(iface, packet) + continue + + if cid == BROADCAST_CHANNEL_ID: + await _handle_broadcast(iface, ctrl_byte, packet) + continue + + if cid in _CHANNELS: + await _handle_allocated(iface, cid, packet) + else: + await _handle_unallocated(iface, cid) + + except ThpError as e: + if __debug__: + log.exception(__name__, e) + + +async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None: + # If the received packet is not an initial codec_v1 packet, do not send error message + if not packet[1:3] == b"##": + return + if __debug__: + log.debug(__name__, "Received codec_v1 message, returning error") + error_message = _get_codec_v1_error_message() + await writer.write_packet_to_wire(iface, error_message) + + +async def _handle_broadcast( + iface: WireInterface, ctrl_byte: int, packet: utils.BufferType +) -> None: + global _READ_BUFFER + if ctrl_byte != CHANNEL_ALLOCATION_REQ: + raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") + if __debug__: + log.debug(__name__, "Received valid message on the broadcast channel") + + length, nonce = ustruct.unpack(">H8s", packet[3:]) + payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH) + if not checksum.is_valid( + payload[-4:], + packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH], + ): + raise ThpError("Checksum is not valid") + + new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER) + cid = int.from_bytes(new_channel.channel_id, "big") + _CHANNELS[cid] = new_channel + + response_data = get_channel_allocation_response( + nonce, new_channel.channel_id, iface + ) + response_header = PacketHeader.get_channel_allocation_response_header( + len(response_data) + CHECKSUM_LENGTH, + ) + if __debug__: + log.debug(__name__, "New channel allocated with id %d", cid) + + await write_payload_to_wire_and_add_checksum(iface, response_header, response_data) + + +async def _handle_allocated( + iface: WireInterface, cid: int, packet: utils.BufferType +) -> None: + channel = _CHANNELS[cid] + if channel is None: + await _handle_unallocated(iface, cid) + raise ThpError("Invalid state of a channel") + if channel.iface is not iface: + # TODO send error message to wire + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: + x = channel.receive_packet(packet) + if x is not None: + await x + + +async def _handle_unallocated(iface: WireInterface, cid: int) -> None: + data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") + header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) + await write_payload_to_wire_and_add_checksum(iface, header, data) + + +def _get_buffer_for_payload( + payload_length: int, + existing_buffer: utils.BufferType, + max_length: int = MAX_PAYLOAD_LEN, +) -> utils.BufferType: + if payload_length > max_length: + raise ThpError("Message too large") + if payload_length > len(existing_buffer): + return _try_allocate_new_buffer(payload_length) + return _reuse_existing_buffer(payload_length, existing_buffer) + + +def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType: + try: + payload: utils.BufferType = bytearray(payload_length) + except MemoryError: + payload = bytearray(PACKET_LENGTH) + raise ThpError("Message too large") + return payload + + +def _reuse_existing_buffer( + payload_length: int, existing_buffer: utils.BufferType +) -> utils.BufferType: + return memoryview(existing_buffer)[:payload_length] + + +def _get_codec_v1_error_message() -> bytes: + # Codec_v1 magic constant "?##" + Failure message type + msg_size + # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B + ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + return ERROR_MSG diff --git a/core/src/trezor/wire/thp/transmission_loop.py b/core/src/trezor/wire/thp/transmission_loop.py new file mode 100644 index 00000000000..cd3e3ba2f8d --- /dev/null +++ b/core/src/trezor/wire/thp/transmission_loop.py @@ -0,0 +1,54 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import loop + +from .writer import write_payload_to_wire_and_add_checksum + +if TYPE_CHECKING: + from . import PacketHeader + from .channel import Channel + +MAX_RETRANSMISSION_COUNT = const(50) +MIN_RETRANSMISSION_COUNT = const(2) + + +class TransmissionLoop: + + def __init__( + self, channel: Channel, header: PacketHeader, transport_payload: bytes + ) -> None: + self.channel: Channel = channel + self.header: PacketHeader = header + self.transport_payload: bytes = transport_payload + self.wait_task: loop.spawn | None = None + self.min_retransmisson_count_achieved: bool = False + + async def start( + self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT + ) -> None: + self.min_retransmisson_count_achieved = False + for i in range(max_retransmission_count): + if i >= MIN_RETRANSMISSION_COUNT: + self.min_retransmisson_count_achieved = True + await write_payload_to_wire_and_add_checksum( + self.channel.iface, self.header, self.transport_payload + ) + self.wait_task = loop.spawn(self._wait(i)) + try: + await self.wait_task + except loop.TaskClosed: + self.wait_task = None + break + + def stop_immediately(self) -> None: + if self.wait_task is not None: + self.wait_task.close() + self.wait_task = None + + async def _wait(self, counter: int = 0) -> None: + timeout_ms = round(10200 - 1010000 / (counter + 100)) + await loop.sleep(timeout_ms) + + def __del__(self) -> None: + self.stop_immediately() diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py new file mode 100644 index 00000000000..9fe69c5cd25 --- /dev/null +++ b/core/src/trezor/wire/thp/writer.py @@ -0,0 +1,92 @@ +from micropython import const +from trezorcrypto import crc +from typing import TYPE_CHECKING + +from trezor import io, log, loop, utils + +from . import PacketHeader + +INIT_HEADER_LENGTH = const(5) +CONT_HEADER_LENGTH = const(3) +PACKET_LENGTH = const(64) +CHECKSUM_LENGTH = const(4) +MAX_PAYLOAD_LEN = const(60000) +MESSAGE_TYPE_LENGTH = const(2) + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Awaitable, Sequence + + +def write_payload_to_wire_and_add_checksum( + iface: WireInterface, header: PacketHeader, transport_payload: bytes +) -> Awaitable[None]: + header_checksum: int = crc.crc32(header.to_bytes()) + checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes( + CHECKSUM_LENGTH, "big" + ) + data = (transport_payload, checksum) + return write_payloads_to_wire(iface, header, data) + + +async def write_payloads_to_wire( + iface: WireInterface, header: PacketHeader, data: Sequence[bytes] +) -> None: + n_of_data = len(data) + total_length = sum(len(item) for item in data) + + current_data_idx = 0 + current_data_offset = 0 + + packet = bytearray(PACKET_LENGTH) + header.pack_to_init_buffer(packet) + packet_offset: int = INIT_HEADER_LENGTH + packet_number = 0 + nwritten = 0 + while nwritten < total_length: + if packet_number == 1: + header.pack_to_cont_buffer(packet) + if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH: + packet[:] = bytearray(PACKET_LENGTH) + header.pack_to_cont_buffer(packet) + while True: + n = utils.memcpy( + packet, packet_offset, data[current_data_idx], current_data_offset + ) + packet_offset += n + current_data_offset += n + nwritten += n + + if packet_offset < PACKET_LENGTH: + current_data_idx += 1 + current_data_offset = 0 + if current_data_idx >= n_of_data: + break + elif packet_offset == PACKET_LENGTH: + break + else: + raise Exception("Should not happen!!!") + packet_number += 1 + packet_offset = CONT_HEADER_LENGTH + + # write packet to wire (in-lined) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + written_by_iface: int = 0 + while written_by_iface < len(packet): + await loop.wait(iface.iface_num() | io.POLL_WRITE) + written_by_iface = iface.write(packet) + + +async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None: + while True: + await loop.wait(iface.iface_num() | io.POLL_WRITE) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + n_written = iface.write(packet) + if n_written == len(packet): + return diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 67b88f8e684..9fc72c3e987 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -3,7 +3,7 @@ import storage.cache as storage_cache from trezor import log, loop -from trezor.enums import MessageType +from trezor.enums import MessageType, ThpMessageType if TYPE_CHECKING: from typing import Callable @@ -17,9 +17,14 @@ from trezor import utils +if utils.USE_THP: + protocol_specific = ThpMessageType.ThpCreateNewSession +else: + protocol_specific = MessageType.Initialize + ALLOW_WHILE_LOCKED = ( - MessageType.Initialize, + protocol_specific, MessageType.EndSession, MessageType.GetFeatures, MessageType.Cancel, diff --git a/core/tests/mock_wire_interface.py b/core/tests/mock_wire_interface.py new file mode 100644 index 00000000000..b74b2150643 --- /dev/null +++ b/core/tests/mock_wire_interface.py @@ -0,0 +1,17 @@ +from trezor.loop import wait + + +class MockHID: + def __init__(self, num): + self.num = num + self.data = [] + + def iface_num(self): + return self.num + + def write(self, msg): + self.data.append(bytearray(msg)) + return len(msg) + + def wait_object(self, mode): + return wait(mode | self.num) diff --git a/core/tests/myTests.sh b/core/tests/myTests.sh new file mode 100755 index 00000000000..1c29c1fd01b --- /dev/null +++ b/core/tests/myTests.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +declare -a results +declare -i passed=0 failed=0 exit_code=0 +declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m' +MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}" +print_summary() { + echo + echo 'Summary:' + echo '-------------------' + printf '%b\n' "${results[@]}" + if [ $exit_code == 0 ]; then + echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!" + else + echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!" + fi +} + +trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT + +cd $(dirname $0) + +[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*) + +declare -i num_of_tests=${#tests[@]} + +for test_case in ${tests[@]}; do + echo ${MICROPYTHON} + echo ${test_case} + echo + if $MICROPYTHON $test_case; then + results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case") + ((passed++)) + else + results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case") + ((failed++)) + exit_code=1 + fi +done + +print_summary +exit $exit_code diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py index 22888546f03..17a870df7f9 100644 --- a/core/tests/test_apps.bitcoin.approver.py +++ b/core/tests/test_apps.bitcoin.approver.py @@ -1,4 +1,4 @@ -from common import H_, await_result, unittest # isort:skip +from common import * # isort:skip import storage.cache_codec from trezor import wire @@ -12,19 +12,33 @@ TxOutput, ) from trezor.wire import context -from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.sign_tx.approvers import CoinJoinApprover from apps.bitcoin.sign_tx.bitcoin import Bitcoin from apps.bitcoin.sign_tx.tx_info import TxInfo from apps.common import coins +from trezor.wire.codec.codec_context import CodecContext + +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec + from trezor.wire.codec.codec_context import CodecContext class TestApprover(unittest.TestCase): + if utils.USE_THP: - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) def tearDownClass(self): context.CURRENT_CONTEXT = None @@ -48,15 +62,14 @@ def setUp(self): self.msg_auth = AuthorizeCoinJoin( coordinator=self.coordinator_name, max_rounds=10, - max_coordinator_fee_rate=int( - self.fee_rate_percent * 10**FEE_RATE_DECIMALS - ), + max_coordinator_fee_rate=int(self.fee_rate_percent * 10**FEE_RATE_DECIMALS), max_fee_per_kvbyte=7000, address_n=[H_(10025), H_(0), H_(0), H_(1)], coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) - storage.cache_codec.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( @@ -155,7 +168,11 @@ def test_coinjoin_lots_of_inputs(self): if txo.address_n: await_result(approver.add_change_output(txo, script_pubkey=bytes(22))) else: - await_result(approver.add_external_output(txo, script_pubkey=bytes(22), tx_info=tx_info)) + await_result( + approver.add_external_output( + txo, script_pubkey=bytes(22), tx_info=tx_info + ) + ) await_result(approver.approve_tx(tx_info, [], None)) diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py index 03d32651c70..4faf2029892 100644 --- a/core/tests/test_apps.bitcoin.authorization.py +++ b/core/tests/test_apps.bitcoin.authorization.py @@ -1,23 +1,37 @@ -from common import H_, unittest # isort:skip +from common import * # isort:skip import storage.cache_codec from trezor.enums import InputScriptType from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx from trezor.wire import context -from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import CoinJoinAuthorization from apps.common import coins _ROUND_ID_LEN = 32 +if utils.USE_THP: + import thp_common +else: + import storage.cache_codec + from trezor.wire.codec.codec_context import CodecContext + class TestAuthorization(unittest.TestCase): coin = coins.by_name("Bitcoin") - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) def tearDownClass(self): context.CURRENT_CONTEXT = None @@ -34,7 +48,8 @@ def setUp(self): ) self.authorization = CoinJoinAuthorization(self.msg_auth) - storage.cache_codec.start_session() + if not utils.USE_THP: + storage.cache_codec.start_session() def test_ownership_proof_account_depth_mismatch(self): # Account depth mismatch. diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index a232a000ae4..9fe07e6220e 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -7,22 +7,39 @@ from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec class TestBitcoinKeychain(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") keychain = await_result(_get_keychain_for_coin(coin)) @@ -98,18 +115,30 @@ def test_unknown(self): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def test_bcash(self): coin = _get_coin_by_name("Bcash") keychain = await_result(_get_keychain_for_coin(coin)) diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index fa2e3ff0414..8c3482b7b63 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -11,20 +11,33 @@ from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +if not utils.USE_THP: + from storage import cache_codec class TestKeychain(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - def tearDown(self): cache.clear_all() diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index a00b412a5f2..09e88510ce3 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -10,7 +10,12 @@ from apps.common.keychain import get_keychain from apps.common.paths import HARDENED from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network @@ -74,17 +79,30 @@ def _check_keychain(self, keychain, slip44_id): addr, ) - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def tearDownClass(self): context.CURRENT_CONTEXT = None - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) - def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) network = make_network(slip44=slip44) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 25eb119bd3c..c419cc9ec32 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,240 +1,520 @@ -from common import * # isort:skip +from common import * # isort:skip # noqa: F403 -from mock_storage import mock_storage -from storage import cache, cache_codec, cache_common -from trezor.messages import EndSession, Initialize -from trezor.wire import context -from trezor.wire.codec.codec_context import CodecContext - -from apps.base import handle_EndSession, handle_Initialize -from apps.common.cache import stored, stored_async KEY = 0 +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from storage import cache, cache_thp + from trezor.wire.thp import ChannelState + from trezor.wire.thp.session_context import SessionContext + + _PROTOCOL_CACHE = cache_thp + +else: + from storage import cache, cache_codec + from trezor.messages import EndSession, Initialize + from apps.base import handle_EndSession + from mock_storage import mock_storage + + _PROTOCOL_CACHE = cache_codec -# Function moved from cache.py, as it was not used there -def is_session_started() -> bool: - return cache_codec._active_session_idx is not None + def is_session_started() -> bool: + return cache_codec.get_active_session() is not None + + def get_active_session(): + return cache_codec.get_active_session() class TestStorageCache(unittest.TestCase): - def setUpClass(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - - def tearDownClass(self): - context.CURRENT_CONTEXT = None - - def setUp(self): - cache.clear_all() - - def test_start_session(self): - session_id_a = cache_codec.start_session() - self.assertIsNotNone(session_id_a) - session_id_b = cache_codec.start_session() - self.assertNotEqual(session_id_a, session_id_b) - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_set(KEY, "something") - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get(KEY) - - def test_end_session(self): - session_id = cache_codec.start_session() - self.assertTrue(is_session_started()) - context.cache_set(KEY, b"A") - cache_codec.end_current_session() - self.assertFalse(is_session_started()) - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) - - # ending an ended session should be a no-op - cache_codec.end_current_session() - self.assertFalse(is_session_started()) - - session_id_a = cache_codec.start_session(session_id) - # original session no longer exists - self.assertNotEqual(session_id_a, session_id) - # original session data no longer exists - self.assertIsNone(context.cache_get(KEY)) - - # create a new session - session_id_b = cache_codec.start_session() - # switch back to original session - session_id = cache_codec.start_session(session_id_a) - self.assertEqual(session_id, session_id_a) - # end original session - cache_codec.end_current_session() - # switch back to B - session_id = cache_codec.start_session(session_id_b) - self.assertEqual(session_id, session_id_b) - - def test_session_queue(self): - session_id = cache_codec.start_session() - self.assertEqual(cache_codec.start_session(session_id), session_id) - context.cache_set(KEY, b"A") - for i in range(cache_codec._MAX_SESSIONS_COUNT): + if utils.USE_THP: + + def setUpClass(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + cache.clear_all() + + def test_new_channel_and_session(self): + channel = thp_common.get_new_channel(self.interface) + + # Assert that channel is created without any sessions + self.assertEqual(len(channel.sessions), 0) + + cid_1 = channel.channel_id + session_cache_1 = cache_thp.get_new_session(channel.channel_cache) + session_1 = SessionContext(channel, session_cache_1) + self.assertEqual(session_1.channel_id, cid_1) + + session_cache_2 = cache_thp.get_new_session(channel.channel_cache) + session_2 = SessionContext(channel, session_cache_2) + self.assertEqual(session_2.channel_id, cid_1) + self.assertEqual(session_1.channel_id, session_2.channel_id) + self.assertNotEqual(session_1.session_id, session_2.session_id) + + channel_2 = thp_common.get_new_channel(self.interface) + cid_2 = channel_2.channel_id + self.assertNotEqual(cid_1, cid_2) + + session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache) + session_3 = SessionContext(channel_2, session_cache_3) + self.assertEqual(session_3.channel_id, cid_2) + + # Sessions 1 and 3 should have different channel_id, but the same session_id + self.assertNotEqual(session_1.channel_id, session_3.channel_id) + self.assertEqual(session_1.session_id, session_3.session_id) + + self.assertEqual(cache_thp._SESSIONS[0], session_cache_1) + self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2) + self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id) + + # Check that session data IS in cache for created sessions ONLY + for i in range(3): + self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].session_id, b"") + self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0) + + # Check that session data IS NOT in cache after cache.clear_all() + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.channel_id, b"") + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.state, b"\x00") + + def test_channel_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3) + channels = [] + for i in range(cache_thp._MAX_CHANNELS_COUNT): + channels.append(thp_common.get_new_channel(self.interface)) + channel_ids = [channel.channel_cache.channel_id for channel in channels] + + # Assert that each channel_id is unique and that cache and list of channels + # have the same "channels" on the same indexes + for i in range(len(channel_ids)): + self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i]) + for j in range(i + 1, len(channel_ids)): + self.assertNotEqual(channel_ids[i], channel_ids[j]) + + # Create a new channel that is over the capacity + new_channel = thp_common.get_new_channel(self.interface) + for c in channels: + self.assertNotEqual(c.channel_id, new_channel.channel_id) + + # Test that the oldest (least used) channel was replaced (_CHANNELS[0]) + self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0]) + self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id) + + # Update the "last used" value of the second channel in cache (_CHANNELS[1]) and + # assert that it is not replaced when creating a new channel + cache_thp.update_channel_last_used(channel_ids[1]) + new_new_channel = thp_common.get_new_channel(self.interface) + self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1]) + + # Assert that it was in fact the _CHANNEL[2] that was replaced + self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2]) + self.assertEqual( + cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id + ) + + def test_session_capacity_in_cache(self): + self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4) + channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache + channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache + + sesions_A = [] + cid = [] + sid = [] + for i in range(3): + sesions_A.append(cache_thp.get_new_session(channel_cache_A)) + cid.append(sesions_A[i].channel_id) + sid.append(sesions_A[i].session_id) + + sessions_B = [] + for i in range(cache_thp._MAX_SESSIONS_COUNT - 3): + sessions_B.append(cache_thp.get_new_session(channel_cache_B)) + + for i in range(3): + self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i]) + self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id) + self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id) + for i in range(3, cache_thp._MAX_SESSIONS_COUNT): + self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i]) + + # Assert that new session replaces the oldest (least used) one (_SESSOIONS[0]) + new_session = cache_thp.get_new_session(channel_cache_B) + self.assertEqual(new_session, cache_thp._SESSIONS[0]) + self.assertNotEqual(new_session.channel_id, cid[0]) + self.assertNotEqual(new_session.session_id, sid[0]) + + # Assert that updating "last used" for session on channel A increases also + # the "last usage" of channel A. + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + cache_thp.update_session_last_used( + channel_cache_A.channel_id, sesions_A[1].session_id + ) + self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage) + + new_new_session = cache_thp.get_new_session(channel_cache_B) + + # Assert that creating a new session on channel B shifts the "last usage" again + # and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced + self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage) + self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1]) + self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2]) + self.assertEqual(new_new_session, cache_thp._SESSIONS[2]) + + def test_clear(self): + channel_A = thp_common.get_new_channel(self.interface) + channel_B = thp_common.get_new_channel(self.interface) + cid_A = channel_A.channel_id + cid_B = channel_B.channel_id + sessions = [] + + for i in range(3): + sessions.append(cache_thp.get_new_session(channel_A.channel_cache)) + sessions.append(cache_thp.get_new_session(channel_B.channel_cache)) + + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A) + self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + + # Assert that clearing of channel A works + self.assertNotEqual(channel_A.channel_cache.channel_id, b"") + self.assertNotEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1) + + channel_A.clear() + + self.assertEqual(channel_A.channel_cache.channel_id, b"") + self.assertEqual(channel_A.channel_cache.last_usage, 0) + self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED) + + # Assert that clearing channel A also cleared all its sessions + for i in range(3): + self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"") + + self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0) + self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B) + + cache.clear_all() + for session in cache_thp._SESSIONS: + self.assertEqual(session.last_usage, 0) + self.assertEqual(session.channel_id, b"") + for channel in cache_thp._CHANNELS: + self.assertEqual(channel.channel_id, b"") + self.assertEqual(channel.last_usage, 0) + self.assertEqual( + cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED + ) + + def test_get_set(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.set(KEY, b"world") + self.assertEqual(session_2.get(KEY), b"world") + + self.assertEqual(session_1.get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(session_1.get(KEY)) + self.assertIsNone(session_2.get(KEY)) + + def test_get_set_int(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + session_1.set_int(KEY, 1234) + + self.assertEqual(session_1.get_int(KEY), 1234) + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.set_int(KEY, 5678) + self.assertEqual(session_2.get_int(KEY), 5678) + + self.assertEqual(session_1.get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(session_1.get_int(KEY)) + self.assertIsNone(session_2.get_int(KEY)) + + def test_get_set_bool(self): + channel = thp_common.get_new_channel(self.interface) + + session_1 = cache_thp.get_new_session(channel.channel_cache) + with self.assertRaises(AssertionError): + session_1.set_bool(KEY, True) + + # Change length of first session field to 0 so that the length check passes + session_1.fields = (0,) + session_1.fields[1:] + + # with self.assertRaises(AssertionError) as e: + session_1.set_bool(KEY, True) + self.assertEqual(session_1.get_bool(KEY), True) + + session_2 = cache_thp.get_new_session(channel.channel_cache) + session_2.fields = session_2.fields = (0,) + session_2.fields[1:] + session_2.set_bool(KEY, False) + self.assertEqual(session_2.get_bool(KEY), False) + + self.assertEqual(session_1.get_bool(KEY), True) + + cache.clear_all() + + # Default value is False + self.assertFalse(session_1.get_bool(KEY)) + self.assertFalse(session_2.get_bool(KEY)) + + def test_delete(self): + channel = thp_common.get_new_channel(self.interface) + session_1 = cache_thp.get_new_session(channel.channel_cache) + + self.assertIsNone(session_1.get(KEY)) + session_1.set(KEY, b"hello") + self.assertEqual(session_1.get(KEY), b"hello") + session_1.delete(KEY) + self.assertIsNone(session_1.get(KEY)) + + session_1.set(KEY, b"hello") + session_2 = cache_thp.get_new_session(channel.channel_cache) + + self.assertIsNone(session_2.get(KEY)) + session_2.set(KEY, b"hello") + self.assertEqual(session_2.get(KEY), b"hello") + session_2.delete(KEY) + self.assertIsNone(session_2.get(KEY)) + + self.assertEqual(session_1.get(KEY), b"hello") + + else: + + def setUpClass(self): + from trezor.wire.codec.codec_context import CodecContext + from trezor.wire import context + + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + from trezor.wire import context + + context.CURRENT_CONTEXT = None + + def setUp(self): + cache.clear_all() + + def test_start_session(self): + session_id_a = cache_codec.start_session() + self.assertIsNotNone(session_id_a) + session_id_b = cache_codec.start_session() + self.assertNotEqual(session_id_a, session_id_b) + + cache.clear_all() + self.assertIsNone(get_active_session()) + for session in cache_codec._SESSIONS: + self.assertEqual(session.session_id, b"") + self.assertEqual(session.last_usage, 0) + + def test_end_session(self): + session_id = cache_codec.start_session() + self.assertTrue(is_session_started()) + get_active_session().set(KEY, b"A") + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + self.assertIsNone(get_active_session()) + + # ending an ended session should be a no-op + cache_codec.end_current_session() + self.assertFalse(is_session_started()) + + session_id_a = cache_codec.start_session(session_id) + # original session no longer exists + self.assertNotEqual(session_id_a, session_id) + # original session data no longer exists + self.assertIsNone(get_active_session().get(KEY)) + + # create a new session + session_id_b = cache_codec.start_session() + # switch back to original session + session_id = cache_codec.start_session(session_id_a) + self.assertEqual(session_id, session_id_a) + # end original session + cache_codec.end_current_session() + # switch back to B + session_id = cache_codec.start_session(session_id_b) + self.assertEqual(session_id, session_id_b) + + def test_session_queue(self): + session_id = cache_codec.start_session() + self.assertEqual(cache_codec.start_session(session_id), session_id) + get_active_session().set(KEY, b"A") + for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT): + cache_codec.start_session() + self.assertNotEqual(cache_codec.start_session(session_id), session_id) + self.assertIsNone(get_active_session().get(KEY)) + + def test_get_set(self): + session_id1 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"hello") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + session_id2 = cache_codec.start_session() + cache_codec.get_active_session().set(KEY, b"world") + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + + cache_codec.start_session(session_id2) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"world") + cache_codec.start_session(session_id1) + self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello") + + cache.clear_all() + self.assertIsNone(cache_codec.get_active_session()) + + def test_get_set_int(self): + session_id1 = cache_codec.start_session() + get_active_session().set_int(KEY, 1234) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + session_id2 = cache_codec.start_session() + get_active_session().set_int(KEY, 5678) + self.assertEqual(get_active_session().get_int(KEY), 5678) + + cache_codec.start_session(session_id2) + self.assertEqual(get_active_session().get_int(KEY), 5678) + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get_int(KEY), 1234) + + cache.clear_all() + self.assertIsNone(get_active_session()) + + def test_delete(self): + session_id1 = cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + get_active_session().set(KEY, b"hello") + cache_codec.start_session() + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + get_active_session().delete(KEY) + self.assertIsNone(get_active_session().get(KEY)) + + cache_codec.start_session(session_id1) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_decorators(self): + run_count = 0 + cache_codec.start_session() + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"foo" + + # cache is empty + self.assertIsNone(get_active_session().get(KEY)) + self.assertEqual(run_count, 0) + self.assertEqual(func(), b"foo") + # function was run + self.assertEqual(run_count, 1) + self.assertEqual(get_active_session().get(KEY), b"foo") + # function does not run again but returns cached value + self.assertEqual(func(), b"foo") + self.assertEqual(run_count, 1) + + def test_empty_value(self): + cache_codec.start_session() + + self.assertIsNone(get_active_session().get(KEY)) + get_active_session().set(KEY, b"") + self.assertEqual(get_active_session().get(KEY), b"") + + get_active_session().delete(KEY) + run_count = 0 + + from apps.common.cache import stored + + @stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"" + + self.assertEqual(func(), b"") + # function gets called once + self.assertEqual(run_count, 1) + self.assertEqual(func(), b"") + # function is not called for a second time + self.assertEqual(run_count, 1) + + if not utils.USE_THP: + + @mock_storage + def test_Initialize(self): + from apps.base import handle_Initialize + + def call_Initialize(**kwargs): + msg = Initialize(**kwargs) + return await_result(handle_Initialize(msg)) + + # calling Initialize without an ID allocates a new one + session_id = cache_codec.start_session() + features = call_Initialize() + self.assertNotEqual(session_id, features.session_id) + + # calling Initialize with the current ID does not allocate a new one + features = call_Initialize(session_id=session_id) + self.assertEqual(session_id, features.session_id) + + # store "hello" + get_active_session().set(KEY, b"hello") + # check that it is cleared + features = call_Initialize() + session_id = features.session_id + self.assertIsNone(get_active_session().get(KEY)) + # store "hello" again + get_active_session().set(KEY, b"hello") + self.assertEqual(get_active_session().get(KEY), b"hello") + + # supplying a different session ID starts a new session + call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH) + self.assertIsNone(get_active_session().get(KEY)) + + # but resuming a session loads the previous one + call_Initialize(session_id=session_id) + self.assertEqual(get_active_session().get(KEY), b"hello") + + def test_EndSession(self): + + self.assertIsNone(get_active_session()) cache_codec.start_session() - self.assertNotEqual(cache_codec.start_session(session_id), session_id) - self.assertIsNone(context.cache_get(KEY)) - - def test_get_set(self): - session_id1 = cache_codec.start_session() - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - - session_id2 = cache_codec.start_session() - context.cache_set(KEY, b"world") - self.assertEqual(context.cache_get(KEY), b"world") - - cache_codec.start_session(session_id2) - self.assertEqual(context.cache_get(KEY), b"world") - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get(KEY), b"hello") - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get(KEY) - - def test_get_set_int(self): - session_id1 = cache_codec.start_session() - context.cache_set_int(KEY, 1234) - self.assertEqual(context.cache_get_int(KEY), 1234) - - session_id2 = cache_codec.start_session() - context.cache_set_int(KEY, 5678) - self.assertEqual(context.cache_get_int(KEY), 5678) - - cache_codec.start_session(session_id2) - self.assertEqual(context.cache_get_int(KEY), 5678) - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get_int(KEY), 1234) - - cache.clear_all() - with self.assertRaises(cache_common.InvalidSessionError): - context.cache_get_int(KEY) - - def test_delete(self): - session_id1 = cache_codec.start_session() - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - context.cache_delete(KEY) - self.assertIsNone(context.cache_get(KEY)) - - context.cache_set(KEY, b"hello") - cache_codec.start_session() - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - context.cache_delete(KEY) - self.assertIsNone(context.cache_get(KEY)) - - cache_codec.start_session(session_id1) - self.assertEqual(context.cache_get(KEY), b"hello") - - def test_decorators(self): - run_count = 0 - cache_codec.start_session() - - @stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"foo" - - # cache is empty - self.assertIsNone(context.cache_get(KEY)) - self.assertEqual(run_count, 0) - self.assertEqual(func(), b"foo") - # function was run - self.assertEqual(run_count, 1) - self.assertEqual(context.cache_get(KEY), b"foo") - # function does not run again but returns cached value - self.assertEqual(func(), b"foo") - self.assertEqual(run_count, 1) - - @stored_async(KEY) - async def async_func(): - nonlocal run_count - run_count += 1 - return b"bar" - - # cache is still full - self.assertEqual(await_result(async_func()), b"foo") - self.assertEqual(run_count, 1) - - cache_codec.start_session() - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - # awaitable is also run only once - self.assertEqual(await_result(async_func()), b"bar") - self.assertEqual(run_count, 2) - - def test_empty_value(self): - cache_codec.start_session() - - self.assertIsNone(context.cache_get(KEY)) - context.cache_set(KEY, b"") - self.assertEqual(context.cache_get(KEY), b"") - - context.cache_delete(KEY) - run_count = 0 - - @stored(KEY) - def func(): - nonlocal run_count - run_count += 1 - return b"" - - self.assertEqual(func(), b"") - # function gets called once - self.assertEqual(run_count, 1) - self.assertEqual(func(), b"") - # function is not called for a second time - self.assertEqual(run_count, 1) - - @mock_storage - def test_Initialize(self): - def call_Initialize(**kwargs): - msg = Initialize(**kwargs) - return await_result(handle_Initialize(msg)) - - # calling Initialize without an ID allocates a new one - session_id = cache_codec.start_session() - features = call_Initialize() - self.assertNotEqual(session_id, features.session_id) - - # calling Initialize with the current ID does not allocate a new one - features = call_Initialize(session_id=session_id) - self.assertEqual(session_id, features.session_id) - - # store "hello" - context.cache_set(KEY, b"hello") - # check that it is cleared - features = call_Initialize() - session_id = features.session_id - self.assertIsNone(context.cache_get(KEY)) - # store "hello" again - context.cache_set(KEY, b"hello") - self.assertEqual(context.cache_get(KEY), b"hello") - - # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH) - self.assertIsNone(context.cache_get(KEY)) - - # but resuming a session loads the previous one - call_Initialize(session_id=session_id) - self.assertEqual(context.cache_get(KEY), b"hello") - - def test_EndSession(self): - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) - cache_codec.start_session() - self.assertTrue(is_session_started()) - self.assertIsNone(context.cache_get(KEY)) - await_result(handle_EndSession(EndSession())) - self.assertFalse(is_session_started()) - self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) + self.assertTrue(is_session_started()) + self.assertIsNone(get_active_session().get(KEY)) + await_result(handle_EndSession(EndSession())) + self.assertFalse(is_session_started()) + self.assertIsNone(cache_codec.get_active_session()) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 78675859e2c..bd8ddacea7b 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -2,28 +2,11 @@ import ustruct +from mock_wire_interface import MockHID from trezor import io -from trezor.loop import wait from trezor.utils import chunks from trezor.wire.codec import codec_v1 - -class MockHID: - def __init__(self, num): - self.num = num - self.data = [] - - def iface_num(self): - return self.num - - def write(self, msg): - self.data.append(bytearray(msg)) - return len(msg) - - def wait_object(self, mode): - return wait(mode | self.num) - - MESSAGE_TYPE = 0x4242 HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL") diff --git a/core/tests/test_trezor.wire.thp.checksum.py b/core/tests/test_trezor.wire.thp.checksum.py new file mode 100644 index 00000000000..41c93250012 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.checksum.py @@ -0,0 +1,94 @@ +from common import * # isort:skip + +if utils.USE_THP: + from trezor.wire.thp import checksum + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolChecksum(unittest.TestCase): + vectors_correct = [ + ( + b"", + b"\x00\x00\x00\x00", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAD", + ), + ( + bytes("a", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("123456789", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "12345678901234567890123456789012345678901234567890123456789012345678901234567890", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ( + b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61", + b"\x9B\xD3\x66\xAE", + ), + ( + b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4", + b"\x6B\xA4\xEC\x92", + ), + ] + vectors_incorrect = [ + ( + b"", + b"\x00\x00\x00\x00\x00", + ), + ( + b"", + b"", + ), + ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b"\x19\x0A\x55\xAE", + ), + ( + bytes("A", "ascii"), + b"\xE8\xB7\xBE\x43", + ), + ( + bytes("abc ", "ascii"), + b"\x35\x24\x41\xC2", + ), + ( + bytes("1234567890", "ascii"), + b"\xCB\xF4\x39\x26", + ), + ( + bytes( + "1234567890123456789012345678901234567890123456789012345678901234567890123456789", + "ascii", + ), + b"\x7C\xA9\x4A\x72", + ), + ] + + def test_computation(self): + for data, chksum in self.vectors_correct: + self.assertEqual(checksum.compute(data), chksum) + + def test_validation_correct(self): + for data, chksum in self.vectors_correct: + self.assertTrue(checksum.is_valid(chksum, data)) + + def test_validation_incorrect(self): + for data, chksum in self.vectors_incorrect: + self.assertFalse(checksum.is_valid(chksum, data)) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.credential_manager.py b/core/tests/test_trezor.wire.thp.credential_manager.py new file mode 100644 index 00000000000..59631979d6a --- /dev/null +++ b/core/tests/test_trezor.wire.thp.credential_manager.py @@ -0,0 +1,66 @@ +from common import * # isort:skip + + +if utils.USE_THP: + import thp_common + from trezor import config + from trezor.messages import ThpCredentialMetadata + + from apps.thp import credential_manager + + def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes: + metadata = ThpCredentialMetadata(host_name=host_name) + return credential_manager.issue_credential(host_static_pubkey, metadata) + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolCredentialManager(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + config.init() + config.wipe() + + def test_derive_cred_auth_key(self): + key1 = credential_manager.derive_cred_auth_key() + key2 = credential_manager.derive_cred_auth_key() + self.assertEqual(len(key1), 32) + self.assertEqual(key1, key2) + + def test_invalidate_cred_auth_key(self): + key1 = credential_manager.derive_cred_auth_key() + credential_manager.invalidate_cred_auth_key() + key2 = credential_manager.derive_cred_auth_key() + self.assertNotEqual(key1, key2) + + def test_credentials(self): + DUMMY_KEY_1 = b"\x00\x00" + DUMMY_KEY_2 = b"\xff\xff" + HOST_NAME_1 = "host_name" + HOST_NAME_2 = "different host_name" + + cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + self.assertEqual(cred_1, cred_2) + + cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1) + self.assertNotEqual(cred_1, cred_3) + + self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) + self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) + self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2)) + + credential_manager.invalidate_cred_auth_key() + cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) + self.assertNotEqual(cred_1, cred_4) + self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) + self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) + self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.crypto.py b/core/tests/test_trezor.wire.thp.crypto.py new file mode 100644 index 00000000000..d26785ce65e --- /dev/null +++ b/core/tests/test_trezor.wire.thp.crypto.py @@ -0,0 +1,156 @@ +from common import * # isort:skip +from trezorcrypto import aesgcm, curve25519 + +import storage + +if utils.USE_THP: + import thp_common + from trezor.wire.thp import crypto + from trezor.wire.thp.crypto import IV_1, IV_2, Handshake + + def get_dummy_device_secret(): + return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08" + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolCrypto(unittest.TestCase): + if utils.USE_THP: + handshake = Handshake() + key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07" + # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag + vectors_enc = [ + ( + key_1, + 0, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"e2c9dd152fbee5821ea7", + b"10625812de81b14a46b9f1e5100a6d0c", + ), + ( + key_1, + 1, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09", + b"79811619ddb07c2b99f8", + b"71c6b872cdc499a7e9a3c7441f053214", + ), + ( + key_1, + 369, + b"\x55\x64", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"c1200f8a7ae9a6d32cef0fff878d55c2", + ), + ( + key_1, + 369, + b"\x55\x64\x73\x82\x91", + b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + b"03bd030390f2dfe815a61c2b157a064f", + b"693ac160cd93a20f7fc255f049d808d0", + ), + ] + # 0:chaining key, 1:input, 2:output_1, 3:output:2 + vectors_hkdf = [ + ( + crypto.PROTOCOL_NAME, + b"\x01\x02", + b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514", + b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba", + ), + ( + b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14", + b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa", + b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48", + b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76", + ), + ] + vectors_iv = [ + (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"), + (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"), + (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"), + (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"), + (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"), + ] + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + utils.DISABLE_ENCRYPTION = False + + def test_encryption(self): + for v in self.vectors_enc: + buffer = bytearray(v[3]) + tag = crypto.enc(buffer, v[0], v[1], v[2]) + self.assertEqual(hexlify(buffer), v[4]) + self.assertEqual(hexlify(tag), v[5]) + self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2])) + self.assertEqual(buffer, v[3]) + + def test_hkdf(self): + for v in self.vectors_hkdf: + ck, k = crypto._hkdf(v[0], v[1]) + self.assertEqual(hexlify(ck), v[2]) + self.assertEqual(hexlify(k), v[3]) + + def test_iv_from_nonce(self): + for v in self.vectors_iv: + x = v[0] + y = x.to_bytes(8, "big") + iv = crypto._get_iv_from_nonce(v[0]) + self.assertEqual(iv, v[1]) + with self.assertRaises(AssertionError) as e: + iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1) + self.assertEqual(e.value.value, "Nonce overflow, terminate the channel") + + def test_incorrect_vectors(self): + pass + + def test_th1_crypto(self): + storage.device.get_device_secret = get_dummy_device_secret + handshake = self.handshake + + host_ephemeral_privkey = curve25519.generate_secret() + host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey) + handshake.handle_th1_crypto(b"", host_ephemeral_pubkey) + + def test_th2_crypto(self): + handshake = self.handshake + + host_static_privkey = curve25519.generate_secret() + host_static_pubkey = curve25519.publickey(host_static_privkey) + aes_ctx = aesgcm(handshake.k, IV_2) + aes_ctx.auth(handshake.h) + encrypted_host_static_pubkey = bytearray( + aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish() + ) + + # Code to encrypt Host's noise encrypted payload correctly: + protomsg = bytearray(b"\x10\x02\x10\x03") + temp_k = handshake.k + temp_h = handshake.h + + temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey) + _, temp_k = crypto._hkdf( + handshake.ck, + curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey), + ) + aes_ctx = aesgcm(temp_k, IV_1) + aes_ctx.encrypt_in_place(protomsg) + aes_ctx.auth(temp_h) + tag = aes_ctx.finish() + encrypted_payload = bytearray(protomsg + tag) + # end of encrypted payload generation + + handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload) + self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py new file mode 100644 index 00000000000..576ddab4db5 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.py @@ -0,0 +1,378 @@ +from common import * # isort:skip +from mock_wire_interface import MockHID +from trezor import config, io, protobuf +from trezor.crypto.curve import curve25519 +from trezor.enums import ThpMessageType +from trezor.wire.errors import UnexpectedMessage +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + from typing import TYPE_CHECKING + + import thp_common + from storage import cache_thp + from storage.cache_common import ( + CHANNEL_HANDSHAKE_HASH, + CHANNEL_KEY_RECEIVE, + CHANNEL_KEY_SEND, + CHANNEL_NONCE_RECEIVE, + CHANNEL_NONCE_SEND, + ) + from trezor.crypto import elligator2 + from trezor.enums import ThpPairingMethod + from trezor.messages import ( + ThpCodeEntryChallenge, + ThpCodeEntryCpaceHost, + ThpCodeEntryTag, + ThpCredentialRequest, + ThpEndRequest, + ThpStartPairingRequest, + ) + from trezor.wire.thp import thp_main + from trezor.wire.thp import ChannelState, checksum, interface_manager + from trezor.wire.thp.crypto import Handshake + from trezor.wire.thp.pairing_context import PairingContext + + from apps.thp import pairing + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def get_dummy_key() -> bytes: + return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31" + + def dummy_encode_iface(iface: WireInterface): + return thp_common._MOCK_INTERFACE_HID + + def send_channel_allocation_request( + interface: WireInterface, nonce: bytes | None = None + ) -> bytes: + if nonce is None or len(nonce) != 8: + nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77" + header = b"\x40\xff\xff\x00\x0c" + chksum = checksum.compute(header + nonce) + cid_req = header + nonce + chksum + gen = thp_main.thp_main_loop(interface) + expected_channel_index = cache_thp._get_next_channel_index() + gen.send(None) + gen.send(cid_req) + gen.send(None) + model = bytes(utils.INTERNAL_MODEL, "big") + response_data = ( + b"\x0a\x04" + model + "\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04" + ) + response_without_crc = ( + b"\x41\xff\xff\x00\x20" + + nonce + + cache_thp._CHANNELS[expected_channel_index].channel_id + + response_data + ) + chkcsum = checksum.compute(response_without_crc) + expected_response = response_without_crc + chkcsum + b"\x00" * 27 + return expected_response + + def get_channel_id_from_response(channel_allocation_response: bytes) -> int: + return int.from_bytes(channel_allocation_response[13:15], "big") + + def get_ack(channel_id: bytes) -> bytes: + if len(channel_id) != 2: + raise Exception("Channel id should by two bytes long") + return ( + b"\x20" + + channel_id + + b"\x00\x04" + + checksum.compute(b"\x20" + channel_id + b"\x00\x04") + + b"\x00" * 55 + ) + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocol(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + interface_manager.encode_iface = dummy_encode_iface + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + buffer = bytearray(64) + buffer2 = bytearray(256) + thp_main.set_read_buffer(buffer) + thp_main.set_write_buffer(buffer2) + interface_manager.decode_iface = thp_common.dummy_decode_iface + + def test_codec_message(self): + self.assertEqual(len(self.interface.data), 0) + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + + # There should be a failiure response to received init packet (starts with "?##") + test_codec_message = b"?## Some data" + gen.send(test_codec_message) + gen.send(None) + self.assertEqual(len(self.interface.data), 1) + + expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10" + self.assertEqual( + self.interface.data[-1][: len(expected_response)], expected_response + ) + + # There should be no response for continuation packet (starts with "?" only) + test_codec_message_2 = b"? Cont packet" + gen.send(test_codec_message_2) + with self.assertRaises(TypeError) as e: + gen.send(None) + self.assertEqual(e.value.value, "object with buffer protocol required") + self.assertEqual(len(self.interface.data), 1) + + def test_message_on_unallocated_channel(self): + gen = thp_main.thp_main_loop(self.interface) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + message_to_channel_789a = ( + b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + ) + gen.send(message_to_channel_789a) + gen.send(None) + unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + self.assertEqual( + utils.get_bytes_as_str(self.interface.data[-1]), + unallocated_chanel_error_on_channel_789a, + ) + + def test_channel_allocation(self): + self.assertEqual(len(thp_main._CHANNELS), 0) + for c in cache_thp._CHANNELS: + self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED) + + expected_channel_index = cache_thp._get_next_channel_index() + expected_response = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-1], expected_response) + + cid = cache_thp._CHANNELS[expected_channel_index].channel_id + self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS) + self.assertEqual(len(thp_main._CHANNELS), 1) + + # test channel's default state is TH1: + cid = get_channel_id_from_response(self.interface.data[-1]) + self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1) + + def test_invalid_encrypted_tag(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + expected_ack_on_received_message = get_ack(cid_1_bytes) + + gen.send(message_with_invalid_tag) + gen.send(None) + + self.assertEqual( + self.interface.data[-1], + expected_ack_on_received_message, + ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + def test_channel_errors(self): + gen = thp_main.thp_main_loop(self.interface) + gen.send(None) + # prepare 2 new channels + expected_response_1 = send_channel_allocation_request(self.interface) + expected_response_2 = send_channel_allocation_request(self.interface) + self.assertEqual(self.interface.data[-2], expected_response_1) + self.assertEqual(self.interface.data[-1], expected_response_2) + + # test invalid encryption tag + config.init() + config.wipe() + cid_1 = get_channel_id_from_response(expected_response_1) + channel = thp_main._CHANNELS[cid_1] + channel.iface = self.interface + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + header = b"\x04" + channel.channel_id + b"\x00\x14" + + tag = b"\x00" * 16 + chksum = checksum.compute(header + tag) + message_with_invalid_tag = header + tag + chksum + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + cid_1_bytes = int.to_bytes(cid_1, 2, "big") + expected_ack_on_received_message = get_ack(cid_1_bytes) + + gen.send(message_with_invalid_tag) + gen.send(None) + + self.assertEqual( + self.interface.data[-1], + expected_ack_on_received_message, + ) + error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03" + chksum_err = checksum.compute(error_without_crc) + gen.send(None) + + decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54 + + self.assertEqual( + self.interface.data[-1], + decryption_failed_error, + ) + + # test invalid tag in handshake phase + cid_2 = get_channel_id_from_response(expected_response_1) + cid_2_bytes = cid_2.to_bytes(2, "big") + channel = thp_main._CHANNELS[cid_2] + channel.iface = self.interface + + channel.set_channel_state(ChannelState.TH2) + + message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9" + + channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) + + # gen.send(message_with_invalid_tag) + # gen.send(None) + # gen.send(None) + # for i in self.interface.data: + # print(utils.get_bytes_as_str(i)) + + def test_skip_pairing(self): + config.init() + config.wipe() + channel = next(iter(thp_main._CHANNELS.values())) + channel.selected_pairing_methods = [ + ThpPairingMethod.NoMethod, + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC_Unidirectional, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpStartPairingRequest() + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + with self.assertRaises(StopIteration): + gen.send(None) + self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT) + + # Teardown: set back initial channel state value + channel.set_channel_state(ChannelState.TH1) + + def TODO_test_pairing(self): + config.init() + config.wipe() + cid = get_channel_id_from_response( + send_channel_allocation_request(self.interface) + ) + channel = thp_main._CHANNELS[cid] + channel.selected_pairing_methods = [ + ThpPairingMethod.CodeEntry, + ThpPairingMethod.NFC_Unidirectional, + ThpPairingMethod.QrCode, + ] + pairing_ctx = PairingContext(channel) + request_message = ThpStartPairingRequest() + with self.assertRaises(UnexpectedMessage) as e: + pairing.handle_pairing_request(pairing_ctx, request_message) + print(e.value.message) + channel.set_channel_state(ChannelState.TP1) + gen = pairing.handle_pairing_request(pairing_ctx, request_message) + + channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key()) + channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0) + channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"") + + gen.send(None) + + async def _dummy(ctx: PairingContext, expected_types): + return await ctx.read([1018, 1024]) + + pairing.show_display_data = _dummy + + msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34") + buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry)) + protobuf.encode(buffer, msg_code_entry) + code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer) + gen.send(code_entry_challenge) + + # tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3" + # tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0" + + pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe" + generator_host = elligator2.map_to_curve25519(pregenerator_host) + cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9" + cpace_host_public_key: bytes = curve25519.multiply( + cpace_host_private_key, generator_host + ) + msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key) + + # msg = ThpQrCodeTag(tag=tag_qrc) + # msg = ThpNfcUnidirectionalTag(tag=tag_nfc) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer) + gen.send(user_message) + + tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81" + msg = ThpCodeEntryTag(tag=tag_ent) + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + + protobuf.encode(buffer, msg) + user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer) + gen.send(user_message) + + host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77" + msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey) + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer) + gen.send(credential_request) + + msg = ThpEndRequest() + + buffer: bytearray = bytearray(protobuf.encoded_length(msg)) + protobuf.encode(buffer, msg) + end_request = Message(1012, buffer) + with self.assertRaises(StopIteration) as e: + gen.send(end_request) + print("response message:", e.value.value.MESSAGE_NAME) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp.writer.py b/core/tests/test_trezor.wire.thp.writer.py new file mode 100644 index 00000000000..0e50f5c4b5b --- /dev/null +++ b/core/tests/test_trezor.wire.thp.writer.py @@ -0,0 +1,151 @@ +from common import * # isort:skip + +from typing import Any, Awaitable + + +if utils.USE_THP: + import thp_common + from mock_wire_interface import MockHID + from trezor.wire.thp import writer + from trezor.wire.thp import ENCRYPTED, PacketHeader + + +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestTrezorHostProtocolWriter(unittest.TestCase): + short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + eight_longer_payloads_expected = [ + b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e", + b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b", + b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8", + b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5", + b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122", + b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f", + b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c", + b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9", + b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516", + b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253", + b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90", + b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd", + b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a", + b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647", + b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384", + b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1", + b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe", + b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b", + b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778", + b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5", + b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2", + b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c", + b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9", + b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6", + b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223", + b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60", + b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d", + b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da", + b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000", + ] + empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + longer_payload_with_checksum_expected = [ + b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a", + b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677", + b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4", + b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", + b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ] + + def await_until_result(self, task: Awaitable) -> Any: + with self.assertRaises(StopIteration): + while True: + task.send(None) + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def test_write_empty_packet(self): + self.await_until_result(writer.write_packet_to_wire(self.interface, b"")) + + print(self.interface.data[0]) + self.assertEqual(len(self.interface.data), 1) + self.assertEqual(self.interface.data[0], b"") + + def test_write_empty_payload(self): + header = PacketHeader(ENCRYPTED, 4660, 4) + await_result(writer.write_payloads_to_wire(self.interface, header, (b"",))) + self.assertEqual(len(self.interface.data), 0) + + def test_write_short_payload(self): + header = PacketHeader(ENCRYPTED, 4660, 5) + data = b"\x07" + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected) + + def test_write_longer_payload(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 256) + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) + + for i in range(len(self.longer_payload_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.longer_payload_expected[i] + ) + + def test_write_eight_longer_payloads(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 2048) + self.await_until_result( + writer.write_payloads_to_wire( + self.interface, header, (data, data, data, data, data, data, data, data) + ) + ) + for i in range(len(self.eight_longer_payloads_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i] + ) + + def test_write_empty_payload_with_checksum(self): + header = PacketHeader(ENCRYPTED, 4660, 4) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"") + ) + + self.assertEqual( + hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected + ) + + def test_write_longer_payload_with_checksum(self): + data = bytearray(range(256)) + header = PacketHeader(ENCRYPTED, 4660, 256) + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, data) + ) + + for i in range(len(self.longer_payload_with_checksum_expected)): + self.assertEqual( + hexlify(self.interface.data[i]), + self.longer_payload_with_checksum_expected[i], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp_deprecated.py b/core/tests/test_trezor.wire.thp_deprecated.py new file mode 100644 index 00000000000..12ef40bb7e2 --- /dev/null +++ b/core/tests/test_trezor.wire.thp_deprecated.py @@ -0,0 +1,338 @@ +from common import * # isort:skip +import ustruct +from typing import TYPE_CHECKING + +from mock_wire_interface import MockHID +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import io +from trezor.utils import chunks +from trezor.wire.protocol_common import Message + +if utils.USE_THP: + import thp_common + import trezor.wire.thp + from trezor.wire.thp import thp_main + from trezor.wire.thp import alternating_bit_protocol as ABP + from trezor.wire.thp import checksum + from trezor.wire.thp.checksum import CHECKSUM_LENGTH + from trezor.wire.thp.writer import PACKET_LENGTH + +if TYPE_CHECKING: + from trezorio import WireInterface + + +MESSAGE_TYPE = 0x4242 +MESSAGE_TYPE_BYTES = b"\x42\x42" +_MESSAGE_TYPE_LEN = 2 +PLAINTEXT_0 = 0x01 +PLAINTEXT_1 = 0x11 +COMMON_CID = 4660 +CONT = 0x80 + +HEADER_INIT_LENGTH = 5 +HEADER_CONT_LENGTH = 3 +if utils.USE_THP: + INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN + + +def make_header(ctrl_byte, cid, length): + return ustruct.pack(">BHH", ctrl_byte, cid, length) + + +def make_cont_header(): + return ustruct.pack(">BH", CONT, COMMON_CID) + + +def makeSimpleMessage(header, message_type, message_data): + return header + ustruct.pack(">H", message_type) + message_data + + +def makeCidRequest(header, message_data): + return header + message_data + + +def getPlaintext() -> bytes: + if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1: + return PLAINTEXT_1 + return PLAINTEXT_0 + + +async def deprecated_read_message( + iface: WireInterface, buffer: utils.BufferType +) -> Message: + return Message(-1, b"\x00") + + +async def deprecated_write_message( + iface: WireInterface, message: Message, is_retransmission: bool = False +) -> None: + pass + + +# This test suite is an adaptation of test_trezor.wire.codec_v1 +@unittest.skipUnless(utils.USE_THP, "only needed for THP") +class TestWireTrezorHostProtocolV1(unittest.TestCase): + + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + super().__init__() + + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def _simple(self): + cid_req_header = make_header( + ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 + ) + cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data) + + message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18) + cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0" + message = makeSimpleMessage( + message_header, + MESSAGE_TYPE, + cid_request_dummy_data + cid_request_dummy_data_checksum, + ) + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(cid_req_message) + gen.send(None) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, cid_request_dummy_data) + + buffer_without_zeroes = buffer[: len(message) - 5] + message_without_header = message[5:] + # message should have been read into the buffer + self.assertEqual(buffer_without_zeroes, message_without_header) + + def _read_one_packet(self): + # zero length message - just a header + PLAINTEXT = getPlaintext() + header = make_header( + PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES) + message = header + MESSAGE_TYPE_BYTES + chksum + + buffer = bytearray(64) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(message) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, b"") + + # message should have been read into the buffer + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) + + def _read_many_packets(self): + message = bytes(range(256)) + header = make_header( + getPlaintext(), + COMMON_CID, + len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message) + # message = MESSAGE_TYPE_BYTES + message + checksum + + # first packet is init header + 59 bytes of data + # other packets are cont header + 61 bytes of data + cont_header = make_cont_header() + packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ + cont_header + chunk + for chunk in chunks( + message[INIT_MESSAGE_DATA_LENGTH:] + chksum, + 64 - HEADER_CONT_LENGTH, + ) + ] + buffer = bytearray(262) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in packets: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + # last packet will stop + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # message should have been read into the buffer ) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) + + def _read_large_message(self): + message = b"hello world" + header = make_header( + getPlaintext(), + COMMON_CID, + _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH, + ) + + packet = ( + header + + MESSAGE_TYPE_BYTES + + message + + checksum.compute(header + MESSAGE_TYPE_BYTES + message) + ) + + # make sure we fit into one packet, to make this easier + self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) + + buffer = bytearray(1) + self.assertTrue(len(buffer) <= len(packet)) + + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(packet) + with self.assertRaises(StopIteration) as e: + gen.send(None) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message) + + # read should have allocated its own buffer and not touch ours + self.assertEqual(buffer, b"\x00") + + def _roundtrip(self): + message_payload = bytes(range(256)) + message = Message( + MESSAGE_TYPE, message_payload, 1 + ) # TODO use different session id + gen = deprecated_write_message(self.interface, message) + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = deprecated_read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + print(utils.get_bytes_as_str(packet)) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(None) + + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message.data) + + def _write_one_packet(self): + message = Message(MESSAGE_TYPE, b"") + gen = deprecated_write_message(self.interface, message) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + with self.assertRaises(StopIteration): + gen.send(None) + + header = make_header( + getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + ) + expected_message = ( + header + + MESSAGE_TYPE_BYTES + + checksum.compute(header + MESSAGE_TYPE_BYTES) + + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH) + ) + self.assertTrue(self.interface.data == [expected_message]) + + def _write_multiple_packets(self): + message_payload = bytes(range(256)) + message = Message(MESSAGE_TYPE, message_payload) + gen = deprecated_write_message(self.interface, message) + + header = make_header( + PLAINTEXT_1, + COMMON_CID, + len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, + ) + cont_header = make_cont_header() + chksum = checksum.compute( + header + message.type.to_bytes(2, "big") + message.data + ) + packets = [ + header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH] + ] + [ + cont_header + chunk + for chunk in chunks( + message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, + thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH, + ) + ] + + for _ in packets: + # we receive as many queries as there are packets + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + # the first sent None only started the generator. the len(packets)-th None + # will finish writing and raise StopIteration + with self.assertRaises(StopIteration): + gen.send(None) + + # packets must be identical up to the last one + self.assertListEqual(packets[:-1], self.interface.data[:-1]) + # last packet must be identical up to message length. remaining bytes in + # the 64-byte packets are garbage -- in particular, it's the bytes of the + # previous packet + last_packet = packets[-1] + packets[-2][len(packets[-1]) :] + self.assertEqual(last_packet, self.interface.data[-1]) + + def _read_huge_packet(self): + PACKET_COUNT = 1180 + # message that takes up 1 180 USB packets + message_size = (PACKET_COUNT - 1) * ( + PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN + ) + INIT_MESSAGE_DATA_LENGTH + + # ensure that a message this big won't fit into memory + # Note: this control is changed, because THP has only 2 byte length field + self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN) + # self.assertRaises(MemoryError, bytearray, message_size) + header = make_header(PLAINTEXT_1, COMMON_CID, message_size) + packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) + buffer = bytearray(65536) + gen = deprecated_read_message(self.interface, buffer) + + query = gen.send(None) + + # THP returns "Message too large" error after reading the message size, + # it is different from codec_v1 as it does not allow big enough messages + # to raise MemoryError in this test + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(trezor.wire.thp.ThpError) as e: + query = gen.send(packet) + + self.assertEqual(e.value.args[0], "Message too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/thp_common.py b/core/tests/thp_common.py new file mode 100644 index 00000000000..298a5130020 --- /dev/null +++ b/core/tests/thp_common.py @@ -0,0 +1,44 @@ +from trezor import utils +from trezor.wire.thp import ChannelState + +if utils.USE_THP: + import unittest + from typing import TYPE_CHECKING + + from mock_wire_interface import MockHID + from storage import cache_thp + from trezor.wire import context + from trezor.wire.thp import interface_manager + from trezor.wire.thp.channel import Channel + from trezor.wire.thp.session_context import SessionContext + + _MOCK_INTERFACE_HID = b"\x00" + + if TYPE_CHECKING: + from trezor.wire import WireInterface + + def dummy_decode_iface(cached_iface: bytes): + return MockHID(0xDEADBEEF) + + def get_new_channel(channel_iface: WireInterface | None = None) -> Channel: + interface_manager.decode_iface = dummy_decode_iface + channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID) + channel = Channel(channel_cache) + channel.set_channel_state(ChannelState.TH1) + if channel_iface is not None: + channel.iface = channel_iface + return channel + + def prepare_context() -> None: + channel = get_new_channel() + session_cache = cache_thp.get_new_session(channel.channel_cache) + session_ctx = SessionContext(channel, session_cache) + context.CURRENT_CONTEXT = session_ctx + + +if __debug__: + # Disable log.debug + def suppres_debug_log() -> None: + from trezor import log + + log.debug = lambda name, msg, *args: None diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py index 31c40fef1fe..b511abd807d 100755 --- a/core/tools/codegen/get_trezor_keys.py +++ b/core/tools/codegen/get_trezor_keys.py @@ -2,7 +2,7 @@ import binascii from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport +from trezorlib.transport.hid import HidTransport devices = HidTransport.enumerate() if len(devices) > 0: diff --git a/docs/ci/jobs.md b/docs/ci/jobs.md index 7a57340f24d..2325549c2a9 100644 --- a/docs/ci/jobs.md +++ b/docs/ci/jobs.md @@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it, it is just a single binary file that you can execute directly. **Are you looking for a Trezor T emulator? This is most likely it.** -### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317) +### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318) -### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332) +### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333) -### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346) +### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347) -### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369) +### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370) -### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392) +### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393) -### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408) +### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409) -### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430) +### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431) -### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455) +### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456) Build of our cryptographic library, which is then incorporated into the other builds. -### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485) +### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486) -### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501) +### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502) -### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518) +### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519) -### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537) +### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538) -### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558) +### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559) Regular version (not only Bitcoin) of above. **Are you looking for a Trezor One emulator? This is most likely it.** -### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573) +### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574) -### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591) +### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592) -### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617) +### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618) Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses. Bitcoin-only version. -### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634) +### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635) --- ## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml) diff --git a/legacy/firmware/fsm.c b/legacy/firmware/fsm.c index 07c4c24b1cd..99f9c356a58 100644 --- a/legacy/firmware/fsm.c +++ b/legacy/firmware/fsm.c @@ -191,6 +191,12 @@ void fsm_sendFailure(FailureType code, const char *text) case FailureType_Failure_InvalidSession: text = _("Invalid session"); break; + case FailureType_Failure_ThpUnallocatedSession: + text = _("Unallocated session"); + break; + case FailureType_Failure_InvalidProtocol: + text = _("Invalid protocol"); + break; case FailureType_Failure_FirmwareError: text = _("Firmware error"); break; diff --git a/legacy/firmware/protob/Makefile b/legacy/firmware/protob/Makefile index b6782666b03..f4aa763e239 100644 --- a/legacy/firmware/protob/Makefile +++ b/legacy/firmware/protob/Makefile @@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \ UnlockBootloader AuthenticateDevice AuthenticityProof \ Solana StellarClaimClaimableBalanceOp \ - ChangeLanguage TranslationDataRequest TranslationDataAck \ + ChangeLanguage TranslationDataRequest TranslationDataAck Thp \ SetBrightness DebugLinkOptigaSetSecMax \ BenchmarkListNames BenchmarkRun BenchmarkNames BenchmarkResult diff --git a/legacy/firmware/protob/messages-debug.options b/legacy/firmware/protob/messages-debug.options index b36b617261e..c099b38b0b9 100644 --- a/legacy/firmware/protob/messages-debug.options +++ b/legacy/firmware/protob/messages-debug.options @@ -2,14 +2,19 @@ DebugLinkDecision.input max_size:33 DebugLinkDecision.x type:FT_IGNORE DebugLinkDecision.y type:FT_IGNORE -DebugLinkState.layout max_size:1024 -DebugLinkState.pin max_size:51 -DebugLinkState.matrix max_size:10 -DebugLinkState.mnemonic_secret max_size:240 -DebugLinkState.reset_word max_size:12 -DebugLinkState.reset_entropy max_size:128 -DebugLinkState.recovery_fake_word max_size:12 -DebugLinkState.tokens type:FT_IGNORE +DebugLinkGetState.thp_channel_id type:FT_IGNORE + +DebugLinkState.layout max_size:1024 +DebugLinkState.pin max_size:51 +DebugLinkState.matrix max_size:10 +DebugLinkState.mnemonic_secret max_size:240 +DebugLinkState.reset_word max_size:12 +DebugLinkState.reset_entropy max_size:128 +DebugLinkState.recovery_fake_word max_size:12 +DebugLinkState.tokens type:FT_IGNORE +DebugLinkState.thp_pairing_code_entry_code type:FT_IGNORE +DebugLinkState.thp_pairing_code_qr_code type:FT_IGNORE +DebugLinkState.thp_pairing_code_nfc_unidirectional type:FT_IGNORE DebugLinkLog.bucket max_size:33 DebugLinkLog.text max_size:256 diff --git a/legacy/firmware/protob/messages-thp.proto b/legacy/firmware/protob/messages-thp.proto new file mode 120000 index 00000000000..4799efe83ae --- /dev/null +++ b/legacy/firmware/protob/messages-thp.proto @@ -0,0 +1 @@ +../../vendor/trezor-common/protob/messages-thp.proto \ No newline at end of file diff --git a/python/requirements.txt b/python/requirements.txt index 440bc2a2bea..e3cf51812de 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,3 +6,5 @@ libusb1>=1.6.4 construct>=2.9,!=2.10.55 typing_extensions>=4.7.1 construct-classes>=0.1.2 +cryptography >=43.0.3 +platformdirs >=2 diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 4f6d56f8ed1..d171f582005 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -95,6 +95,15 @@ def client(self) -> TrezorClientDebugLink: raise RuntimeError return self._client + @client.setter + def client(self, new_client: TrezorClientDebugLink) -> None: + """Setter for the client property to update _client.""" + if not isinstance(new_client, TrezorClientDebugLink): + raise TypeError( + f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}." + ) + self._client = new_client + def make_args(self) -> List[str]: return [] @@ -112,7 +121,7 @@ def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None: start = time.monotonic() try: while True: - if transport._ping(): + if transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 39e26f569fc..2e4a530af5b 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -7,7 +7,7 @@ from importlib import metadata from . import device -from .client import TrezorClient +from .transport.session import Session try: cryptography_version = metadata.version("cryptography") @@ -361,7 +361,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -371,7 +371,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index f96ef7970ea..b961dda4262 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -20,17 +20,17 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect(messages.BenchmarkNames) def list_names( - client: "TrezorClient", + session: "Session", ) -> "MessageType": - return client.call(messages.BenchmarkListNames()) + return session.call(messages.BenchmarkListNames()) @expect(messages.BenchmarkResult) -def run(client: "TrezorClient", name: str) -> "MessageType": - return client.call(messages.BenchmarkRun(name=name)) +def run(session: "Session", name: str) -> "MessageType": + return session.call(messages.BenchmarkRun(name=name)) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index d2e4b97912c..afe251a06c3 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,22 +18,22 @@ from . import messages from .protobuf import dict_to_proto -from .tools import expect, session +from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.BinanceAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -42,16 +42,15 @@ def get_address( @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -60,7 +59,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - response = client.call(envelope) + response = session.call(envelope) if not isinstance(response, messages.BinanceTxRequest): raise RuntimeError( @@ -77,7 +76,7 @@ def sign_tx( else: raise ValueError("can not determine msg type") - response = client.call(msg) + response = session.call(msg) if not isinstance(response, messages.BinanceSignedTx): raise RuntimeError( diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index a71ead2adc2..3ccb1a95959 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -13,7 +13,6 @@ # # You should have received a copy of the License along with this library. # If not, see . - import warnings from copy import copy from decimal import Decimal @@ -23,12 +22,12 @@ from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import expect, prepare_message_bytes, session +from .tools import expect, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +104,7 @@ def make_bin_output(vout: "Vout") -> messages.TxOutputBinType: @expect(messages.PublicKey) def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,13 +115,13 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any): @expect(messages.Address) def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -153,13 +152,13 @@ def get_authenticated_address( chunkify: bool = False, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -172,15 +171,16 @@ def get_authenticated_address( ) +# TODO this is used by tests only @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -190,8 +190,9 @@ def get_ownership_id( ) +# TODO this is used by tests only def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -202,11 +203,11 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -226,7 +227,7 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -234,7 +235,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -247,7 +248,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -255,7 +256,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.VerifyMessage( address=address, signature=signature, @@ -269,9 +270,9 @@ def verify_message( return isinstance(resp, messages.Success) -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -319,17 +320,17 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") elif preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call(signtx) + res = session.call(signtx) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -388,7 +389,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg) + res = session.call(msg) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -418,7 +419,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg)) + res = session.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") @@ -432,7 +433,7 @@ def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: @expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -441,7 +442,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 49d2c6463f8..f39cfb42221 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,8 +35,8 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -825,7 +825,7 @@ def _get_collateral_inputs_items( @expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_parameters: messages.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -833,7 +833,7 @@ def get_address( derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -847,12 +847,12 @@ def get_address( @expect(messages.CardanoPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -863,12 +863,12 @@ def get_public_key( @expect(messages.CardanoNativeScriptHash) def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: messages.CardanoNativeScript, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -878,7 +878,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -915,7 +915,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( messages.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -951,14 +951,14 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call(auxiliary_data) + auxiliary_data_supplement = session.call(auxiliary_data) if not isinstance( auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement ): @@ -971,7 +971,7 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR @@ -980,24 +980,24 @@ def sign_tx( _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR for reference_input in reference_inputs: - response = client.call(reference_input) + response = session.call(reference_input) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request) + response = session.call(witness_request) if not isinstance(response, messages.CardanoTxWitnessResponse): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"].append( @@ -1009,12 +1009,12 @@ def sign_tx( } ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxBodyHash): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoSignTxFinished): raise UNEXPECTED_RESPONSE_ERROR diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 6db335a7adc..0b14778ed7b 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport -from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from .. import exceptions, transport, ui +from ..client import ProtocolVersion, TrezorClient +from ..messages import Capability +from ..transport import Transport +from ..transport.session import Session, SessionV1, SessionV2 +from ..transport.thp.channel_database import get_channel_db + +LOG = logging.getLogger(__name__) -if TYPE_CHECKING: +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +57,7 @@ def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +66,69 @@ def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: return self.typemap[value] +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +def get_client(transport: Transport) -> TrezorClient: + stored_channels = get_channel_db().load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + get_channel_db().remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + return client + + class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +137,54 @@ def __init__( self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: + client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + raise RuntimeError("Failed to resume session - no session id provided") + + # Try resume session from id + if self.session_id is not None: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + RuntimeError("Failed to resume session - no session id provided") + return session + + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=derive_cardano) + + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +197,13 @@ def get_transport(self) -> "Transport": # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: - transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + return get_client(self.get_transport()) + + def get_management_session(self) -> Session: + client = self.get_client() + management_session = client.get_management_session() + return management_session @contextmanager def client_context(self): @@ -128,7 +237,57 @@ def client_context(self): # other exceptions may cause a traceback -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": +def with_session( + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` or + `session=obj.get_management_session()` based on the parameters provided. + + If default parameters are ok, this decorator can be used without parentheses. + + TODO: handle resumption of sessions and their (potential) closure. + """ + + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + if management: + session = obj.get_management_session() + else: + session = obj.get_session( + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, + ) + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + return function_with_session + + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] + + return decorator + + +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -142,23 +301,62 @@ def trezorctl_command_with_client( obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) try: return func(client, *args, **kwargs) finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass return trezorctl_command_with_client +# def with_client( +# func: "t.Callable[Concatenate[TrezorClient, P], R]", +# ) -> "t.Callable[P, R]": +# """Wrap a Click command in `with obj.client_context() as client`. + +# Sessions are handled transparently. The user is warned when session did not resume +# cleanly. The session is closed after the command completes - unless the session +# was resumed, in which case it should remain open. +# """ + +# @click.pass_obj +# @functools.wraps(func) +# def trezorctl_command_with_client( +# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" +# ) -> "R": +# with obj.client_context() as client: +# session_was_resumed = obj.session_id == client.session_id +# if not session_was_resumed and obj.session_id is not None: +# # tried to resume but failed +# click.echo("Warning: failed to resume session.", err=True) + +# try: +# return func(client, *args, **kwargs) +# finally: +# if not session_was_resumed: +# try: +# client.end_session() +# except Exception: +# pass + +# # the return type of @click.pass_obj is improperly specified and pyright doesn't +# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) +# return trezorctl_command_with_client + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. @@ -188,14 +386,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index e445089815c..7908223881f 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,17 +20,15 @@ import click from .. import benchmark -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session -def list_names_patern( - client: "TrezorClient", pattern: Optional[str] = None -) -> List[str]: - names = list(benchmark.list_names(client).names) +def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]: + names = list(benchmark.list_names(session).names) if pattern is None: return names return [name for name in names if fnmatch(name, pattern)] @@ -43,10 +41,10 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: +@with_session(empty_passphrase=True) +def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: @@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def run(client: "TrezorClient", pattern: Optional[str]) -> None: +@with_session(empty_passphrase=True) +def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: for name in names: - result = benchmark.run(client, name) + result = benchmark.run(session, name) click.echo(f"{name}: {result.value} {result.unit}") diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb2711..d8097b3e900 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index d6a9867215c..77bbe83f811 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -174,15 +175,15 @@ def cli() -> None: help="Sort pubkeys lexicographically using BIP-67", ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, multisig_sort_pubkeys: bool, chunkify: bool, @@ -235,7 +236,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -252,9 +253,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -266,7 +267,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -326,7 +327,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -363,9 +364,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -375,7 +376,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -390,8 +391,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -462,7 +463,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -483,9 +484,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -495,7 +496,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 26d4eab5b99..1e6935d6d9a 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_session(derive_cardano=True) def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_session(derive_cardano=True) def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b69..469bc719a48 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session(empty_passphrase=True) +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 50613a04eee..1670117eb8d 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,13 +18,12 @@ import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -35,51 +34,51 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. +# @cli.command() +# @click.argument("message_name_or_type") +# @click.argument("hex_data") +# @click.pass_obj +# def send_bytes( +# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str +# ) -> None: +# """Send raw bytes to Trezor. - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. +# Message type and message data must be specified separately, due to how message +# chunking works on the transport level. Message length is calculated and sent +# automatically, and it is currently impossible to explicitly specify invalid length. - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) +# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, +# in which case the value of that enum is used. +# """ +# if message_name_or_type.isdigit(): +# message_type = int(message_name_or_type) +# else: +# message_type = getattr(messages.MessageType, message_name_or_type) - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") +# if not isinstance(message_type, int): +# raise click.ClickException("Invalid message type.") - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e +# try: +# message_data = bytes.fromhex(hex_data) +# except Exception as e: +# raise click.ClickException("Invalid hex data.") from e - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) +# transport = obj.get_transport() +# transport.deprecated_begin_session() +# transport.write(message_type, message_data) - response_type, response_data = transport.read() - transport.end_session() +# response_type, response_data = transport.read() +# transport.deprecated_end_session() - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") +# click.echo(f"Response type: {response_type}") +# click.echo(f"Response data: {response_data.hex()}") - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") +# try: +# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) +# click.echo("Parsed message:") +# click.echo(protobuf.format_message(msg)) +# except Exception as e: +# click.echo(f"Could not parse response: {e}") @cli.command() @@ -106,17 +105,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> str: +@with_session(management=True) +def prodtest_t1(session: "Session") -> str: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - return debuglink_prodtest_t1(client) + return debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> str: +@with_session(management=True) +def optiga_set_sec_max(session: "Session") -> str: """Set Optiga's security event counter to maximum.""" - return debuglink_optiga_set_sec_max(client) + return debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 52c0bd3961c..d53aad19934 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -24,11 +24,11 @@ import requests from .. import debuglink, device, exceptions, messages, ui -from . import ChoiceType, with_client +from . import ChoiceType, with_session if t.TYPE_CHECKING: - from ..client import TrezorClient from ..protobuf import MessageType + from ..transport.session import Session from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> str: +@with_session(management=True) +def wipe(session: "Session", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: click.echo("Wiping user data!") try: - return device.wipe(client) + return device.wipe( + session + ) # TODO decide where the wipe should happen - management or regular session except exceptions.TrezorFailure as e: click.echo("Action failed: {} {}".format(*e.args)) sys.exit(3) @@ -103,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_session(management=True) def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -136,7 +139,7 @@ def load( try: return debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -171,9 +174,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_session(management=True) def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -201,7 +204,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup return device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -222,9 +225,9 @@ def recover( @click.option("-s", "--skip-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) -@with_client +@with_session(management=True) def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -241,7 +244,7 @@ def setup( BT = messages.BackupType if backup_type is None: - if client.version >= (2, 7, 1): + if session.version >= (2, 7, 1): # SLIP39 extendable was introduced in 2.7.1 backup_type = BT.Slip39_Single_Extendable else: @@ -251,10 +254,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -262,7 +265,7 @@ def setup( ) return device.reset( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -277,23 +280,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_session(management=True) def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> str: """Perform device seed backup.""" - return device.backup(client, group_threshold, groups) + return device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> str: +@with_session(management=True) +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -307,9 +308,9 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - return device.sd_protect(client, operation) + return device.sd_protect(session, operation) @cli.command() @@ -319,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str: Currently only supported on Trezor Model One. """ - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_management_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - return device.reboot_to_bootloader(client) + return device.reboot_to_bootloader(client.get_management_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> str: +@with_session(management=True) +def tutorial(session: "Session") -> str: """Show on-device tutorial.""" - return device.show_device_tutorial(client) + return device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> str: +@with_session(management=True) +def unlock_bootloader(session: "Session") -> str: """Unlocks bootloader. Irreversible.""" - return device.unlock_bootloader(client) + return device.unlock_bootloader(session) @cli.command() @@ -347,11 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> str: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: +@with_session(management=True) +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str: """Show a "Do not disconnect" dialog.""" if enable is False: - return device.set_busy(client, None) + return device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -361,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - return device.set_busy(client, expiry * 1000) + return device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -381,9 +382,9 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_session(management=True) def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -408,7 +409,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -456,14 +457,14 @@ def format(self, record: logging.LogRecord) -> str: else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4c..27d461d8b0b 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356d..d810d2bf2d1 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5983c572493..024a0bf63fb 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session(empty_passphrase=True) +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: +@with_session(empty_passphrase=True) +def credentials_add(session: "Session", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - return fido.add_credential(client, bytes.fromhex(hex_credential_id)) + return fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> str: +@with_session(empty_passphrase=True) +def credentials_remove(session: "Session", index: int) -> str: """Remove the resident credential at the given index.""" - return fido.remove_credential(client, index) + return fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> str: +@with_session(empty_passphrase=True) +def counter_set(session: "Session", counter: int) -> str: """Set FIDO/U2F counter value.""" - return fido.set_counter(client, counter) + return fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session(empty_passphrase=True) +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f2839..37a393cb4c5 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,10 +37,11 @@ from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection MODEL_CHOICE = ChoiceType( @@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -306,25 +307,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -359,8 +361,8 @@ def version_str(version: Iterable[int]) -> str: # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -447,11 +449,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -461,7 +463,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -654,6 +656,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + management_session = client.get_management_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -709,7 +712,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + management_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -719,7 +722,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(management_session) click.echo("Waiting for bootloader...") while True: @@ -735,13 +738,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_management_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_session(management=True) +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae39..0441ebc09b4 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad187236..eac16c2d8c2 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b3503..634a92028e6 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index eac93eb7965..d5e615750dc 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,11 @@ import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from ..transport.session import Session +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + pass try: from PIL import Image @@ -180,18 +181,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_session(management=True) +def pin(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - return device.change_pin(client, remove=_should_remove(enable, remove)) + return device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_session(management=True) +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - return device.change_wipe_code(client, remove=_should_remove(enable, remove)) + return device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> str: +@with_session(management=True) +def label(session: "Session", label: str) -> str: """Set new device label.""" - return device.apply_settings(client, label=label) + return device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> str: +@with_session(management=True) +def brightness(session: "Session") -> str: """Set display brightness.""" - return device.set_brightness(client) + return device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> str: +@with_session(management=True) +def haptic_feedback(session: "Session", enable: bool) -> str: """Enable or disable haptic feedback.""" - return device.apply_settings(client, haptic_feedback=enable) + return device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_session(management=True) def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> str: """Set new language with translations.""" if remove != (path_or_url is None): @@ -260,29 +261,29 @@ def language( f"Failed to load translations from {path_or_url}" ) from None return device.change_language( - client, language_data=language_data, show_display=display + session, language_data=language_data, show_display=display ) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str: +@with_session(management=True) +def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - return device.apply_settings(client, display_rotation=rotation) + return device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> str: +@with_session(management=True) +def auto_lock_delay(session: "Session", delay: str) -> str: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> str: +@with_session(management=True) +def flags(session: "Session", flags: str) -> str: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str: flags_int = int(flags, 16) else: flags_int = int(flags) - return device.apply_flags(client, flags=flags_int) + return device.apply_flags(session, flags=flags_int) @cli.command() @@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: +@with_session(management=True) +def homescreen(session: "Session", filename: str, quality: int) -> str: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "Unknown image format requested by the device." ) - return device.apply_settings(client, homescreen=img) + return device.apply_settings(session, homescreen=img) @cli.command() @@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_session(management=True) def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> str: """Set safety check level. @@ -392,18 +393,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - return device.apply_settings(client, safety_checks=level) + return device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> str: +@with_session(management=True) +def experimental_features(session: "Session", enable: bool) -> str: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - return device.apply_settings(client, experimental_features=enable) + return device.apply_settings(session, experimental_features=enable) # @@ -426,25 +427,25 @@ def passphrase_main() -> None: @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: +@with_session(management=True) +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None return device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> str: +@with_session(management=True) +def passphrase_off(session: "Session") -> str: """Disable passphrase.""" - return device.apply_settings(client, use_passphrase=False) + return device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -457,10 +458,10 @@ def passphrase_off(client: "TrezorClient") -> str: @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: +@with_session(management=True) +def hide_passphrase_from_host(session: "Session", hide: bool) -> str: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - return device.apply_settings(client, hide_passphrase_from_host=hide) + return device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 3fe80a51646..8152116b550 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> messages.SolanaPublicKey: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> messages.SolanaAddress: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5b..9acb6a57ed7 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db1..e4f0c1a877d 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 60f8e8d3092..b3a885e4c8d 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,12 @@ import click -from .. import __version__, log, messages, protobuf, ui -from ..client import TrezorClient +from .. import __version__, log, messages, protobuf +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.session import Session +from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -50,6 +53,7 @@ stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None: "--record", help="Record screen changes into a specified directory.", ) +@click.option( + "-n", + "--no-store", + is_flag=True, + help="Do not store channels data between commands.", + default=False, +) @click.version_option(version=__version__) @click.pass_context def cli_main( @@ -204,9 +215,10 @@ def cli_main( script: bool, session_id: Optional[str], record: Optional[str], + no_store: bool, ) -> None: configure_logging(verbose) - + channel_database.set_channel_database(should_not_store=no_store) bytes_session_id: Optional[bytes] = None if session_id is not None: try: @@ -214,6 +226,7 @@ def cli_main( except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") + # ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) # Optionally record the screen into a specified directory. @@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + print(d.get_path()) + return + + from . import get_client for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + client = get_client(transport) description = format_device_name(client.features) - client.end_session() + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" - click.echo(f"{transport} - {description}") + except Exception as e: + description = "Failed to read details " + str(type(e)) + click.echo(f"{transport.get_path()} - {description}") return None @@ -314,15 +332,19 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session(empty_passphrase=True) +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.ping(message, button_protection) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +def get_session( + obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False +) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str: obj.session_id = None with obj.client_context() as client: + if client.features.model == "1" and client.version < (1, 9, 0): raise click.ClickException( "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() - if client.session_id is None: + # client.ensure_unlocked() + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + if session.id is None: raise click.ClickException("Passphrase not enabled or firmware too old.") else: - return client.session_id.hex() + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session(must_resume=True, empty_passphrase=True) +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + if session is None: + click.echo("Cannot clear session as it was not properly resumed.") + return + session.call(messages.LockDevice()) + session.end() + # TODO different behaviour than main, not sure if ok + + +@cli.command() +def delete_channels() -> None: + """ + Delete cached channels. + + Do not use together with the `-n` (`--no-store`) flag, + as the JSON database will not be deleted in that case. + """ + get_channel_db().clear_stored_channels() + click.echo("Deleted stored channels") @cli.command() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2ec853dfd3b..d82554dd93c 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,25 +13,24 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import typing as t +from enum import IntEnum -from mnemonic import Mnemonic +from . import mapping, messages, models +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.thp.channel_data import ChannelData +from .transport.thp.protocol_and_channel import ProtocolAndChannel +from .transport.thp.protocol_v1 import ProtocolV1 +from .transport.thp.protocol_v2 import ProtocolV2 -from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .tools import expect, parse_path, session - -if TYPE_CHECKING: - from .protobuf import MessageType - from .transport import Transport - from .ui import TrezorClientUI - -UI = TypeVar("UI", bound="TrezorClientUI") +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -48,445 +47,653 @@ """.strip() -def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any -) -> "TrezorClient": - """Get a client for a connected Trezor device. - - Returns a TrezorClient instance with minimum fuss. - - If path is specified, does a prefix-search for the specified device. Otherwise, uses - the value of TREZOR_PATH env variable, or finds first connected Trezor. - If no UI is supplied, instantiates the default CLI UI. - """ - from .transport import get_transport - from .ui import ClickUI - - if path is None: - path = os.getenv("TREZOR_PATH") - - transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() +LOG = logging.getLogger(__name__) - return TrezorClient(transport, ui, **kwargs) +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features + _management_session: Session | None = None + _features: messages.Features | None = None + _protocol_version: int + _has_setup_pin: bool = False # Should by used only by conftest def __init__( self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide a UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: "MessageType") -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: "MessageType") -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> "MessageType": - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: "MessageType") -> "MessageType": - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) + self.mapping = protobuf_mapping + if protocol is None: + self.protocol = self._get_protocol() + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + if isinstance(self.protocol, ProtocolV1): + self._protocol_version = ProtocolVersion.PROTOCOL_V1 + elif isinstance(self.protocol, ProtocolV2): + self._protocol_version = ProtocolVersion.PROTOCOL_V2 + else: + self._protocol_version = ProtocolVersion.UNKNOWN + + @classmethod + def resume( + cls, + transport: Transport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol_v1 = ProtocolV1(transport, protobuf_mapping) + if channel_data.protocol_version == 2: + try: + protocol_v1.write(messages.Ping(message="Sanity check - to resume")) + except Exception as e: + print(type(e)) + response = protocol_v1.read() + if ( + isinstance(response, messages.Failure) + and response.code == messages.FailureType.InvalidProtocol + ): + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + protocol.write(0, messages.Ping()) + response = protocol.read(0) + if not isinstance(response, messages.Success): + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + LOG.debug("Protocol V2 detected - can be resumed") else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - # Trezor Model One bootloader 1.8.0 or older does not send model name - model = models.by_internal_name(features.internal_model) - if model is None: - model = models.by_name(features.model or "1") - if model is None: - raise RuntimeError( - "Unsupported Trezor model" - f" (internal_model: {features.internal_model}, model: {features.model})" - ) - self.model = model - - if features.vendor not in self.model.vendors: - raise RuntimeError("Unsupported device") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return TrezorClient(transport, protobuf_mapping, protocol) - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. + def get_session( + self, + passphrase: str | object | None = None, + derive_cardano: bool = False, + ) -> Session: + """ + Returns initialized session (with derived seed). - Should be called after changing settings or performing operations that affect - device state. + Will fail if the device is not initialized """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( - self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id + from .transport.session import SessionV1, SessionV2 + + if isinstance(self.protocol, ProtocolV1): + if passphrase is None: + passphrase = "" + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + assert isinstance(passphrase, str) or passphrase is None + return SessionV2.new(self, passphrase, derive_cardano) + raise NotImplementedError # TODO + + def resume_session(self, session: Session): + """ + Note: this function potentially modifies the input session. """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, + from .debuglink import SessionDebugWrapper + from .transport.session import SessionV1, SessionV2 + + if isinstance(session, SessionDebugWrapper): + session = session._session + + if isinstance(session, SessionV2): + return session + elif isinstance(session, SessionV1): + session.init_session() + return session + + else: + raise NotImplementedError + + def get_management_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1, SessionV2 + + if not new_session and self._management_session is not None: + return self._management_session + if isinstance(self.protocol, ProtocolV1): + self._management_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, ) + elif isinstance(self.protocol, ProtocolV2): + self._management_session = SessionV2(client=self, id=b"\x00") + assert self._management_session is not None + return self._management_session + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def model(self) -> models.TrezorModel: + f = self.features + model = models.by_name(f.model or "1") + + if model is None: + raise RuntimeError( + "Unsupported Trezor model" + f" (internal_model: {f.internal_model}, model: {f.model})" + ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + return ver - @expect(messages.Success, field="message", ret_type=str) - def ping( - self, - msg: str, - button_protection: bool = False, - ) -> "MessageType": - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - return resp - finally: - self.close() - - return self.call( - messages.Ping(message=msg, button_protection=button_protection) - ) + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() - def get_device_id(self) -> Optional[str]: - return self.features.device_id + def _get_protocol(self) -> ProtocolAndChannel: + self.transport.open() - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. + protocol.write(messages.Initialize()) - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + if response.code == messages.FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() - - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() +def get_default_client( + path: t.Optional[str] = None, + **kwargs: t.Any, +) -> "TrezorClient": + """Get a client for a connected Trezor device. + + Returns a TrezorClient instance with minimum fuss. + + If path is specified, does a prefix-search for the specified device. Otherwise, uses + the value of TREZOR_PATH env variable, or finds first connected Trezor. + If no UI is supplied, instantiates the default CLI UI. + """ - def end_session(self) -> None: - """Close the current session and clear cached passphrase. + if path is None: + path = os.getenv("TREZOR_PATH") - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. + transport = get_transport(path, prefix_search=True) - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + return TrezorClient(transport, **kwargs) + + +# class TrezorClient(t.Generic[UI]): +# """Trezor client, a connection to a Trezor device. + +# This class allows you to manage connection state, send and receive protobuf +# messages, handle user interactions, and perform some generic tasks +# (send a cancel message, initialize or clear a session, ping the device). +# """ + +# model: models.TrezorModel +# transport: "Transport" +# session_id: t.Optional[bytes] +# ui: UI +# features: messages.Features + +# def __init__( +# self, +# transport: "Transport", +# ui: UI, +# session_id: t.Optional[bytes] = None, +# derive_cardano: t.Optional[bool] = None, +# model: t.Optional[models.TrezorModel] = None, +# _init_device: bool = True, +# ) -> None: +# """Create a TrezorClient instance. + +# You have to provide a `transport`, i.e., a raw connection to the device. You can +# use `trezorlib.transport.get_transport` to find one. + +# You have to provide a UI implementation for the three kinds of interaction: +# - button request (notify the user that their interaction is needed) +# - PIN request (on T1, ask the user to input numbers for a PIN matrix) +# - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for +# details. + +# You can supply a `session_id` you might have saved in the previous session. If +# you do, the user might not need to enter their passphrase again. + +# You can provide Trezor model information. If not provided, it is detected from +# the model name reported at initialization time. + +# By default, the instance will open a connection to the Trezor device, send an +# `Initialize` message, set up the `features` field from the response, and connect +# to a session. By specifying `_init_device=False`, this step is skipped. Notably, +# this means that `client.features` is unset. Use `client.init_device()` or +# `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. +# Only use this if you are _sure_ that you know what you are doing. This feature +# might be removed at any time. +# """ +# LOG.info(f"creating client instance for device: {transport.get_path()}") +# # Here, self.model could be set to None. Unless _init_device is False, it will +# # get correctly reconfigured as part of the init_device flow. +# self.model = model # type: ignre ["None" is incompatible with "TrezorModel"] +# if self.model: +# self.mapping = self.model.default_mapping +# else: +# self.mapping = mapping.DEFAULT_MAPPING +# self.transport = transport +# self.ui = ui +# self.session_counter = 0 +# self.session_id = session_id +# if _init_device: +# self.init_device(session_id=session_id, derive_cardano=derive_cardano) +# self.resume_session() + +# def open(self) -> None: +# if self.session_counter == 0: +# session_id = self.transport.resume_session(b"") +# if self.session_id != session_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = session_id +# self.transport.deprecated_begin_session() +# self.session_counter += 1 + +# def resume_session(self) -> None: +# new_id = self.transport.resume_session(self.session_id or b"") +# if self.session_id != new_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = new_id + +# def close(self) -> None: +# self.session_counter = max(self.session_counter - 1, 0) +# if self.session_counter == 0: +# # TODO call EndSession here? +# self.transport.deprecated_end_session() + +# def cancel(self) -> None: +# self._raw_write(messages.Cancel()) + +# def call_raw(self, msg: "MessageType") -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 + +# self._raw_write(msg) +# x = self._raw_read() +# return x + +# def _raw_write(self, msg: "MessageType") -> None: +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# LOG.debug( +# f"sending message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# msg_type, msg_bytes = self.mapping.encode(msg) +# LOG.log( +# DUMP_BYTES, +# f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# self.transport.write(msg_type, msg_bytes) + +# def _raw_read(self) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# msg_type, msg_bytes = self.transport.read() +# print("type/data", msg_type, msg_bytes) +# LOG.log( +# DUMP_BYTES, +# f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# msg = self.mapping.decode(msg_type, msg_bytes) +# LOG.debug( +# f"received message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# return msg + +# def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": +# try: +# pin = self.ui.get_pin(msg.type) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if any(d not in "123456789" for d in pin) or not ( +# 1 <= len(pin) <= MAX_PIN_LENGTH +# ): +# self.call_raw(messages.Cancel()) +# raise ValueError("Invalid PIN provided") + +# resp = self.call_raw(messages.PinMatrixAck(pin=pin)) +# if isinstance(resp, messages.Failure) and resp.code in ( +# messages.FailureType.PinInvalid, +# messages.FailureType.PinCancelled, +# messages.FailureType.PinExpected, +# ): +# raise exceptions.PinException(resp.code, resp.message) +# else: +# return resp + +# def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": +# available_on_device = Capability.PassphraseEntry in self.features.capabilities + +# def send_passphrase( +# passphrase: t.Optional[str] = None, on_device: t.Optional[bool] = None +# ) -> "MessageType": +# msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) +# resp = self.call_raw(msg) +# if isinstance(resp, messages.Deprecated_PassphraseStateRequest): +# self.session_id = resp.state +# resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) +# return resp + +# # short-circuit old style entry +# if msg._on_device is True: +# return send_passphrase(None, None) + +# try: +# passphrase = self.ui.get_passphrase(available_on_device=available_on_device) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if passphrase is PASSPHRASE_ON_DEVICE: +# if not available_on_device: +# self.call_raw(messages.Cancel()) +# raise RuntimeError("Device is not capable of entering passphrase") +# else: +# return send_passphrase(on_device=True) + +# # else process host-entered passphrase +# if not isinstance(passphrase, str): +# raise RuntimeError("Passphrase must be a str") +# passphrase = Mnemonic.normalize_string(passphrase) +# if len(passphrase) > MAX_PASSPHRASE_LENGTH: +# self.call_raw(messages.Cancel()) +# raise ValueError("Passphrase too long") + +# return send_passphrase(passphrase, on_device=False) + +# def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# # do this raw - send ButtonAck first, notify UI later +# self._raw_write(messages.ButtonAck()) +# self.ui.button_request(msg) +# return self._raw_read() + +# @session +# def call(self, msg: "MessageType") -> "MessageType": +# self.check_firmware_version() +# resp = self.call_raw(msg) +# while True: +# if isinstance(resp, messages.PinMatrixRequest): +# resp = self._callback_pin(resp) +# elif isinstance(resp, messages.PassphraseRequest): +# resp = self._callback_passphrase(resp) +# elif isinstance(resp, messages.ButtonRequest): +# resp = self._callback_button(resp) +# elif isinstance(resp, messages.Failure): +# print("self.call-failure") + +# if resp.code == messages.FailureType.ActionCancelled: +# raise exceptions.Cancelled +# raise exceptions.TrezorFailure(resp) +# else: +# print("self.call-end") +# return resp + +# def _refresh_features(self, features: messages.Features) -> None: +# """Update internal fields based on passed-in Features message.""" + +# if not self.model: +# # Trezor Model One bootloader 1.8.0 or older does not send model name +# model = models.by_internal_name(features.internal_model) +# if model is None: +# model = models.by_name(features.model or "1") +# if model is None: +# raise RuntimeError( +# "Unsupported Trezor model" +# f" (internal_model: {features.internal_model}, model: {features.model})" +# ) +# self.model = model + +# if features.vendor not in self.model.vendors: +# raise RuntimeError("Unsupported device") + +# self.features = features +# self.version = ( +# self.features.major_version, +# self.features.minor_version, +# self.features.patch_version, +# ) +# self.check_firmware_version(warn_only=True) +# if self.features.session_id is not None: +# self.session_id = self.features.session_id +# self.features.session_id = None + +# @session +# def refresh_features(self) -> messages.Features: +# """Reload features from the device. + +# Should be called after changing settings or performing operations that affect +# device state. +# """ +# resp = self.call_raw(messages.GetFeatures()) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to GetFeatures") +# self._refresh_features(resp) +# return resp + +# def init_device( +# self, +# *, +# session_id: t.Optional[bytes] = None, +# new_session: bool = False, +# derive_cardano: t.Optional[bool] = None, +# ) -> t.Optional[bytes]: +# """Initialize the device and return a session ID. + +# You can optionally specify a session ID. If the session still exists on the +# device, the same session ID will be returned and the session is resumed. +# Otherwise a different session ID is returned. + +# Specify `new_session=True` to open a fresh session. Since firmware version +# 1.9.0/2.3.0, the previous session will remain cached on the device, and can be +# resumed by calling `init_device` again with the appropriate session ID. + +# If neither `new_session` nor `session_id` is specified, the current session ID +# will be reused. If no session ID was cached, a new session ID will be allocated +# and returned. + +# # Version notes: + +# Trezor One older than 1.9.0 does not have session management. Optional arguments +# have no effect and the function returns None + +# Trezor T older than 2.3.0 does not have session cache. Requesting a new session +# will overwrite the old one. In addition, this function will always return None. +# A valid session_id can be obtained from the `session_id` attribute, but only +# after a passphrase-protected call is performed. You can use the following code: + +# >>> client.init_device() +# >>> client.ensure_unlocked() +# >>> valid_session_id = client.session_id +# """ +# if new_session: +# self.session_id = None +# elif session_id is not None: +# self.session_id = session_id + +# print("before init conn") + +# resp = self.transport.initialize_connection( +# mapping=self.mapping, +# session_id=session_id, +# derive_cardano=derive_cardano, +# ) +# print("here") +# if isinstance(resp, messages.Failure): +# # can happen if `derive_cardano` does not match the current session +# raise exceptions.TrezorFailure(resp) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to Initialize") + +# if self.session_id is not None and resp.session_id == self.session_id: +# LOG.info("Successfully resumed session") +# elif session_id is not None: +# LOG.info("Failed to resume session") + +# # TT < 2.3.0 compatibility: +# # _refresh_features will clear out the session_id field. We want this function +# # to return its value, so that callers can rely on it being either a valid +# # session_id, or None if we can't do that. +# # Older TT FW does not report session_id in Features and self.session_id might +# # be invalid because TT will not allocate a session_id until a passphrase +# # exchange happens. +# reported_session_id = resp.session_id +# self._refresh_features(resp) +# print("there:", reported_session_id) +# return reported_session_id + +# def is_outdated(self) -> bool: +# if self.features.bootloader_mode: +# return False +# return self.version < self.model.minimum_version + +# def check_firmware_version(self, warn_only: bool = False) -> None: +# if self.is_outdated(): +# if warn_only: +# warnings.warn("Firmware is out of date", stacklevel=2) +# else: +# raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + +# @expect(messages.Success, field="message", ret_type=str) +# def ping( +# self, +# msg: str, +# button_protection: bool = False, +# ) -> "MessageType": +# # We would like ping to work on any valid TrezorClient instance, but +# # due to the protection modes, we need to go through self.call, and that will +# # raise an exception if the firmware is too old. +# # So we short-circuit the simplest variant of ping with call_raw. +# if not button_protection: +# # XXX this should be: `with self:` +# try: +# self.open() +# resp = self.call_raw(messages.Ping(message=msg)) +# if isinstance(resp, messages.ButtonRequest): +# # device is PIN-locked. +# # respond and hope for the best +# resp = self._callback_button(resp) +# return resp +# finally: +# self.close() + +# return self.call( +# messages.Ping(message=msg, button_protection=button_protection) +# ) + +# def get_device_id(self) -> t.Optional[str]: +# return self.features.device_id + +# @session +# def lock(self, *, _refresh_features: bool = True) -> None: +# """Lock the device. + +# If the device does not have a PIN configured, this will do nothing. +# Otherwise, a lock screen will be shown and the device will prompt for PIN +# before further actions. + +# This call does _not_ invalidate passphrase cache. If passphrase is in use, +# the device will not prompt for it after unlocking. + +# To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate +# passphrase cache, use `clear_session()`. +# """ +# # Private argument _refresh_features can be used internally to avoid +# # refreshing in cases where we will refresh soon anyway. This is used +# # in TrezorClient.clear_session() +# self.call(messages.LockDevice()) +# if _refresh_features: +# self.refresh_features() + +# @session +# def ensure_unlocked(self) -> None: +# """Ensure the device is unlocked and a passphrase is cached. + +# If the device is locked, this will prompt for PIN. If passphrase is enabled +# and no passphrase is cached for the current session, the device will also +# prompt for passphrase. + +# After calling this method, further actions on the device will not prompt for +# PIN or passphrase until the device is locked or the session becomes invalid. +# """ +# from .btc import get_address + +# get_address(self, "Testnet", PASSPHRASE_TEST_PATH) +# self.refresh_features() + +# def end_session(self) -> None: +# """Close the current session and clear cached passphrase. + +# The session will become invalid until `init_device()` is called again. +# If passphrase is enabled, further actions will prompt for it again. + +# This is a no-op in bootloader mode, as it does not support session management. +# """ +# # since: 2.3.4, 1.9.4 +# print("end session") +# try: +# if not self.features.bootloader_mode: +# self.transport.end_session(self.session_id or b"") +# # self.call(messages.EndSession()) +# except exceptions.TrezorFailure: +# # A failure most likely means that the FW version does not support +# # the EndSession call. We ignore the failure and clear the local session_id. +# # The client-side end result is identical. +# pass +# except ValueError as e: +# print(e) +# print(e.args) +# self.session_id = None + +# @session +# def clear_session(self) -> None: +# """Lock the device and present a fresh session. + +# The current session will be invalidated and a new one will be started. If the +# device has PIN enabled, it will become locked. + +# Equivalent to calling `lock()`, `end_session()` and `init_device()`. +# """ +# self.lock(_refresh_features=False) +# self.end_session() +# self.init_device(new_session=True) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 11fac1bc22a..707401cf1bf 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -21,47 +21,44 @@ import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - Sequence, - Tuple, - Union, -) from mnemonic import Mnemonic -from . import mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure +from . import btc, mapping, messages, models, protobuf +from .client import ( + MAX_PASSPHRASE_LENGTH, + MAX_PIN_LENGTH, + PASSPHRASE_ON_DEVICE, + TrezorClient, +) +from .exceptions import Cancelled, PinException, TrezorFailure from .log import DUMP_BYTES -from .messages import DebugWaitType -from .tools import expect +from .messages import Capability, DebugWaitType +from .tools import expect, parse_path +from .transport.session import Session, SessionV1 +from .transport.thp.protocol_v1 import ProtocolV1 -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] + AnyDict = t.Dict[str, t.Any] class InputFunc(Protocol): + def __call__( self, hold_ms: int | None = None, @@ -70,6 +67,7 @@ def __call__( EXPECTED_RESPONSES_CONTEXT_LINES = 3 +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") LOG = logging.getLogger(__name__) @@ -104,11 +102,13 @@ def __init__(self, json_str: str) -> None: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value( + self, key: str, value: t.Any + ) -> list["AnyDict"]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -121,7 +121,7 @@ def recursively_find(data: Any) -> Iterator[Any]: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any + self, key: str, value: t.Any ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: @@ -129,8 +129,10 @@ def find_unique_object_with_key_and_value( assert len(objects) == 1 return objects[0] - def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_values_by_key( + self, key: str, only_type: type | None = None + ) -> list[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -148,8 +150,8 @@ def recursively_find(data: Any) -> Iterator[Any]: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: type | None = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -160,7 +162,7 @@ def find_unique_value_by_key( class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -422,11 +424,13 @@ def input_func( class DebugLink: + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) self.model: models.TrezorModel | None = None self.version: tuple[int, int, int] = (0, 0, 0) @@ -479,10 +483,16 @@ def set_screen_text_file(self, file_path: Path | None) -> None: self.screen_text_file = file_path def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -499,15 +509,10 @@ def _write(self, msg: protobuf.MessageType) -> None: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read() # Collapse tokens to make log use less lines. msg_for_log = msg @@ -521,18 +526,27 @@ def _read(self) -> protobuf.MessageType: ) return msg - def _call(self, msg: protobuf.MessageType) -> Any: + def _call(self, msg: protobuf.MessageType) -> t.Any: self._write(msg) return self._read() - def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState: + def state( + self, + wait_type: DebugWaitType | None = None, + thp_channel_id: bytes | None = None, + ) -> messages.DebugLinkState: if wait_type is None: wait_type = ( DebugWaitType.CURRENT_LAYOUT if self.has_global_layout else DebugWaitType.IMMEDIATE ) - result = self._call(messages.DebugLinkGetState(wait_layout=wait_type)) + result = self._call( + messages.DebugLinkGetState( + wait_layout=wait_type, + thp_channel_id=thp_channel_id, + ) + ) while not isinstance(result, (messages.Failure, messages.DebugLinkState)): result = self._read() if isinstance(result, messages.Failure): @@ -544,7 +558,7 @@ def read_layout(self) -> LayoutContent: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -560,7 +574,7 @@ def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[LayoutContent]: + def wait_for_layout_change(self) -> t.Iterator[LayoutContent]: # set up a dummy layout content object to be yielded layout_content = LayoutContent( ["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("] @@ -622,7 +636,7 @@ def encode_pin(self, pin: str, matrix: str | None = None) -> str: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[str | None, int | None]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -700,7 +714,7 @@ def input(self, word: str, wait: bool | None = None) -> LayoutContent: def click( self, - click: Tuple[int, int], + click: t.Tuple[int, int], hold_ms: int | None = None, wait: bool | None = None, ) -> LayoutContent: @@ -862,10 +876,10 @@ def __init__(self, debuglink: DebugLink) -> None: self.clear() def clear(self) -> None: - self.pins: Iterator[str] | None = None + self.pins: t.Iterator[str] | None = None self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -896,7 +910,7 @@ def button_request(self, br: messages.ButtonRequest) -> None: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE @@ -918,12 +932,15 @@ def get_passphrase(self, available_on_device: bool) -> str: class MessageFilter: - def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -971,7 +988,7 @@ def match(self, message: protobuf.MessageType) -> bool: return True def to_string(self, maxwidth: int = 80) -> str: - fields: list[Tuple[str, str]] = [] + fields: list[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -1001,7 +1018,8 @@ def to_string(self, maxwidth: int = 80) -> str: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -1009,6 +1027,245 @@ def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + self._session = session + self.reset_debug_features() + if isinstance(session, SessionDebugWrapper): + raise Exception("Cannot wrap already wrapped session!") + + @property + def protocol_version(self) -> int: + return self.client.protocol_version + + @property + def client(self) -> TrezorClientDebugLink: + assert isinstance(self._session.client, TrezorClientDebugLink) + return self._session.client + + @property + def id(self) -> bytes: + return self._session.id + + def _write(self, msg: t.Any) -> None: + print("writing message:", msg.__class__.__name__) + self._session._write(self._filter_message(msg)) + + def _read(self) -> t.Any: + resp = self._filter_message(self._session._read()) + print("reading message:", resp.__class__.__name__) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self, *, _refresh_features: bool = True) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + # TODO update the documentation above + # Private argument _refresh_features can be used internally to avoid + # refreshing in cases where we will refresh soon anyway. This is used + # in TrezorClient.clear_session() + self.call(messages.LockDevice()) + if _refresh_features: + self.refresh_features() + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def ensure_unlocked(self) -> None: + btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: + message_type = msg.__class__ + callback = self.filters.get(message_type) + if callable(callback): + return callback(deepcopy(msg)) + else: + return msg + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self._session.passphrase_callback + self.passphrase = self._session.passphrase + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client.ui, DebugUI): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + if isinstance(input_flow, t.Generator): + # Ensure that the input flow is exhausted + try: + input_flow.throw( + AssertionError("input flow continues past end of test") + ) + except StopIteration: + pass + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: list[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -1034,54 +1291,165 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: raise # set transport explicitly so that sync_responses can work + super().__init__(transport) + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features() + self.reset_debug_features(new_management_session=True) self.sync_responses() - super().__init__(transport, ui=self.ui) - # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version + self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: return self.debug.layout_type - def reset_debug_features(self) -> None: - """Prepare the debugging client for a new testcase. + def get_new_client(self) -> TrezorClientDebugLink: + return TrezorClientDebugLink(self.transport, self.debug.allow_interactions) + + def reset_debug_features(self, new_management_session: bool = False) -> None: + """ + Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ self.ui: DebugUI = DebugUI(self.debug) + # self.pin_callback = self.ui.debug_callback_button self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: dict[ - type[protobuf.MessageType], - Callable[[protobuf.MessageType], protobuf.MessageType] | None, + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + if new_management_session: + self._management_session = self.get_management_session(new_session=True) + + @property + def button_callback(self): + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def pin_callback(self): + + def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp + + return _callback_pin + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> t.Any: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + # session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if session.passphrase is None and isinstance(session, SessionV1): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + else: + passphrase = session.passphrase + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() + + def get_session( + self, + passphrase: str | object | None = "", + derive_cardano: bool = False, + ) -> Session: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + return super().get_session(passphrase, derive_cardano) def set_filter( self, - message_type: type[protobuf.MessageType], - callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1106,7 +1474,8 @@ def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: return msg def set_input_flow( - self, input_flow: Generator[None, messages.ButtonRequest | None, None] + self, + input_flow: t.Generator[None, messages.ButtonRequest | None, None], ) -> None: """Configure a sequence of input events for the current with-block. @@ -1140,6 +1509,7 @@ def set_input_flow( if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow + assert input_flow is not None input_flow.send(None) # start the generator def watch_layout(self, watch: bool = True) -> None: @@ -1162,7 +1532,7 @@ def __enter__(self) -> "TrezorClientDebugLink": self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1175,20 +1545,21 @@ def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: else: input_flow = None - self.reset_debug_features() + self.reset_debug_features(new_management_session=False) if exc_type is None: # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( - self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], ) -> None: """Set a sequence of expected responses to client calls. @@ -1227,7 +1598,7 @@ def set_expected_responses( ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ @@ -1235,6 +1606,7 @@ def use_pin_sequence(self, pins: Iterable[str]) -> None: def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" + self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1244,15 +1616,14 @@ def use_mnemonic(self, mnemonic: str) -> None: def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() + resp = self.get_management_session()._read() resp = self._filter_message(resp) if self.actual_responses is not None: self.actual_responses.append(resp) return resp def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) + return self.get_management_session()._write(self._filter_message(msg)) @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: @@ -1322,23 +1693,25 @@ def sync_responses(self) -> None: # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.open() try: - self.transport.write(*cancel_msg) - + # self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) + self.get_management_session()._write(messages.Ping(message=message)) resp = None while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + resp = self.get_management_session()._read() + + raise Exception + except Exception: pass + finally: - self.transport.end_session() + pass # TODO fix + # self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1352,8 +1725,8 @@ def mnemonic_callback(self, _) -> str: @expect(messages.Success, field="message", ret_type=str) def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], + session: "Session", + mnemonic: str | t.Iterable[str], pin: str | None, passphrase_protection: bool, label: str | None, @@ -1366,12 +1739,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - resp = client.call( + resp = session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1382,7 +1755,7 @@ def load_device( no_backup=no_backup, ) ) - client.init_device() + session.refresh_features() return resp @@ -1391,11 +1764,11 @@ def load_device( @expect(messages.Success, field="message", ret_type=str) -def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> protobuf.MessageType: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - return client.call( + return session.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" ) @@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: @expect(messages.Success, field="message", ret_type=str) -def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: - return client.call(messages.DebugLinkOptigaSetSecMax()) +def optiga_set_sec_max(session: "Session") -> protobuf.MessageType: + return session.call(messages.DebugLinkOptigaSetSecMax()) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index ebd7ca85f51..2542f00ddee 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -23,20 +23,19 @@ from . import messages from .exceptions import Cancelled, TrezorException -from .tools import Address, expect, session +from .tools import Address, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @expect(messages.Success, field="message", ret_type=str) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -67,13 +66,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings) - client.refresh_features() + out = session.call(settings) + session.refresh_features() return out def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> "MessageType": @@ -83,76 +82,70 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) return response @expect(messages.Success, field="message", ret_type=str) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> "MessageType": data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: assert isinstance(response, messages.TranslationDataRequest) - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) assert isinstance(response, messages.Success) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return response @expect(messages.Success, field="message", ret_type=str) -@session -def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": - out = client.call(messages.ApplyFlags(flags=flags)) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> "MessageType": + out = session.call(messages.ApplyFlags(flags=flags)) + session.refresh_features() return out @expect(messages.Success, field="message", ret_type=str) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangePin(remove=remove)) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangePin(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangeWipeCode(remove=remove)) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangeWipeCode(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> "MessageType": - ret = client.call(messages.SdProtect(operation=operation)) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def wipe(client: "TrezorClient") -> "MessageType": - ret = client.call(messages.WipeDevice()) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> "MessageType": + + ret = session.call(messages.WipeDevice()) + # if not session.features.bootloader_mode: + # session.refresh_features() return ret -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -188,13 +181,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -216,24 +209,23 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) - client.init_device() + session.refresh_features() return res @expect(messages.Success, field="message", ret_type=str) -@session def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -257,13 +249,13 @@ def reset( DeprecationWarning, ) - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 @@ -280,25 +272,24 @@ def reset( backup_type=backup_type, ) - resp = client.call(msg) + resp = session.call(msg) if not isinstance(resp, messages.EntropyRequest): raise RuntimeError("Invalid response, expected EntropyRequest") external_entropy = os.urandom(32) # LOG.debug("Computer generated entropy: " + external_entropy.hex()) - ret = client.call(messages.EntropyAck(entropy=external_entropy)) - client.init_device() + ret = session.call(messages.EntropyAck(entropy=external_entropy)) + session.refresh_features() # TODO is necessary? return ret @expect(messages.Success, field="message", ret_type=str) -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> "MessageType": - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -307,37 +298,36 @@ def backup( ], ) ) - client.refresh_features() + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -def cancel_authorization(client: "TrezorClient") -> "MessageType": - return client.call(messages.CancelAuthorization()) +def cancel_authorization(session: "Session") -> "MessageType": + return session.call(messages.CancelAuthorization()) @expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) -def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": - resp = client.call(messages.UnlockPath(address_n=n)) +def unlock_path(session: "Session", n: "Address") -> "MessageType": + resp = session.call(messages.UnlockPath(address_n=n)) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp else: raise TrezorException("Unexpected response in UnlockPath flow") -@session @expect(messages.Success, field="message", ret_type=str) def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> "MessageType": - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -345,42 +335,37 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return response -@session @expect(messages.Success, field="message", ret_type=str) -def show_device_tutorial(client: "TrezorClient") -> "MessageType": - return client.call(messages.ShowDeviceTutorial()) +def show_device_tutorial(session: "Session") -> "MessageType": + return session.call(messages.ShowDeviceTutorial()) -@session @expect(messages.Success, field="message", ret_type=str) -def unlock_bootloader(client: "TrezorClient") -> "MessageType": - return client.call(messages.UnlockBootloader()) +def unlock_bootloader(session: "Session") -> "MessageType": + return session.call(messages.UnlockBootloader()) @expect(messages.Success, field="message", ret_type=str) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": +def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType": """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms)) + session.refresh_features() return ret @expect(messages.AuthenticityProof) -def authenticate(client: "TrezorClient", challenge: bytes): - return client.call(messages.AuthenticateDevice(challenge=challenge)) +def authenticate(session: "Session", challenge: bytes): + return session.call(messages.AuthenticateDevice(challenge=challenge)) @expect(messages.Success, field="message", ret_type=str) -def set_brightness( - client: "TrezorClient", value: Optional[int] = None -) -> "MessageType": - return client.call(messages.SetBrightness(value=value)) +def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType": + return session.call(messages.SetBrightness(value=value)) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 1ffaafb4ab7..fffe6f0adc2 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,12 +18,12 @@ from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, expect, session +from .tools import b58decode, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -321,17 +321,16 @@ def parse_transaction_json( @expect(messages.EosPublicKey) def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - response = client.call( + response = session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display) ) return response -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -347,11 +346,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 1cf2eeeaed1..60eaa3366ba 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,12 +18,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import expect, prepare_message_bytes, session, unharden +from .tools import expect, prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -163,13 +163,13 @@ def network_from_address_n( @expect(messages.EthereumAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,16 +181,15 @@ def get_address( @expect(messages.EthereumPublicKey) def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -226,13 +225,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -247,9 +246,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -282,13 +280,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( @expect(messages.EthereumMessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + # TODO session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return response def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index 4ed6f22951f..90064bb238c 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -20,8 +20,8 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect( @@ -29,27 +29,27 @@ field="credentials", ret_type=List[messages.WebAuthnCredential], ) -def list_credentials(client: "TrezorClient") -> "MessageType": - return client.call(messages.WebAuthnListResidentCredentials()) +def list_credentials(session: "Session") -> "MessageType": + return session.call(messages.WebAuthnListResidentCredentials()) @expect(messages.Success, field="message", ret_type=str) -def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": - return client.call( +def add_credential(session: "Session", credential_id: bytes) -> "MessageType": + return session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id) ) @expect(messages.Success, field="message", ret_type=str) -def remove_credential(client: "TrezorClient", index: int) -> "MessageType": - return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) +def remove_credential(session: "Session", index: int) -> "MessageType": + return session.call(messages.WebAuthnRemoveResidentCredential(index=index)) @expect(messages.Success, field="message", ret_type=str) -def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": - return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) +def set_counter(session: "Session", u2f_counter: int) -> "MessageType": + return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) -def get_next_counter(client: "TrezorClient") -> "MessageType": - return client.call(messages.GetNextU2FCounter()) +def get_next_counter(session: "Session") -> "MessageType": + return session.call(messages.GetNextU2FCounter()) diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 5cc5d8830cb..a588b160e1e 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,7 @@ from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import expect, session +from ..tools import expect from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -38,7 +38,7 @@ from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: "TrezorClient", + session: "Session", data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -97,7 +96,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -107,5 +106,5 @@ def update( @expect(messages.FirmwareHash, field="hash", ret_type=bytes) -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): - return client.call(messages.GetFirmwareHash(challenge=challenge)) +def get_hash(session: "Session", challenge: t.Optional[bytes]): + return session.call(messages.GetFirmwareHash(challenge=challenge)) diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index d50324d5868..04b75f0aa56 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,11 +65,21 @@ def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() + def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: + """Serialize a Python protobuf class. + + Returns the byte representation of the protobuf message. + """ + + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: """Deserialize a protobuf message into a Python class.""" cls = self.type_to_class[msg_wire_type] @@ -83,7 +95,9 @@ def from_module(cls, module: ModuleType) -> Self: mapping = cls() message_types = getattr(module, "MessageType") - for entry in message_types: + thp_message_types = getattr(module, "ThpMessageType") + + for entry in (*message_types, *thp_message_types): msg_class = getattr(module, entry.name, None) if msg_class is None: raise ValueError( diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index b52119311f2..86fd70dfd80 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -43,6 +43,8 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 FirmwareError = 99 @@ -400,6 +402,34 @@ class TezosBallotType(IntEnum): Pass = 2 +class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpNewSession = 1001 + ThpStartPairingRequest = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceHost = 1018 + ThpCodeEntryCpaceTrezor = 1019 + ThpCodeEntryTag = 1020 + ThpCodeEntrySecret = 1021 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcUnidirectionalTag = 1032 + ThpNfcUnidirectionalSecret = 1033 + + +class ThpPairingMethod(IntEnum): + NoMethod = 1 + CodeEntry = 2 + QrCode = 3 + NFC_Unidirectional = 4 + + class MessageType(IntEnum): Initialize = 0 Ping = 1 @@ -4100,6 +4130,7 @@ class DebugLinkGetState(protobuf.MessageType): 1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None), 2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None), 3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE), + 4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4108,10 +4139,12 @@ def __init__( wait_word_list: Optional["bool"] = None, wait_word_pos: Optional["bool"] = None, wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE, + thp_channel_id: Optional["bytes"] = None, ) -> None: self.wait_word_list = wait_word_list self.wait_word_pos = wait_word_pos self.wait_layout = wait_layout + self.thp_channel_id = thp_channel_id class DebugLinkState(protobuf.MessageType): @@ -4130,6 +4163,9 @@ class DebugLinkState(protobuf.MessageType): 11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None), 12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None), 13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None), + 14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None), + 15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None), + 16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4148,6 +4184,9 @@ def __init__( recovery_word_pos: Optional["int"] = None, reset_word_pos: Optional["int"] = None, mnemonic_type: Optional["BackupType"] = None, + thp_pairing_code_entry_code: Optional["int"] = None, + thp_pairing_code_qr_code: Optional["bytes"] = None, + thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None, ) -> None: self.tokens: Sequence["str"] = tokens if tokens is not None else [] self.layout = layout @@ -4162,6 +4201,9 @@ def __init__( self.recovery_word_pos = recovery_word_pos self.reset_word_pos = reset_word_pos self.mnemonic_type = mnemonic_type + self.thp_pairing_code_entry_code = thp_pairing_code_entry_code + self.thp_pairing_code_qr_code = thp_pairing_code_qr_code + self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional class DebugLinkStop(protobuf.MessageType): @@ -7824,6 +7866,280 @@ def __init__( self.amount = amount +class ThpDeviceProperties(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None), + 3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None), + 4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None), + 5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + internal_model: Optional["str"] = None, + model_variant: Optional["int"] = None, + bootloader_mode: Optional["bool"] = None, + protocol_version: Optional["int"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.internal_model = internal_model + self.model_variant = model_variant + self.bootloader_mode = bootloader_mode + self.protocol_version = protocol_version + + +class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + host_pairing_credential: Optional["bytes"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.host_pairing_credential = host_pairing_credential + + +class ThpCreateNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1000 + FIELDS = { + 1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None), + 3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + passphrase: Optional["str"] = None, + on_device: Optional["bool"] = None, + derive_cardano: Optional["bool"] = None, + ) -> None: + self.passphrase = passphrase + self.on_device = on_device + self.derive_cardano = derive_cardano + + +class ThpNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1001 + FIELDS = { + 1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + new_session_id: Optional["int"] = None, + ) -> None: + self.new_session_id = new_session_id + + +class ThpStartPairingRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1008 + FIELDS = { + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_name: Optional["str"] = None, + ) -> None: + self.host_name = host_name + + +class ThpPairingPreparationsFinished(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1009 + + +class ThpCodeEntryCommitment(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1016 + FIELDS = { + 1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + commitment: Optional["bytes"] = None, + ) -> None: + self.commitment = commitment + + +class ThpCodeEntryChallenge(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1017 + FIELDS = { + 1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + challenge: Optional["bytes"] = None, + ) -> None: + self.challenge = challenge + + +class ThpCodeEntryCpaceHost(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1018 + FIELDS = { + 1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_host_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_host_public_key = cpace_host_public_key + + +class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1019 + FIELDS = { + 1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_trezor_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_trezor_public_key = cpace_trezor_public_key + + +class ThpCodeEntryTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1020 + FIELDS = { + 2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpCodeEntrySecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1021 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpQrCodeTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1024 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpQrCodeSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1025 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpNfcUnidirectionalTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1032 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpNfcUnidirectionalSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1033 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpCredentialRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1010 + FIELDS = { + 1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_static_pubkey: Optional["bytes"] = None, + ) -> None: + self.host_static_pubkey = host_static_pubkey + + +class ThpCredentialResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1011 + FIELDS = { + 1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + trezor_static_pubkey: Optional["bytes"] = None, + credential: Optional["bytes"] = None, + ) -> None: + self.trezor_static_pubkey = trezor_static_pubkey + self.credential = credential + + +class ThpEndRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1012 + + +class ThpEndResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1013 + + class ThpCredentialMetadata(protobuf.MessageType): MESSAGE_WIRE_TYPE = None FIELDS = { diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 4ed6f5aa81c..d951c52d7cd 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -20,25 +20,25 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.Entropy, field="entropy", ret_type=bytes) -def get_entropy(client: "TrezorClient", size: int) -> "MessageType": - return client.call(messages.GetEntropy(size=size)) +def get_entropy(session: "Session", size: int) -> "MessageType": + return session.call(messages.GetEntropy(size=size)) @expect(messages.SignedIdentity) def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -50,12 +50,12 @@ def sign_identity( @expect(messages.ECDHSessionKey) def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -66,7 +66,7 @@ def get_ecdh_session_key( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -74,7 +74,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -89,7 +89,7 @@ def encrypt_keyvalue( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -97,7 +97,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -111,5 +111,5 @@ def decrypt_keyvalue( @expect(messages.Nonce, field="nonce", ret_type=bytes) -def get_nonce(client: "TrezorClient"): - return client.call(messages.GetNonce()) +def get_nonce(session: "Session"): + return session.call(messages.GetNonce()) diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 5bce7574e82..5b071626b48 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -20,9 +20,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -33,13 +33,13 @@ @expect(messages.MoneroAddress, field="address", ret_type=bytes) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -51,10 +51,10 @@ def get_address( @expect(messages.MoneroWatchKey) def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type) ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 3a67aec72c2..6aa087757a6 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -21,9 +21,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig @expect(messages.NEMAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ) @@ -213,7 +213,7 @@ def get_address( @expect(messages.NEMSignedTx) def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> "MessageType": try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -222,4 +222,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg) + return session.call(msg) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 7a953b8fac5..f026236c071 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,9 +21,9 @@ from .tools import dict_from_camelcase, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") @@ -31,12 +31,12 @@ @expect(messages.RippleAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -45,14 +45,14 @@ def get_address( @expect(messages.RippleSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> "MessageType": msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg) + return session.call(msg) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index be7f2e5fcb5..1a228b2f957 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -4,29 +4,29 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect(messages.SolanaPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) ) @expect(messages.SolanaAddress) def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -37,12 +37,12 @@ def get_address( @expect(messages.SolanaTxSignature) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> "MessageType": - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index ebf81e4fd04..12a75ca5d8a 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -21,9 +21,9 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: @expect(messages.StellarAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -338,7 +338,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -354,10 +354,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index cff06ed6c83..b74dc562599 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -20,19 +20,19 @@ from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.TezosAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -41,12 +41,12 @@ def get_address( @expect(messages.TezosPublicKey, field="public_key", ret_type=str) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -55,11 +55,11 @@ def get_public_key( @expect(messages.TezosSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> "MessageType": sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg) + return session.call(sign_tx_msg) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 4fd1558ec29..3e9bd1c5608 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -40,7 +40,7 @@ # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .protobuf import MessageType @@ -284,23 +284,6 @@ def wrapped_f(*args: "P.args", **kwargs: "P.kwargs") -> "Union[MT, R]": return decorator -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189 diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b77..45d05150c2b 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. + PATH_PREFIX: str - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ + if device.get_path() == path: + return device - PATH_PREFIX: str - ENABLED = False + if prefix_search and device.get_path().startswith(path): + return device - def __str__(self) -> str: - return self.get_path() + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: + def find_debug(self: "T") -> "T": raise NotImplementedError - def read(self) -> MessagePayload: + def open(self) -> None: raise NotImplementedError - def write(self, message_type: int, message_data: bytes) -> None: + def close(self) -> None: raise NotImplementedError - def find_debug(self: "T") -> "T": + def write_chunk(self, chunk: bytes) -> None: raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def read_chunk(self) -> bytes: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device - - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + CHUNK_SIZE: t.ClassVar[int] -def all_transports() -> Iterable[Type["Transport"]]: +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +121,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f701..8d69e5b253f 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,30 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) +TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -45,7 +51,7 @@ def __init__(self, path: str, status: int, message: str) -> None: super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN + + +def supports_protocolV2() -> bool: + return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT + + +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + from ..messages import FailureType + + protocol_version = PROTOCOL_VERSION_1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + if isinstance(response, messages.Failure): + if response.code == FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol_version = PROTOCOL_VERSION_2 + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = ( + supports_protocolV2() + or detect_protocol_version(transport) == PROTOCOL_VERSION_1 + ) + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] class BridgeHandle: @@ -84,7 +134,7 @@ def read_buf(self) -> bytes: class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -112,13 +162,12 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +184,7 @@ def find_debug(self) -> "BridgeTransport": raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,17 +193,20 @@ def _call(self, action: str, data: Optional[str] = None) -> requests.Response: @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] - def begin_session(self) -> None: + def deprecated_begin_session(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +215,32 @@ def begin_session(self) -> None: raise self.session = data.json()["session"] - def end_session(self) -> None: + def deprecated_end_session(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: + def deprecated_write(self, message_type: int, message_data: bytes) -> None: header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: + def deprecated_read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] + + def open(self) -> None: + pass + # TODO self.handle.open() + + def close(self) -> None: + pass + # TODO self.handle.close() + + def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :) + self.handle.write_buf(chunk) + + def read_chunk(self) -> bytes: # TODO check if it works :) + return self.handle.read_buf() diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd70..995fd6960ca 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any + + +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ def open(self) -> None: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ def open(self) -> None: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ def probe_hid_version(self) -> int: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index a5a0ee6be4d..00000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 00000000000..6b6f4cce2c3 --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from .thp.protocol_v1 import ProtocolV1 +from .thp.protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + + +class Session: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + def __init__( + self, client: TrezorClient, id: bytes, passphrase: str | object | None = None + ) -> None: + self.client = client + self._id = id + self.passphrase = passphrase + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: t.Any) -> t.Any: + # TODO self.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.pin_callback is None: + raise Exception # TODO + resp = self.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + if self.passphrase_callback is None: + raise Exception # TODO + resp = self.passphrase_callback(self, resp) + elif isinstance(resp, messages.ButtonRequest): + if self.button_callback is None: + raise Exception # TODO + resp = self.button_callback(self, resp) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> t.Any: + return self.call(messages.EndSession()) + + def ping(self, message: str, button_protection: bool | None = None) -> str: + resp: messages.Success = self.call( + messages.Ping(message=message, button_protection=button_protection) + ) + return resp.message or "" + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + + +class SessionV1(Session): + derive_cardano: bool | None = False + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | object = "", + derive_cardano: bool = False, + session_id: bytes | None = None, + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, id=session_id or b"") + + session._init_callbacks() + session.passphrase = passphrase + session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, session_id) + session.init_session() + return session + + def _init_callbacks(self) -> None: + self.button_callback = self.client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self.client.passphrase_callback + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + return self.client.protocol.read() + + def init_session(self, derive_cardano: bool | None = None): + if self.id == b"": + session_id = None + else: + session_id = self.id + resp: messages.Features = self.call_raw( + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) + ) + if isinstance(self.passphrase, str): + self.passphrase_callback = _send_passphrase + self._id = resp.session_id + + +def _send_passphrase(session: Session, resp: t.Any) -> None: + assert isinstance(session.passphrase, str) + return session.call(messages.PassphraseAck(passphrase=session.passphrase)) + + +def _callback_button(session: Session, msg: t.Any) -> t.Any: + print("Please confirm action on your Trezor device.") # TODO how to handle UI? + return session.call(messages.ButtonAck()) + + +class SessionV2(Session): + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV2) + session = cls(client, b"\x00") + new_session: messages.ThpNewSession = session.call( + messages.ThpCreateNewSession( + passphrase=passphrase, derive_cardano=derive_cardano + ) + ) + assert new_session.new_session_id is not None + session_id = new_session.new_session_id + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + + def __init__(self, client: TrezorClient, id: bytes) -> None: + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + + self.pin_callback = client.pin_callback + self.button_callback = client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.channel: ProtocolV2 = client.protocol.get_channel() + self.update_id_and_sid(id) + + def _write(self, msg: t.Any) -> None: + LOG.debug("writing message %s", type(msg)) + self.channel.write(self.sid, msg) + + def _read(self) -> t.Any: + msg = self.channel.read(self.sid) + LOG.debug("reading message %s", type(msg)) + return msg + + def update_id_and_sid(self, id: bytes) -> None: + self._id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/thp/alternating_bit_protocol.py b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py new file mode 100644 index 00000000000..62fb650fab0 --- /dev/null +++ b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +# from storage.cache_thp import ChannelCache +# from trezor import log +# from trezor.wire.thp import ThpError + + +# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: +# """ +# Checks if: +# - an ACK message is expected +# - the received ACK message acknowledges correct sequence number (bit) +# """ +# if not _is_ack_expected(cache): +# return False + +# if not _has_ack_correct_sync_bit(cache, ack_bit): +# return False + +# return True + + +# def _is_ack_expected(cache: ChannelCache) -> bool: +# is_expected: bool = not is_sending_allowed(cache) +# if __debug__ and not is_expected: +# log.debug(__name__, "Received unexpected ACK message") +# return is_expected + + +# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: +# is_correct: bool = get_send_seq_bit(cache) == sync_bit +# if __debug__ and not is_correct: +# log.debug(__name__, "Received ACK message with wrong ack bit") +# return is_correct + + +# def is_sending_allowed(cache: ChannelCache) -> bool: +# """ +# Checks whether sending a message in the provided channel is allowed. + +# Note: Sending a message in a channel before receipt of ACK message for the previously +# sent message (in the channel) is prohibited, as it can lead to desynchronization. +# """ +# return bool(cache.sync >> 7) + + +# def get_send_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the sequential number (bit) of the next message to be sent +# in the provided channel. +# """ +# return (cache.sync & 0x20) >> 5 + + +# def get_expected_receive_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the (expected) sequential number (bit) of the next message +# to be received in the provided channel. +# """ +# return (cache.sync & 0x40) >> 6 + + +# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: +# """ +# Set the flag whether sending a message in this channel is allowed or not. +# """ +# cache.sync &= 0x7F +# if sending_allowed: +# cache.sync |= 0x80 + + +# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# """ +# Set the expected sequential number (bit) of the next message to be received +# in the provided channel +# """ +# if __debug__: +# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected receive sync bit") + +# # set second bit to "seq_bit" value +# cache.sync &= 0xBF +# if seq_bit: +# cache.sync |= 0x40 + + +# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected send seq bit") +# if __debug__: +# log.debug(__name__, "setting sync send seq bit to %d", seq_bit) +# # set third bit to "seq_bit" value +# cache.sync &= 0xDF +# if seq_bit: +# cache.sync |= 0x20 + + +# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: +# """ +# Set the sequential bit of the "next message to be send" to the opposite value, +# i.e. 1 -> 0 and 0 -> 1 +# """ +# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/python/src/trezorlib/transport/thp/channel_data.py b/python/src/trezorlib/transport/thp/channel_data.py new file mode 100644 index 00000000000..3d70deecafd --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_data.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + def __init__( + self, + protocol_version: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + ) -> None: + self.protocol_version: int = protocol_version + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + + def to_dict(self): + return { + "protocol_version": self.protocol_version, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + } diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py new file mode 100644 index 00000000000..143430069fb --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import json +import logging +import os +import typing as t + +from ..thp.channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: + + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + def read_all_channels(self) -> t.List: ... + def save_all_channels(self, channels: t.List[t.Dict]) -> None: ... + def save_channel(self, new_channel: ProtocolAndChannel): ... + def remove_channel(self, transport_path: str) -> None: ... + + +class DummyChannelDatabase(ChannelDatabase): + + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + + def clear_stored_channels(self) -> None: + pass + + def read_all_channels(self) -> t.List: + return [] + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + return + + def save_channel(self, new_channel: ProtocolAndChannel): + pass + + def remove_channel(self, transport_path: str) -> None: + pass + + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self.read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: ProtocolAndChannel): + + LOG.debug("save channel") + channels = self.read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self.save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self.save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self.read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self.save_all_channels(remaining_channels) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version=dict["protocol_version"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + ) + + +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: + json.dump([], f) + + +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") + + db = JsonChannelDatabase(DATA_PATH) diff --git a/python/src/trezorlib/transport/thp/checksum.py b/python/src/trezorlib/transport/thp/checksum.py new file mode 100644 index 00000000000..8e0f32f0132 --- /dev/null +++ b/python/src/trezorlib/transport/thp/checksum.py @@ -0,0 +1,19 @@ +import zlib + +CHECKSUM_LENGTH = 4 + + +def compute(data: bytes) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/python/src/trezorlib/transport/thp/control_byte.py b/python/src/trezorlib/transport/thp/control_byte.py new file mode 100644 index 00000000000..ce7f6066f98 --- /dev/null +++ b/python/src/trezorlib/transport/thp/control_byte.py @@ -0,0 +1,59 @@ +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise Exception("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise Exception("Unexpected acknowledgement bit") + + +def get_seq_bit(ctrl_byte: int) -> int: + return (ctrl_byte & 0x10) >> 4 + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/python/src/trezorlib/transport/thp/curve25519.py b/python/src/trezorlib/transport/thp/curve25519.py new file mode 100644 index 00000000000..43127c49e57 --- /dev/null +++ b/python/src/trezorlib/transport/thp/curve25519.py @@ -0,0 +1,116 @@ +from typing import Tuple + +p = 2**255 - 19 +J = 486662 + +c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1) +c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8 +a24 = 121666 # (J + 2) // 4 + + +def decode_scalar(scalar: bytes) -> int: + # decodeScalar25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + if len(scalar) != 32: + raise ValueError("Invalid length of scalar") + + array = bytearray(scalar) + array[0] &= 248 + array[31] &= 127 + array[31] |= 64 + + return int.from_bytes(array, "little") + + +def decode_coordinate(coordinate: bytes) -> int: + # decodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + if len(coordinate) != 32: + raise ValueError("Invalid length of coordinate") + + array = bytearray(coordinate) + array[-1] &= 0x7F + return int.from_bytes(array, "little") % p + + +def encode_coordinate(coordinate: int) -> bytes: + # encodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + return coordinate.to_bytes(32, "little") + + +def get_private_key(secret: bytes) -> bytes: + return decode_scalar(secret).to_bytes(32, "little") + + +def get_public_key(private_key: bytes) -> bytes: + base_point = int.to_bytes(9, 32, "little") + return multiply(private_key, base_point) + + +def multiply(private_scalar: bytes, public_point: bytes): + # X25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + def ladder_operation( + x1: int, x2: int, z2: int, x3: int, z3: int + ) -> Tuple[int, int, int, int]: + # https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3 + # (x4, z4) = 2 * (x2, z2) + # (x5, z5) = (x2, z2) + (x3, z3) + # where (x1, 1) = (x3, z3) - (x2, z2) + + a = (x2 + z2) % p + aa = (a * a) % p + b = (x2 - z2) % p + bb = (b * b) % p + e = (aa - bb) % p + c = (x3 + z3) % p + d = (x3 - z3) % p + da = (d * a) % p + cb = (c * b) % p + t0 = (da + cb) % p + x5 = (t0 * t0) % p + t1 = (da - cb) % p + t2 = (t1 * t1) % p + z5 = (x1 * t2) % p + x4 = (aa * bb) % p + t3 = (a24 * e) % p + t4 = (bb + t3) % p + z4 = (e * t4) % p + + return x4, z4, x5, z5 + + def conditional_swap(first: int, second: int, condition: int): + # Returns (second, first) if condition is true and (first, second) otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask), (second & false_mask) | ( + first & true_mask + ) + + k = decode_scalar(private_scalar) + u = decode_coordinate(public_point) + + x_1 = u + x_2 = 1 + z_2 = 0 + x_3 = u + z_3 = 1 + swap = 0 + + for i in reversed(range(256)): + bit = (k >> i) & 1 + swap = bit ^ swap + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + swap = bit + x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3) + + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + + x = pow(z_2, p - 2, p) * x_2 % p + return encode_coordinate(x) diff --git a/python/src/trezorlib/transport/thp/message_header.py b/python/src/trezorlib/transport/thp/message_header.py new file mode 100644 index 00000000000..d2ff002d636 --- /dev/null +++ b/python/src/trezorlib/transport/thp/message_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class MessageHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 00000000000..fa420ac0af2 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp.channel_data import ChannelData + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_data + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def get_channel_data(self) -> ChannelData: + raise NotImplementedError + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 00000000000..baea7e74010 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py new file mode 100644 index 00000000000..07ff2cadd43 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify +from enum import IntEnum + +import click +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import exceptions, messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp import checksum, curve25519, thp_io +from ..thp.channel_data import ChannelData +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.message_header import MessageHeader +from . import control_byte +from .channel_database import ChannelDatabase, get_channel_db +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +MANAGEMENT_SESSION_ID: int = 0 + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + channel_id: int + channel_database: ChannelDatabase + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + sync_bit_send: int + sync_bit_receive: int + + _has_valid_channel: bool = False + _features: messages.Features | None = None + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.channel_database: ChannelDatabase = get_channel_db() + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self._has_valid_channel = True + + def get_channel(self) -> ProtocolV2: + if not self._has_valid_channel: + self._establish_new_channel() + return self + + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + ) + + def read(self, session_id: int) -> t.Any: + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on a different session.") + self.channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data) + self.channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + self.session_id: int = 0 + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + _ = self._read_until_valid_crc_check() # TODO check ACK + _, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + + def _establish_new_channel(self) -> None: + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + # Send channel allocation request + # Note that [:8] on the following line is required when tests use + # WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail. + channel_id_request_nonce = os.urandom(8)[:8] + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + MessageHeader.get_channel_allocation_request_header(12), + channel_id_request_nonce, + ) + + # Read channel allocation response + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, channel_id_request_nonce + ): + # TODO raise exception here, I guess + raise Exception("Invalid channel allocation response.") + + self.channel_id = int.from_bytes(payload[8:10], "big") + self.device_properties = payload[10:] + + # Send handshake init request + ha_init_req_header = MessageHeader(0, self.channel_id, 36) + # Note that [:32] on the following line is required when tests use + # WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail. + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32]) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read handshake init response + header, payload = self._read_until_valid_crc_check() + self._send_ack_0() + + if not header.is_handshake_init_response(): + click.echo( + "Received message is not a valid handshake init response message", + err=True, + ) + + trezor_ephemeral_pubkey = payload[:32] + encrypted_trezor_static_pubkey = payload[32:80] + noise_tag = payload[80:96] + + # TODO check noise tag + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # Prepare and send handshake completion request + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + except Exception as e: + click.echo( + f"Exception of type{type(e)}", err=True + ) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + pairing_methods=[ + messages.ThpPairingMethod.NoMethod, + ] + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload) + ha_completion_req_header = MessageHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + click.echo( + "Received message is not a valid handshake completion response", + err=True, + ) + self._send_ack_1() + + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + # Send StartPairingReqest message + message = messages.ThpStartPairingRequest() + message_type, message_data = self.mapping.encode(message) + + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read + _, msg_type, msg_data = self.read_and_decrypt() + maaa = self.mapping.decode(msg_type, msg_data) + + assert isinstance(maaa, messages.ThpEndResponse) + self._has_valid_channel = True + + def _send_ack_0(self): + LOG.debug("sending ack 0") + header = MessageHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_1(self): + LOG.debug("sending ack 1") + header = MessageHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int | None = None, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + if ctrl_byte is None: + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) + self.sync_bit_send = 1 - self.sync_bit_send + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = MessageHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + return self.read_and_decrypt() + if not header.is_encrypted_transport(): + click.echo( + "Trying to decrypt not encrypted message!" + + hexlify(header.to_bytes_init() + raw_payload).decode(), + err=True, + ) + + if not control_byte.is_ack(header.ctrl_byte): + LOG.debug( + "--> Get sequence bit %d %s %s", + control_byte.get_seq_bit(header.ctrl_byte), + "from control byte", + hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(), + ) + if control_byte.get_seq_bit(header.ctrl_byte): + self._send_ack_1() + else: + self._send_ack_0() + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + session_id, + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[MessageHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + click.echo( + "Received a message with an invalid checksum:" + + hexlify(header.to_bytes_init() + payload + chksum).decode(), + err=True, + ) + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: MessageHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + click.echo( + "Received message is not a channel allocation response", err=True + ) + return False + if len(payload) < 10: + click.echo("Invalid channel allocation response payload", err=True) + return False + if payload[:8] != original_nonce: + click.echo( + "Invalid channel allocation response payload (nonce mismatch)", err=True + ) + return False + return True + + class ControlByteType(IntEnum): + CHANNEL_ALLOCATION_RES = 1 + HANDSHAKE_INIT_RES = 2 + HANDSHAKE_COMP_RES = 3 + ACK = 4 + ENCRYPTED_TRANSPORT = 5 diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 00000000000..d0237f9e36d --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,93 @@ +import struct +from typing import Tuple + +from .. import Transport +from ..thp import checksum +from .message_header import MessageHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(transport, header, data) + + +def write_payload_to_wire( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + transport.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] + while buffer: + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] + + +def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]: + """ + Reads from the given wire transport. + + Returns `Tuple[MessageHeader, bytes, bytes]`: + 1. `header` (`MessageHeader`): Header of the message. + 2. `data` (`bytes`): Contents of the message (if any). + 3. `checksum` (`bytes`): crc32 checksum of the header + data. + + """ + buffer = bytearray() + + # Read header with first part of message data + header, first_chunk = read_first(transport) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(transport, header.cid)) + + data_len = header.data_length - checksum.CHECKSUM_LENGTH + msg_data = buffer[:data_len] + chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH] + + return (header, msg_data, chksum) + + +def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]: + chunk = transport.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return MessageHeader(ctrl_byte, cid, data_length), data + + +def read_next(transport: Transport, cid: int) -> bytes: + chunk = transport.read_chunk() + ctrl_byte, read_cid = struct.unpack( + MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c63..2960df89945 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ def __init__(self, device: Optional[str] = None) -> None: devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None - - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) + self.device: Tuple[str, int] = (host, port) - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ def _try_path(cls, path: str) -> "UdpTransport": @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ def close(self) -> None: self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ def write_chunk(self, chunk: bytes) -> None: self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ def read_chunk(self) -> bytes: if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a6..023ed5f2455 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -64,6 +121,8 @@ def open(self) -> None: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: @@ -75,6 +134,8 @@ def close(self) -> None: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -97,6 +158,8 @@ def write_chunk(self, chunk: bytes) -> None: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -117,70 +180,6 @@ def read_chunk(self) -> bytes: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True) diff --git a/python/tools/encfs_aes_getpass.py b/python/tools/encfs_aes_getpass.py index 82773e50fa7..37a221154cc 100755 --- a/python/tools/encfs_aes_getpass.py +++ b/python/tools/encfs_aes_getpass.py @@ -35,7 +35,6 @@ from trezorlib.client import TrezorClient from trezorlib.tools import Address from trezorlib.transport import enumerate_devices -from trezorlib.ui import ClickUI version_tuple = tuple(map(int, trezorlib.__version__.split("."))) if not (0, 11) <= version_tuple < (0, 14): @@ -71,7 +70,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write("Available devices:\n") for d in devices: try: - client = TrezorClient(d, ui=ClickUI()) + client = TrezorClient(d) except IOError: sys.stderr.write("[-] \n") continue @@ -80,7 +79,7 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport": sys.stderr.write(f"[{i}] {client.features.label}\n") else: sys.stderr.write(f"[{i}] \n") - client.close() + # TODO client.close() i += 1 sys.stderr.write("----------------------------\n") @@ -106,7 +105,8 @@ def main() -> None: devices = wait_for_devices() transport = choose_device(devices) - client = TrezorClient(transport, ui=ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() rootdir = os.environ["encfs_root"] # Read "man encfs" for more passw_file = os.path.join(rootdir, "password.dat") @@ -120,7 +120,7 @@ def main() -> None: sys.stderr.write("Computer asked Trezor for new strong password.\n") # 32 bytes, good for AES - trezor_entropy = trezorlib.misc.get_entropy(client, 32) + trezor_entropy = trezorlib.misc.get_entropy(session, 32) urandom_entropy = os.urandom(32) passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest() @@ -129,7 +129,7 @@ def main() -> None: bip32_path = Address([10, 0]) passw_encrypted = trezorlib.misc.encrypt_keyvalue( - client, bip32_path, label, passw, False, True + session, bip32_path, label, passw, False, True ) data = { @@ -144,7 +144,7 @@ def main() -> None: data = json.load(open(passw_file, "r")) passw = trezorlib.misc.decrypt_keyvalue( - client, + session, data["bip32_path"], data["label"], bytes.fromhex(data["password_encrypted_hex"]), diff --git a/python/tools/helloworld.py b/python/tools/helloworld.py index 76b4502da2d..b8711dbb00a 100755 --- a/python/tools/helloworld.py +++ b/python/tools/helloworld.py @@ -24,13 +24,14 @@ def main() -> None: # Use first connected device client = get_default_client() + session = client.get_session(derive_cardano=True) # Print out Trezor's features and settings - print(client.features) + print(session.features) # Get the first address of first BIP44 account bip32_path = parse_path("44h/0h/0h/0/0") - address = btc.get_address(client, "Bitcoin", bip32_path, True) + address = btc.get_address(session, "Bitcoin", bip32_path, False) print("Bitcoin address:", address) diff --git a/python/tools/pwd_reader.py b/python/tools/pwd_reader.py index afd405e1642..1c012c7abf4 100755 --- a/python/tools/pwd_reader.py +++ b/python/tools/pwd_reader.py @@ -26,23 +26,24 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.tools import parse_path from trezorlib.transport import get_transport +from trezorlib.transport.session import Session # Return path by BIP-32 BIP32_PATH = parse_path("10016h/0") # Deriving master key -def getMasterKey(client: TrezorClient) -> str: +def getMasterKey(session: Session) -> str: bip32_path = BIP32_PATH ENC_KEY = "Activate TREZOR Password Manager?" ENC_VALUE = bytes.fromhex( "2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee" ) - key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True) + key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True) return key.hex() @@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict: # Decrypt give entry nonce -def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: +def getDecryptedNonce(session: Session, entry: dict) -> str: print() print("Waiting for Trezor input ...") print() @@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str: ENC_KEY = f"Unlock {item} for user {entry['username']}?" ENC_VALUE = entry["nonce"] decrypted_nonce = misc.decrypt_keyvalue( - client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True + session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True ) return decrypted_nonce.hex() @@ -144,13 +145,14 @@ def main() -> None: print(e) return - client = TrezorClient(transport=transport, ui=ui.ClickUI()) + client = TrezorClient(transport=transport) + session = client.get_management_session() print() print("Confirm operation on Trezor") print() - masterKey = getMasterKey(client) + masterKey = getMasterKey(session) # print('master key:', masterKey) fileName = getFileEncKey(masterKey)[0] @@ -173,7 +175,7 @@ def main() -> None: entry_id = input("Select entry number to decrypt: ") entry_id = str(entry_id) - plain_nonce = getDecryptedNonce(client, entries[entry_id]) + plain_nonce = getDecryptedNonce(session, entries[entry_id]) pwdArr = entries[entry_id]["password"]["data"] pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr]) diff --git a/python/tools/pybridge.py b/python/tools/pybridge.py index 30d69bbc9b1..d94ec121d03 100644 --- a/python/tools/pybridge.py +++ b/python/tools/pybridge.py @@ -24,6 +24,8 @@ from gevent import monkey +import trezorlib.transport + monkey.patch_all() import json @@ -103,11 +105,11 @@ def __init__(self, transport: trezorlib.transport.Transport) -> None: self.session: Session | None = None self.transport = transport - client = TrezorClient(transport, ui=SilentUI()) + client = TrezorClient(transport) # TODO add silent UI? self.model = ( trezorlib.models.by_name(client.features.model) or trezorlib.models.TREZOR_T ) - client.end_session() + # TODO client.end_session() def acquire(self, sid: str) -> str: if self.session_id() != sid: @@ -116,11 +118,11 @@ def acquire(self, sid: str) -> str: self.session.release() self.session = Session(self) - self.transport.begin_session() + # TODO self.transport.deprecated_begin_session() return self.session.id def release(self) -> None: - self.transport.end_session() + # TODO self.transport.deprecated_end_session() self.session = None def session_id(self) -> str | None: @@ -141,10 +143,14 @@ def to_json(self) -> dict: } def write(self, msg_id: int, data: bytes) -> None: - self.transport.write(msg_id, data) + raise NotImplementedError + # TODO + # self.transport.write(msg_id, data) def read(self) -> tuple[int, bytes]: - return self.transport.read() + raise NotImplementedError + # TODO + # return self.transport.read() @classmethod def find(cls, path: str) -> Transport | None: diff --git a/python/tools/rng_entropy_collector.py b/python/tools/rng_entropy_collector.py index 2b0a5b80d79..437561b1549 100755 --- a/python/tools/rng_entropy_collector.py +++ b/python/tools/rng_entropy_collector.py @@ -7,14 +7,15 @@ import io import sys -from trezorlib import misc, ui +from trezorlib import misc from trezorlib.client import TrezorClient from trezorlib.transport import get_transport def main() -> None: try: - client = TrezorClient(get_transport(), ui=ui.ClickUI()) + client = TrezorClient(get_transport()) + session = client.get_management_session() except Exception as e: print(e) return @@ -25,11 +26,9 @@ def main() -> None: with io.open(arg1, "wb") as f: for _ in range(0, arg2, step): - entropy = misc.get_entropy(client, step) + entropy = misc.get_entropy(session, step) f.write(entropy) - client.close() - if __name__ == "__main__": main() diff --git a/python/tools/trezor-otp.py b/python/tools/trezor-otp.py index bc0b66daa97..a88f745b412 100755 --- a/python/tools/trezor-otp.py +++ b/python/tools/trezor-otp.py @@ -27,26 +27,25 @@ from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue from trezorlib.tools import parse_path from trezorlib.transport import get_transport -from trezorlib.ui import ClickUI BIP32_PATH = parse_path("10016h/0") def encrypt(type: str, domain: str, secret: str) -> str: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() dom = type.upper() + ": " + domain - enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True) - client.close() + enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True) return enc.hex() def decrypt(type: str, domain: str, secret: bytes) -> bytes: transport = get_transport() - client = TrezorClient(transport, ClickUI()) + client = TrezorClient(transport) + session = client.get_management_session() dom = type.upper() + ": " + domain - dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True) - client.close() + dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True) return dec diff --git a/rust/trezor-client/src/protos/generated/messages_common.rs b/rust/trezor-client/src/protos/generated/messages_common.rs index 4fd72b22f0b..5323fe3c303 100644 --- a/rust/trezor-client/src/protos/generated/messages_common.rs +++ b/rust/trezor-client/src/protos/generated/messages_common.rs @@ -414,6 +414,10 @@ pub mod failure { Failure_WipeCodeMismatch = 13, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidSession) Failure_InvalidSession = 14, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_ThpUnallocatedSession) + Failure_ThpUnallocatedSession = 15, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidProtocol) + Failure_InvalidProtocol = 16, // @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_FirmwareError) Failure_FirmwareError = 99, } @@ -441,6 +445,8 @@ pub mod failure { 12 => ::std::option::Option::Some(FailureType::Failure_PinMismatch), 13 => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), 14 => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + 15 => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + 16 => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), 99 => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -462,6 +468,8 @@ pub mod failure { "Failure_PinMismatch" => ::std::option::Option::Some(FailureType::Failure_PinMismatch), "Failure_WipeCodeMismatch" => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch), "Failure_InvalidSession" => ::std::option::Option::Some(FailureType::Failure_InvalidSession), + "Failure_ThpUnallocatedSession" => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession), + "Failure_InvalidProtocol" => ::std::option::Option::Some(FailureType::Failure_InvalidProtocol), "Failure_FirmwareError" => ::std::option::Option::Some(FailureType::Failure_FirmwareError), _ => ::std::option::Option::None } @@ -482,6 +490,8 @@ pub mod failure { FailureType::Failure_PinMismatch, FailureType::Failure_WipeCodeMismatch, FailureType::Failure_InvalidSession, + FailureType::Failure_ThpUnallocatedSession, + FailureType::Failure_InvalidProtocol, FailureType::Failure_FirmwareError, ]; } @@ -508,7 +518,9 @@ pub mod failure { FailureType::Failure_PinMismatch => 11, FailureType::Failure_WipeCodeMismatch => 12, FailureType::Failure_InvalidSession => 13, - FailureType::Failure_FirmwareError => 14, + FailureType::Failure_ThpUnallocatedSession => 14, + FailureType::Failure_InvalidProtocol => 15, + FailureType::Failure_FirmwareError => 16, }; Self::enum_descriptor().value_by_index(index) } @@ -2481,9 +2493,9 @@ impl ::protobuf::reflect::ProtobufValue for HDNodeType { static file_descriptor_proto_data: &'static [u8] = b"\ \n\x15messages-common.proto\x12\x19hw.trezor.messages.common\x1a\roption\ s.proto\"%\n\x07Success\x12\x1a\n\x07message\x18\x01\x20\x01(\t:\0R\x07m\ - essage\"\x8f\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2..hw.t\ + essage\"\xcf\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2..hw.t\ rezor.messages.common.Failure.FailureTypeR\x04code\x12\x18\n\x07message\ - \x18\x02\x20\x01(\tR\x07message\"\xa5\x03\n\x0bFailureType\x12\x1d\n\x19\ + \x18\x02\x20\x01(\tR\x07message\"\xe5\x03\n\x0bFailureType\x12\x1d\n\x19\ Failure_UnexpectedMessage\x10\x01\x12\x1a\n\x16Failure_ButtonExpected\ \x10\x02\x12\x15\n\x11Failure_DataError\x10\x03\x12\x1b\n\x17Failure_Act\ ionCancelled\x10\x04\x12\x17\n\x13Failure_PinExpected\x10\x05\x12\x18\n\ @@ -2492,44 +2504,45 @@ static file_descriptor_proto_data: &'static [u8] = b"\ essError\x10\t\x12\x1a\n\x16Failure_NotEnoughFunds\x10\n\x12\x1a\n\x16Fa\ ilure_NotInitialized\x10\x0b\x12\x17\n\x13Failure_PinMismatch\x10\x0c\ \x12\x1c\n\x18Failure_WipeCodeMismatch\x10\r\x12\x1a\n\x16Failure_Invali\ - dSession\x10\x0e\x12\x19\n\x15Failure_FirmwareError\x10c\"\xab\x06\n\rBu\ - ttonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.hw.trezor.messages.com\ - mon.ButtonRequest.ButtonRequestTypeR\x04code\x12\x14\n\x05pages\x18\x02\ - \x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\x20\x01(\tR\x04name\"\ - \x99\x05\n\x11ButtonRequestType\x12\x17\n\x13ButtonRequest_Other\x10\x01\ - \x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\x12\x1f\n\x1bButtonRe\ - quest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonRequest_ResetDevice\x10\ - \x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\x12\x1c\n\x18ButtonR\ - equest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonRequest_ProtectCall\x10\x07\ - \x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\n\x1bButtonRequest_Fi\ - rmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Address\x10\n\x12\x1b\n\x17\ - ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRequest_MnemonicWordCoun\ - t\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\x10\r\x120\n(_Depreca\ - ted_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\x08\x01\x12'\n#ButtonRe\ - quest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eButtonRequest_RecoveryHom\ - epage\x10\x10\x12\x19\n\x15ButtonRequest_Success\x10\x11\x12\x19\n\x15Bu\ - ttonRequest_Warning\x10\x12\x12!\n\x1dButtonRequest_PassphraseEntry\x10\ - \x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\x04\x08\x03\x10\x04\"\ - \x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\x12T\n\x04type\x18\x01\ - \x20\x01(\x0e2@.hw.trezor.messages.common.PinMatrixRequest.PinMatrixRequ\ - estTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestType\x12\x20\n\x1cPinMat\ - rixRequestType_Current\x10\x01\x12!\n\x1dPinMatrixRequestType_NewFirst\ - \x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\x10\x03\x12&\n\"PinMa\ - trixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMatrixRequestType_WipeCo\ - deSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\x03pin\x18\x01\x20\ - \x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n_on_device\x18\x01\ - \x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphraseAck\x12\x1e\n\np\ - assphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\x06_state\x18\x02\ - \x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_device\x18\x03\x20\ - \x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateRequest\x12\x14\n\ - \x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\x1dDeprecated\ - _PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\x12\x14\n\x05de\ - pth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprint\x18\x02\x20\ - \x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\x02(\rR\x08chil\ - dNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainCode\x12\x1f\n\ - \x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\npublic_key\ - \x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor.lib.protobu\ - fB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ + dSession\x10\x0e\x12!\n\x1dFailure_ThpUnallocatedSession\x10\x0f\x12\x1b\ + \n\x17Failure_InvalidProtocol\x10\x10\x12\x19\n\x15Failure_FirmwareError\ + \x10c\"\xab\x06\n\rButtonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.h\ + w.trezor.messages.common.ButtonRequest.ButtonRequestTypeR\x04code\x12\ + \x14\n\x05pages\x18\x02\x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\ + \x20\x01(\tR\x04name\"\x99\x05\n\x11ButtonRequestType\x12\x17\n\x13Butto\ + nRequest_Other\x10\x01\x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\ + \x12\x1f\n\x1bButtonRequest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonReq\ + uest_ResetDevice\x10\x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\ + \x12\x1c\n\x18ButtonRequest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonReques\ + t_ProtectCall\x10\x07\x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\ + \n\x1bButtonRequest_FirmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Addre\ + ss\x10\n\x12\x1b\n\x17ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRe\ + quest_MnemonicWordCount\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\ + \x10\r\x120\n(_Deprecated_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\ + \x08\x01\x12'\n#ButtonRequest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eB\ + uttonRequest_RecoveryHomepage\x10\x10\x12\x19\n\x15ButtonRequest_Success\ + \x10\x11\x12\x19\n\x15ButtonRequest_Warning\x10\x12\x12!\n\x1dButtonRequ\ + est_PassphraseEntry\x10\x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\ + \x04\x08\x03\x10\x04\"\x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\ + \x12T\n\x04type\x18\x01\x20\x01(\x0e2@.hw.trezor.messages.common.PinMatr\ + ixRequest.PinMatrixRequestTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestT\ + ype\x12\x20\n\x1cPinMatrixRequestType_Current\x10\x01\x12!\n\x1dPinMatri\ + xRequestType_NewFirst\x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\ + \x10\x03\x12&\n\"PinMatrixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMa\ + trixRequestType_WipeCodeSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\ + \x03pin\x18\x01\x20\x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n\ + _on_device\x18\x01\x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphra\ + seAck\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\ + \x06_state\x18\x02\x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_dev\ + ice\x18\x03\x20\x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateReq\ + uest\x12\x14\n\x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\ + \x1dDeprecated_PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\ + \x12\x14\n\x05depth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprin\ + t\x18\x02\x20\x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\ + \x02(\rR\x08childNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainC\ + ode\x12\x1f\n\x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\ + \npublic_key\x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor\ + .lib.protobufB\x13TrezorMessageCommon\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_debug.rs b/rust/trezor-client/src/protos/generated/messages_debug.rs index d384b11545c..3197a4fab4e 100644 --- a/rust/trezor-client/src/protos/generated/messages_debug.rs +++ b/rust/trezor-client/src/protos/generated/messages_debug.rs @@ -1128,6 +1128,8 @@ pub struct DebugLinkGetState { pub wait_word_pos: ::std::option::Option, // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetState.wait_layout) pub wait_layout: ::std::option::Option<::protobuf::EnumOrUnknown>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkGetState.thp_channel_id) + pub thp_channel_id: ::std::option::Option<::std::vec::Vec>, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkGetState.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -1204,8 +1206,44 @@ impl DebugLinkGetState { self.wait_layout = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); } + // optional bytes thp_channel_id = 4; + + pub fn thp_channel_id(&self) -> &[u8] { + match self.thp_channel_id.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_channel_id(&mut self) { + self.thp_channel_id = ::std::option::Option::None; + } + + pub fn has_thp_channel_id(&self) -> bool { + self.thp_channel_id.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_channel_id(&mut self, v: ::std::vec::Vec) { + self.thp_channel_id = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_channel_id(&mut self) -> &mut ::std::vec::Vec { + if self.thp_channel_id.is_none() { + self.thp_channel_id = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_channel_id.as_mut().unwrap() + } + + // Take field + pub fn take_thp_channel_id(&mut self) -> ::std::vec::Vec { + self.thp_channel_id.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(3); + let mut fields = ::std::vec::Vec::with_capacity(4); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( "wait_word_list", @@ -1222,6 +1260,11 @@ impl DebugLinkGetState { |m: &DebugLinkGetState| { &m.wait_layout }, |m: &mut DebugLinkGetState| { &mut m.wait_layout }, )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_channel_id", + |m: &DebugLinkGetState| { &m.thp_channel_id }, + |m: &mut DebugLinkGetState| { &mut m.thp_channel_id }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "DebugLinkGetState", fields, @@ -1249,6 +1292,9 @@ impl ::protobuf::Message for DebugLinkGetState { 24 => { self.wait_layout = ::std::option::Option::Some(is.read_enum_or_unknown()?); }, + 34 => { + self.thp_channel_id = ::std::option::Option::Some(is.read_bytes()?); + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -1270,6 +1316,9 @@ impl ::protobuf::Message for DebugLinkGetState { if let Some(v) = self.wait_layout { my_size += ::protobuf::rt::int32_size(3, v.value()); } + if let Some(v) = self.thp_channel_id.as_ref() { + my_size += ::protobuf::rt::bytes_size(4, &v); + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -1285,6 +1334,9 @@ impl ::protobuf::Message for DebugLinkGetState { if let Some(v) = self.wait_layout { os.write_enum(3, ::protobuf::EnumOrUnknown::value(&v))?; } + if let Some(v) = self.thp_channel_id.as_ref() { + os.write_bytes(4, v)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -1305,6 +1357,7 @@ impl ::protobuf::Message for DebugLinkGetState { self.wait_word_list = ::std::option::Option::None; self.wait_word_pos = ::std::option::Option::None; self.wait_layout = ::std::option::Option::None; + self.thp_channel_id = ::std::option::Option::None; self.special_fields.clear(); } @@ -1313,6 +1366,7 @@ impl ::protobuf::Message for DebugLinkGetState { wait_word_list: ::std::option::Option::None, wait_word_pos: ::std::option::Option::None, wait_layout: ::std::option::Option::None, + thp_channel_id: ::std::option::Option::None, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -1436,6 +1490,12 @@ pub struct DebugLinkState { pub mnemonic_type: ::std::option::Option<::protobuf::EnumOrUnknown>, // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.tokens) pub tokens: ::std::vec::Vec<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_entry_code) + pub thp_pairing_code_entry_code: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_qr_code) + pub thp_pairing_code_qr_code: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.debug.DebugLinkState.thp_pairing_code_nfc_unidirectional) + pub thp_pairing_code_nfc_unidirectional: ::std::option::Option<::std::vec::Vec>, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.debug.DebugLinkState.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -1783,8 +1843,99 @@ impl DebugLinkState { self.mnemonic_type = ::std::option::Option::Some(::protobuf::EnumOrUnknown::new(v)); } + // optional uint32 thp_pairing_code_entry_code = 14; + + pub fn thp_pairing_code_entry_code(&self) -> u32 { + self.thp_pairing_code_entry_code.unwrap_or(0) + } + + pub fn clear_thp_pairing_code_entry_code(&mut self) { + self.thp_pairing_code_entry_code = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_entry_code(&self) -> bool { + self.thp_pairing_code_entry_code.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_entry_code(&mut self, v: u32) { + self.thp_pairing_code_entry_code = ::std::option::Option::Some(v); + } + + // optional bytes thp_pairing_code_qr_code = 15; + + pub fn thp_pairing_code_qr_code(&self) -> &[u8] { + match self.thp_pairing_code_qr_code.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_pairing_code_qr_code(&mut self) { + self.thp_pairing_code_qr_code = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_qr_code(&self) -> bool { + self.thp_pairing_code_qr_code.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_qr_code(&mut self, v: ::std::vec::Vec) { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_pairing_code_qr_code(&mut self) -> &mut ::std::vec::Vec { + if self.thp_pairing_code_qr_code.is_none() { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_pairing_code_qr_code.as_mut().unwrap() + } + + // Take field + pub fn take_thp_pairing_code_qr_code(&mut self) -> ::std::vec::Vec { + self.thp_pairing_code_qr_code.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes thp_pairing_code_nfc_unidirectional = 16; + + pub fn thp_pairing_code_nfc_unidirectional(&self) -> &[u8] { + match self.thp_pairing_code_nfc_unidirectional.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_thp_pairing_code_nfc_unidirectional(&mut self) { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::None; + } + + pub fn has_thp_pairing_code_nfc_unidirectional(&self) -> bool { + self.thp_pairing_code_nfc_unidirectional.is_some() + } + + // Param is passed by value, moved + pub fn set_thp_pairing_code_nfc_unidirectional(&mut self, v: ::std::vec::Vec) { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_thp_pairing_code_nfc_unidirectional(&mut self) -> &mut ::std::vec::Vec { + if self.thp_pairing_code_nfc_unidirectional.is_none() { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.thp_pairing_code_nfc_unidirectional.as_mut().unwrap() + } + + // Take field + pub fn take_thp_pairing_code_nfc_unidirectional(&mut self) -> ::std::vec::Vec { + self.thp_pairing_code_nfc_unidirectional.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { - let mut fields = ::std::vec::Vec::with_capacity(13); + let mut fields = ::std::vec::Vec::with_capacity(16); let mut oneofs = ::std::vec::Vec::with_capacity(0); fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( "layout", @@ -1851,6 +2002,21 @@ impl DebugLinkState { |m: &DebugLinkState| { &m.tokens }, |m: &mut DebugLinkState| { &mut m.tokens }, )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_entry_code", + |m: &DebugLinkState| { &m.thp_pairing_code_entry_code }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_entry_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_qr_code", + |m: &DebugLinkState| { &m.thp_pairing_code_qr_code }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_qr_code }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "thp_pairing_code_nfc_unidirectional", + |m: &DebugLinkState| { &m.thp_pairing_code_nfc_unidirectional }, + |m: &mut DebugLinkState| { &mut m.thp_pairing_code_nfc_unidirectional }, + )); ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( "DebugLinkState", fields, @@ -1913,6 +2079,15 @@ impl ::protobuf::Message for DebugLinkState { 106 => { self.tokens.push(is.read_string()?); }, + 112 => { + self.thp_pairing_code_entry_code = ::std::option::Option::Some(is.read_uint32()?); + }, + 122 => { + self.thp_pairing_code_qr_code = ::std::option::Option::Some(is.read_bytes()?); + }, + 130 => { + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::Some(is.read_bytes()?); + }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; }, @@ -1965,6 +2140,15 @@ impl ::protobuf::Message for DebugLinkState { for value in &self.tokens { my_size += ::protobuf::rt::string_size(13, &value); }; + if let Some(v) = self.thp_pairing_code_entry_code { + my_size += ::protobuf::rt::uint32_size(14, v); + } + if let Some(v) = self.thp_pairing_code_qr_code.as_ref() { + my_size += ::protobuf::rt::bytes_size(15, &v); + } + if let Some(v) = self.thp_pairing_code_nfc_unidirectional.as_ref() { + my_size += ::protobuf::rt::bytes_size(16, &v); + } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); my_size @@ -2010,6 +2194,15 @@ impl ::protobuf::Message for DebugLinkState { for v in &self.tokens { os.write_string(13, &v)?; }; + if let Some(v) = self.thp_pairing_code_entry_code { + os.write_uint32(14, v)?; + } + if let Some(v) = self.thp_pairing_code_qr_code.as_ref() { + os.write_bytes(15, v)?; + } + if let Some(v) = self.thp_pairing_code_nfc_unidirectional.as_ref() { + os.write_bytes(16, v)?; + } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) } @@ -2040,6 +2233,9 @@ impl ::protobuf::Message for DebugLinkState { self.reset_word_pos = ::std::option::Option::None; self.mnemonic_type = ::std::option::Option::None; self.tokens.clear(); + self.thp_pairing_code_entry_code = ::std::option::Option::None; + self.thp_pairing_code_qr_code = ::std::option::Option::None; + self.thp_pairing_code_nfc_unidirectional = ::std::option::Option::None; self.special_fields.clear(); } @@ -2058,6 +2254,9 @@ impl ::protobuf::Message for DebugLinkState { reset_word_pos: ::std::option::Option::None, mnemonic_type: ::std::option::Option::None, tokens: ::std::vec::Vec::new(), + thp_pairing_code_entry_code: ::std::option::Option::None, + thp_pairing_code_qr_code: ::std::option::Option::None, + thp_pairing_code_nfc_unidirectional: ::std::option::Option::None, special_fields: ::protobuf::SpecialFields::new(), }; &instance @@ -3650,39 +3849,44 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x01\x20\x03(\tR\x06tokens:\x02\x18\x01\"-\n\x15DebugLinkReseedRandom\ \x12\x14\n\x05value\x18\x01\x20\x01(\rR\x05value\"j\n\x15DebugLinkRecord\ Screen\x12)\n\x10target_directory\x18\x01\x20\x01(\tR\x0ftargetDirectory\ - \x12&\n\rrefresh_index\x18\x02\x20\x01(\r:\x010R\x0crefreshIndex\"\x91\ + \x12&\n\rrefresh_index\x18\x02\x20\x01(\r:\x010R\x0crefreshIndex\"\xb7\ \x02\n\x11DebugLinkGetState\x12(\n\x0ewait_word_list\x18\x01\x20\x01(\ \x08R\x0cwaitWordListB\x02\x18\x01\x12&\n\rwait_word_pos\x18\x02\x20\x01\ (\x08R\x0bwaitWordPosB\x02\x18\x01\x12e\n\x0bwait_layout\x18\x03\x20\x01\ (\x0e29.hw.trezor.messages.debug.DebugLinkGetState.DebugWaitType:\tIMMED\ - IATER\nwaitLayout\"C\n\rDebugWaitType\x12\r\n\tIMMEDIATE\x10\0\x12\x0f\n\ - \x0bNEXT_LAYOUT\x10\x01\x12\x12\n\x0eCURRENT_LAYOUT\x10\x02\"\x97\x04\n\ - \x0eDebugLinkState\x12\x16\n\x06layout\x18\x01\x20\x01(\x0cR\x06layout\ - \x12\x10\n\x03pin\x18\x02\x20\x01(\tR\x03pin\x12\x16\n\x06matrix\x18\x03\ - \x20\x01(\tR\x06matrix\x12'\n\x0fmnemonic_secret\x18\x04\x20\x01(\x0cR\ - \x0emnemonicSecret\x129\n\x04node\x18\x05\x20\x01(\x0b2%.hw.trezor.messa\ - ges.common.HDNodeTypeR\x04node\x123\n\x15passphrase_protection\x18\x06\ - \x20\x01(\x08R\x14passphraseProtection\x12\x1d\n\nreset_word\x18\x07\x20\ - \x01(\tR\tresetWord\x12#\n\rreset_entropy\x18\x08\x20\x01(\x0cR\x0creset\ - Entropy\x12,\n\x12recovery_fake_word\x18\t\x20\x01(\tR\x10recoveryFakeWo\ - rd\x12*\n\x11recovery_word_pos\x18\n\x20\x01(\rR\x0frecoveryWordPos\x12$\ - \n\x0ereset_word_pos\x18\x0b\x20\x01(\rR\x0cresetWordPos\x12N\n\rmnemoni\ - c_type\x18\x0c\x20\x01(\x0e2).hw.trezor.messages.management.BackupTypeR\ - \x0cmnemonicType\x12\x16\n\x06tokens\x18\r\x20\x03(\tR\x06tokens\"\x0f\n\ - \rDebugLinkStop\"P\n\x0cDebugLinkLog\x12\x14\n\x05level\x18\x01\x20\x01(\ - \rR\x05level\x12\x16\n\x06bucket\x18\x02\x20\x01(\tR\x06bucket\x12\x12\n\ - \x04text\x18\x03\x20\x01(\tR\x04text\"G\n\x13DebugLinkMemoryRead\x12\x18\ - \n\x07address\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06length\x18\x02\ - \x20\x01(\rR\x06length\")\n\x0fDebugLinkMemory\x12\x16\n\x06memory\x18\ - \x01\x20\x01(\x0cR\x06memory\"^\n\x14DebugLinkMemoryWrite\x12\x18\n\x07a\ - ddress\x18\x01\x20\x01(\rR\x07address\x12\x16\n\x06memory\x18\x02\x20\ - \x01(\x0cR\x06memory\x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\x05flash\"\ - -\n\x13DebugLinkFlashErase\x12\x16\n\x06sector\x18\x01\x20\x01(\rR\x06se\ - ctor\".\n\x14DebugLinkEraseSdCard\x12\x16\n\x06format\x18\x01\x20\x01(\ - \x08R\x06format\"0\n\x14DebugLinkWatchLayout\x12\x14\n\x05watch\x18\x01\ - \x20\x01(\x08R\x05watch:\x02\x18\x01\"\x1f\n\x19DebugLinkResetDebugEvent\ - s:\x02\x18\x01\"\x1a\n\x18DebugLinkOptigaSetSecMaxB=\n#com.satoshilabs.t\ - rezor.lib.protobufB\x12TrezorMessageDebug\x80\xa6\x1d\x01\ + IATER\nwaitLayout\x12$\n\x0ethp_channel_id\x18\x04\x20\x01(\x0cR\x0cthpC\ + hannelId\"C\n\rDebugWaitType\x12\r\n\tIMMEDIATE\x10\0\x12\x0f\n\x0bNEXT_\ + LAYOUT\x10\x01\x12\x12\n\x0eCURRENT_LAYOUT\x10\x02\"\xdb\x05\n\x0eDebugL\ + inkState\x12\x16\n\x06layout\x18\x01\x20\x01(\x0cR\x06layout\x12\x10\n\ + \x03pin\x18\x02\x20\x01(\tR\x03pin\x12\x16\n\x06matrix\x18\x03\x20\x01(\ + \tR\x06matrix\x12'\n\x0fmnemonic_secret\x18\x04\x20\x01(\x0cR\x0emnemoni\ + cSecret\x129\n\x04node\x18\x05\x20\x01(\x0b2%.hw.trezor.messages.common.\ + HDNodeTypeR\x04node\x123\n\x15passphrase_protection\x18\x06\x20\x01(\x08\ + R\x14passphraseProtection\x12\x1d\n\nreset_word\x18\x07\x20\x01(\tR\tres\ + etWord\x12#\n\rreset_entropy\x18\x08\x20\x01(\x0cR\x0cresetEntropy\x12,\ + \n\x12recovery_fake_word\x18\t\x20\x01(\tR\x10recoveryFakeWord\x12*\n\ + \x11recovery_word_pos\x18\n\x20\x01(\rR\x0frecoveryWordPos\x12$\n\x0eres\ + et_word_pos\x18\x0b\x20\x01(\rR\x0cresetWordPos\x12N\n\rmnemonic_type\ + \x18\x0c\x20\x01(\x0e2).hw.trezor.messages.management.BackupTypeR\x0cmne\ + monicType\x12\x16\n\x06tokens\x18\r\x20\x03(\tR\x06tokens\x12<\n\x1bthp_\ + pairing_code_entry_code\x18\x0e\x20\x01(\rR\x17thpPairingCodeEntryCode\ + \x126\n\x18thp_pairing_code_qr_code\x18\x0f\x20\x01(\x0cR\x14thpPairingC\ + odeQrCode\x12L\n#thp_pairing_code_nfc_unidirectional\x18\x10\x20\x01(\ + \x0cR\x1fthpPairingCodeNfcUnidirectional\"\x0f\n\rDebugLinkStop\"P\n\x0c\ + DebugLinkLog\x12\x14\n\x05level\x18\x01\x20\x01(\rR\x05level\x12\x16\n\ + \x06bucket\x18\x02\x20\x01(\tR\x06bucket\x12\x12\n\x04text\x18\x03\x20\ + \x01(\tR\x04text\"G\n\x13DebugLinkMemoryRead\x12\x18\n\x07address\x18\ + \x01\x20\x01(\rR\x07address\x12\x16\n\x06length\x18\x02\x20\x01(\rR\x06l\ + ength\")\n\x0fDebugLinkMemory\x12\x16\n\x06memory\x18\x01\x20\x01(\x0cR\ + \x06memory\"^\n\x14DebugLinkMemoryWrite\x12\x18\n\x07address\x18\x01\x20\ + \x01(\rR\x07address\x12\x16\n\x06memory\x18\x02\x20\x01(\x0cR\x06memory\ + \x12\x14\n\x05flash\x18\x03\x20\x01(\x08R\x05flash\"-\n\x13DebugLinkFlas\ + hErase\x12\x16\n\x06sector\x18\x01\x20\x01(\rR\x06sector\".\n\x14DebugLi\ + nkEraseSdCard\x12\x16\n\x06format\x18\x01\x20\x01(\x08R\x06format\"0\n\ + \x14DebugLinkWatchLayout\x12\x14\n\x05watch\x18\x01\x20\x01(\x08R\x05wat\ + ch:\x02\x18\x01\"\x1f\n\x19DebugLinkResetDebugEvents:\x02\x18\x01\"\x1a\ + \n\x18DebugLinkOptigaSetSecMaxB=\n#com.satoshilabs.trezor.lib.protobufB\ + \x12TrezorMessageDebug\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/rust/trezor-client/src/protos/generated/messages_thp.rs b/rust/trezor-client/src/protos/generated/messages_thp.rs index 9e0d8e8aea8..b449bb30daa 100644 --- a/rust/trezor-client/src/protos/generated/messages_thp.rs +++ b/rust/trezor-client/src/protos/generated/messages_thp.rs @@ -25,6 +25,3265 @@ /// of protobuf runtime. const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_3_3_0; +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpDeviceProperties) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpDeviceProperties { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.internal_model) + pub internal_model: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.model_variant) + pub model_variant: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.bootloader_mode) + pub bootloader_mode: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.protocol_version) + pub protocol_version: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpDeviceProperties.pairing_methods) + pub pairing_methods: ::std::vec::Vec<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpDeviceProperties.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpDeviceProperties { + fn default() -> &'a ThpDeviceProperties { + ::default_instance() + } +} + +impl ThpDeviceProperties { + pub fn new() -> ThpDeviceProperties { + ::std::default::Default::default() + } + + // optional string internal_model = 1; + + pub fn internal_model(&self) -> &str { + match self.internal_model.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_internal_model(&mut self) { + self.internal_model = ::std::option::Option::None; + } + + pub fn has_internal_model(&self) -> bool { + self.internal_model.is_some() + } + + // Param is passed by value, moved + pub fn set_internal_model(&mut self, v: ::std::string::String) { + self.internal_model = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_internal_model(&mut self) -> &mut ::std::string::String { + if self.internal_model.is_none() { + self.internal_model = ::std::option::Option::Some(::std::string::String::new()); + } + self.internal_model.as_mut().unwrap() + } + + // Take field + pub fn take_internal_model(&mut self) -> ::std::string::String { + self.internal_model.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional uint32 model_variant = 2; + + pub fn model_variant(&self) -> u32 { + self.model_variant.unwrap_or(0) + } + + pub fn clear_model_variant(&mut self) { + self.model_variant = ::std::option::Option::None; + } + + pub fn has_model_variant(&self) -> bool { + self.model_variant.is_some() + } + + // Param is passed by value, moved + pub fn set_model_variant(&mut self, v: u32) { + self.model_variant = ::std::option::Option::Some(v); + } + + // optional bool bootloader_mode = 3; + + pub fn bootloader_mode(&self) -> bool { + self.bootloader_mode.unwrap_or(false) + } + + pub fn clear_bootloader_mode(&mut self) { + self.bootloader_mode = ::std::option::Option::None; + } + + pub fn has_bootloader_mode(&self) -> bool { + self.bootloader_mode.is_some() + } + + // Param is passed by value, moved + pub fn set_bootloader_mode(&mut self, v: bool) { + self.bootloader_mode = ::std::option::Option::Some(v); + } + + // optional uint32 protocol_version = 4; + + pub fn protocol_version(&self) -> u32 { + self.protocol_version.unwrap_or(0) + } + + pub fn clear_protocol_version(&mut self) { + self.protocol_version = ::std::option::Option::None; + } + + pub fn has_protocol_version(&self) -> bool { + self.protocol_version.is_some() + } + + // Param is passed by value, moved + pub fn set_protocol_version(&mut self, v: u32) { + self.protocol_version = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(5); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "internal_model", + |m: &ThpDeviceProperties| { &m.internal_model }, + |m: &mut ThpDeviceProperties| { &mut m.internal_model }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "model_variant", + |m: &ThpDeviceProperties| { &m.model_variant }, + |m: &mut ThpDeviceProperties| { &mut m.model_variant }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "bootloader_mode", + |m: &ThpDeviceProperties| { &m.bootloader_mode }, + |m: &mut ThpDeviceProperties| { &mut m.bootloader_mode }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "protocol_version", + |m: &ThpDeviceProperties| { &m.protocol_version }, + |m: &mut ThpDeviceProperties| { &mut m.protocol_version }, + )); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "pairing_methods", + |m: &ThpDeviceProperties| { &m.pairing_methods }, + |m: &mut ThpDeviceProperties| { &mut m.pairing_methods }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpDeviceProperties", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpDeviceProperties { + const NAME: &'static str = "ThpDeviceProperties"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.internal_model = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.model_variant = ::std::option::Option::Some(is.read_uint32()?); + }, + 24 => { + self.bootloader_mode = ::std::option::Option::Some(is.read_bool()?); + }, + 32 => { + self.protocol_version = ::std::option::Option::Some(is.read_uint32()?); + }, + 40 => { + self.pairing_methods.push(is.read_enum_or_unknown()?); + }, + 42 => { + ::protobuf::rt::read_repeated_packed_enum_or_unknown_into(is, &mut self.pairing_methods)? + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.internal_model.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.model_variant { + my_size += ::protobuf::rt::uint32_size(2, v); + } + if let Some(v) = self.bootloader_mode { + my_size += 1 + 1; + } + if let Some(v) = self.protocol_version { + my_size += ::protobuf::rt::uint32_size(4, v); + } + for value in &self.pairing_methods { + my_size += ::protobuf::rt::int32_size(5, value.value()); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.internal_model.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.model_variant { + os.write_uint32(2, v)?; + } + if let Some(v) = self.bootloader_mode { + os.write_bool(3, v)?; + } + if let Some(v) = self.protocol_version { + os.write_uint32(4, v)?; + } + for v in &self.pairing_methods { + os.write_enum(5, ::protobuf::EnumOrUnknown::value(v))?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpDeviceProperties { + ThpDeviceProperties::new() + } + + fn clear(&mut self) { + self.internal_model = ::std::option::Option::None; + self.model_variant = ::std::option::Option::None; + self.bootloader_mode = ::std::option::Option::None; + self.protocol_version = ::std::option::Option::None; + self.pairing_methods.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpDeviceProperties { + static instance: ThpDeviceProperties = ThpDeviceProperties { + internal_model: ::std::option::Option::None, + model_variant: ::std::option::Option::None, + bootloader_mode: ::std::option::Option::None, + protocol_version: ::std::option::Option::None, + pairing_methods: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpDeviceProperties { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpDeviceProperties").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpDeviceProperties { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpDeviceProperties { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpHandshakeCompletionReqNoisePayload { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.host_pairing_credential) + pub host_pairing_credential: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.pairing_methods) + pub pairing_methods: ::std::vec::Vec<::protobuf::EnumOrUnknown>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpHandshakeCompletionReqNoisePayload.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpHandshakeCompletionReqNoisePayload { + fn default() -> &'a ThpHandshakeCompletionReqNoisePayload { + ::default_instance() + } +} + +impl ThpHandshakeCompletionReqNoisePayload { + pub fn new() -> ThpHandshakeCompletionReqNoisePayload { + ::std::default::Default::default() + } + + // optional bytes host_pairing_credential = 1; + + pub fn host_pairing_credential(&self) -> &[u8] { + match self.host_pairing_credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_pairing_credential(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + } + + pub fn has_host_pairing_credential(&self) -> bool { + self.host_pairing_credential.is_some() + } + + // Param is passed by value, moved + pub fn set_host_pairing_credential(&mut self, v: ::std::vec::Vec) { + self.host_pairing_credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_pairing_credential(&mut self) -> &mut ::std::vec::Vec { + if self.host_pairing_credential.is_none() { + self.host_pairing_credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_pairing_credential.as_mut().unwrap() + } + + // Take field + pub fn take_host_pairing_credential(&mut self) -> ::std::vec::Vec { + self.host_pairing_credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_pairing_credential", + |m: &ThpHandshakeCompletionReqNoisePayload| { &m.host_pairing_credential }, + |m: &mut ThpHandshakeCompletionReqNoisePayload| { &mut m.host_pairing_credential }, + )); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "pairing_methods", + |m: &ThpHandshakeCompletionReqNoisePayload| { &m.pairing_methods }, + |m: &mut ThpHandshakeCompletionReqNoisePayload| { &mut m.pairing_methods }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpHandshakeCompletionReqNoisePayload", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpHandshakeCompletionReqNoisePayload { + const NAME: &'static str = "ThpHandshakeCompletionReqNoisePayload"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_pairing_credential = ::std::option::Option::Some(is.read_bytes()?); + }, + 16 => { + self.pairing_methods.push(is.read_enum_or_unknown()?); + }, + 18 => { + ::protobuf::rt::read_repeated_packed_enum_or_unknown_into(is, &mut self.pairing_methods)? + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_pairing_credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + for value in &self.pairing_methods { + my_size += ::protobuf::rt::int32_size(2, value.value()); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_pairing_credential.as_ref() { + os.write_bytes(1, v)?; + } + for v in &self.pairing_methods { + os.write_enum(2, ::protobuf::EnumOrUnknown::value(v))?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpHandshakeCompletionReqNoisePayload { + ThpHandshakeCompletionReqNoisePayload::new() + } + + fn clear(&mut self) { + self.host_pairing_credential = ::std::option::Option::None; + self.pairing_methods.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpHandshakeCompletionReqNoisePayload { + static instance: ThpHandshakeCompletionReqNoisePayload = ThpHandshakeCompletionReqNoisePayload { + host_pairing_credential: ::std::option::Option::None, + pairing_methods: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpHandshakeCompletionReqNoisePayload { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpHandshakeCompletionReqNoisePayload").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpHandshakeCompletionReqNoisePayload { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpHandshakeCompletionReqNoisePayload { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCreateNewSession) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCreateNewSession { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.passphrase) + pub passphrase: ::std::option::Option<::std::string::String>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.on_device) + pub on_device: ::std::option::Option, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCreateNewSession.derive_cardano) + pub derive_cardano: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCreateNewSession.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCreateNewSession { + fn default() -> &'a ThpCreateNewSession { + ::default_instance() + } +} + +impl ThpCreateNewSession { + pub fn new() -> ThpCreateNewSession { + ::std::default::Default::default() + } + + // optional string passphrase = 1; + + pub fn passphrase(&self) -> &str { + match self.passphrase.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_passphrase(&mut self) { + self.passphrase = ::std::option::Option::None; + } + + pub fn has_passphrase(&self) -> bool { + self.passphrase.is_some() + } + + // Param is passed by value, moved + pub fn set_passphrase(&mut self, v: ::std::string::String) { + self.passphrase = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_passphrase(&mut self) -> &mut ::std::string::String { + if self.passphrase.is_none() { + self.passphrase = ::std::option::Option::Some(::std::string::String::new()); + } + self.passphrase.as_mut().unwrap() + } + + // Take field + pub fn take_passphrase(&mut self) -> ::std::string::String { + self.passphrase.take().unwrap_or_else(|| ::std::string::String::new()) + } + + // optional bool on_device = 2; + + pub fn on_device(&self) -> bool { + self.on_device.unwrap_or(false) + } + + pub fn clear_on_device(&mut self) { + self.on_device = ::std::option::Option::None; + } + + pub fn has_on_device(&self) -> bool { + self.on_device.is_some() + } + + // Param is passed by value, moved + pub fn set_on_device(&mut self, v: bool) { + self.on_device = ::std::option::Option::Some(v); + } + + // optional bool derive_cardano = 3; + + pub fn derive_cardano(&self) -> bool { + self.derive_cardano.unwrap_or(false) + } + + pub fn clear_derive_cardano(&mut self) { + self.derive_cardano = ::std::option::Option::None; + } + + pub fn has_derive_cardano(&self) -> bool { + self.derive_cardano.is_some() + } + + // Param is passed by value, moved + pub fn set_derive_cardano(&mut self, v: bool) { + self.derive_cardano = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(3); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "passphrase", + |m: &ThpCreateNewSession| { &m.passphrase }, + |m: &mut ThpCreateNewSession| { &mut m.passphrase }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "on_device", + |m: &ThpCreateNewSession| { &m.on_device }, + |m: &mut ThpCreateNewSession| { &mut m.on_device }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "derive_cardano", + |m: &ThpCreateNewSession| { &m.derive_cardano }, + |m: &mut ThpCreateNewSession| { &mut m.derive_cardano }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCreateNewSession", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCreateNewSession { + const NAME: &'static str = "ThpCreateNewSession"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.passphrase = ::std::option::Option::Some(is.read_string()?); + }, + 16 => { + self.on_device = ::std::option::Option::Some(is.read_bool()?); + }, + 24 => { + self.derive_cardano = ::std::option::Option::Some(is.read_bool()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.passphrase.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + if let Some(v) = self.on_device { + my_size += 1 + 1; + } + if let Some(v) = self.derive_cardano { + my_size += 1 + 1; + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.passphrase.as_ref() { + os.write_string(1, v)?; + } + if let Some(v) = self.on_device { + os.write_bool(2, v)?; + } + if let Some(v) = self.derive_cardano { + os.write_bool(3, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCreateNewSession { + ThpCreateNewSession::new() + } + + fn clear(&mut self) { + self.passphrase = ::std::option::Option::None; + self.on_device = ::std::option::Option::None; + self.derive_cardano = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCreateNewSession { + static instance: ThpCreateNewSession = ThpCreateNewSession { + passphrase: ::std::option::Option::None, + on_device: ::std::option::Option::None, + derive_cardano: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCreateNewSession { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCreateNewSession").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCreateNewSession { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCreateNewSession { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNewSession) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNewSession { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNewSession.new_session_id) + pub new_session_id: ::std::option::Option, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNewSession.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNewSession { + fn default() -> &'a ThpNewSession { + ::default_instance() + } +} + +impl ThpNewSession { + pub fn new() -> ThpNewSession { + ::std::default::Default::default() + } + + // optional uint32 new_session_id = 1; + + pub fn new_session_id(&self) -> u32 { + self.new_session_id.unwrap_or(0) + } + + pub fn clear_new_session_id(&mut self) { + self.new_session_id = ::std::option::Option::None; + } + + pub fn has_new_session_id(&self) -> bool { + self.new_session_id.is_some() + } + + // Param is passed by value, moved + pub fn set_new_session_id(&mut self, v: u32) { + self.new_session_id = ::std::option::Option::Some(v); + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "new_session_id", + |m: &ThpNewSession| { &m.new_session_id }, + |m: &mut ThpNewSession| { &mut m.new_session_id }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNewSession", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNewSession { + const NAME: &'static str = "ThpNewSession"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 8 => { + self.new_session_id = ::std::option::Option::Some(is.read_uint32()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.new_session_id { + my_size += ::protobuf::rt::uint32_size(1, v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.new_session_id { + os.write_uint32(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNewSession { + ThpNewSession::new() + } + + fn clear(&mut self) { + self.new_session_id = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNewSession { + static instance: ThpNewSession = ThpNewSession { + new_session_id: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNewSession { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNewSession").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNewSession { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNewSession { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpStartPairingRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpStartPairingRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpStartPairingRequest.host_name) + pub host_name: ::std::option::Option<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpStartPairingRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpStartPairingRequest { + fn default() -> &'a ThpStartPairingRequest { + ::default_instance() + } +} + +impl ThpStartPairingRequest { + pub fn new() -> ThpStartPairingRequest { + ::std::default::Default::default() + } + + // optional string host_name = 1; + + pub fn host_name(&self) -> &str { + match self.host_name.as_ref() { + Some(v) => v, + None => "", + } + } + + pub fn clear_host_name(&mut self) { + self.host_name = ::std::option::Option::None; + } + + pub fn has_host_name(&self) -> bool { + self.host_name.is_some() + } + + // Param is passed by value, moved + pub fn set_host_name(&mut self, v: ::std::string::String) { + self.host_name = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_name(&mut self) -> &mut ::std::string::String { + if self.host_name.is_none() { + self.host_name = ::std::option::Option::Some(::std::string::String::new()); + } + self.host_name.as_mut().unwrap() + } + + // Take field + pub fn take_host_name(&mut self) -> ::std::string::String { + self.host_name.take().unwrap_or_else(|| ::std::string::String::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_name", + |m: &ThpStartPairingRequest| { &m.host_name }, + |m: &mut ThpStartPairingRequest| { &mut m.host_name }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpStartPairingRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpStartPairingRequest { + const NAME: &'static str = "ThpStartPairingRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_name = ::std::option::Option::Some(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_name.as_ref() { + my_size += ::protobuf::rt::string_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_name.as_ref() { + os.write_string(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpStartPairingRequest { + ThpStartPairingRequest::new() + } + + fn clear(&mut self) { + self.host_name = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpStartPairingRequest { + static instance: ThpStartPairingRequest = ThpStartPairingRequest { + host_name: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpStartPairingRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpStartPairingRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpStartPairingRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpStartPairingRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpPairingPreparationsFinished) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpPairingPreparationsFinished { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpPairingPreparationsFinished.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpPairingPreparationsFinished { + fn default() -> &'a ThpPairingPreparationsFinished { + ::default_instance() + } +} + +impl ThpPairingPreparationsFinished { + pub fn new() -> ThpPairingPreparationsFinished { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpPairingPreparationsFinished", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpPairingPreparationsFinished { + const NAME: &'static str = "ThpPairingPreparationsFinished"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpPairingPreparationsFinished { + ThpPairingPreparationsFinished::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpPairingPreparationsFinished { + static instance: ThpPairingPreparationsFinished = ThpPairingPreparationsFinished { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpPairingPreparationsFinished { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpPairingPreparationsFinished").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpPairingPreparationsFinished { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpPairingPreparationsFinished { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCommitment) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCommitment { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCommitment.commitment) + pub commitment: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCommitment.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCommitment { + fn default() -> &'a ThpCodeEntryCommitment { + ::default_instance() + } +} + +impl ThpCodeEntryCommitment { + pub fn new() -> ThpCodeEntryCommitment { + ::std::default::Default::default() + } + + // optional bytes commitment = 1; + + pub fn commitment(&self) -> &[u8] { + match self.commitment.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_commitment(&mut self) { + self.commitment = ::std::option::Option::None; + } + + pub fn has_commitment(&self) -> bool { + self.commitment.is_some() + } + + // Param is passed by value, moved + pub fn set_commitment(&mut self, v: ::std::vec::Vec) { + self.commitment = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_commitment(&mut self) -> &mut ::std::vec::Vec { + if self.commitment.is_none() { + self.commitment = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.commitment.as_mut().unwrap() + } + + // Take field + pub fn take_commitment(&mut self) -> ::std::vec::Vec { + self.commitment.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "commitment", + |m: &ThpCodeEntryCommitment| { &m.commitment }, + |m: &mut ThpCodeEntryCommitment| { &mut m.commitment }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCommitment", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCommitment { + const NAME: &'static str = "ThpCodeEntryCommitment"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.commitment = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.commitment.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.commitment.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCommitment { + ThpCodeEntryCommitment::new() + } + + fn clear(&mut self) { + self.commitment = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCommitment { + static instance: ThpCodeEntryCommitment = ThpCodeEntryCommitment { + commitment: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCommitment { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCommitment").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCommitment { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCommitment { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryChallenge) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryChallenge { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryChallenge.challenge) + pub challenge: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryChallenge.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryChallenge { + fn default() -> &'a ThpCodeEntryChallenge { + ::default_instance() + } +} + +impl ThpCodeEntryChallenge { + pub fn new() -> ThpCodeEntryChallenge { + ::std::default::Default::default() + } + + // optional bytes challenge = 1; + + pub fn challenge(&self) -> &[u8] { + match self.challenge.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_challenge(&mut self) { + self.challenge = ::std::option::Option::None; + } + + pub fn has_challenge(&self) -> bool { + self.challenge.is_some() + } + + // Param is passed by value, moved + pub fn set_challenge(&mut self, v: ::std::vec::Vec) { + self.challenge = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_challenge(&mut self) -> &mut ::std::vec::Vec { + if self.challenge.is_none() { + self.challenge = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.challenge.as_mut().unwrap() + } + + // Take field + pub fn take_challenge(&mut self) -> ::std::vec::Vec { + self.challenge.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "challenge", + |m: &ThpCodeEntryChallenge| { &m.challenge }, + |m: &mut ThpCodeEntryChallenge| { &mut m.challenge }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryChallenge", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryChallenge { + const NAME: &'static str = "ThpCodeEntryChallenge"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.challenge = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.challenge.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.challenge.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryChallenge { + ThpCodeEntryChallenge::new() + } + + fn clear(&mut self) { + self.challenge = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryChallenge { + static instance: ThpCodeEntryChallenge = ThpCodeEntryChallenge { + challenge: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryChallenge { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryChallenge").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryChallenge { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryChallenge { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceHost) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceHost { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceHost.cpace_host_public_key) + pub cpace_host_public_key: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceHost.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceHost { + fn default() -> &'a ThpCodeEntryCpaceHost { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceHost { + pub fn new() -> ThpCodeEntryCpaceHost { + ::std::default::Default::default() + } + + // optional bytes cpace_host_public_key = 1; + + pub fn cpace_host_public_key(&self) -> &[u8] { + match self.cpace_host_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_host_public_key(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_host_public_key(&self) -> bool { + self.cpace_host_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_host_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_host_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_host_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_host_public_key.is_none() { + self.cpace_host_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_host_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_host_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_host_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_host_public_key", + |m: &ThpCodeEntryCpaceHost| { &m.cpace_host_public_key }, + |m: &mut ThpCodeEntryCpaceHost| { &mut m.cpace_host_public_key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceHost", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceHost { + const NAME: &'static str = "ThpCodeEntryCpaceHost"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_host_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_host_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_host_public_key.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceHost { + ThpCodeEntryCpaceHost::new() + } + + fn clear(&mut self) { + self.cpace_host_public_key = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceHost { + static instance: ThpCodeEntryCpaceHost = ThpCodeEntryCpaceHost { + cpace_host_public_key: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceHost { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceHost").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceHost { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceHost { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryCpaceTrezor { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.cpace_trezor_public_key) + pub cpace_trezor_public_key: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryCpaceTrezor.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryCpaceTrezor { + fn default() -> &'a ThpCodeEntryCpaceTrezor { + ::default_instance() + } +} + +impl ThpCodeEntryCpaceTrezor { + pub fn new() -> ThpCodeEntryCpaceTrezor { + ::std::default::Default::default() + } + + // optional bytes cpace_trezor_public_key = 1; + + pub fn cpace_trezor_public_key(&self) -> &[u8] { + match self.cpace_trezor_public_key.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_cpace_trezor_public_key(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + } + + pub fn has_cpace_trezor_public_key(&self) -> bool { + self.cpace_trezor_public_key.is_some() + } + + // Param is passed by value, moved + pub fn set_cpace_trezor_public_key(&mut self, v: ::std::vec::Vec) { + self.cpace_trezor_public_key = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_cpace_trezor_public_key(&mut self) -> &mut ::std::vec::Vec { + if self.cpace_trezor_public_key.is_none() { + self.cpace_trezor_public_key = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.cpace_trezor_public_key.as_mut().unwrap() + } + + // Take field + pub fn take_cpace_trezor_public_key(&mut self) -> ::std::vec::Vec { + self.cpace_trezor_public_key.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "cpace_trezor_public_key", + |m: &ThpCodeEntryCpaceTrezor| { &m.cpace_trezor_public_key }, + |m: &mut ThpCodeEntryCpaceTrezor| { &mut m.cpace_trezor_public_key }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryCpaceTrezor", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryCpaceTrezor { + const NAME: &'static str = "ThpCodeEntryCpaceTrezor"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.cpace_trezor_public_key = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.cpace_trezor_public_key.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryCpaceTrezor { + ThpCodeEntryCpaceTrezor::new() + } + + fn clear(&mut self) { + self.cpace_trezor_public_key = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryCpaceTrezor { + static instance: ThpCodeEntryCpaceTrezor = ThpCodeEntryCpaceTrezor { + cpace_trezor_public_key: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryCpaceTrezor { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryCpaceTrezor").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryCpaceTrezor { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryCpaceTrezor { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntryTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntryTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntryTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntryTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntryTag { + fn default() -> &'a ThpCodeEntryTag { + ::default_instance() + } +} + +impl ThpCodeEntryTag { + pub fn new() -> ThpCodeEntryTag { + ::std::default::Default::default() + } + + // optional bytes tag = 2; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpCodeEntryTag| { &m.tag }, + |m: &mut ThpCodeEntryTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntryTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntryTag { + const NAME: &'static str = "ThpCodeEntryTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 18 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntryTag { + ThpCodeEntryTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntryTag { + static instance: ThpCodeEntryTag = ThpCodeEntryTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntryTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntryTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntryTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntryTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCodeEntrySecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCodeEntrySecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCodeEntrySecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCodeEntrySecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCodeEntrySecret { + fn default() -> &'a ThpCodeEntrySecret { + ::default_instance() + } +} + +impl ThpCodeEntrySecret { + pub fn new() -> ThpCodeEntrySecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpCodeEntrySecret| { &m.secret }, + |m: &mut ThpCodeEntrySecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCodeEntrySecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCodeEntrySecret { + const NAME: &'static str = "ThpCodeEntrySecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCodeEntrySecret { + ThpCodeEntrySecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCodeEntrySecret { + static instance: ThpCodeEntrySecret = ThpCodeEntrySecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCodeEntrySecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCodeEntrySecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCodeEntrySecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCodeEntrySecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeTag { + fn default() -> &'a ThpQrCodeTag { + ::default_instance() + } +} + +impl ThpQrCodeTag { + pub fn new() -> ThpQrCodeTag { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpQrCodeTag| { &m.tag }, + |m: &mut ThpQrCodeTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeTag { + const NAME: &'static str = "ThpQrCodeTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeTag { + ThpQrCodeTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeTag { + static instance: ThpQrCodeTag = ThpQrCodeTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpQrCodeSecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpQrCodeSecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpQrCodeSecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpQrCodeSecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpQrCodeSecret { + fn default() -> &'a ThpQrCodeSecret { + ::default_instance() + } +} + +impl ThpQrCodeSecret { + pub fn new() -> ThpQrCodeSecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpQrCodeSecret| { &m.secret }, + |m: &mut ThpQrCodeSecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpQrCodeSecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpQrCodeSecret { + const NAME: &'static str = "ThpQrCodeSecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpQrCodeSecret { + ThpQrCodeSecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpQrCodeSecret { + static instance: ThpQrCodeSecret = ThpQrCodeSecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpQrCodeSecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpQrCodeSecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpQrCodeSecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpQrCodeSecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcUnidirectionalTag) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcUnidirectionalTag { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcUnidirectionalTag.tag) + pub tag: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcUnidirectionalTag.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcUnidirectionalTag { + fn default() -> &'a ThpNfcUnidirectionalTag { + ::default_instance() + } +} + +impl ThpNfcUnidirectionalTag { + pub fn new() -> ThpNfcUnidirectionalTag { + ::std::default::Default::default() + } + + // optional bytes tag = 1; + + pub fn tag(&self) -> &[u8] { + match self.tag.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_tag(&mut self) { + self.tag = ::std::option::Option::None; + } + + pub fn has_tag(&self) -> bool { + self.tag.is_some() + } + + // Param is passed by value, moved + pub fn set_tag(&mut self, v: ::std::vec::Vec) { + self.tag = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_tag(&mut self) -> &mut ::std::vec::Vec { + if self.tag.is_none() { + self.tag = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.tag.as_mut().unwrap() + } + + // Take field + pub fn take_tag(&mut self) -> ::std::vec::Vec { + self.tag.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "tag", + |m: &ThpNfcUnidirectionalTag| { &m.tag }, + |m: &mut ThpNfcUnidirectionalTag| { &mut m.tag }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcUnidirectionalTag", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcUnidirectionalTag { + const NAME: &'static str = "ThpNfcUnidirectionalTag"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.tag = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.tag.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.tag.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcUnidirectionalTag { + ThpNfcUnidirectionalTag::new() + } + + fn clear(&mut self) { + self.tag = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcUnidirectionalTag { + static instance: ThpNfcUnidirectionalTag = ThpNfcUnidirectionalTag { + tag: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcUnidirectionalTag { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcUnidirectionalTag").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcUnidirectionalTag { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcUnidirectionalTag { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpNfcUnidirectionalSecret { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret.secret) + pub secret: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpNfcUnidirectionalSecret.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpNfcUnidirectionalSecret { + fn default() -> &'a ThpNfcUnidirectionalSecret { + ::default_instance() + } +} + +impl ThpNfcUnidirectionalSecret { + pub fn new() -> ThpNfcUnidirectionalSecret { + ::std::default::Default::default() + } + + // optional bytes secret = 1; + + pub fn secret(&self) -> &[u8] { + match self.secret.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_secret(&mut self) { + self.secret = ::std::option::Option::None; + } + + pub fn has_secret(&self) -> bool { + self.secret.is_some() + } + + // Param is passed by value, moved + pub fn set_secret(&mut self, v: ::std::vec::Vec) { + self.secret = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_secret(&mut self) -> &mut ::std::vec::Vec { + if self.secret.is_none() { + self.secret = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.secret.as_mut().unwrap() + } + + // Take field + pub fn take_secret(&mut self) -> ::std::vec::Vec { + self.secret.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "secret", + |m: &ThpNfcUnidirectionalSecret| { &m.secret }, + |m: &mut ThpNfcUnidirectionalSecret| { &mut m.secret }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpNfcUnidirectionalSecret", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpNfcUnidirectionalSecret { + const NAME: &'static str = "ThpNfcUnidirectionalSecret"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.secret = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.secret.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.secret.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpNfcUnidirectionalSecret { + ThpNfcUnidirectionalSecret::new() + } + + fn clear(&mut self) { + self.secret = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpNfcUnidirectionalSecret { + static instance: ThpNfcUnidirectionalSecret = ThpNfcUnidirectionalSecret { + secret: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpNfcUnidirectionalSecret { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpNfcUnidirectionalSecret").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpNfcUnidirectionalSecret { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpNfcUnidirectionalSecret { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialRequest { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialRequest.host_static_pubkey) + pub host_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialRequest { + fn default() -> &'a ThpCredentialRequest { + ::default_instance() + } +} + +impl ThpCredentialRequest { + pub fn new() -> ThpCredentialRequest { + ::std::default::Default::default() + } + + // optional bytes host_static_pubkey = 1; + + pub fn host_static_pubkey(&self) -> &[u8] { + match self.host_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_host_static_pubkey(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + } + + pub fn has_host_static_pubkey(&self) -> bool { + self.host_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_host_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.host_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_host_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.host_static_pubkey.is_none() { + self.host_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.host_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_host_static_pubkey(&mut self) -> ::std::vec::Vec { + self.host_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "host_static_pubkey", + |m: &ThpCredentialRequest| { &m.host_static_pubkey }, + |m: &mut ThpCredentialRequest| { &mut m.host_static_pubkey }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialRequest { + const NAME: &'static str = "ThpCredentialRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.host_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.host_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.host_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialRequest { + ThpCredentialRequest::new() + } + + fn clear(&mut self) { + self.host_static_pubkey = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialRequest { + static instance: ThpCredentialRequest = ThpCredentialRequest { + host_static_pubkey: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpCredentialResponse { + // message fields + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.trezor_static_pubkey) + pub trezor_static_pubkey: ::std::option::Option<::std::vec::Vec>, + // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpCredentialResponse.credential) + pub credential: ::std::option::Option<::std::vec::Vec>, + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpCredentialResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpCredentialResponse { + fn default() -> &'a ThpCredentialResponse { + ::default_instance() + } +} + +impl ThpCredentialResponse { + pub fn new() -> ThpCredentialResponse { + ::std::default::Default::default() + } + + // optional bytes trezor_static_pubkey = 1; + + pub fn trezor_static_pubkey(&self) -> &[u8] { + match self.trezor_static_pubkey.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_trezor_static_pubkey(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + } + + pub fn has_trezor_static_pubkey(&self) -> bool { + self.trezor_static_pubkey.is_some() + } + + // Param is passed by value, moved + pub fn set_trezor_static_pubkey(&mut self, v: ::std::vec::Vec) { + self.trezor_static_pubkey = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_trezor_static_pubkey(&mut self) -> &mut ::std::vec::Vec { + if self.trezor_static_pubkey.is_none() { + self.trezor_static_pubkey = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.trezor_static_pubkey.as_mut().unwrap() + } + + // Take field + pub fn take_trezor_static_pubkey(&mut self) -> ::std::vec::Vec { + self.trezor_static_pubkey.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + // optional bytes credential = 2; + + pub fn credential(&self) -> &[u8] { + match self.credential.as_ref() { + Some(v) => v, + None => &[], + } + } + + pub fn clear_credential(&mut self) { + self.credential = ::std::option::Option::None; + } + + pub fn has_credential(&self) -> bool { + self.credential.is_some() + } + + // Param is passed by value, moved + pub fn set_credential(&mut self, v: ::std::vec::Vec) { + self.credential = ::std::option::Option::Some(v); + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_credential(&mut self) -> &mut ::std::vec::Vec { + if self.credential.is_none() { + self.credential = ::std::option::Option::Some(::std::vec::Vec::new()); + } + self.credential.as_mut().unwrap() + } + + // Take field + pub fn take_credential(&mut self) -> ::std::vec::Vec { + self.credential.take().unwrap_or_else(|| ::std::vec::Vec::new()) + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(2); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "trezor_static_pubkey", + |m: &ThpCredentialResponse| { &m.trezor_static_pubkey }, + |m: &mut ThpCredentialResponse| { &mut m.trezor_static_pubkey }, + )); + fields.push(::protobuf::reflect::rt::v2::make_option_accessor::<_, _>( + "credential", + |m: &ThpCredentialResponse| { &m.credential }, + |m: &mut ThpCredentialResponse| { &mut m.credential }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpCredentialResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpCredentialResponse { + const NAME: &'static str = "ThpCredentialResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.trezor_static_pubkey = ::std::option::Option::Some(is.read_bytes()?); + }, + 18 => { + self.credential = ::std::option::Option::Some(is.read_bytes()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + if let Some(v) = self.trezor_static_pubkey.as_ref() { + my_size += ::protobuf::rt::bytes_size(1, &v); + } + if let Some(v) = self.credential.as_ref() { + my_size += ::protobuf::rt::bytes_size(2, &v); + } + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + if let Some(v) = self.trezor_static_pubkey.as_ref() { + os.write_bytes(1, v)?; + } + if let Some(v) = self.credential.as_ref() { + os.write_bytes(2, v)?; + } + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpCredentialResponse { + ThpCredentialResponse::new() + } + + fn clear(&mut self) { + self.trezor_static_pubkey = ::std::option::Option::None; + self.credential = ::std::option::Option::None; + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpCredentialResponse { + static instance: ThpCredentialResponse = ThpCredentialResponse { + trezor_static_pubkey: ::std::option::Option::None, + credential: ::std::option::Option::None, + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpCredentialResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpCredentialResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpCredentialResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpCredentialResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndRequest) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndRequest { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndRequest.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndRequest { + fn default() -> &'a ThpEndRequest { + ::default_instance() + } +} + +impl ThpEndRequest { + pub fn new() -> ThpEndRequest { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndRequest", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndRequest { + const NAME: &'static str = "ThpEndRequest"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndRequest { + ThpEndRequest::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndRequest { + static instance: ThpEndRequest = ThpEndRequest { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndRequest { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndRequest").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndRequest { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndRequest { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + +// @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpEndResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct ThpEndResponse { + // special fields + // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpEndResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a ThpEndResponse { + fn default() -> &'a ThpEndResponse { + ::default_instance() + } +} + +impl ThpEndResponse { + pub fn new() -> ThpEndResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(0); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "ThpEndResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for ThpEndResponse { + const NAME: &'static str = "ThpEndResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> ThpEndResponse { + ThpEndResponse::new() + } + + fn clear(&mut self) { + self.special_fields.clear(); + } + + fn default_instance() -> &'static ThpEndResponse { + static instance: ThpEndResponse = ThpEndResponse { + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for ThpEndResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("ThpEndResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for ThpEndResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for ThpEndResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + // @@protoc_insertion_point(message:hw.trezor.messages.thp.ThpCredentialMetadata) #[derive(PartialEq,Clone,Default,Debug)] pub struct ThpCredentialMetadata { @@ -537,17 +3796,316 @@ impl ::protobuf::reflect::ProtobufValue for ThpAuthenticatedCredentialData { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +#[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] +// @@protoc_insertion_point(enum:hw.trezor.messages.thp.ThpMessageType) +pub enum ThpMessageType { + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCreateNewSession) + ThpMessageType_ThpCreateNewSession = 1000, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpNewSession) + ThpMessageType_ThpNewSession = 1001, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpStartPairingRequest) + ThpMessageType_ThpStartPairingRequest = 1008, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpPairingPreparationsFinished) + ThpMessageType_ThpPairingPreparationsFinished = 1009, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCredentialRequest) + ThpMessageType_ThpCredentialRequest = 1010, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCredentialResponse) + ThpMessageType_ThpCredentialResponse = 1011, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpEndRequest) + ThpMessageType_ThpEndRequest = 1012, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpEndResponse) + ThpMessageType_ThpEndResponse = 1013, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCommitment) + ThpMessageType_ThpCodeEntryCommitment = 1016, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryChallenge) + ThpMessageType_ThpCodeEntryChallenge = 1017, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCpaceHost) + ThpMessageType_ThpCodeEntryCpaceHost = 1018, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryCpaceTrezor) + ThpMessageType_ThpCodeEntryCpaceTrezor = 1019, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntryTag) + ThpMessageType_ThpCodeEntryTag = 1020, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpCodeEntrySecret) + ThpMessageType_ThpCodeEntrySecret = 1021, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpQrCodeTag) + ThpMessageType_ThpQrCodeTag = 1024, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpQrCodeSecret) + ThpMessageType_ThpQrCodeSecret = 1025, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpNfcUnidirectionalTag) + ThpMessageType_ThpNfcUnidirectionalTag = 1032, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpMessageType.ThpMessageType_ThpNfcUnidirectionalSecret) + ThpMessageType_ThpNfcUnidirectionalSecret = 1033, +} + +impl ::protobuf::Enum for ThpMessageType { + const NAME: &'static str = "ThpMessageType"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1000 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCreateNewSession), + 1001 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNewSession), + 1008 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpStartPairingRequest), + 1009 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished), + 1010 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialRequest), + 1011 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialResponse), + 1012 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndRequest), + 1013 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndResponse), + 1016 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCommitment), + 1017 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryChallenge), + 1018 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHost), + 1019 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor), + 1020 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryTag), + 1021 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntrySecret), + 1024 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeTag), + 1025 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeSecret), + 1032 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcUnidirectionalTag), + 1033 => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcUnidirectionalSecret), + _ => ::std::option::Option::None + } + } + + fn from_str(str: &str) -> ::std::option::Option { + match str { + "ThpMessageType_ThpCreateNewSession" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCreateNewSession), + "ThpMessageType_ThpNewSession" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNewSession), + "ThpMessageType_ThpStartPairingRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpStartPairingRequest), + "ThpMessageType_ThpPairingPreparationsFinished" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished), + "ThpMessageType_ThpCredentialRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialRequest), + "ThpMessageType_ThpCredentialResponse" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCredentialResponse), + "ThpMessageType_ThpEndRequest" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndRequest), + "ThpMessageType_ThpEndResponse" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpEndResponse), + "ThpMessageType_ThpCodeEntryCommitment" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCommitment), + "ThpMessageType_ThpCodeEntryChallenge" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryChallenge), + "ThpMessageType_ThpCodeEntryCpaceHost" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHost), + "ThpMessageType_ThpCodeEntryCpaceTrezor" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor), + "ThpMessageType_ThpCodeEntryTag" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntryTag), + "ThpMessageType_ThpCodeEntrySecret" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpCodeEntrySecret), + "ThpMessageType_ThpQrCodeTag" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeTag), + "ThpMessageType_ThpQrCodeSecret" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpQrCodeSecret), + "ThpMessageType_ThpNfcUnidirectionalTag" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcUnidirectionalTag), + "ThpMessageType_ThpNfcUnidirectionalSecret" => ::std::option::Option::Some(ThpMessageType::ThpMessageType_ThpNfcUnidirectionalSecret), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ThpMessageType] = &[ + ThpMessageType::ThpMessageType_ThpCreateNewSession, + ThpMessageType::ThpMessageType_ThpNewSession, + ThpMessageType::ThpMessageType_ThpStartPairingRequest, + ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished, + ThpMessageType::ThpMessageType_ThpCredentialRequest, + ThpMessageType::ThpMessageType_ThpCredentialResponse, + ThpMessageType::ThpMessageType_ThpEndRequest, + ThpMessageType::ThpMessageType_ThpEndResponse, + ThpMessageType::ThpMessageType_ThpCodeEntryCommitment, + ThpMessageType::ThpMessageType_ThpCodeEntryChallenge, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHost, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor, + ThpMessageType::ThpMessageType_ThpCodeEntryTag, + ThpMessageType::ThpMessageType_ThpCodeEntrySecret, + ThpMessageType::ThpMessageType_ThpQrCodeTag, + ThpMessageType::ThpMessageType_ThpQrCodeSecret, + ThpMessageType::ThpMessageType_ThpNfcUnidirectionalTag, + ThpMessageType::ThpMessageType_ThpNfcUnidirectionalSecret, + ]; +} + +impl ::protobuf::EnumFull for ThpMessageType { + fn enum_descriptor() -> ::protobuf::reflect::EnumDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::EnumDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().enum_by_package_relative_name("ThpMessageType").unwrap()).clone() + } + + fn descriptor(&self) -> ::protobuf::reflect::EnumValueDescriptor { + let index = match self { + ThpMessageType::ThpMessageType_ThpCreateNewSession => 0, + ThpMessageType::ThpMessageType_ThpNewSession => 1, + ThpMessageType::ThpMessageType_ThpStartPairingRequest => 2, + ThpMessageType::ThpMessageType_ThpPairingPreparationsFinished => 3, + ThpMessageType::ThpMessageType_ThpCredentialRequest => 4, + ThpMessageType::ThpMessageType_ThpCredentialResponse => 5, + ThpMessageType::ThpMessageType_ThpEndRequest => 6, + ThpMessageType::ThpMessageType_ThpEndResponse => 7, + ThpMessageType::ThpMessageType_ThpCodeEntryCommitment => 8, + ThpMessageType::ThpMessageType_ThpCodeEntryChallenge => 9, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceHost => 10, + ThpMessageType::ThpMessageType_ThpCodeEntryCpaceTrezor => 11, + ThpMessageType::ThpMessageType_ThpCodeEntryTag => 12, + ThpMessageType::ThpMessageType_ThpCodeEntrySecret => 13, + ThpMessageType::ThpMessageType_ThpQrCodeTag => 14, + ThpMessageType::ThpMessageType_ThpQrCodeSecret => 15, + ThpMessageType::ThpMessageType_ThpNfcUnidirectionalTag => 16, + ThpMessageType::ThpMessageType_ThpNfcUnidirectionalSecret => 17, + }; + Self::enum_descriptor().value_by_index(index) + } +} + +// Note, `Default` is implemented although default value is not 0 +impl ::std::default::Default for ThpMessageType { + fn default() -> Self { + ThpMessageType::ThpMessageType_ThpCreateNewSession + } +} + +impl ThpMessageType { + fn generated_enum_descriptor_data() -> ::protobuf::reflect::GeneratedEnumDescriptorData { + ::protobuf::reflect::GeneratedEnumDescriptorData::new::("ThpMessageType") + } +} + +#[derive(Clone,Copy,PartialEq,Eq,Debug,Hash)] +// @@protoc_insertion_point(enum:hw.trezor.messages.thp.ThpPairingMethod) +pub enum ThpPairingMethod { + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.NoMethod) + NoMethod = 1, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.CodeEntry) + CodeEntry = 2, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.QrCode) + QrCode = 3, + // @@protoc_insertion_point(enum_value:hw.trezor.messages.thp.ThpPairingMethod.NFC_Unidirectional) + NFC_Unidirectional = 4, +} + +impl ::protobuf::Enum for ThpPairingMethod { + const NAME: &'static str = "ThpPairingMethod"; + + fn value(&self) -> i32 { + *self as i32 + } + + fn from_i32(value: i32) -> ::std::option::Option { + match value { + 1 => ::std::option::Option::Some(ThpPairingMethod::NoMethod), + 2 => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + 3 => ::std::option::Option::Some(ThpPairingMethod::QrCode), + 4 => ::std::option::Option::Some(ThpPairingMethod::NFC_Unidirectional), + _ => ::std::option::Option::None + } + } + + fn from_str(str: &str) -> ::std::option::Option { + match str { + "NoMethod" => ::std::option::Option::Some(ThpPairingMethod::NoMethod), + "CodeEntry" => ::std::option::Option::Some(ThpPairingMethod::CodeEntry), + "QrCode" => ::std::option::Option::Some(ThpPairingMethod::QrCode), + "NFC_Unidirectional" => ::std::option::Option::Some(ThpPairingMethod::NFC_Unidirectional), + _ => ::std::option::Option::None + } + } + + const VALUES: &'static [ThpPairingMethod] = &[ + ThpPairingMethod::NoMethod, + ThpPairingMethod::CodeEntry, + ThpPairingMethod::QrCode, + ThpPairingMethod::NFC_Unidirectional, + ]; +} + +impl ::protobuf::EnumFull for ThpPairingMethod { + fn enum_descriptor() -> ::protobuf::reflect::EnumDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::EnumDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().enum_by_package_relative_name("ThpPairingMethod").unwrap()).clone() + } + + fn descriptor(&self) -> ::protobuf::reflect::EnumValueDescriptor { + let index = match self { + ThpPairingMethod::NoMethod => 0, + ThpPairingMethod::CodeEntry => 1, + ThpPairingMethod::QrCode => 2, + ThpPairingMethod::NFC_Unidirectional => 3, + }; + Self::enum_descriptor().value_by_index(index) + } +} + +// Note, `Default` is implemented although default value is not 0 +impl ::std::default::Default for ThpPairingMethod { + fn default() -> Self { + ThpPairingMethod::NoMethod + } +} + +impl ThpPairingMethod { + fn generated_enum_descriptor_data() -> ::protobuf::reflect::GeneratedEnumDescriptorData { + ::protobuf::reflect::GeneratedEnumDescriptorData::new::("ThpPairingMethod") + } +} + static file_descriptor_proto_data: &'static [u8] = b"\ \n\x12messages-thp.proto\x12\x16hw.trezor.messages.thp\x1a\roptions.prot\ - o\":\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\ - \x08hostName:\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\ - \x12R\n\rcred_metadata\x18\x01\x20\x01(\x0b2-.hw.trezor.messages.thp.Thp\ - CredentialMetadataR\x0ccredMetadata\x12\x10\n\x03mac\x18\x02\x20\x01(\ - \x0cR\x03mac:\x04\x98\xb2\x19\x01\"\xa8\x01\n\x1eThpAuthenticatedCredent\ - ialData\x12,\n\x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStatic\ - Pubkey\x12R\n\rcred_metadata\x18\x02\x20\x01(\x0b2-.hw.trezor.messages.t\ - hp.ThpCredentialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01B;\n#com.s\ - atoshilabs.trezor.lib.protobufB\x10TrezorMessageThp\x80\xa6\x1d\x01\ + o\"\x88\x02\n\x13ThpDeviceProperties\x12%\n\x0einternal_model\x18\x01\ + \x20\x01(\tR\rinternalModel\x12#\n\rmodel_variant\x18\x02\x20\x01(\rR\ + \x0cmodelVariant\x12'\n\x0fbootloader_mode\x18\x03\x20\x01(\x08R\x0eboot\ + loaderMode\x12)\n\x10protocol_version\x18\x04\x20\x01(\rR\x0fprotocolVer\ + sion\x12Q\n\x0fpairing_methods\x18\x05\x20\x03(\x0e2(.hw.trezor.messages\ + .thp.ThpPairingMethodR\x0epairingMethods\"\xb2\x01\n%ThpHandshakeComplet\ + ionReqNoisePayload\x126\n\x17host_pairing_credential\x18\x01\x20\x01(\ + \x0cR\x15hostPairingCredential\x12Q\n\x0fpairing_methods\x18\x02\x20\x03\ + (\x0e2(.hw.trezor.messages.thp.ThpPairingMethodR\x0epairingMethods\"y\n\ + \x13ThpCreateNewSession\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassp\ + hrase\x12\x1b\n\ton_device\x18\x02\x20\x01(\x08R\x08onDevice\x12%\n\x0ed\ + erive_cardano\x18\x03\x20\x01(\x08R\rderiveCardano\"5\n\rThpNewSession\ + \x12$\n\x0enew_session_id\x18\x01\x20\x01(\rR\x0cnewSessionId\"5\n\x16Th\ + pStartPairingRequest\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostNam\ + e\"\x20\n\x1eThpPairingPreparationsFinished\"8\n\x16ThpCodeEntryCommitme\ + nt\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitment\"5\n\x15ThpCo\ + deEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\x0cR\tchallenge\"\ + J\n\x15ThpCodeEntryCpaceHost\x121\n\x15cpace_host_public_key\x18\x01\x20\ + \x01(\x0cR\x12cpaceHostPublicKey\"P\n\x17ThpCodeEntryCpaceTrezor\x125\n\ + \x17cpace_trezor_public_key\x18\x01\x20\x01(\x0cR\x14cpaceTrezorPublicKe\ + y\"#\n\x0fThpCodeEntryTag\x12\x10\n\x03tag\x18\x02\x20\x01(\x0cR\x03tag\ + \",\n\x12ThpCodeEntrySecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\ + \x06secret\"\x20\n\x0cThpQrCodeTag\x12\x10\n\x03tag\x18\x01\x20\x01(\x0c\ + R\x03tag\")\n\x0fThpQrCodeSecret\x12\x16\n\x06secret\x18\x01\x20\x01(\ + \x0cR\x06secret\"+\n\x17ThpNfcUnidirectionalTag\x12\x10\n\x03tag\x18\x01\ + \x20\x01(\x0cR\x03tag\"4\n\x1aThpNfcUnidirectionalSecret\x12\x16\n\x06se\ + cret\x18\x01\x20\x01(\x0cR\x06secret\"D\n\x14ThpCredentialRequest\x12,\n\ + \x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStaticPubkey\"i\n\ + \x15ThpCredentialResponse\x120\n\x14trezor_static_pubkey\x18\x01\x20\x01\ + (\x0cR\x12trezorStaticPubkey\x12\x1e\n\ncredential\x18\x02\x20\x01(\x0cR\ + \ncredential\"\x0f\n\rThpEndRequest\"\x10\n\x0eThpEndResponse\":\n\x15Th\ + pCredentialMetadata\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostName\ + :\x04\x98\xb2\x19\x01\"\x82\x01\n\x14ThpPairingCredential\x12R\n\rcred_m\ + etadata\x18\x01\x20\x01(\x0b2-.hw.trezor.messages.thp.ThpCredentialMetad\ + ataR\x0ccredMetadata\x12\x10\n\x03mac\x18\x02\x20\x01(\x0cR\x03mac:\x04\ + \x98\xb2\x19\x01\"\xa8\x01\n\x1eThpAuthenticatedCredentialData\x12,\n\ + \x12host_static_pubkey\x18\x01\x20\x01(\x0cR\x10hostStaticPubkey\x12R\n\ + \rcred_metadata\x18\x02\x20\x01(\x0b2-.hw.trezor.messages.thp.ThpCredent\ + ialMetadataR\x0ccredMetadata:\x04\x98\xb2\x19\x01*\xbe\x07\n\x0eThpMessa\ + geType\x121\n\"ThpMessageType_ThpCreateNewSession\x10\xe8\x07\x1a\x08\ + \x80\xa6\x1d\x01\xc8\xb5\x18\x01\x12+\n\x1cThpMessageType_ThpNewSession\ + \x10\xe9\x07\x1a\x08\x80\xa6\x1d\x01\xd0\xb5\x18\x01\x124\n%ThpMessageTy\ + pe_ThpStartPairingRequest\x10\xf0\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\ + \x18\x01\x12<\n-ThpMessageType_ThpPairingPreparationsFinished\x10\xf1\ + \x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x122\n#ThpMessageType_ThpCr\ + edentialRequest\x10\xf2\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x123\ + \n$ThpMessageType_ThpCredentialResponse\x10\xf3\x07\x1a\x08\x80\xa6\x1d\ + \x01\xe0\xb5\x18\x01\x12+\n\x1cThpMessageType_ThpEndRequest\x10\xf4\x07\ + \x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12,\n\x1dThpMessageType_ThpEnd\ + Response\x10\xf5\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x124\n%ThpM\ + essageType_ThpCodeEntryCommitment\x10\xf8\x07\x1a\x08\x80\xa6\x1d\x01\ + \xe0\xb5\x18\x01\x123\n$ThpMessageType_ThpCodeEntryChallenge\x10\xf9\x07\ + \x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x123\n$ThpMessageType_ThpCodeEn\ + tryCpaceHost\x10\xfa\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x125\n&\ + ThpMessageType_ThpCodeEntryCpaceTrezor\x10\xfb\x07\x1a\x08\x80\xa6\x1d\ + \x01\xe0\xb5\x18\x01\x12-\n\x1eThpMessageType_ThpCodeEntryTag\x10\xfc\ + \x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x120\n!ThpMessageType_ThpCo\ + deEntrySecret\x10\xfd\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12*\n\ + \x1bThpMessageType_ThpQrCodeTag\x10\x80\x08\x1a\x08\x80\xa6\x1d\x01\xd8\ + \xb5\x18\x01\x12-\n\x1eThpMessageType_ThpQrCodeSecret\x10\x81\x08\x1a\ + \x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x125\n&ThpMessageType_ThpNfcUnidire\ + ctionalTag\x10\x88\x08\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x128\n)Th\ + pMessageType_ThpNfcUnidirectionalSecret\x10\x89\x08\x1a\x08\x80\xa6\x1d\ + \x01\xd8\xb5\x18\x01\"\x05\x08\0\x10\xe7\x07\"\t\x08\xcc\x08\x10\xff\xff\ + \xff\xff\x07*S\n\x10ThpPairingMethod\x12\x0c\n\x08NoMethod\x10\x01\x12\r\ + \n\tCodeEntry\x10\x02\x12\n\n\x06QrCode\x10\x03\x12\x16\n\x12NFC_Unidire\ + ctional\x10\x04B;\n#com.satoshilabs.trezor.lib.protobufB\x10TrezorMessag\ + eThp\x80\xa6\x1d\x01\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -566,11 +4124,33 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(1); deps.push(super::options::file_descriptor().clone()); - let mut messages = ::std::vec::Vec::with_capacity(3); + let mut messages = ::std::vec::Vec::with_capacity(23); + messages.push(ThpDeviceProperties::generated_message_descriptor_data()); + messages.push(ThpHandshakeCompletionReqNoisePayload::generated_message_descriptor_data()); + messages.push(ThpCreateNewSession::generated_message_descriptor_data()); + messages.push(ThpNewSession::generated_message_descriptor_data()); + messages.push(ThpStartPairingRequest::generated_message_descriptor_data()); + messages.push(ThpPairingPreparationsFinished::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCommitment::generated_message_descriptor_data()); + messages.push(ThpCodeEntryChallenge::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceHost::generated_message_descriptor_data()); + messages.push(ThpCodeEntryCpaceTrezor::generated_message_descriptor_data()); + messages.push(ThpCodeEntryTag::generated_message_descriptor_data()); + messages.push(ThpCodeEntrySecret::generated_message_descriptor_data()); + messages.push(ThpQrCodeTag::generated_message_descriptor_data()); + messages.push(ThpQrCodeSecret::generated_message_descriptor_data()); + messages.push(ThpNfcUnidirectionalTag::generated_message_descriptor_data()); + messages.push(ThpNfcUnidirectionalSecret::generated_message_descriptor_data()); + messages.push(ThpCredentialRequest::generated_message_descriptor_data()); + messages.push(ThpCredentialResponse::generated_message_descriptor_data()); + messages.push(ThpEndRequest::generated_message_descriptor_data()); + messages.push(ThpEndResponse::generated_message_descriptor_data()); messages.push(ThpCredentialMetadata::generated_message_descriptor_data()); messages.push(ThpPairingCredential::generated_message_descriptor_data()); messages.push(ThpAuthenticatedCredentialData::generated_message_descriptor_data()); - let mut enums = ::std::vec::Vec::with_capacity(0); + let mut enums = ::std::vec::Vec::with_capacity(2); + enums.push(ThpMessageType::generated_enum_descriptor_data()); + enums.push(ThpPairingMethod::generated_enum_descriptor_data()); ::protobuf::reflect::GeneratedFileDescriptor::new_generated( file_descriptor_proto(), deps, diff --git a/rust/trezor-client/src/protos/generated/options.rs b/rust/trezor-client/src/protos/generated/options.rs index 79bcf5e3472..5d96704fd93 100644 --- a/rust/trezor-client/src/protos/generated/options.rs +++ b/rust/trezor-client/src/protos/generated/options.rs @@ -42,6 +42,14 @@ pub mod exts { pub const wire_no_fsm: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50008, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + pub const channel_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50009, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const channel_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50010, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const pairing_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50011, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + + pub const pairing_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50012, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); + pub const bitcoin_only: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(60000, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); pub const has_bitcoin_only_values: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(51001, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL); @@ -68,19 +76,25 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x08wireTiny\ :L\n\x0fwire_bootloader\x18\xd7\x86\x03\x20\x01(\x08\x12!.google.protobu\ f.EnumValueOptionsR\x0ewireBootloader:C\n\x0bwire_no_fsm\x18\xd8\x86\x03\ - \x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\twireNoFsm:F\n\x0cb\ - itcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueO\ - ptionsR\x0bbitcoinOnly:U\n\x17has_bitcoin_only_values\x18\xb9\x8e\x03\ - \x20\x01(\x08\x12\x1c.google.protobuf.EnumOptionsR\x14hasBitcoinOnlyValu\ - es:T\n\x14experimental_message\x18\xa1\x96\x03\x20\x01(\x08\x12\x1f.goog\ - le.protobuf.MessageOptionsR\x13experimentalMessage:>\n\twire_type\x18\ - \xa2\x96\x03\x20\x01(\r\x12\x1f.google.protobuf.MessageOptionsR\x08wireT\ - ype:F\n\rinternal_only\x18\xa3\x96\x03\x20\x01(\x08\x12\x1f.google.proto\ - buf.MessageOptionsR\x0cinternalOnly:N\n\x12experimental_field\x18\x89\ - \x9e\x03\x20\x01(\x08\x12\x1d.google.protobuf.FieldOptionsR\x11experimen\ - talField:U\n\x17include_in_bitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12\ - \x1c.google.protobuf.FileOptionsR\x14includeInBitcoinOnlyB4\n#com.satosh\ - ilabs.trezor.lib.protobufB\rTrezorOptions\ + \x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\twireNoFsm:B\n\ncha\ + nnel_in\x18\xd9\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptio\ + nsR\tchannelIn:D\n\x0bchannel_out\x18\xda\x86\x03\x20\x01(\x08\x12!.goog\ + le.protobuf.EnumValueOptionsR\nchannelOut:B\n\npairing_in\x18\xdb\x86\ + \x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\tpairingIn:D\n\ + \x0bpairing_out\x18\xdc\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumVa\ + lueOptionsR\npairingOut:F\n\x0cbitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\ + \x12!.google.protobuf.EnumValueOptionsR\x0bbitcoinOnly:U\n\x17has_bitcoi\ + n_only_values\x18\xb9\x8e\x03\x20\x01(\x08\x12\x1c.google.protobuf.EnumO\ + ptionsR\x14hasBitcoinOnlyValues:T\n\x14experimental_message\x18\xa1\x96\ + \x03\x20\x01(\x08\x12\x1f.google.protobuf.MessageOptionsR\x13experimenta\ + lMessage:>\n\twire_type\x18\xa2\x96\x03\x20\x01(\r\x12\x1f.google.protob\ + uf.MessageOptionsR\x08wireType:F\n\rinternal_only\x18\xa3\x96\x03\x20\ + \x01(\x08\x12\x1f.google.protobuf.MessageOptionsR\x0cinternalOnly:N\n\ + \x12experimental_field\x18\x89\x9e\x03\x20\x01(\x08\x12\x1d.google.proto\ + buf.FieldOptionsR\x11experimentalField:U\n\x17include_in_bitcoin_only\ + \x18\xe0\xd4\x03\x20\x01(\x08\x12\x1c.google.protobuf.FileOptionsR\x14in\ + cludeInBitcoinOnlyB4\n#com.satoshilabs.trezor.lib.protobufB\rTrezorOptio\ + ns\ "; /// `FileDescriptorProto` object which was a source for this generated file diff --git a/tests/REGISTERED_MARKERS b/tests/REGISTERED_MARKERS index fab4ec8b3a3..bec85ca898b 100644 --- a/tests/REGISTERED_MARKERS +++ b/tests/REGISTERED_MARKERS @@ -11,6 +11,7 @@ multisig nem ontology peercoin +protocol ripple sd_card solana diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index e841d60f0c2..80d7fc1765a 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -21,7 +21,9 @@ import pytest from trezorlib import btc, device, exceptions, messages +from trezorlib.client import PASSPHRASE_ON_DEVICE from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import MessageType from trezorlib.tools import parse_path @@ -58,8 +60,8 @@ def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int): debug = device_handler.debuglink() - - device_handler.run(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore + Session(device_handler.client.get_management_session()).lock() + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore assert "PinKeyboard" in debug.read_layout().all_components() @@ -97,7 +99,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore + device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore assert ( "1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1" @@ -132,6 +134,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() + + # Prepare session to use later + session = Session(device_handler.client.get_session()) + # try to sign a transaction inp1 = messages.TxInputType( address_n=parse_path("86h/0h/0h/0/0"), @@ -147,8 +153,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run( - btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + device_handler.run_with_provided_session( + session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert ( @@ -175,11 +181,11 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.TxAck, None) + session.set_filter(messages.TxAck, None) return msg - with device_handler.client: - device_handler.client.set_filter(messages.TxAck, sleepy_filter) + with session, device_handler.client: + session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction # In all cases we set wait=False to avoid waiting for the screen and triggering # the layout deadlock detection. In reality there is no deadlock but the @@ -187,7 +193,7 @@ def sleepy_filter(msg: MessageType) -> MessageType: # timeout is 3. In this test we don't need the result of the input event so # waiting for it is not necessary. if debug.layout_type is LayoutType.TT: - debug.click(buttons.OK, wait=False) + debug.click(buttons.OK, hold_ms=1000, wait=False) elif debug.layout_type is LayoutType.Mercury: debug.click(buttons.TAP_TO_CONFIRM, wait=False) elif debug.layout_type is LayoutType.TR: @@ -196,7 +202,6 @@ def sleepy_filter(msg: MessageType) -> MessageType: signatures, tx = device_handler.result() assert len(signatures) == 1 assert tx - assert device_handler.features().unlocked is False @@ -206,8 +211,10 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler") debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore - + session = Session( + device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + ) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.TR: @@ -248,8 +255,10 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore - + session = Session( + device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE) + ) + device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() if debug.layout_type is LayoutType.TR: @@ -287,7 +296,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) layout = unlock_dry_run(debug) assert TR.recovery__num_of_words in debug.read_layout().text_content() @@ -319,7 +328,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -345,7 +354,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -405,7 +414,11 @@ def test_autolock_does_not_interrupt_preauthorized( debug = device_handler.debuglink() - device_handler.run( + # Prepare session to use later + session = Session(device_handler.client.get_session()) + + device_handler.run_with_provided_session( + session, btc.authorize_coinjoin, coordinator="www.example.com", max_rounds=2, @@ -519,14 +532,15 @@ def test_autolock_does_not_interrupt_preauthorized( def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.SignTx, None) + session.set_filter(messages.SignTx, None) return msg - with device_handler.client: + with session: # Start DoPreauthorized flow when device is unlocked. Wait 10s before # delivering SignTx, by that time autolock timer should have fired. - device_handler.client.set_filter(messages.SignTx, sleepy_filter) - device_handler.run( + session.set_filter(messages.SignTx, sleepy_filter) + device_handler.run_with_provided_session( + session, btc.sign_tx, "Testnet", inputs, diff --git a/tests/click_tests/test_backup_slip39_custom.py b/tests/click_tests/test_backup_slip39_custom.py index be01683d075..0976a08ad32 100644 --- a/tests/click_tests/test_backup_slip39_custom.py +++ b/tests/click_tests/test_backup_slip39_custom.py @@ -53,7 +53,9 @@ def test_backup_slip39_custom( assert features.initialized is False - device_handler.run( + session = device_handler.client.get_management_session() + device_handler.run_with_provided_session( + session, device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -68,7 +70,7 @@ def test_backup_slip39_custom( assert device_handler.result() == "Initialized" - device_handler.run( + device_handler.run_with_session( device.backup, group_threshold=group_threshold, groups=[(share_threshold, share_count)], diff --git a/tests/click_tests/test_lock.py b/tests/click_tests/test_lock.py index 4b719885e66..afaacb078ca 100644 --- a/tests/click_tests/test_lock.py +++ b/tests/click_tests/test_lock.py @@ -19,7 +19,7 @@ import pytest -from trezorlib import models +from trezorlib import messages, models from trezorlib.debuglink import LayoutType from .. import buttons, common @@ -34,6 +34,9 @@ @pytest.mark.setup_client(pin=PIN4) def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() + session = device_handler.client.get_management_session() + session.call(messages.LockDevice()) + session.refresh_features() short_duration = { models.T1B1: 500, @@ -59,22 +62,25 @@ def hold(duration: int) -> None: assert device_handler.features().unlocked is False # unlock with message - device_handler.run(common.get_test_address) + device_handler.run_with_session(common.get_test_address) assert "PinKeyboard" in debug.read_layout().all_components() debug.input("1234") assert device_handler.result() + session.refresh_features() assert device_handler.features().unlocked is True # short touch hold(short_duration) time.sleep(0.5) # so that the homescreen appears again (hacky) + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False # unlock by touching @@ -85,8 +91,10 @@ def hold(duration: int) -> None: assert "PinKeyboard" in layout.all_components() debug.input("1234") + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False diff --git a/tests/click_tests/test_passphrase_mercury.py b/tests/click_tests/test_passphrase_mercury.py index 9bed04da84a..d0783e0dcd8 100644 --- a/tests/click_tests/test_passphrase_mercury.py +++ b/tests/click_tests/test_passphrase_mercury.py @@ -97,7 +97,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore # TODO assert debug.read_layout().main_component() == "PassphraseKeyboard" diff --git a/tests/click_tests/test_passphrase_tr.py b/tests/click_tests/test_passphrase_tr.py index 57685451ba0..0affa4fbb6b 100644 --- a/tests/click_tests/test_passphrase_tr.py +++ b/tests/click_tests/test_passphrase_tr.py @@ -91,7 +91,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore layout = debug.read_layout() assert "PassphraseKeyboard" in layout.all_components() assert layout.passphrase() == "" diff --git a/tests/click_tests/test_passphrase_tt.py b/tests/click_tests/test_passphrase_tt.py index 8f490c03098..79993b954fb 100644 --- a/tests/click_tests/test_passphrase_tt.py +++ b/tests/click_tests/test_passphrase_tt.py @@ -69,7 +69,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/click_tests/test_pin.py b/tests/click_tests/test_pin.py index 48f54c5573f..4d8afedb313 100644 --- a/tests/click_tests/test_pin.py +++ b/tests/click_tests/test_pin.py @@ -23,6 +23,7 @@ from trezorlib import device, exceptions from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from .. import buttons from .. import translations as TR @@ -91,17 +92,19 @@ def prepare( tap = False + Session(device_handler.client.get_management_session()).lock() + # Setup according to the wanted situation if situation == Situation.PIN_INPUT: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore tap = True if situation == Situation.PIN_INPUT_CANCEL: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore elif situation == Situation.PIN_SETUP: # Set new PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore assert ( TR.pin__turn_on in debug.read_layout().text_content() or TR.pin__info in debug.read_layout().text_content() @@ -115,14 +118,14 @@ def prepare( go_next(debug) elif situation == Situation.PIN_CHANGE: # Change PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore _input_see_confirm(debug, old_pin) assert TR.pin__change in debug.read_layout().text_content() go_next(debug) _input_see_confirm(debug, old_pin) elif situation == Situation.WIPE_CODE_SETUP: # Set wipe code - device_handler.run(device.change_wipe_code) # type: ignore + device_handler.run_with_session(device.change_wipe_code) # type: ignore if old_pin: _input_see_confirm(debug, old_pin) assert TR.wipe_code__turn_on in debug.read_layout().text_content() diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index 8770649296a..769b2b507c3 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index d61d97962df..55fd4157efe 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -42,7 +42,7 @@ def test_repeated_backup( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -94,7 +94,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # run recovery to unlock backup - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) @@ -161,7 +161,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/click_tests/test_reset_bip39.py b/tests/click_tests/test_reset_bip39.py index 907246fb51d..18692b12797 100644 --- a/tests/click_tests/test_reset_bip39.py +++ b/tests/click_tests/test_reset_bip39.py @@ -40,7 +40,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"): assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Bip39, diff --git a/tests/click_tests/test_reset_slip39_advanced.py b/tests/click_tests/test_reset_slip39_advanced.py index 874ad7a6211..d26a55fb00c 100644 --- a/tests/click_tests/test_reset_slip39_advanced.py +++ b/tests/click_tests/test_reset_slip39_advanced.py @@ -52,7 +52,7 @@ def test_reset_slip39_advanced( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, backup_type=messages.BackupType.Slip39_Advanced, pin_protection=False, diff --git a/tests/click_tests/test_reset_slip39_basic.py b/tests/click_tests/test_reset_slip39_basic.py index f8c6592f6d4..fbdd8f63f76 100644 --- a/tests/click_tests/test_reset_slip39_basic.py +++ b/tests/click_tests/test_reset_slip39_basic.py @@ -48,7 +48,7 @@ def test_reset_slip39_basic( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, diff --git a/tests/click_tests/test_tutorial_mercury.py b/tests/click_tests/test_tutorial_mercury.py index 987b32b48c7..7129cf91312 100644 --- a/tests/click_tests/test_tutorial_mercury.py +++ b/tests/click_tests/test_tutorial_mercury.py @@ -36,7 +36,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -57,7 +57,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -84,7 +84,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -108,7 +108,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -139,7 +139,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 diff --git a/tests/click_tests/test_tutorial_tr.py b/tests/click_tests/test_tutorial_tr.py index 81d2645ace5..88dc895a64e 100644 --- a/tests/click_tests/test_tutorial_tr.py +++ b/tests/click_tests/test_tutorial_tr.py @@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it( device_handler: "BackgroundDeviceHandler", cancelled: bool = False ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) yield debug diff --git a/tests/common.py b/tests/common.py index b2a20bb39dd..41fd00f4d54 100644 --- a/tests/common.py +++ b/tests/common.py @@ -34,8 +34,8 @@ from _pytest.mark.structures import MarkDecorator from trezorlib.debuglink import DebugLink - from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ButtonRequest + from trezorlib.transport.session import Session PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] @@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None: assert got >= expected -def get_test_address(client: "Client") -> str: +def get_test_address(session: "Session") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" - return btc.get_address(client, "Testnet", TEST_ADDRESS_N) + return btc.get_address(session, "Testnet", TEST_ADDRESS_N) def compact_size(n: int) -> bytes: @@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: debug.swipe_up() -def is_core(client: "Client") -> bool: - return client.model is not models.T1B1 +def is_core(session: "Session") -> bool: + return session.model is not models.T1B1 diff --git a/tests/conftest.py b/tests/conftest.py index c78c9a766f6..704cd8114d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,17 +20,22 @@ import typing as t from enum import IntEnum from pathlib import Path +from time import sleep +import cryptography import pytest import xdist from _pytest.python import IdMaker from _pytest.reports import TestReport from trezorlib import debuglink, log, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport.thp.protocol_v1 import ProtocolV1 # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -135,6 +140,10 @@ def _get_port() -> int: @pytest.fixture(scope="session") def _raw_client(request: pytest.FixtureRequest) -> Client: + return _get_raw_client(request) + + +def _get_raw_client(request: pytest.FixtureRequest) -> Client: # In case tests run in parallel, each process has its own emulator/client. # Requesting the emulator fixture only if relevant. if request.session.config.getoption("control_emulators"): @@ -273,6 +282,29 @@ def client( if _raw_client.model not in models_filter: pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}") + protocol_marker: Mark | None = request.node.get_closest_marker("protocol") + if protocol_marker: + args = protocol_marker.args + protocol_version = _raw_client.protocol_version + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V1 + and "protocol_v1" not in args + ): + pytest.xfail( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if ( + protocol_version == ProtocolVersion.PROTOCOL_V2 + and "protocol_v2" not in args + ): + pytest.xfail( + f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." + ) + + if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2: + pass sd_marker = request.node.get_closest_marker("sd_card") if sd_marker and not _raw_client.features.sd_card_present: raise RuntimeError( @@ -283,14 +315,15 @@ def client( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + _raw_client.reset_debug_features(new_management_session=True) _raw_client.open() - try: - _raw_client.sync_responses() - _raw_client.init_device() - except Exception: - request.session.shouldstop = "Failed to communicate with Trezor" - pytest.fail("Failed to communicate with Trezor") + if isinstance(_raw_client.protocol, ProtocolV1): + try: + _raw_client.sync_responses() + # TODO _raw_client.init_device() + except Exception: + request.session.shouldstop = "Failed to communicate with Trezor" + pytest.fail("Failed to communicate with Trezor") # Resetting all the debug events to not be influenced by previous test _raw_client.debug.reset_debug_events() @@ -303,13 +336,34 @@ def client( should_format = sd_marker.kwargs.get("formatted", True) _raw_client.debug.erase_sd_card(format=should_format) - wipe_device(_raw_client) + while True: + try: + session = _raw_client.get_management_session() + wipe_device(session) + sleep(1.5) # Makes tests more stable (wait for wipe to finish) + break + except cryptography.exceptions.InvalidTag: + # Get a new client + _raw_client = _get_raw_client(request) + + from trezorlib.transport.thp.channel_database import get_channel_db + + get_channel_db().clear_stored_channels() + _raw_client.protocol = None + _raw_client.__init__( + transport=_raw_client.transport, + auto_interact=_raw_client.debug.allow_interactions, + ) + if not _raw_client.features.bootloader_mode: + _raw_client.refresh_features() # Load language again, as it got erased in wipe if _raw_client.model is not models.T1B1: lang = request.session.config.getoption("lang") or "en" assert isinstance(lang, str) - translations.set_language(_raw_client, lang) + translations.set_language( + SessionDebugWrapper(_raw_client.get_management_session()), lang + ) setup_params = dict( uninitialized=False, @@ -327,10 +381,10 @@ def client( use_passphrase = setup_params["passphrase"] is True or isinstance( setup_params["passphrase"], str ) - if not setup_params["uninitialized"]: + session = _raw_client.get_management_session(new_session=True) debuglink.load_device( - _raw_client, + session, mnemonic=setup_params["mnemonic"], # type: ignore pin=setup_params["pin"], # type: ignore passphrase_protection=use_passphrase, @@ -338,14 +392,16 @@ def client( needs_backup=setup_params["needs_backup"], # type: ignore no_backup=setup_params["no_backup"], # type: ignore ) + if setup_params["pin"] is not None: + _raw_client._has_setup_pin = True if request.node.get_closest_marker("experimental"): - apply_settings(_raw_client, experimental_features=True) + apply_settings(session, experimental_features=True) if use_passphrase and isinstance(setup_params["passphrase"], str): _raw_client.use_passphrase(setup_params["passphrase"]) - _raw_client.clear_session() + # TODO _raw_client.clear_session() with ui_tests.screen_recording(_raw_client, request): yield _raw_client @@ -353,6 +409,29 @@ def client( _raw_client.close() +@pytest.fixture(scope="function") +def session( + request: pytest.FixtureRequest, client: Client +) -> t.Generator[SessionDebugWrapper, None, None]: + if bool(request.node.get_closest_marker("uninitialized_session")): + session = client.get_management_session() + else: + derive_cardano = bool(request.node.get_closest_marker("cardano")) + passphrase = client.passphrase or "" + session = client.get_session( + derive_cardano=derive_cardano, passphrase=passphrase + ) + try: + wrapped_session = SessionDebugWrapper(session) + if client._has_setup_pin: + wrapped_session.lock() + yield wrapped_session + finally: + pass + # TODO + # session.end() + + def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool: """Return True if the current process is the main test runner. @@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None: "markers", 'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance', ) + config.addinivalue_line( + "markers", + "uninitialized_session: use uninitialized session instance", + ) with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f: for line in f: config.addinivalue_line("markers", line.strip()) diff --git a/tests/device_handler.py b/tests/device_handler.py index 45ec1df9f78..b2c61acbfc8 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -48,7 +48,9 @@ def _configure_client(self, client: "Client") -> None: self.client.watch_layout(True) self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT - def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def run_with_session( + self, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: """Runs some function that interacts with a device. Makes sure the UI is updated before returning. @@ -58,15 +60,30 @@ def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: # wait for the first UI change triggered by the task running in the background with self.debuglink().wait_for_layout_change(): - self.task = self._pool.submit(function, self.client, *args, **kwargs) + session = self.client.get_session() + self.task = self._pool.submit(function, session, *args, **kwargs) + + def run_with_provided_session( + self, session, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: + """Runs some function that interacts with a device. + + Makes sure the UI is updated before returning. + """ + if self.task is not None: + raise RuntimeError("Wait for previous task first") + + # wait for the first UI change triggered by the task running in the background + with self.debuglink().wait_for_layout_change(): + self.task = self._pool.submit(function, session, *args, **kwargs) def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have # a close() method. - while self.client.session_counter > 0: - self.client.close() + # while self.client.session_counter > 0: + # self.client.close() try: self.task.result(timeout=1) except Exception: @@ -90,7 +107,7 @@ def result(self, timeout: float | None = None) -> Any: def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") - self.client.init_device() + self.client.refresh_features() return self.client.features def debuglink(self) -> "DebugLink": diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index cdb6e722713..6b5a0247676 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -17,7 +17,7 @@ import pytest from trezorlib.binance import get_address -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowAddressQRCode @@ -38,23 +38,23 @@ @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) -def test_binance_get_address(client: Client, path: str, expected_address: str): +def test_binance_get_address(session: Session, path: str, expected_address: str): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - address = get_address(client, parse_path(path), show_display=True) + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS) def test_binance_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index ea04fdbd88f..f65baa5dd83 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...input_flows import InputFlowShowXpubQRCode @@ -31,11 +31,11 @@ @pytest.mark.setup_client( mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) -def test_binance_get_public_key(client: Client): - with client: +def test_binance_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - sig = binance.get_public_key(client, BINANCE_PATH, show_display=True) + sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() == "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e" diff --git a/tests/device_tests/binance/test_sign_tx.py b/tests/device_tests/binance/test_sign_tx.py index ceb06924650..1665e005a46 100644 --- a/tests/device_tests/binance/test_sign_tx.py +++ b/tests/device_tests/binance/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import binance -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path BINANCE_TEST_VECTORS = [ @@ -110,10 +110,10 @@ @pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS) @pytest.mark.parametrize("chunkify", (True, False)) def test_binance_sign_message( - client: Client, chunkify: bool, message: dict, expected_response: dict + session: Session, chunkify: bool, message: dict, expected_response: dict ): response = binance.sign_tx( - client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify + session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify ) assert response.public_key.hex() == expected_response["public_key"] diff --git a/tests/device_tests/bitcoin/payment_req.py b/tests/device_tests/bitcoin/payment_req.py index 73d98859ba1..f928a5fa8e8 100644 --- a/tests/device_tests/bitcoin/payment_req.py +++ b/tests/device_tests/bitcoin/payment_req.py @@ -4,6 +4,7 @@ from ecdsa import SECP256k1, SigningKey from trezorlib import btc, messages +from trezorlib.transport.session import Session from ...common import compact_size @@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data): def make_payment_request( - client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None + session: Session, + recipient_name, + outputs, + change_addresses=None, + memos=None, + nonce=None, ): h_pr = sha256(b"SL\x00\x24") @@ -52,7 +58,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, memo.text.encode()) elif isinstance(memo, RefundMemo): address_resp = btc.get_authenticated_address( - client, "Testnet", memo.address_n + session, "Testnet", memo.address_n ) msg_memo = messages.RefundMemo( address=address_resp.address, mac=address_resp.mac @@ -63,7 +69,7 @@ def make_payment_request( hash_bytes_prefixed(h_pr, address_resp.address.encode()) elif isinstance(memo, CoinPurchaseMemo): address_resp = btc.get_authenticated_address( - client, memo.coin_name, memo.address_n + session, memo.coin_name, memo.address_n ) msg_memo = messages.CoinPurchaseMemo( coin_type=memo.slip44, diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index b149ff53d16..2ee16a074c5 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,6 +19,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -59,15 +60,15 @@ @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.setup_client(pin=PIN) -def test_sign_tx(client: Client, chunkify: bool): +def test_sign_tx(session: Session, chunkify: bool): # NOTE: FAKE input tx - + assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=2, max_coordinator_fee_rate=500_000, # 0.5 % @@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool): script_type=messages.InputScriptType.SPENDTAPROOT, ) - client.call(messages.LockDevice()) + session.call(messages.LockDevice()) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool): preauthorized=True, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/5"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool): no_fee_indices=[], ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool): ] ) signatures, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a second time. btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool): # Test for a third time, number of rounds should be exceeded. with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool): ) -def test_sign_tx_large(client: Client): +def test_sign_tx_large(session: Session): # NOTE: FAKE input tx commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") @@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client): output_denom = 10_000 # sats max_expected_delay = 60 # seconds - with client: - btc.authorize_coinjoin( - client, - coordinator="www.example.com", - max_rounds=2, - max_coordinator_fee_rate=500_000, # 0.5 % - max_fee_per_kvbyte=3500, - n=parse_path("m/10025h/1h/0h/1h"), - coin_name="Testnet", - script_type=messages.InputScriptType.SPENDTAPROOT, - ) + btc.authorize_coinjoin( + session, + coordinator="www.example.com", + max_rounds=2, + max_coordinator_fee_rate=500_000, # 0.5 % + max_fee_per_kvbyte=3500, + n=parse_path("m/10025h/1h/0h/1h"), + coin_name="Testnet", + script_type=messages.InputScriptType.SPENDTAPROOT, + ) # INPUTS. @@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client): ) start = time.time() - with client: - btc.sign_tx( - client, - "Testnet", - inputs, - outputs, - prev_txes=TX_CACHE_TESTNET, - coinjoin_request=coinjoin_req, - preauthorized=True, - serialize=False, - ) + btc.sign_tx( + session, + "Testnet", + inputs, + outputs, + prev_txes=TX_CACHE_TESTNET, + coinjoin_request=coinjoin_req, + preauthorized=True, + serialize=False, + ) delay = time.time() - start assert delay <= max_expected_delay -def test_sign_tx_spend(client: Client): +def test_sign_tx_spend(session: Session): # NOTE: FAKE input tx inputs = [ @@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client): # Ensure that Trezor refuses to spend from CoinJoin without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest(), @@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client): request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client): ) -def test_sign_tx_migration(client: Client): +def test_sign_tx_migration(session: Session): inputs = [ messages.TxInputType( address_n=parse_path("m/84h/1h/3h/0/12"), @@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client): # Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths. with pytest.raises(TrezorFailure, match="Forbidden key path"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, prev_txes=TX_CACHE_TESTNET, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest(), @@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client): request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_2cc3c1), @@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client): ) -def test_wrong_coordinator(client: Client): +def test_wrong_coordinator(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client): with pytest.raises(TrezorFailure, match="Unauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client): ) -def test_wrong_account_type(client: Client): +def test_wrong_account_type(session: Session): params = { - "client": client, + "session": session, "coordinator": "www.example.com", "max_rounds": 10, "max_coordinator_fee_rate": 500_000, # 0.5 % @@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client): ) -def test_cancel_authorization(client: Client): +def test_cancel_authorization(session: Session): # Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator. btc.authorize_coinjoin( - client, + session, coordinator="www.example.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) - device.cancel_authorization(client) + device.cancel_authorization(session) with pytest.raises(TrezorFailure, match="No preauthorized operation"): btc.get_ownership_proof( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client): ) -def test_get_public_key(client: Client): +def test_get_public_key(session: Session): ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h") EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU" # Ensure that user cannot access SLIP-25 path without UnlockPath. with pytest.raises(TrezorFailure, match="Forbidden key path"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) # Get unlock path MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH) # Ensure that UnlockPath fails with invalid MAC. invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:] with pytest.raises(TrezorFailure, match="Invalid MAC"): resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -690,15 +689,15 @@ def test_get_public_key(client: Client): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, ] ) resp = btc.get_public_node( - client, + session, ACCOUNT_PATH, coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, @@ -708,11 +707,12 @@ def test_get_public_key(client: Client): assert resp.xpub == EXPECTED_XPUB -def test_get_address(client: Client): +def test_get_address(session: Session): + # Ensure that the SLIP-0025 external chain is inaccessible without user confirmation. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -720,20 +720,20 @@ def test_get_address(client: Client): ) # Unlock CoinJoin path. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, messages.Failure(code=messages.FailureType.ActionCancelled), ] ) - unlock_path_mac = device.unlock_path(client, SLIP25_PATH) + unlock_path_mac = device.unlock_path(session, SLIP25_PATH) # Ensure that the SLIP-0025 external chain is accessible after user confirmation. for chunkify in (True, False): resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -745,7 +745,7 @@ def test_get_address(client: Client): assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7" resp = btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -758,7 +758,7 @@ def test_get_address(client: Client): # Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -769,7 +769,7 @@ def test_get_address(client: Client): with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/0h/1h/1/1"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -781,7 +781,7 @@ def test_get_address(client: Client): # Ensure that another SLIP-0025 account is inaccessible with the same MAC. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_address( - client, + session, "Testnet", parse_path("m/10025h/1h/1h/1h/0/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -793,8 +793,10 @@ def test_get_address(client: Client): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. + session1 = client.get_session() + btc.authorize_coinjoin( - client, + session1, coordinator="www.example1.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -803,14 +805,14 @@ def test_multisession_authorization(client: Client): coin_name="Testnet", script_type=messages.InputScriptType.SPENDTAPROOT, ) - + session2 = client.get_session() # Open a second session. - session_id1 = client.session_id - client.init_device(new_session=True) + # session_id1 = session.session_id + # TODO client.init_device(new_session=True) # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( - client, + session2, coordinator="www.example2.com", max_rounds=10, max_coordinator_fee_rate=500_000, # 0.5 % @@ -823,7 +825,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example1.com should fail in session 2. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -834,7 +836,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -849,12 +851,12 @@ def test_multisession_authorization(client: Client): ) # Switch back to the first session. - session_id2 = client.session_id - client.init_device(session_id=session_id1) - + # session_id2 = session.session_id + # TODO client.init_device(session_id=session_id1) + client.resume_session(session1) # Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1. ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -871,7 +873,7 @@ def test_multisession_authorization(client: Client): # Requesting a preauthorized ownership proof for www.example2.com should fail in session 1. with pytest.raises(TrezorFailure, match="Unauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -881,12 +883,12 @@ def test_multisession_authorization(client: Client): ) # Cancel the authorization in session 1. - device.cancel_authorization(client) + device.cancel_authorization(session1) # Requesting a preauthorized ownership proof should fail now. with pytest.raises(TrezorFailure, match="No preauthorized operation"): ownership_proof, _ = btc.get_ownership_proof( - client, + session1, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -896,11 +898,11 @@ def test_multisession_authorization(client: Client): ) # Switch to the second session. - client.init_device(session_id=session_id2) - + # TODO client.init_device(session_id=session_id2) + client.resume_session(session2) # Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2. ownership_proof, _ = btc.get_ownership_proof( - client, + session2, "Testnet", parse_path("m/10025h/1h/0h/1h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index 76538828632..d1f0129741c 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -53,7 +53,7 @@ pytestmark = pytest.mark.altcoin -def test_send_bch_change(client: Client): +def test_send_bch_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/0/0"), # bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv @@ -72,14 +72,14 @@ def test_send_bch_change(client: Client): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_bc37c2), @@ -92,9 +92,9 @@ def test_send_bch_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) - + # raise Exception(hexlify(serialized_tx)) assert_tx_matches( serialized_tx, hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c", @@ -102,7 +102,7 @@ def test_send_bch_change(client: Client): ) -def test_send_bch_nochange(client: Client): +def test_send_bch_nochange(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client): ) -def test_send_bch_oldaddr(client: Client): +def test_send_bch_oldaddr(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/145h/0h/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( @@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -252,15 +252,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_bd32ff), @@ -271,16 +271,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_bch_multisig_wrongchange(client: Client): +def test_send_bch_multisig_wrongchange(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -327,13 +327,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_062fbd), @@ -346,7 +346,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1], prev_txes=TX_API + session, "Bcash", [inp1], [out1], prev_txes=TX_API ) assert ( signatures1[0].hex() @@ -359,12 +359,12 @@ def getmultisig(chain, nr, signatures): @pytest.mark.multisig -def test_send_bch_multisig_change(client: Client): +def test_send_bch_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" + session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash" ).node for i in range(1, 4) ] @@ -395,13 +395,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -415,7 +415,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -434,13 +434,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -454,7 +454,7 @@ def getmultisig(chain, nr, signatures): ] ) (signatures1, serialized_tx) = btc.sign_tx( - client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API + session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -468,7 +468,7 @@ def getmultisig(chain, nr, signatures): @pytest.mark.models("core") -def test_send_bch_external_presigned(client: Client): +def test_send_bch_external_presigned(session: Session): inp1 = messages.TxInputType( # address_n=parse_path("44'/145'/0'/1/0"), # bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw @@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_502e85), @@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 71c1a6c3ad4..831ea216cbd 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path, tx_hash @@ -51,7 +51,7 @@ # All data taken from T1 -def test_send_bitcoin_gold_change(client: Client): +def test_send_bitcoin_gold_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client): ) -def test_send_bitcoin_gold_nochange(client: Client): +def test_send_bitcoin_gold_nochange(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client): ) -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -193,15 +193,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -213,16 +213,16 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure): - btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) + btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API) @pytest.mark.multisig -def test_send_btg_multisig_change(client: Client): +def test_send_btg_multisig_change(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" + session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -254,13 +254,13 @@ def getmultisig(chain, nr, signatures): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -275,7 +275,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -293,13 +293,13 @@ def getmultisig(chain, nr, signatures): ) out2.address_n[2] = H_(1) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -314,7 +314,7 @@ def getmultisig(chain, nr, signatures): ] ) signatures, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -327,7 +327,7 @@ def getmultisig(chain, nr, signatures): ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -347,16 +347,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_db7239), @@ -371,7 +371,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -380,7 +380,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_witness_change(client: Client): +def test_send_p2sh_witness_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API + session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API ) assert ( @@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # NOTE: fake input tx used nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" + session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold" ).node for i in range(1, 4) ] @@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client): request_finished(), ] ) - signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API) + signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API) # store signature inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - client.set_expected_responses( + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_7f1f6b), @@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1], [out1], prev_txes=TX_API + session, "Bgold", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client): ) -def test_send_mixed_inputs(client: Client): +def test_send_mixed_inputs(session: Session): # NOTE: fake input tx used # First is non-segwit, second is segwit. @@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client): @pytest.mark.models("core") -def test_send_btg_external_presigned(client: Client): +def test_send_btg_external_presigned(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_6f0398), @@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API + session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 4dde98bfbfd..06b335c1487 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")] -def test_send_dash(client: Client): +def test_send_dash(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -57,13 +57,13 @@ def test_send_dash(client: Client): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -77,7 +77,9 @@ def test_send_dash(client: Client): request_finished(), ] ) - _, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API) + _, serialized_tx = btc.sign_tx( + session, "Dash", [inp1], [out1], prev_txes=TX_API + ) assert ( serialized_tx.hex() @@ -85,7 +87,7 @@ def test_send_dash(client: Client): ) -def test_send_dash_dip2_input(client: Client): +def test_send_dash_dip2_input(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/5h/0h/0/0"), # dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH @@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(inp1.prev_hash), @@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Dash", [inp1], [out1, out2], prev_txes=TX_API + session, "Dash", [inp1], [out1, out2], prev_txes=TX_API ) assert ( diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 78bb1b0c3af..204d0559280 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -57,7 +57,7 @@ ] -def test_send_decred(client: Client): +def test_send_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -76,13 +76,13 @@ def test_send_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -95,7 +95,7 @@ def test_send_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API ) assert ( @@ -105,7 +105,7 @@ def test_send_decred(client: Client): @pytest.mark.models("core") -def test_purchase_ticket_decred(client: Client): +def test_purchase_ticket_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), @@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1], [out1, out2, out3], @@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client): @pytest.mark.models("core") -def test_spend_from_stake_generation_and_revocation_decred(client: Client): +def test_spend_from_stake_generation_and_revocation_decred(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_8b6890), @@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert ( @@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client): ) -def test_send_decred_change(client: Client): +def test_send_decred_change(session: Session): # NOTE: fake input tx used inp1 = messages.TxInputType( @@ -278,15 +278,15 @@ def test_send_decred_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_input(2), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -311,7 +311,7 @@ def test_send_decred_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2, inp3], [out1, out2], @@ -325,12 +325,12 @@ def test_send_decred_change(client: Client): @pytest.mark.multisig -def test_decred_multisig_change(client: Client): +def test_decred_multisig_change(session: Session): # NOTE: fake input tx used paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)] nodes = [ - btc.get_public_node(client, address_n, coin_name="Decred Testnet").node + btc.get_public_node(session, address_n, coin_name="Decred Testnet").node for address_n in paths ] @@ -384,15 +384,15 @@ def test_multisig(index): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_9ac7d2), @@ -410,7 +410,7 @@ def test_multisig(index): ] ) signature, serialized_tx = btc.sign_tx( - client, + session, "Decred Testnet", [inp1, inp2], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 6efdd99ed82..7a077b20527 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -18,7 +18,7 @@ from trezorlib import btc, messages, models from trezorlib.cli import btc as btc_cli -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_ from ...input_flows import InputFlowShowXpubQRCode @@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type): @pytest.mark.parametrize( "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) -def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors): - with client: +def test_descriptors( + session: Session, coin, account, purpose, script_type, descriptors +): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( - client, + session, _address_n(purpose, coin, account, script_type), show_display=True, coin_name=coin, @@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri "coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS ) def test_descriptors_trezorlib( - client: Client, coin, account, purpose, script_type, descriptors + session: Session, coin, account, purpose, script_type, descriptors ): - with client: + with session.client as client: if client.model != models.T1B1: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( - client, coin, account, purpose, script_type, show_display=True + session, coin, account, purpose, script_type, show_display=True ) assert res == descriptors diff --git a/tests/device_tests/bitcoin/test_firo.py b/tests/device_tests/bitcoin/test_firo.py index 52db787957d..2ceeb2c2d77 100644 --- a/tests/device_tests/bitcoin/test_firo.py +++ b/tests/device_tests/bitcoin/test_firo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -30,7 +30,7 @@ @pytest.mark.altcoin -def test_spend_lelantus(client: Client): +def test_spend_lelantus(session: Session): inp1 = messages.TxInputType( # THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk address_n=parse_path("m/44h/1h/0h/0/4"), @@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API + session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API ) assert_tx_matches( serialized_tx, diff --git a/tests/device_tests/bitcoin/test_fujicoin.py b/tests/device_tests/bitcoin/test_fujicoin.py index f28747c7173..45886e8603b 100644 --- a/tests/device_tests/bitcoin/test_fujicoin.py +++ b/tests/device_tests/bitcoin/test_fujicoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path TXHASH_33043a = bytes.fromhex( @@ -27,7 +27,7 @@ pytestmark = pytest.mark.altcoin -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7 address_n=parse_path("m/86h/75h/0h/0/1"), @@ -42,7 +42,7 @@ def test_send_p2tr(client: Client): amount=99_996_670_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1]) + _, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1]) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86 assert ( diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 73d984a4cec..f92e6f3e67c 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel from trezorlib.tools import parse_path @@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs): ) -def test_btc(client: Client): +def test_btc(session: Session): assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) == "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1")) == "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo" ) assert ( - btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" ) @pytest.mark.altcoin -def test_ltc(client: Client): +def test_ltc(session: Session): assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0")) == "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1")) == "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6" ) assert ( - btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0")) + btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0")) == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" ) -def test_tbtc(client: Client): +def test_tbtc(session: Session): assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1")) == "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b" ) assert ( - btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" ) @pytest.mark.altcoin -def test_bch(client: Client): +def test_bch(session: Session): assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0")) == "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1")) == "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4" ) assert ( - btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0")) + btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0")) == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" ) @pytest.mark.altcoin -def test_grs(client: Client): +def test_grs(session: Session): assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0")) == "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0")) == "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN" ) assert ( - btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) + btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")) == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" ) @pytest.mark.altcoin -def test_tgrs(client: Client): +def test_tgrs(session: Session): assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0")) == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0")) == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN" ) assert ( - btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) + btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1")) == "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6" ) @pytest.mark.altcoin -def test_elements(client: Client): +def test_elements(session: Session): assert ( - btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0")) + btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0")) == "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr" ) @pytest.mark.models("core") -def test_address_mac(client: Client): +def test_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/1/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/1/0") ) assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert ( @@ -150,7 +150,7 @@ def test_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Testnet", parse_path("m/44h/1h/0h/1/0") + session, "Testnet", parse_path("m/44h/1h/0h/1/0") ) assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ" assert ( @@ -160,16 +160,16 @@ def test_address_mac(client: Client): # Script type mismatch. resp = btc.get_authenticated_address( - client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False ) assert resp.mac is None @pytest.mark.models("core") @pytest.mark.altcoin -def test_altcoin_address_mac(client: Client): +def test_altcoin_address_mac(session: Session): resp = btc.get_authenticated_address( - client, "Litecoin", parse_path("m/44h/2h/0h/1/0") + session, "Litecoin", parse_path("m/44h/2h/0h/1/0") ) assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h" assert ( @@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Bcash", parse_path("m/44h/145h/0h/1/0") + session, "Bcash", parse_path("m/44h/145h/0h/1/0") ) assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw" assert ( @@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client): ) resp = btc.get_authenticated_address( - client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") + session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1") ) assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di" assert ( @@ -198,9 +198,9 @@ def test_altcoin_address_mac(client: Client): @pytest.mark.multisig @pytest.mark.models(skip="legacy", reason="Sortedmulti is not supported") -def test_multisig_pubkeys_order(client: Client): - xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub - xpub_external = btc.get_public_node(client, parse_path("m/44h/1")).xpub +def test_multisig_pubkeys_order(session: Session): + xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub + xpub_external = btc.get_public_node(session, parse_path("m/44h/1")).xpub multisig_unsorted_1 = messages.MultisigRedeemScriptType( nodes=[bip32.deserialize(xpub) for xpub in [xpub_internal, xpub_internal]], @@ -239,45 +239,45 @@ def test_multisig_pubkeys_order(client: Client): assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) == address_unsorted_1 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1 ) == address_unsorted_2 ) assert ( btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2 ) == address_unsorted_2 ) @pytest.mark.multisig -def test_multisig(client: Client): +def test_multisig(session: Session): xpubs = [] for n in range(1, 4): - node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h")) + node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h")) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/0/0"), show_display=(nr == 1), @@ -287,7 +287,7 @@ def test_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(f"m/44h/0h/{nr}h/1/0"), show_display=(nr == 1), @@ -299,11 +299,11 @@ def test_multisig(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node for i in range(1, 4) ] @@ -322,12 +322,12 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with client, pytest.raises(TrezorFailure): - if is_core(client): + with session.client as client, pytest.raises(TrezorFailure): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=show_display, @@ -337,22 +337,22 @@ def test_multisig_missing(client: Client, show_display): @pytest.mark.altcoin @pytest.mark.multisig -def test_bch_multisig(client: Client): +def test_bch_multisig(session: Session): xpubs = [] for n in range(1, 4): node = btc.get_public_node( - client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" + session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash" ) xpubs.append(node.xpub) for nr in range(1, 4): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/0/0"), show_display=(nr == 1), @@ -362,7 +362,7 @@ def test_bch_multisig(client: Client): ) assert ( btc.get_address( - client, + session, "Bcash", parse_path(f"m/44h/145h/{nr}h/1/0"), show_display=(nr == 1), @@ -372,43 +372,43 @@ def test_bch_multisig(client: Client): ) -def test_public_ckd(client: Client): - node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node - node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node +def test_public_ckd(session: Session): + node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node + node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node node_sub2 = bip32.public_ckd(node, [1, 0]) assert node_sub1.chain_code == node_sub2.chain_code assert node_sub1.public_key == node_sub2.public_key - address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) + address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")) address2 = bip32.get_address(node_sub2, 0) assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE" assert address1 == address2 -def test_invalid_path(client: Client): +def test_invalid_path(session: Session): with pytest.raises(TrezorFailure, match="Forbidden key path"): # slip44 id mismatch btc.get_address( - client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True ) -def test_unknown_path(client: Client): +def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with client: - client.set_expected_responses([messages.Failure]) + with session: + session.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) # disable safety checks - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -417,31 +417,31 @@ def test_unknown_path(client: Client): messages.Address, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) # try again with a warning - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with client: + with session: # no warning is displayed when the call is silent - client.set_expected_responses([messages.Address]) - btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False) + session.set_expected_responses([messages.Address]) + btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @pytest.mark.altcoin -def test_crw(client: Client): +def test_crw(session: Session): assert ( - btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0")) + btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0")) == "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48" ) @pytest.mark.multisig @pytest.mark.models(skip="legacy", reason="Not fixed") -def test_multisig_different_paths(client: Client): +def test_multisig_different_paths(session: Session): nodes = [ - btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node + btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node for i in range(2) ] @@ -457,12 +457,12 @@ def test_multisig_different_paths(client: Client): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with client: - if is_core(client): + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, @@ -470,13 +470,13 @@ def test_multisig_different_paths(client: Client): script_type=messages.InputScriptType.SPENDMULTISIG, ) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - with client: - if is_core(client): + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", parse_path("m/45h/0/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 848097a8cbb..b1e3affac72 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -25,10 +25,10 @@ from ...input_flows import InputFlowConfirmAllWarnings -def test_show_segwit(client: Client): +def test_show_segwit(session: Session): assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -39,7 +39,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/49h/1h/0h/0/0"), False, @@ -50,7 +50,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -61,7 +61,7 @@ def test_show_segwit(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path("m/44h/1h/0h/0/0"), False, @@ -73,14 +73,14 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin -def test_show_segwit_altcoin(client: Client): - with client: - if is_core(client): +def test_show_segwit_altcoin(session: Session): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/1/0"), True, @@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/49h/1h/0h/0/0"), True, @@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"), True, @@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client): ) assert ( btc.get_address( - client, + session, "Elements", parse_path("m/49h/1h/0h/0/0"), True, @@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/49h/1h/{i}h/0/7"), False, @@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display): +def test_multisig_missing(session: Session, show_display): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node for i in range(1, 4) ] @@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/49h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py index 55b0fbfdb5e..7c220adf65a 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit_native.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -141,7 +141,7 @@ @pytest.mark.parametrize("show_display", (True, False)) @pytest.mark.parametrize("coin, path, script_type, address", VECTORS) def test_show_segwit( - client: Client, + session: Session, show_display: bool, coin: str, path: str, @@ -150,7 +150,7 @@ def test_show_segwit( ): assert ( btc.get_address( - client, + session, coin, parse_path(path), show_display, @@ -166,10 +166,10 @@ def test_show_segwit( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) @pytest.mark.parametrize("path, address", BIP86_VECTORS) -def test_bip86(client: Client, path: str, address: str): +def test_bip86(session: Session, path: str, address: str): assert ( btc.get_address( - client, + session, "Bitcoin", parse_path(path), False, @@ -181,10 +181,10 @@ def test_bip86(client: Client, path: str, address: str): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -197,7 +197,7 @@ def test_show_multisig_3(client: Client): for i in [1, 2, 3]: assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/1"), False, @@ -208,7 +208,7 @@ def test_show_multisig_3(client: Client): ) assert ( btc.get_address( - client, + session, "Testnet", parse_path(f"m/84h/1h/{i}h/0/0"), False, @@ -221,11 +221,11 @@ def test_show_multisig_3(client: Client): @pytest.mark.multisig @pytest.mark.parametrize("show_display", (True, False)) -def test_multisig_missing(client: Client, show_display: bool): +def test_multisig_missing(session: Session, show_display: bool): # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] @@ -246,7 +246,7 @@ def test_multisig_missing(client: Client, show_display: bool): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=show_display, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 8770176d427..208f9c98ae1 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import is_core @@ -55,20 +55,20 @@ @pytest.mark.models("legacy") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_t1( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): def input_flow_t1(): yield - client.debug.press_no() + session.debug.press_no() yield - client.debug.press_yes() + session.debug.press_yes() - with client: + with session: # This is the only place where even T1 is using input flow - client.set_input_flow(input_flow_t1) + session.set_input_flow(input_flow_t1) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -82,18 +82,18 @@ def input_flow_t1(): @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_tt( - client: Client, + session: Session, chunkify: bool, path: str, script_type: messages.InputScriptType, address: str, ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -107,13 +107,13 @@ def test_show_tt( @pytest.mark.models("core") @pytest.mark.parametrize("path, script_type, address", VECTORS) def test_show_cancel( - client: Client, path: str, script_type: messages.InputScriptType, address: str + session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowShowAddressQRCodeCancel(client) client.set_input_flow(IF.get()) btc.get_address( - client, + session, "Bitcoin", tools.parse_path(path), script_type=script_type, @@ -121,10 +121,10 @@ def test_show_cancel( ) -def test_show_unrecognized_path(client: Client): +def test_show_unrecognized_path(session: Session): with pytest.raises(TrezorFailure): btc.get_address( - client, + session, "Bitcoin", tools.parse_path("m/24684621h/516582h/5156h/21/856"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -133,10 +133,10 @@ def test_show_unrecognized_path(client: Client): @pytest.mark.multisig -def test_show_multisig_3(client: Client): +def test_show_multisig_3(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in [1, 2, 3] ] @@ -157,13 +157,13 @@ def test_show_multisig_3(client: Client): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with client: - if is_core(client): + with session.client as client, session: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, @@ -250,7 +250,7 @@ def test_show_multisig_3(client: Client): "script_type, bip48_type, address, xpubs, ignore_xpub_magic", VECTORS_MULTISIG ) def test_show_multisig_xpubs( - client: Client, + session: Session, script_type: messages.InputScriptType, bip48_type: int, address: str, @@ -259,7 +259,7 @@ def test_show_multisig_xpubs( ): nodes = [ btc.get_public_node( - client, + session, tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h"), coin_name="Bitcoin", ) @@ -273,13 +273,13 @@ def test_show_multisig_xpubs( ) for i in range(3): - with client: + with session, session.client as client: IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) client.set_input_flow(IF.get()) client.debug.synchronize_at("Homescreen") client.watch_layout() btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/48h/0h/{i}h/{bip48_type}h/0/0"), show_display=True, @@ -290,10 +290,10 @@ def test_show_multisig_xpubs( @pytest.mark.multisig -def test_show_multisig_15(client: Client): +def test_show_multisig_15(session: Session): nodes = [ btc.get_public_node( - client, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" + session, tools.parse_path(f"m/45h/{i}"), coin_name="Bitcoin" ).node for i in range(15) ] @@ -314,13 +314,13 @@ def test_show_multisig_15(client: Client): for multisig in [multisig1, multisig2]: for i in range(15): - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) assert ( btc.get_address( - client, + session, "Bitcoin", tools.parse_path(f"m/45h/{i}/0/0"), show_display=True, diff --git a/tests/device_tests/bitcoin/test_getownershipproof.py b/tests/device_tests/bitcoin/test_getownershipproof.py index b21fe944b0e..51309eb625d 100644 --- a/tests/device_tests/bitcoin/test_getownershipproof.py +++ b/tests/device_tests/bitcoin/test_getownershipproof.py @@ -17,14 +17,14 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path -def test_p2wpkh_ownership_id(client: Client): +def test_p2wpkh_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -35,9 +35,9 @@ def test_p2wpkh_ownership_id(client: Client): ) -def test_p2tr_ownership_id(client: Client): +def test_p2tr_ownership_id(session: Session): ownership_id = btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -48,12 +48,12 @@ def test_p2tr_ownership_id(client: Client): ) -def test_attack_ownership_id(client: Client): +def test_attack_ownership_id(session: Session): # Multisig with global suffix specification. # Use account numbers 1, 2 and 3 to create a valid multisig, # but not containing the keys from account 0 used below. nodes = [ - btc.get_public_node(client, parse_path(f"m/84h/0h/{i}h")).node + btc.get_public_node(session, parse_path(f"m/84h/0h/{i}h")).node for i in range(1, 4) ] multisig1 = messages.MultisigRedeemScriptType( @@ -62,7 +62,7 @@ def test_attack_ownership_id(client: Client): # Multisig with per-node suffix specification. node = btc.get_public_node( - client, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/84h/0h/0h/0"), coin_name="Bitcoin" ).node multisig2 = messages.MultisigRedeemScriptType( pubkeys=[ @@ -77,7 +77,7 @@ def test_attack_ownership_id(client: Client): for multisig in (multisig1, multisig2): with pytest.raises(TrezorFailure): btc.get_ownership_id( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), multisig=multisig, @@ -85,9 +85,9 @@ def test_attack_ownership_id(client: Client): ) -def test_p2wpkh_ownership_proof(client: Client): +def test_p2wpkh_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -98,9 +98,9 @@ def test_p2wpkh_ownership_proof(client: Client): ) -def test_p2tr_ownership_proof(client: Client): +def test_p2tr_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/86h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDTAPROOT, @@ -111,10 +111,10 @@ def test_p2tr_ownership_proof(client: Client): ) -def test_fake_ownership_id(client: Client): +def test_fake_ownership_id(session: Session): with pytest.raises(TrezorFailure, match="Invalid ownership identifier"): btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -124,9 +124,9 @@ def test_fake_ownership_id(client: Client): ) -def test_confirm_ownership_proof(client: Client): +def test_confirm_ownership_proof(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -139,9 +139,9 @@ def test_confirm_ownership_proof(client: Client): ) -def test_confirm_ownership_proof_with_data(client: Client): +def test_confirm_ownership_proof_with_data(session: Session): ownership_proof, _ = btc.get_ownership_proof( - client, + session, "Bitcoin", parse_path("m/84h/0h/0h/1/0"), script_type=messages.InputScriptType.SPENDWITNESS, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index e8b90cbb487..81dadf8a60a 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -110,35 +110,35 @@ @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node(client: Client, coin_name, xpub_magic, path, xpub): - res = btc.get_public_node(client, path, coin_name=coin_name) +def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): + res = btc.get_public_node(session, path, coin_name=coin_name) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) -def test_get_public_node_show(client: Client, coin_name, xpub_magic, path, xpub): - with client: +def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) - res = btc.get_public_node(client, path, coin_name=coin_name, show_display=True) + res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") @pytest.mark.parametrize("coin_name, path", VECTORS_INVALID) -def test_invalid_path(client: Client, coin_name, path): +def test_invalid_path(session: Session, coin_name, path): with pytest.raises(TrezorFailure, match="Forbidden key path"): - btc.get_public_node(client, path, coin_name=coin_name) + btc.get_public_node(session, path, coin_name=coin_name) -def test_slip25_path(client: Client): +def test_slip25_path(session: Session): # Ensure that CoinJoin XPUBs are inaccessible without user authorization. with pytest.raises(TrezorFailure, match="Forbidden key path"): btc.get_public_node( - client, + session, parse_path("m/10025h/0h/0h/1h"), script_type=messages.InputScriptType.SPENDTAPROOT, ) @@ -169,14 +169,14 @@ def test_slip25_path(client: Client): @pytest.mark.parametrize("script_type, xpub, xpub_ignored_magic", VECTORS_SCRIPT_TYPES) -def test_script_type(client: Client, script_type, xpub, xpub_ignored_magic): +def test_script_type(session: Session, script_type, xpub, xpub_ignored_magic): path = parse_path("m/44h/0h/0") res = btc.get_public_node( - client, path, coin_name="Bitcoin", script_type=script_type + session, path, coin_name="Bitcoin", script_type=script_type ) assert res.xpub == xpub res = btc.get_public_node( - client, + session, path, coin_name="Bitcoin", script_type=script_type, diff --git a/tests/device_tests/bitcoin/test_getpublickey_curve.py b/tests/device_tests/bitcoin/test_getpublickey_curve.py index 8b8cba68871..393afca61c8 100644 --- a/tests/device_tests/bitcoin/test_getpublickey_curve.py +++ b/tests/device_tests/bitcoin/test_getpublickey_curve.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -54,21 +54,21 @@ @pytest.mark.parametrize("curve, path, pubkey", VECTORS) -def test_publickey_curve(client: Client, curve, path, pubkey): - resp = btc.get_public_node(client, path, ecdsa_curve_name=curve) +def test_publickey_curve(session: Session, curve, path, pubkey): + resp = btc.get_public_node(session, path, ecdsa_curve_name=curve) assert resp.node.public_key.hex() == pubkey -def test_ed25519_public(client: Client): +def test_ed25519_public(session: Session): with pytest.raises(TrezorFailure): - btc.get_public_node(client, PATH_PUBLIC, ecdsa_curve_name="ed25519") + btc.get_public_node(session, PATH_PUBLIC, ecdsa_curve_name="ed25519") @pytest.mark.xfail(reason="Currently path validation on get_public_node is disabled.") -def test_coin_and_curve(client: Client): +def test_coin_and_curve(session: Session): with pytest.raises( TrezorFailure, match="Cannot use coin_name or script_type with ecdsa_curve_name" ): btc.get_public_node( - client, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" + session, PATH_PRIVATE, coin_name="Bitcoin", ecdsa_curve_name="ed25519" ) diff --git a/tests/device_tests/bitcoin/test_grs.py b/tests/device_tests/bitcoin/test_grs.py index d25ffd20f00..ff2b5c4cdfc 100644 --- a/tests/device_tests/bitcoin/test_grs.py +++ b/tests/device_tests/bitcoin/test_grs.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ pytestmark = pytest.mark.altcoin -def test_legacy(client: Client): +def test_legacy(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -56,7 +56,7 @@ def test_legacy(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -64,7 +64,7 @@ def test_legacy(client: Client): ) -def test_legacy_change(client: Client): +def test_legacy_change(session: Session): inp1 = messages.TxInputType( # FXHDsC5ZqWQHkDmShzgRVZ1MatpWhwxTAA address_n=parse_path("m/44h/17h/0h/0/2"), @@ -78,7 +78,7 @@ def test_legacy_change(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin", [inp1], [out1], prev_txes=TX_API + session, "Groestlcoin", [inp1], [out1], prev_txes=TX_API ) assert ( serialized_tx.hex() @@ -86,7 +86,7 @@ def test_legacy_change(client: Client): ) -def test_send_segwit_p2sh(client: Client): +def test_send_segwit_p2sh(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -107,7 +107,7 @@ def test_send_segwit_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -120,7 +120,7 @@ def test_send_segwit_p2sh(client: Client): ) -def test_send_segwit_p2sh_change(client: Client): +def test_send_segwit_p2sh_change(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 address_n=parse_path("m/49h/1h/0h/1/0"), @@ -141,7 +141,7 @@ def test_send_segwit_p2sh_change(client: Client): amount=123_456_789 - 11_000 - 12_300_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -154,7 +154,7 @@ def test_send_segwit_p2sh_change(client: Client): ) -def test_send_segwit_native(client: Client): +def test_send_segwit_native(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -174,7 +174,7 @@ def test_send_segwit_native(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -187,7 +187,7 @@ def test_send_segwit_native(client: Client): ) -def test_send_segwit_native_change(client: Client): +def test_send_segwit_native_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=12_300_000, @@ -207,7 +207,7 @@ def test_send_segwit_native_change(client: Client): amount=12_300_000 - 11_000 - 5_000_000, ) _, serialized_tx = btc.sign_tx( - client, + session, "Groestlcoin Testnet", [inp1], [out1, out2], @@ -220,7 +220,7 @@ def test_send_segwit_native_change(client: Client): ) -def test_send_p2tr(client: Client): +def test_send_p2tr(session: Session): inp1 = messages.TxInputType( # tgrs1paxhjl357yzctuf3fe58fcdx6nul026hhh6kyldpfsf3tckj9a3wsvuqrgn address_n=parse_path("m/86h/1h/1h/0/0"), @@ -236,7 +236,7 @@ def test_send_p2tr(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Groestlcoin Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # Transaction hex changed with fix #2085, all other details are the same as this tx: # https://blockbook-test.groestlcoin.org/tx/c66a79075044aaab3dba17daffb23f48addee87d7c87c7bc88e2997ce38a74ee diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index f883afc7bcd..111acefc6f3 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -43,7 +43,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.komodo] -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: 2807c5b126ec8e2b078cab0f12e4c8b4ce1d7724905f8ebef8dca26b0c8e0f1d:0 # input 1: 10.9998 KMD @@ -61,13 +61,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -82,7 +82,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1], @@ -100,7 +100,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_one_one_rewards_claim(client: Client): +def test_one_one_rewards_claim(session: Session): # prevout: 7b28bd91119e9776f0d4ebd80e570165818a829bbf4477cd1afe5149dbcd34b1:0 # input 1: 10.9997 KMD @@ -125,16 +125,16 @@ def test_one_one_rewards_claim(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -150,7 +150,7 @@ def test_one_one_rewards_claim(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Komodo", [inp1], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 7a3da905d68..15f81a155dd 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -55,12 +55,12 @@ @pytest.mark.multisig @pytest.mark.parametrize("chunkify", (True, False)) -def test_2_of_3(client: Client, chunkify: bool): +def test_2_of_3(session: Session, chunkify: bool): # input tx: 6b07c1321b52d9c85743f9695e13eb431b41708cdf4e1585258d51208e5b93fc nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Testnet" ).node for index in range(1, 4) ] @@ -89,7 +89,7 @@ def test_2_of_3(client: Client, chunkify: bool): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_6b07c1), @@ -101,12 +101,12 @@ def test_2_of_3(client: Client, chunkify: bool): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1], @@ -143,10 +143,10 @@ def test_2_of_3(client: Client, chunkify: bool): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) assert ( @@ -163,12 +163,12 @@ def test_2_of_3(client: Client, chunkify: bool): @pytest.mark.multisig @pytest.mark.models(skip="legacy", reason="Sortedmulti is not supported") -def test_pubkeys_order(client: Client): +def test_pubkeys_order(session: Session): node_internal = btc.get_public_node( - client, parse_path("m/45h/0"), coin_name="Bitcoin" + session, parse_path("m/45h/0"), coin_name="Bitcoin" ).node node_external = btc.get_public_node( - client, parse_path("m/45h/1"), coin_name="Bitcoin" + session, parse_path("m/45h/1"), coin_name="Bitcoin" ).node multisig_unsorted_1 = messages.MultisigRedeemScriptType( @@ -204,10 +204,10 @@ def test_pubkeys_order(client: Client): ) address_unsorted_1 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1 ) address_unsorted_2 = btc.get_address( - client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 + session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2 ) prev_hash, prev_tx = forge_prevtx( @@ -283,7 +283,7 @@ def test_pubkeys_order(client: Client): tx_unsorted_2 = "0100000001637ffac0d4fbd8a6c02b114e36b079615ec3e4bdf09b769c7bf8b5fd6f8e781701000000910047304402204914036468434698e2d87985007a66691f170195e4a16507bbb86b4c00da5fde02200a788312d447b3796ee5288ce9e9c0247896debfa473339302bc928da6dd78cb014751210369b79f2094a6eb89e7aff0e012a5699f7272968a341e48e99e64a54312f2932b210262e9ac5bea4c84c7dea650424ed768cf123af9e447eef3c63d37c41d1f825e4952aeffffffff01301b0f000000000017a914320ad0ff0f1b605ab1fa8e29b70d22827cf45a9f8700000000" _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_1], [output_unsorted_1], @@ -292,7 +292,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_unsorted_2], [output_unsorted_2], @@ -301,7 +301,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_2 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_1], [output_sorted_1], @@ -310,7 +310,7 @@ def test_pubkeys_order(client: Client): assert tx.hex() == tx_unsorted_1 _, tx = btc.sign_tx( - client, + session, "Bitcoin", [input_sorted_2], [output_sorted_2], @@ -320,11 +320,11 @@ def test_pubkeys_order(client: Client): @pytest.mark.multisig -def test_15_of_15(client: Client): +def test_15_of_15(session: Session): # input tx: 0d5b5648d47b5650edea1af3d47bbe5624213abb577cf1b1c96f98321f75cdbc node = btc.get_public_node( - client, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" + session, parse_path("m/48h/1h/1h/0h"), coin_name="Testnet" ).node pubs = [messages.HDNodePathType(node=node, address_n=[0, x]) for x in range(15)] @@ -350,9 +350,9 @@ def test_15_of_15(client: Client): multisig=multisig, ) - with client: + with session: sig, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) signatures[x] = sig[0] @@ -364,9 +364,9 @@ def test_15_of_15(client: Client): @pytest.mark.multisig @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_missing_pubkey(client: Client): +def test_missing_pubkey(session: Session): node = btc.get_public_node( - client, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" + session, parse_path("m/48h/0h/1h/0h/0"), coin_name="Bitcoin" ).node multisig = messages.MultisigRedeemScriptType( @@ -396,16 +396,16 @@ def test_missing_pubkey(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) - if client.model is models.T1B1: + if session.model is models.T1B1: assert exc.value.message.endswith("Failed to derive scriptPubKey") else: assert exc.value.message.endswith("Pubkey not found in multisig script") @pytest.mark.multisig -def test_attack_change_input(client: Client): +def test_attack_change_input(session: Session): """ In Phases 1 and 2 the attacker replaces a non-multisig input `input_real` with a multisig input `input_fake`, which allows the @@ -428,7 +428,7 @@ def test_attack_change_input(client: Client): multisig_fake = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -463,12 +463,12 @@ def test_attack_change_input(client: Client): ) # Transaction can be signed without the attack processor - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], @@ -485,11 +485,11 @@ def attack_processor(msg): attack_count -= 1 return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, + session, "Testnet", [input_real], [output_payee, output_change], diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index 9703a9b6727..6e7a2a8ed78 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ... import bip32 @@ -123,7 +123,7 @@ def _responses( - client: Client, + session: Session, INP1: messages.TxInputType, INP2: messages.TxInputType, change_indices: Optional[list[int]] = None, @@ -144,7 +144,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 1 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(request_output(1)) @@ -153,7 +153,7 @@ def _responses( resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) if 2 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - if is_core(client): + if is_core(session): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp += [ @@ -182,7 +182,7 @@ def _responses( # both outputs are external -def test_external_external(client: Client): +def test_external_external(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -195,10 +195,10 @@ def test_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -207,7 +207,7 @@ def test_external_external(client: Client): # first external, second internal -def test_external_internal(client: Client): +def test_external_internal(session: Session): out1 = messages.TxOutputType( address="1F8yBZB2NZhPZvJekhjTwjhQRRvQeTjjXr", amount=40_000_000, @@ -220,21 +220,21 @@ def test_external_internal(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[] if is_core(client) else [2], foreign_indices=[2], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -243,7 +243,7 @@ def test_external_internal(client: Client): # first internal, second external -def test_internal_external(client: Client): +def test_internal_external(session: Session): out1 = messages.TxOutputType( address_n=parse_path("m/45h/0/1/0"), amount=40_000_000, @@ -256,21 +256,21 @@ def test_internal_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( _responses( - client, + session, INP1, INP2, change_indices=[] if is_core(client) else [1], foreign_indices=[1], ) ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -279,7 +279,7 @@ def test_internal_external(client: Client): # both outputs are external -def test_multisig_external_external(client: Client): +def test_multisig_external_external(session: Session): out1 = messages.TxOutputType( address="3B23k4kFBRtu49zvpG3Z9xuFzfpHvxBcwt", amount=40_000_000, @@ -292,10 +292,10 @@ def test_multisig_external_external(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -304,7 +304,7 @@ def test_multisig_external_external(client: Client): # inputs match, change matches (first is change) -def test_multisig_change_match_first(client: Client): +def test_multisig_change_match_first(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 0], @@ -325,12 +325,10 @@ def test_multisig_change_match_first(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[1]) - ) + with session: + session.set_expected_responses(_responses(session, INP1, INP2, change=1)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -339,7 +337,7 @@ def test_multisig_change_match_first(client: Client): # inputs match, change matches (second is change) -def test_multisig_change_match_second(client: Client): +def test_multisig_change_match_second(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_EXT2, NODE_INT], address_n=[1, 1], @@ -360,12 +358,12 @@ def test_multisig_change_match_second(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses( - _responses(client, INP1, INP2, change_indices=[2]) + with session: + session.set_expected_responses( + _responses(session, INP1, INP2, change_indices=[2]) ) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -374,7 +372,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change mismatches (second tries to be change but isn't) -def test_multisig_mismatch_multisig_change(client: Client): +def test_multisig_mismatch_multisig_change(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0], @@ -395,10 +393,10 @@ def test_multisig_mismatch_multisig_change(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -408,7 +406,7 @@ def test_multisig_mismatch_multisig_change(client: Client): # inputs match, change mismatches (second tries to be change but isn't) @pytest.mark.models(skip="legacy", reason="Not fixed") -def test_multisig_mismatch_multisig_change_different_paths(client: Client): +def test_multisig_mismatch_multisig_change_different_paths(session: Session): multisig_out2 = messages.MultisigRedeemScriptType( pubkeys=[ messages.HDNodePathType(node=NODE_EXT1, address_n=[1, 0]), @@ -432,10 +430,10 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP2)) + with session: + session.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP2], [out1, out2], @@ -444,7 +442,7 @@ def test_multisig_mismatch_multisig_change_different_paths(client: Client): # inputs mismatch, change matches with first input -def test_multisig_mismatch_inputs(client: Client): +def test_multisig_mismatch_inputs(session: Session): multisig_out1 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT2, NODE_EXT1, NODE_INT], address_n=[1, 0], @@ -465,10 +463,10 @@ def test_multisig_mismatch_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses(_responses(client, INP1, INP3)) + with session: + session.set_expected_responses(_responses(session, INP1, INP3)) btc.sign_tx( - client, + session, "Bitcoin", [INP1, INP3], [out1, out2], diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index ac33ee8b40e..77d57aa9517 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import is_core @@ -94,11 +94,11 @@ # accepted in case we make this more restrictive in the future. @pytest.mark.parametrize("path, script_types", VECTORS) def test_getpublicnode( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: res = btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin", script_type=script_type + session, parse_path(path), coin_name="Bitcoin", script_type=script_type ) assert res.xpub @@ -107,18 +107,18 @@ def test_getpublicnode( @pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("path, script_types", VECTORS) def test_getaddress( - client: Client, + session: Session, chunkify: bool, path: str, script_types: list[messages.InputScriptType], ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) res = btc.get_address( - client, + session, "Bitcoin", parse_path(path), show_display=True, @@ -131,16 +131,16 @@ def test_getaddress( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signmessage( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path(path), script_type=script_type, @@ -152,12 +152,14 @@ def test_signmessage( @pytest.mark.parametrize("path, script_types", VECTORS) def test_signtx( - client: Client, path: str, script_types: list[messages.InputScriptType] + session: Session, path: str, script_types: list[messages.InputScriptType] ): address_n = parse_path(path) for script_type in script_types: - address = btc.get_address(client, "Bitcoin", address_n, script_type=script_type) + address = btc.get_address( + session, "Bitcoin", address_n, script_type=script_type + ) prevhash, prevtx = forge_prevtx([(address, 390_000)]) inp1 = messages.TxInputType( address_n=address_n, @@ -173,12 +175,12 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert serialized_tx.hex() @@ -187,12 +189,12 @@ def test_signtx( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) def test_getaddress_multisig( - client: Client, paths: list[str], address_index: list[int] + session: Session, paths: list[str], address_index: list[int] ): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -200,12 +202,12 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) address = btc.get_address( - client, + session, "Bitcoin", parse_path(paths[0]) + address_index, show_display=True, @@ -218,11 +220,11 @@ def test_getaddress_multisig( @pytest.mark.multisig @pytest.mark.parametrize("paths, address_index", VECTORS_MULTISIG) -def test_signtx_multisig(client: Client, paths: list[str], address_index: list[int]): +def test_signtx_multisig(session: Session, paths: list[str], address_index: list[int]): pubs = [ messages.HDNodePathType( node=btc.get_public_node( - client, parse_path(path), coin_name="Bitcoin" + session, parse_path(path), coin_name="Bitcoin" ).node, address_n=address_index, ) @@ -235,7 +237,7 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i address_n = parse_path(paths[0]) + address_index address = btc.get_address( - client, + session, "Bitcoin", address_n, multisig=multisig, @@ -259,12 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index b5063891993..0aa8acb0802 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,7 +43,7 @@ ) -def test_opreturn(client: Client): +def test_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/1h/0/21"), # myGMXcCxmuDooMdzZFPMmvHviijzqYKhza amount=89_581, @@ -63,13 +63,13 @@ def test_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), messages.ButtonRequest(code=B.SignTx), @@ -86,7 +86,7 @@ def test_opreturn(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -96,7 +96,7 @@ def test_opreturn(client: Client): ) -def test_nonzero_opreturn(client: Client): +def test_nonzero_opreturn(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/10h/0/5"), amount=390_000, @@ -110,18 +110,18 @@ def test_nonzero_opreturn(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="OP_RETURN output with non-zero amount" ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) -def test_opreturn_address(client: Client): +def test_opreturn_address(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/2"), amount=390_000, @@ -136,11 +136,11 @@ def test_opreturn_address(client: Client): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( TrezorFailure, match="Output's address_n provided but not expected." ): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_peercoin.py b/tests/device_tests/bitcoin/test_peercoin.py index b1b62e49e55..b3de714e26e 100644 --- a/tests/device_tests/bitcoin/test_peercoin.py +++ b/tests/device_tests/bitcoin/test_peercoin.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -32,7 +32,7 @@ @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_included(client: Client): +def test_timestamp_included(session: Session): # tx: 41b29ad615d8eea40a4654a052d18bb10cd08f203c351f4d241f88b031357d3d # input 0: 0.1 PPC @@ -50,7 +50,7 @@ def test_timestamp_included(client: Client): ) _, timestamp_tx = btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -66,7 +66,7 @@ def test_timestamp_included(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing(client: Client): +def test_timestamp_missing(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -81,7 +81,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -92,7 +92,7 @@ def test_timestamp_missing(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -104,7 +104,7 @@ def test_timestamp_missing(client: Client): @pytest.mark.altcoin @pytest.mark.peercoin -def test_timestamp_missing_prevtx(client: Client): +def test_timestamp_missing_prevtx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/6h/0h/0/0"), amount=100_000, @@ -122,7 +122,7 @@ def test_timestamp_missing_prevtx(client: Client): with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], @@ -134,7 +134,7 @@ def test_timestamp_missing_prevtx(client: Client): prevtx.timestamp = None with pytest.raises(TrezorFailure, match="Timestamp must be set."): btc.sign_tx( - client, + session, "Peercoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index 2e1cab3eda3..8b9dca8574b 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -20,7 +20,7 @@ from trezorlib import btc, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path @@ -286,7 +286,7 @@ def case( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -296,7 +296,7 @@ def test_signmessage( signature: str, ): sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -312,7 +312,7 @@ def test_signmessage( "coin_name, path, script_type, no_script_type, address, message, signature", VECTORS ) def test_signmessage_info( - client: Client, + session: Session, coin_name: str, path: str, script_type: messages.InputScriptType, @@ -321,11 +321,11 @@ def test_signmessage_info( message: str, signature: str, ): - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignMessageInfo(client) client.set_input_flow(IF.get()) sig = btc.sign_message( - client, + session, coin_name=coin_name, n=parse_path(path), script_type=script_type, @@ -352,12 +352,12 @@ def test_signmessage_info( @pytest.mark.models("core") @pytest.mark.parametrize("message", MESSAGE_LENGTHS) -def test_signmessage_pagination(client: Client, message: str): - with client: +def test_signmessage_pagination(session: Session, message: str): + with session.client as client: IF = InputFlowSignMessagePagination(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, @@ -365,19 +365,19 @@ def test_signmessage_pagination(client: Client, message: str): # We cannot differentiate between a newline and space in the message read from Trezor. # TODO: do the check also for T2B1 - if client.layout_type in (LayoutType.TT, LayoutType.Mercury): + if session.client.layout_type in (LayoutType.TT, LayoutType.Mercury): message_read = IF.message_read.replace(" ", "").replace("...", "") signed_message = message.replace("\n", "").replace(" ", "") assert signed_message in message_read @pytest.mark.models("t2t1", reason="Tailored to TT fonts and screen size") -def test_signmessage_pagination_trailing_newline(client: Client): +def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -387,18 +387,18 @@ def test_signmessage_pagination_trailing_newline(client: Client): ] ) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/44h/0h/0h/0/0"), message=message, ) -def test_signmessage_path_warning(client: Client): +def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with client: - client.set_expected_responses( + with session, session.client as client: + session.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -409,11 +409,11 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) - if is_core(client): + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) btc.sign_message( - client, + session, coin_name="Bitcoin", n=parse_path("m/86h/0h/0h/0/0"), message=message, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index 96fc4edc691..135992224e6 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.tools import H_, parse_path @@ -111,7 +111,7 @@ CORNER_BUTTON = (215, 25) -def test_one_one_fee(client: Client): +def test_one_one_fee(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -127,13 +127,13 @@ def test_one_one_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_0dac36), @@ -148,7 +148,7 @@ def test_one_one_fee(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -158,7 +158,7 @@ def test_one_one_fee(client: Client): ) -def test_testnet_one_two_fee(client: Client): +def test_testnet_one_two_fee(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd inp1 = messages.TxInputType( @@ -180,13 +180,13 @@ def test_testnet_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -203,7 +203,7 @@ def test_testnet_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -213,7 +213,7 @@ def test_testnet_one_two_fee(client: Client): ) -def test_testnet_fee_high_warning(client: Client): +def test_testnet_fee_high_warning(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -230,13 +230,13 @@ def test_testnet_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -250,7 +250,7 @@ def test_testnet_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -260,7 +260,7 @@ def test_testnet_fee_high_warning(client: Client): ) -def test_one_two_fee(client: Client): +def test_one_two_fee(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -282,14 +282,14 @@ def test_one_two_fee(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -305,7 +305,7 @@ def test_one_two_fee(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -316,7 +316,7 @@ def test_one_two_fee(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_one_three_fee(client: Client, chunkify: bool): +def test_one_three_fee(session: Session, chunkify: bool): # input tx: bb5169091f09e833e155b291b662019df56870effe388c626221c5ea84274bc4 inp1 = messages.TxInputType( @@ -344,16 +344,16 @@ def test_one_three_fee(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -371,7 +371,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2, out3], @@ -386,7 +386,7 @@ def test_one_three_fee(client: Client, chunkify: bool): ) -def test_two_two(client: Client): +def test_two_two(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -415,15 +415,15 @@ def test_two_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -449,7 +449,7 @@ def test_two_two(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -464,7 +464,7 @@ def test_two_two(client: Client): @pytest.mark.slow -def test_lots_of_inputs(client: Client): +def test_lots_of_inputs(session: Session): # Tests if device implements serialization of len(inputs) correctly # input tx: 3019487f064329247daad245aed7a75349d09c14b1d24f170947690e030f5b20 @@ -485,7 +485,7 @@ def test_lots_of_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET + session, "Testnet", inputs, [out], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -495,7 +495,7 @@ def test_lots_of_inputs(client: Client): @pytest.mark.slow -def test_lots_of_outputs(client: Client): +def test_lots_of_outputs(session: Session): # Tests if device implements serialization of len(outputs) correctly # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e @@ -518,7 +518,7 @@ def test_lots_of_outputs(client: Client): outputs.append(out) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -528,7 +528,7 @@ def test_lots_of_outputs(client: Client): @pytest.mark.slow -def test_lots_of_change(client: Client): +def test_lots_of_change(session: Session): # Tests if device implements prompting for multiple change addresses correctly # input tx: 892d06cb3394b8e6006eec9a2aa90692b718a29be6844b6c6a9e89ec3aa6aac4 @@ -559,13 +559,13 @@ def test_lots_of_change(client: Client): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), ] + request_change_outputs + [ @@ -585,7 +585,7 @@ def test_lots_of_change(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], outputs, prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -594,7 +594,7 @@ def test_lots_of_change(client: Client): ) -def test_fee_high_warning(client: Client): +def test_fee_high_warning(session: Session): # input tx: 1f326f65768d55ef146efbb345bd87abe84ac7185726d0457a026fc347a26ef3 inp1 = messages.TxInputType( @@ -610,13 +610,13 @@ def test_fee_high_warning(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.FeeOverThreshold), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -631,7 +631,7 @@ def test_fee_high_warning(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -642,7 +642,7 @@ def test_fee_high_warning(client: Client): @pytest.mark.models("core") -def test_fee_high_hardfail(client: Client): +def test_fee_high_hardfail(session: Session): # input tx: 25fee583181847cbe9d9fd9a483a8b8626c99854a72d01de848ef40508d0f3bc # (The "25fee" tx hash is very suitable for testing high fees) @@ -660,18 +660,18 @@ def test_fee_high_hardfail(client: Client): ) with pytest.raises(TrezorFailure, match="fee is unexpectedly large"): - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET) # set SafetyCheckLevel to PromptTemporarily and try again device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: + with session.client as client: IF = InputFlowSignTxHighFee(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert IF.finished @@ -682,7 +682,7 @@ def test_fee_high_hardfail(client: Client): ) -def test_not_enough_funds(client: Client): +def test_not_enough_funds(session: Session): # input tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 inp1 = messages.TxInputType( @@ -698,21 +698,21 @@ def test_not_enough_funds(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.Failure(code=messages.FailureType.NotEnoughFunds), ] ) with pytest.raises(TrezorFailure, match="NotEnoughFunds"): - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) -def test_p2sh(client: Client): +def test_p2sh(session: Session): # input tx: 58d56a5d1325cf83543ee4c87fd73a784e4ba1499ced574be359fa2bdcb9ac8e inp1 = messages.TxInputType( @@ -728,13 +728,13 @@ def test_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_58d56a), @@ -748,7 +748,7 @@ def test_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -758,7 +758,7 @@ def test_p2sh(client: Client): ) -def test_testnet_big_amount(client: Client): +def test_testnet_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 # input tx: 074b0070939db4c2635c1bef0c8e68412ccc8d3c8782137547c7a2bbde073fc0 @@ -775,7 +775,7 @@ def test_testnet_big_amount(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -785,7 +785,7 @@ def test_testnet_big_amount(client: Client): ) -def test_attack_change_outputs(client: Client): +def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a inp1 = messages.TxInputType( @@ -815,15 +815,15 @@ def test_attack_change_outputs(client: Client): ) # Test if the transaction can be signed normally - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ac4ca0), @@ -849,7 +849,7 @@ def test_attack_change_outputs(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET + session, "Bitcoin", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_MAINNET ) assert_tx_matches( @@ -871,14 +871,14 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) btc.sign_tx( - client, + session, "Bitcoin", [inp1, inp2], [out1, out2], @@ -886,7 +886,7 @@ def attack_processor(msg): ) -def test_attack_modify_change_address(client: Client): +def test_attack_modify_change_address(session: Session): # Ensure that if the change output is modified after the user confirms the # transaction, then signing fails. @@ -926,16 +926,18 @@ def attack_processor(msg): return msg - with client, pytest.raises( + with session, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - client.set_filter(messages.TxAck, attack_processor) + session.set_filter(messages.TxAck, attack_processor) - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # input tx: d2dcdaf547ea7f57a713c607f15e883ddc4a98167ee2c43ed953c53cb5153e24 inp1 = messages.TxInputType( @@ -960,7 +962,7 @@ def test_attack_change_input_address(client: Client): # Test if the transaction can be signed normally _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -982,14 +984,14 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1004,7 +1006,7 @@ def attack_processor(msg): # Now run the attack, must trigger the exception with pytest.raises(TrezorFailure) as exc: btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1015,7 +1017,7 @@ def attack_processor(msg): assert exc.value.message.endswith("Transaction has changed during signing") -def test_spend_coinbase(client: Client): +def test_spend_coinbase(session: Session): # NOTE: the input transaction is not real # We did not have any coinbase transaction at connected with `all all` seed, # so it was artificially created for the test purpose @@ -1033,13 +1035,13 @@ def test_spend_coinbase(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(FAKE_TXHASH_005f6f), @@ -1052,7 +1054,7 @@ def test_spend_coinbase(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -1062,7 +1064,7 @@ def test_spend_coinbase(client: Client): ) -def test_two_changes(client: Client): +def test_two_changes(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1091,13 +1093,13 @@ def test_two_changes(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), request_output(2), messages.ButtonRequest(code=B.SignTx), @@ -1118,7 +1120,7 @@ def test_two_changes(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change1, out_change2], @@ -1126,7 +1128,7 @@ def test_two_changes(client: Client): ) -def test_change_on_main_chain_allowed(client: Client): +def test_change_on_main_chain_allowed(session: Session): # input tx: e5040e1bc1ae7667ffb9e5248e90b2fb93cd9150234151ce90e14ab2f5933bcd # see 87be0736f202f7c2bff0781b42bad3e0cdcb54761939da69ea793a3735552c56 @@ -1150,13 +1152,13 @@ def test_change_on_main_chain_allowed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1174,7 +1176,7 @@ def test_change_on_main_chain_allowed(client: Client): ) btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out_change], @@ -1182,7 +1184,7 @@ def test_change_on_main_chain_allowed(client: Client): ) -def test_not_enough_vouts(client: Client): +def test_not_enough_vouts(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a prev_tx = TX_CACHE_MAINNET[TXHASH_ac4ca0] @@ -1222,7 +1224,7 @@ def test_not_enough_vouts(client: Client): TrezorFailure, match="Not enough outputs in previous transaction." ): btc.sign_tx( - client, + session, "Bitcoin", [inp0, inp1, inp2], [out1], @@ -1240,7 +1242,7 @@ def test_not_enough_vouts(client: Client): ("branch_id", 13), ), ) -def test_prevtx_forbidden_fields(client: Client, field, value): +def test_prevtx_forbidden_fields(session: Session, field, value): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1258,7 +1260,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} + session, "Bitcoin", [inp0], [out1], prev_txes={TXHASH_157041: prev_tx} ) @@ -1266,7 +1268,7 @@ def test_prevtx_forbidden_fields(client: Client, field, value): "field, value", (("expiry", 9), ("timestamp", 42), ("version_group_id", 69), ("branch_id", 13)), ) -def test_signtx_forbidden_fields(client: Client, field: str, value: int): +def test_signtx_forbidden_fields(session: Session, field: str, value: int): inp0 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), # 1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL prev_hash=TXHASH_157041, @@ -1283,7 +1285,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): name = field.replace("_", " ") with pytest.raises(TrezorFailure, match=rf"(?i){name} not enabled on this coin"): btc.sign_tx( - client, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs + session, "Bitcoin", [inp0], [out1], prev_txes=TX_CACHE_MAINNET, **kwargs ) @@ -1291,7 +1293,7 @@ def test_signtx_forbidden_fields(client: Client, field: str, value: int): "script_type", (messages.InputScriptType.SPENDADDRESS, messages.InputScriptType.EXTERNAL), ) -def test_incorrect_input_script_type(client: Client, script_type): +def test_incorrect_input_script_type(session: Session, script_type): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( "030e669acac1f280d1ddf441cd2ba5e97417bf2689e4bbec86df4f831bf9f7ffd0" @@ -1300,7 +1302,7 @@ def test_incorrect_input_script_type(client: Client, script_type): multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1335,7 +1337,9 @@ def test_incorrect_input_script_type(client: Client, script_type): with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( @@ -1346,7 +1350,7 @@ def test_incorrect_input_script_type(client: Client, script_type): ), ) def test_incorrect_output_script_type( - client: Client, script_type: messages.OutputScriptType + session: Session, script_type: messages.OutputScriptType ): address_n = parse_path("m/44h/1h/0h/0/0") # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q attacker_multisig_public_key = bytes.fromhex( @@ -1356,7 +1360,7 @@ def test_incorrect_output_script_type( multisig = messages.MultisigRedeemScriptType( m=1, nodes=[ - btc.get_public_node(client, address_n, coin_name="Testnet").node, + btc.get_public_node(session, address_n, coin_name="Testnet").node, messages.HDNodeType( depth=0, fingerprint=0, @@ -1390,14 +1394,16 @@ def test_incorrect_output_script_type( with pytest.raises( TrezorFailure, match="Multisig field provided but not expected." ): - btc.sign_tx(client, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET + ) @pytest.mark.parametrize( "lock_time, sequence", ((499_999_999, 0xFFFFFFFE), (500_000_000, 0xFFFFFFFE), (1, 0xFFFFFFFF)), ) -def test_lock_time(client: Client, lock_time: int, sequence: int): +def test_lock_time(session: Session, lock_time: int, sequence: int): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1414,13 +1420,13 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -1436,7 +1442,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): ) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1446,7 +1452,7 @@ def test_lock_time(client: Client, lock_time: int, sequence: int): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_lock_time_blockheight(client: Client): +def test_lock_time_blockheight(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1463,12 +1469,12 @@ def test_lock_time_blockheight(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowLockTimeBlockHeight(client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1481,7 +1487,7 @@ def test_lock_time_blockheight(client: Client): @pytest.mark.parametrize( "lock_time_str", ("1985-11-05 00:53:20", "2048-08-16 22:14:00") ) -def test_lock_time_datetime(client: Client, lock_time_str: str): +def test_lock_time_datetime(session: Session, lock_time_str: str): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1502,12 +1508,12 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with client: + with session.client as client: IF = InputFlowLockTimeDatetime(client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1517,7 +1523,7 @@ def test_lock_time_datetime(client: Client, lock_time_str: str): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information(client: Client): +def test_information(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1534,12 +1540,12 @@ def test_information(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformation(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1548,7 +1554,7 @@ def test_information(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_mixed(client: Client): +def test_information_mixed(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q amount=31_000_000, @@ -1569,12 +1575,12 @@ def test_information_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationMixed(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -1583,7 +1589,7 @@ def test_information_mixed(client: Client): @pytest.mark.models("core", reason="Cannot test layouts on T1") -def test_information_cancel(client: Client): +def test_information_cancel(session: Session): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -1600,12 +1606,12 @@ def test_information_cancel(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(Cancelled): + with session.client as client, pytest.raises(Cancelled): IF = InputFlowSignTxInformationCancel(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], @@ -1618,7 +1624,7 @@ def test_information_cancel(client: Client): skip="mercury", reason="Cannot test layouts on T1, not implemented in mercury UI", ) -def test_information_replacement(client: Client): +def test_information_replacement(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -1650,12 +1656,12 @@ def test_information_replacement(client: Client): orig_index=0, ) - with client: + with session.client as client: IF = InputFlowSignTxInformationReplacement(client) client.set_input_flow(IF.get()) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index d3dfa3d00ec..50cc19151b6 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -42,7 +42,7 @@ @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_testnet(client: Client, amount_unit): +def test_signtx_testnet(session: Session, amount_unit): inp1 = messages.TxInputType( # tb1qajr3a3y5uz27lkxrmn7ck8lp22dgytvagr5nqy address_n=parse_path("m/84h/1h/0h/0/87"), @@ -61,9 +61,9 @@ def test_signtx_testnet(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -79,7 +79,7 @@ def test_signtx_testnet(client: Client, amount_unit): @pytest.mark.parametrize("amount_unit", VECTORS) -def test_signtx_btc(client: Client, amount_unit): +def test_signtx_btc(session: Session, amount_unit): # input tx: 0dac366fd8a67b2a89fbb0d31086e7acded7a5bbf9ef9daa935bc873229ef5b5 inp1 = messages.TxInputType( @@ -95,9 +95,9 @@ def test_signtx_btc(client: Client, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index fd8e0cff3e9..4d44e3ec763 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path @@ -82,7 +82,7 @@ @pytest.mark.models("core") -def test_p2pkh_presigned(client: Client): +def test_p2pkh_presigned(session: Session): inp1 = messages.TxInputType( # mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q address_n=parse_path("m/44h/1h/0h/0/0"), @@ -142,9 +142,9 @@ def test_p2pkh_presigned(client: Client): ) # Test with first input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1ext, inp2], [out1, out2], @@ -155,9 +155,9 @@ def test_p2pkh_presigned(client: Client): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -170,7 +170,7 @@ def test_p2pkh_presigned(client: Client): inp2ext.script_sig[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2ext], [out1, out2], @@ -179,7 +179,7 @@ def test_p2pkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_presigned(client: Client): +def test_p2wpkh_in_p2sh_presigned(session: Session): inp1 = messages.TxInputType( # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX amount=123_456_789, @@ -216,20 +216,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -252,7 +252,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -267,20 +267,20 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -293,7 +293,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid public key hash"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -302,7 +302,7 @@ def test_p2wpkh_in_p2sh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wpkh_presigned(client: Client): +def test_p2wpkh_presigned(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -339,9 +339,9 @@ def test_p2wpkh_presigned(client: Client): ) # Test with second input as pre-signed external. - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -358,7 +358,7 @@ def test_p2wpkh_presigned(client: Client): inp2.witness[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -367,7 +367,7 @@ def test_p2wpkh_presigned(client: Client): @pytest.mark.models("core") -def test_p2wsh_external_presigned(client: Client): +def test_p2wsh_external_presigned(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/0"), amount=10_000, @@ -399,14 +399,14 @@ def test_p2wsh_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -429,7 +429,7 @@ def test_p2wsh_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -444,14 +444,14 @@ def test_p2wsh_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_ec16dc), @@ -470,12 +470,12 @@ def test_p2wsh_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) @pytest.mark.models("core") -def test_p2tr_external_presigned(client: Client): +def test_p2tr_external_presigned(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -509,14 +509,14 @@ def test_p2tr_external_presigned(client: Client): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -530,7 +530,7 @@ def test_p2tr_external_presigned(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) assert_tx_matches( @@ -541,14 +541,14 @@ def test_p2tr_external_presigned(client: Client): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(1), @@ -558,7 +558,7 @@ def test_p2tr_external_presigned(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -567,18 +567,18 @@ def test_p2tr_external_presigned(client: Client): @pytest.mark.models("core") -def test_p2pkh_with_proof(client: Client): +def test_p2pkh_with_proof(session: Session): # TODO pass @pytest.mark.models("core") -def test_p2wpkh_in_p2sh_with_proof(client: Client): +def test_p2wpkh_in_p2sh_with_proof(session: Session): # TODO pass -def test_p2wpkh_with_proof(client: Client): +def test_p2wpkh_with_proof(session: Session): inp1 = messages.TxInputType( # seed "alcohol woman abuse must during monitor noble actual mixed trade anger aisle" # 84'/1'/0'/0/0 @@ -610,18 +610,18 @@ def test_p2wpkh_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e5b7e2), @@ -643,7 +643,7 @@ def test_p2wpkh_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -660,7 +660,7 @@ def test_p2wpkh_with_proof(client: Client): inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -671,7 +671,7 @@ def test_p2wpkh_with_proof(client: Client): @pytest.mark.setup_client( mnemonic="abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" ) -def test_p2tr_with_proof(client: Client): +def test_p2tr_with_proof(session: Session): # Resulting TXID 48ec6dc7bb772ff18cbce0135fedda7c0e85212c7b2f85a5d0cc7a917d77c48a inp1 = messages.TxInputType( @@ -703,15 +703,15 @@ def test_p2tr_with_proof(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -722,7 +722,7 @@ def test_p2tr_with_proof(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -736,10 +736,12 @@ def test_p2tr_with_proof(client: Client): # Test corrupted ownership proof. inp1.ownership_proof[10] ^= 1 with pytest.raises(TrezorFailure, match="Invalid signature|Invalid external input"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET) + btc.sign_tx( + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_CACHE_TESTNET + ) -def test_p2wpkh_with_false_proof(client: Client): +def test_p2wpkh_with_false_proof(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -768,8 +770,8 @@ def test_p2wpkh_with_false_proof(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), @@ -779,7 +781,7 @@ def test_p2wpkh_with_false_proof(client: Client): with pytest.raises(TrezorFailure, match="Invalid external input"): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -787,7 +789,7 @@ def test_p2wpkh_with_false_proof(client: Client): ) -def test_p2tr_external_unverified(client: Client): +def test_p2tr_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -823,13 +825,13 @@ def test_p2tr_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. @@ -840,7 +842,7 @@ def test_p2tr_external_unverified(client: Client): ) -def test_p2wpkh_external_unverified(client: Client): +def test_p2wpkh_external_unverified(session: Session): inp1 = messages.TxInputType( # tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9 address_n=parse_path("m/84h/1h/0h/0/0"), @@ -875,13 +877,13 @@ def test_p2wpkh_external_unverified(client: Client): # Unverified external inputs should be rejected when safety checks are enabled. with pytest.raises(TrezorFailure, match="[Ee]xternal input"): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Signing should succeed after disabling safety checks. - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Second witness is missing from the serialized transaction. diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 5ef4ba0389c..27f0599de9b 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -36,7 +36,7 @@ # Litecoin does not have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should fail. @pytest.mark.altcoin -def test_invalid_path_fail(client: Client): +def test_invalid_path_fail(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -52,7 +52,7 @@ def test_invalid_path_fail(client: Client): ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) assert exc.value.code == messages.FailureType.DataError assert exc.value.message.endswith("Forbidden key path") @@ -61,7 +61,7 @@ def test_invalid_path_fail(client: Client): # Litecoin does not have strong replay protection using SIGHASH_FORKID, but # spending from Bitcoin path should pass with safety checks set to prompt. @pytest.mark.altcoin -def test_invalid_path_prompt(client: Client): +def test_invalid_path_prompt(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -77,21 +77,21 @@ def test_invalid_path_prompt(client: Client): ) device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, # spending from Bitcoin path should work. @pytest.mark.altcoin -def test_invalid_path_pass_forkid(client: Client): +def test_invalid_path_pass_forkid(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/0"), amount=390_000, @@ -106,32 +106,32 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - if is_core(client): + with session.client as client: + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) -def test_attack_path_segwit(client: Client): +def test_attack_path_segwit(session: Session): # Scenario: The attacker falsely claims that the transaction uses Testnet paths to # avoid the path warning dialog, but in step6_sign_segwit_inputs() uses Bitcoin paths # to get a valid signature. device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) # Generate keys address_a = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/0h/0/0"), script_type=messages.InputScriptType.SPENDWITNESS, ) address_b = btc.get_address( - client, + session, "Testnet", parse_path("m/84h/0h/1h/0/1"), script_type=messages.InputScriptType.SPENDWITNESS, @@ -178,15 +178,15 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} ) -def test_invalid_path_fail_asap(client: Client): +def test_invalid_path_fail_asap(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/0"), amount=1_000_000, @@ -202,14 +202,14 @@ def test_invalid_path_fail_asap(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), ] ) try: - btc.sign_tx(client, "Testnet", [inp1], [out1]) + btc.sign_tx(session, "Testnet", [inp1], [out1]) except TrezorFailure: pass diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index de0f3807689..d3ab1cf37b0 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -15,7 +15,7 @@ # If not, see . from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...tx_cache import TxCache @@ -34,7 +34,7 @@ ) -def test_non_segwit_segwit_inputs(client: Client): +def test_non_segwit_segwit_inputs(session: Session): # First is non-segwit, second is segwit. inp1 = messages.TxInputType( @@ -58,9 +58,9 @@ def test_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -71,7 +71,7 @@ def test_non_segwit_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_inputs(client: Client): +def test_segwit_non_segwit_inputs(session: Session): # First is segwit, second is non-segwit. inp1 = messages.TxInputType( @@ -94,9 +94,9 @@ def test_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) assert len(signatures) == 2 @@ -107,7 +107,7 @@ def test_segwit_non_segwit_inputs(client: Client): ) -def test_segwit_non_segwit_segwit_inputs(client: Client): +def test_segwit_non_segwit_segwit_inputs(session: Session): # First is segwit, second is non-segwit and third is segwit again. inp1 = messages.TxInputType( @@ -138,9 +138,9 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 @@ -151,7 +151,7 @@ def test_segwit_non_segwit_segwit_inputs(client: Client): ) -def test_non_segwit_segwit_non_segwit_inputs(client: Client): +def test_non_segwit_segwit_non_segwit_inputs(session: Session): # First is non-segwit, second is segwit and third is non-segwit again. inp1 = messages.TxInputType( @@ -180,9 +180,9 @@ def test_non_segwit_segwit_non_segwit_inputs(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: signatures, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API + session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) assert len(signatures) == 3 diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index e02cb2b6c6b..32c90d05e0c 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -18,8 +18,8 @@ import pytest -from trezorlib import btc, messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import btc, messages, misc, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -138,7 +138,7 @@ def case(id, *args, altcoin: bool = False, models: str | None = None): case("out12", (PaymentRequestParams([1, 2], [], get_nonce=True),)), ), ) -def test_payment_request(client: Client, payment_request_params): +def test_payment_request(session: Session, payment_request_params): for txo in outputs: txo.payment_req_index = None @@ -148,10 +148,10 @@ def test_payment_request(client: Client, payment_request_params): for txo_index in params.txo_indices: outputs[txo_index].payment_req_index = i request_outputs.append(outputs[txo_index]) - nonce = misc.get_nonce(client) if params.get_nonce else None + nonce = misc.get_nonce(session) if params.get_nonce else None payment_reqs.append( make_payment_request( - client, + session, recipient_name="trezor.io", outputs=request_outputs, change_addresses=["tb1qkvwu9g3k2pdxewfqr7syz89r3gj557l3uuf9r9"], @@ -161,7 +161,7 @@ def test_payment_request(client: Client, payment_request_params): ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -174,7 +174,7 @@ def test_payment_request(client: Client, payment_request_params): # Ensure that the nonce has been invalidated. with pytest.raises(TrezorFailure, match="Invalid nonce in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -184,15 +184,18 @@ def test_payment_request(client: Client, payment_request_params): @pytest.mark.models(skip="safe3") -def test_payment_request_details(client: Client): +def test_payment_request_details(session: Session): + if session.model is models.T2B1: + pytest.skip("Details not implemented on T2B1") + # Test that payment request details are shown when requested. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None - nonce = misc.get_nonce(client) + nonce = misc.get_nonce(session) payment_reqs = [ make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[TextMemo("Invoice #87654321.")], @@ -200,12 +203,12 @@ def test_payment_request_details(client: Client): ) ] - with client: + with session.client as client: IF = InputFlowPaymentRequestDetails(client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -216,16 +219,16 @@ def test_payment_request_details(client: Client): assert serialized_tx.hex() == SERIALIZED_TX -def test_payment_req_wrong_amount(client: Client): +def test_payment_req_wrong_amount(session: Session): # Test wrong total amount in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Decrease the total amount of the payment request. @@ -233,7 +236,7 @@ def test_payment_req_wrong_amount(client: Client): with pytest.raises(TrezorFailure, match="Invalid amount in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -242,18 +245,18 @@ def test_payment_req_wrong_amount(client: Client): ) -def test_payment_req_wrong_mac_refund(client: Client): +def test_payment_req_wrong_mac_refund(session: Session): # Test wrong MAC in payment request memo. memo = RefundMemo(parse_path("m/44h/1h/0h/1/0")) outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -263,7 +266,7 @@ def test_payment_req_wrong_mac_refund(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -274,7 +277,7 @@ def test_payment_req_wrong_mac_refund(client: Client): @pytest.mark.altcoin @pytest.mark.models("t2t1", reason="Dash not supported on Safe family") -def test_payment_req_wrong_mac_purchase(client: Client): +def test_payment_req_wrong_mac_purchase(session: Session): # Test wrong MAC in payment request memo. memo = CoinPurchaseMemo( amount="22.34904 DASH", @@ -286,11 +289,11 @@ def test_payment_req_wrong_mac_purchase(client: Client): outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], memos=[memo], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Corrupt the MAC value. @@ -300,7 +303,7 @@ def test_payment_req_wrong_mac_purchase(client: Client): with pytest.raises(TrezorFailure, match="Invalid address MAC"): btc.sign_tx( - client, + session, "Testnet", inputs, outputs, @@ -309,16 +312,16 @@ def test_payment_req_wrong_mac_purchase(client: Client): ) -def test_payment_req_wrong_output(client: Client): +def test_payment_req_wrong_output(session: Session): # Test wrong output in payment request. outputs[0].payment_req_index = 0 outputs[1].payment_req_index = 0 outputs[2].payment_req_index = None payment_req = make_payment_request( - client, + session, recipient_name="trezor.io", outputs=outputs[:2], - nonce=misc.get_nonce(client), + nonce=misc.get_nonce(session), ) # Use a different address in the second output. @@ -335,7 +338,7 @@ def test_payment_req_wrong_output(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature in payment request"): btc.sign_tx( - client, + session, "Testnet", inputs, fake_outputs, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 307823a9f3f..a2f96c04ed1 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -5,7 +5,7 @@ import pytest from trezorlib import btc, messages, models, tools -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import is_core @@ -78,7 +78,7 @@ def _check_error_message(value: bytes, model: models.TrezorModel, message: str): @with_bad_prevhashes -def test_invalid_prev_hash(client: Client, prev_hash): +def test_invalid_prev_hash(session: Session, prev_hash): inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), amount=123_456_789, @@ -93,12 +93,12 @@ def test_invalid_prev_hash(client: Client, prev_hash): ) with pytest.raises(TrezorFailure) as e: - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes={}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes={}) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_attack(client: Client, prev_hash): +def test_invalid_prev_hash_attack(session: Session, prev_hash): # prepare input with a valid prev-hash inp1 = messages.TxInputType( address_n=tools.parse_path("m/44h/0h/0h/0/0"), @@ -130,20 +130,20 @@ def attack_filter(msg): msg.tx.inputs[0].prev_hash = prev_hash return msg - with client, pytest.raises(TrezorFailure) as e: - client.set_filter(messages.TxAck, attack_filter) - if is_core(client): + with session, session.client as client, pytest.raises(TrezorFailure) as e: + session.set_filter(messages.TxAck, attack_filter) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed assert counter == 0 - _check_error_message(prev_hash, client.model, e.value.message) + _check_error_message(prev_hash, session.model, e.value.message) @with_bad_prevhashes -def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): +def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): prev_tx = copy(PREV_TX) # smoke check: replace prev_hash with all zeros, reserialize and hash, try to sign @@ -161,16 +161,16 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): amount=99_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) # attack: replace prev_hash with an invalid value prev_tx.inputs[0].prev_hash = prev_hash tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with client, pytest.raises(TrezorFailure) as e: - if client.model is not models.T1B1: + with session, session.client as client, pytest.raises(TrezorFailure) as e: + if session.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) - btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) - _check_error_message(prev_hash, client.model, e.value.message) + btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) + _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index 97fe7e2d873..fd5db6a5027 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -90,7 +90,7 @@ ) -def test_p2pkh_fee_bump(client: Client): +def test_p2pkh_fee_bump(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/4"), amount=174_998, @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -132,7 +132,7 @@ def test_p2pkh_fee_bump(client: Client): request_meta(TXHASH_beafc7), request_input(0, TXHASH_beafc7), request_output(0, TXHASH_beafc7), - (is_core(client), request_orig_input(0, TXHASH_50f6f1)), + (is_core(session), request_orig_input(0, TXHASH_50f6f1)), request_orig_input(0, TXHASH_50f6f1), request_orig_output(0, TXHASH_50f6f1), request_orig_output(1, TXHASH_50f6f1), @@ -145,7 +145,7 @@ def test_p2pkh_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -159,7 +159,7 @@ def test_p2pkh_fee_bump(client: Client): ) -def test_p2wpkh_op_return_fee_bump(client: Client): +def test_p2wpkh_op_return_fee_bump(session: Session): # Original input. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/1h/0/14"), @@ -190,9 +190,9 @@ def test_p2wpkh_op_return_fee_bump(client: Client): orig_index=1, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -207,7 +207,7 @@ def test_p2wpkh_op_return_fee_bump(client: Client): # txid 48bc29fc42a64b43d043b0b7b99b21aa39654234754608f791c60bcbd91a8e92 -def test_p2tr_fee_bump(client: Client): +def test_p2tr_fee_bump(session: Session): inp1 = messages.TxInputType( # tb1p8tvmvsvhsee73rhym86wt435qrqm92psfsyhy6a3n5gw455znnpqm8wald address_n=parse_path("m/86h/1h/0h/0/1"), @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(client: Client): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -269,7 +269,7 @@ def test_p2tr_fee_bump(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_CACHE_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -281,7 +281,7 @@ def test_p2tr_fee_bump(client: Client): ) -def test_p2wpkh_finalize(client: Client): +def test_p2wpkh_finalize(session: Session): # Original input with disabled RBF opt-in, i.e. we finalize the transaction. inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/0/2"), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(client: Client): orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -339,7 +339,7 @@ def test_p2wpkh_finalize(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -401,7 +401,7 @@ def test_p2wpkh_finalize(client: Client): ), ) def test_p2wpkh_payjoin( - client, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx + session, out1_amount, out2_amount, copayer_witness, fee_confirm, expected_tx ): # Original input. inp1 = messages.TxInputType( @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -478,7 +478,7 @@ def test_p2wpkh_payjoin( ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -489,7 +489,7 @@ def test_p2wpkh_payjoin( assert serialized_tx.hex() == expected_tx -def test_p2wpkh_in_p2sh_remove_change(client: Client): +def test_p2wpkh_in_p2sh_remove_change(session: Session): # Test fee bump with change-output removal. Originally fee was 3780, now 98060. inp1 = messages.TxInputType( @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -553,7 +553,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -567,7 +567,7 @@ def test_p2wpkh_in_p2sh_remove_change(client: Client): ) -def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): +def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): # Use the change output and an external output to bump the fee. # Originally fee was 3780, now 108060 (94280 from change and 10000 from external). @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): orig_index=0, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -634,7 +634,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -649,7 +649,7 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(client: Client): @pytest.mark.models("core") -def test_tx_meld(client: Client): +def test_tx_meld(session: Session): # Meld two original transactions into one, joining the change-outputs into a different one. inp1 = messages.TxInputType( @@ -720,8 +720,8 @@ def test_tx_meld(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -785,7 +785,7 @@ def test_tx_meld(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3], @@ -799,7 +799,7 @@ def test_tx_meld(client: Client): ) -def test_attack_steal_change(client: Client): +def test_attack_steal_change(session: Session): # Attempt to steal amount equivalent to the change in the original transaction by # hiding the fact that an output in the original transaction is a change-output. @@ -860,7 +860,7 @@ def test_attack_steal_change(client: Client): TrezorFailure, match="Original output is missing change-output parameters" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -870,7 +870,7 @@ def test_attack_steal_change(client: Client): @pytest.mark.models("core") -def test_attack_false_internal(client: Client): +def test_attack_false_internal(session: Session): # Falsely claim that an external input is internal in the original transaction. # If this were possible, it would allow an attacker to make it look like the # user was spending more in the original than they actually were, making it @@ -914,7 +914,7 @@ def test_attack_false_internal(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -922,7 +922,7 @@ def test_attack_false_internal(client: Client): ) -def test_attack_fake_int_input_amount(client: Client): +def test_attack_fake_int_input_amount(session: Session): # Give a fake input amount for an original internal input while giving the correct # amount for the replacement input. If an attacker could increase the amount of an # internal input in the original transaction, then they could bump the fee of the @@ -968,7 +968,7 @@ def test_attack_fake_int_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Bitcoin", [inp1], [out1, out2], @@ -977,7 +977,7 @@ def test_attack_fake_int_input_amount(client: Client): @pytest.mark.models("core") -def test_attack_fake_ext_input_amount(client: Client): +def test_attack_fake_ext_input_amount(session: Session): # Give a fake input amount for an original external input while giving the correct # amount for the replacement input. If an attacker could decrease the amount of an # external input in the original transaction, then they could steal the fee from @@ -1044,7 +1044,7 @@ def test_attack_fake_ext_input_amount(client: Client): TrezorFailure, match="Original input does not match current input" ): btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2], @@ -1052,7 +1052,7 @@ def test_attack_fake_ext_input_amount(client: Client): ) -def test_p2wpkh_invalid_signature(client: Client): +def test_p2wpkh_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. # Original input with disabled RBF opt-in, i.e. we finalize the transaction. @@ -1096,7 +1096,7 @@ def test_p2wpkh_invalid_signature(client: Client): with pytest.raises(TrezorFailure, match="Invalid signature"): btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -1105,7 +1105,7 @@ def test_p2wpkh_invalid_signature(client: Client): ) -def test_p2tr_invalid_signature(client: Client): +def test_p2tr_invalid_signature(session: Session): # Ensure that transaction replacement fails when the original signature is invalid. inp1 = messages.TxInputType( @@ -1151,4 +1151,4 @@ def test_p2tr_invalid_signature(client: Client): prev_txes = {TXHASH_8e4af7: prev_tx_invalid} with pytest.raises(TrezorFailure, match="Invalid signature"): - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=prev_txes) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index 763626caef0..ef8c988ff39 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -47,7 +47,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2sh(client: Client, chunkify: bool): +def test_send_p2sh(session: Session, chunkify: bool): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -66,16 +66,16 @@ def test_send_p2sh(client: Client, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -90,7 +90,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1], [out1, out2], @@ -105,7 +105,7 @@ def test_send_p2sh(client: Client, chunkify: bool): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), # 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -124,13 +124,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -146,7 +146,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -156,11 +156,11 @@ def test_send_p2sh_change(client: Client): ) -def test_testnet_segwit_big_amount(client: Client): +def test_testnet_segwit_big_amount(session: Session): # This test is testing transaction with amount bigger than fits to uint32 address_n = parse_path("m/49h/1h/0h/0/0") address = btc.get_address( - client, + session, "Testnet", address_n, script_type=messages.InputScriptType.SPENDP2SHWITNESS, @@ -179,13 +179,13 @@ def test_testnet_segwit_big_amount(client: Client): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(prev_hash), @@ -198,7 +198,7 @@ def test_testnet_segwit_big_amount(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} + session, "Testnet", [inp1], [out1], prev_txes={prev_hash: prev_tx} ) # Transaction does not exist on the blockchain, not using assert_tx_matches() assert ( @@ -208,12 +208,12 @@ def test_testnet_segwit_big_amount(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input: 338e2d02e0eaf8848e38925904e51546cf22e58db5b1860c4a0e72b69c56afe5 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet" ).node for i in range(1, 4) ] @@ -241,7 +241,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_338e2d), @@ -254,10 +254,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -265,10 +265,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -278,7 +278,7 @@ def test_send_multisig_1(client: Client): ) -def test_attack_change_input_address(client: Client): +def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. @@ -303,17 +303,17 @@ def test_attack_change_input_address(client: Client): ) # Test if the transaction can be signed normally. - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), # The user is required to confirm transfer to another account. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -328,7 +328,7 @@ def test_attack_change_input_address(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -349,15 +349,15 @@ def attack_processor(msg): return msg # Now run the attack, must trigger the exception - with client: - client.set_filter(messages.TxAck, attack_processor) + with session: + session.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) -def test_attack_mixed_inputs(client: Client): +def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 @@ -389,11 +389,11 @@ def test_attack_mixed_inputs(client: Client): request_output(0), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ), ( - is_core(client), + is_core(session), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), ), messages.ButtonRequest(code=messages.ButtonRequestType.FeeOverThreshold), @@ -417,16 +417,16 @@ def test_attack_mixed_inputs(client: Client): request_finished(), ] - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with client: + with session: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - client.set_expected_responses(expected_responses) + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], @@ -436,7 +436,7 @@ def test_attack_mixed_inputs(client: Client): # In Phase 1 make the user confirm a lower value of the segwit input. inp2.amount = FAKE_AMOUNT - if client.model is models.T1B1: + if session.model is models.T1B1: # T1 fails as soon as it encounters the fake amount. expected_responses = ( expected_responses[:4] + expected_responses[5:15] + [messages.Failure()] @@ -446,10 +446,10 @@ def test_attack_mixed_inputs(client: Client): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, client: - client.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, session: + session.set_expected_responses(expected_responses) btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 0c779c777ef..920b0bf48b7 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import H_, parse_path from ...bip32 import deserialize @@ -61,7 +61,7 @@ ) -def test_send_p2sh(client: Client): +def test_send_p2sh(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -82,16 +82,16 @@ def test_send_p2sh(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_20912f), @@ -106,7 +106,7 @@ def test_send_p2sh(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -116,7 +116,7 @@ def test_send_p2sh(client: Client): ) -def test_send_p2sh_change(client: Client): +def test_send_p2sh_change(session: Session): # input tx: 20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337 inp1 = messages.TxInputType( @@ -137,13 +137,13 @@ def test_send_p2sh_change(client: Client): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -159,7 +159,7 @@ def test_send_p2sh_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) # Transaction does not exist on the blockchain, not using assert_tx_matches() @@ -169,7 +169,7 @@ def test_send_p2sh_change(client: Client): ) -def test_send_native(client: Client): +def test_send_native(session: Session): # input tx: b36780ceb86807ca6e7535a6fd418b1b788cb9b227d2c8a26a0de295e523219e inp1 = messages.TxInputType( @@ -190,16 +190,16 @@ def test_send_native(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b36780), @@ -214,7 +214,7 @@ def test_send_native(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -224,7 +224,7 @@ def test_send_native(client: Client): ) -def test_send_to_taproot(client: Client): +def test_send_to_taproot(session: Session): # input tx: ec16dc5a539c5d60001a7471c37dbb0b5294c289c77df8bd07870b30d73e2231 inp1 = messages.TxInputType( @@ -244,9 +244,9 @@ def test_send_to_taproot(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -256,7 +256,7 @@ def test_send_to_taproot(client: Client): ) -def test_send_native_change(client: Client): +def test_send_native_change(session: Session): # input tx: fcb3f5436224900afdba50e9e763d98b920dfed056e552040d99ea9bc03a9d83 inp1 = messages.TxInputType( @@ -277,13 +277,13 @@ def test_send_native_change(client: Client): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -300,7 +300,7 @@ def test_send_native_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -310,7 +310,7 @@ def test_send_native_change(client: Client): ) -def test_send_both(client: Client): +def test_send_both(session: Session): # input 1 tx: 65047a2b107d6301d72d4a1e49e7aea9cf06903fdc4ae74a4a9bba9bc1a414d2 # input 2 tx: d159fd2fcb5854a7c8b275d598765a446f1e2ff510bf077545a404a0c9db65f7 @@ -344,21 +344,21 @@ def test_send_both(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_65047a), @@ -382,7 +382,7 @@ def test_send_both(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2], [out1, out2, out3], @@ -397,12 +397,12 @@ def test_send_both(client: Client): @pytest.mark.multisig -def test_send_multisig_1(client: Client): +def test_send_multisig_1(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -433,7 +433,7 @@ def test_send_multisig_1(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -449,10 +449,10 @@ def test_send_multisig_1(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -460,10 +460,10 @@ def test_send_multisig_1(client: Client): # sign with third key inp1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -474,12 +474,12 @@ def test_send_multisig_1(client: Client): @pytest.mark.multisig -def test_send_multisig_2(client: Client): +def test_send_multisig_2(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -510,7 +510,7 @@ def test_send_multisig_2(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -526,10 +526,10 @@ def test_send_multisig_2(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -537,10 +537,10 @@ def test_send_multisig_2(client: Client): # sign with first key inp1.address_n[2] = H_(1) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -551,12 +551,12 @@ def test_send_multisig_2(client: Client): @pytest.mark.multisig -def test_send_multisig_3_change(client: Client): +def test_send_multisig_3_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/84h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -595,7 +595,7 @@ def test_send_multisig_3_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -611,13 +611,13 @@ def test_send_multisig_3_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -626,13 +626,13 @@ def test_send_multisig_3_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -643,12 +643,12 @@ def test_send_multisig_3_change(client: Client): @pytest.mark.multisig -def test_send_multisig_4_change(client: Client): +def test_send_multisig_4_change(session: Session): # input tx: b9abfa0d4a28f6f25e1f6c0f974bfc3f7c5a44c4d381b1796e3fbeef51b560a6 nodes = [ btc.get_public_node( - client, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" + session, parse_path(f"m/49h/1h/{index}h"), coin_name="Testnet" ) for index in range(1, 4) ] @@ -687,7 +687,7 @@ def test_send_multisig_4_change(client: Client): request_output(0), messages.ButtonRequest(code=B.UnknownDerivationPath), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_b9abfa), @@ -703,13 +703,13 @@ def test_send_multisig_4_change(client: Client): request_finished(), ] - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) # store signature @@ -718,13 +718,13 @@ def test_send_multisig_4_change(client: Client): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with client: - client.set_expected_responses(expected_responses) - if is_core(client): + with session, session.client as client: + session.set_expected_responses(expected_responses) + if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) assert_tx_matches( @@ -734,7 +734,7 @@ def test_send_multisig_4_change(client: Client): ) -def test_multisig_mismatch_inputs_single(client: Client): +def test_multisig_mismatch_inputs_single(session: Session): # Ensure that if there is a non-multisig input, then a multisig output # will not be identified as a change output. @@ -788,18 +788,18 @@ def test_multisig_mismatch_inputs_single(client: Client): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), # Ensure that the multisig output is not identified as a change output. messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_1c022d), @@ -824,7 +824,7 @@ def test_multisig_mismatch_inputs_single(client: Client): ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API_TESTNET ) assert_tx_matches( diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index f548154ae70..0453474af91 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path @@ -64,7 +64,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_send_p2tr(client: Client, chunkify: bool): +def test_send_p2tr(session: Session, chunkify: bool): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -79,13 +79,13 @@ def test_send_p2tr(client: Client, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -94,7 +94,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify + session, "Testnet", [inp1], [out1], prev_txes=TX_API, chunkify=chunkify ) assert_tx_matches( @@ -104,7 +104,7 @@ def test_send_p2tr(client: Client, chunkify: bool): ) -def test_send_two_with_change(client: Client): +def test_send_two_with_change(session: Session): inp1 = messages.TxInputType( # tb1pswrqtykue8r89t9u4rprjs0gt4qzkdfuursfnvqaa3f2yql07zmq8s8a5u address_n=parse_path("m/86h/1h/0h/0/0"), @@ -133,14 +133,14 @@ def test_send_two_with_change(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -153,7 +153,7 @@ def test_send_two_with_change(client: Client): ] ) _, serialized_tx = btc.sign_tx( - client, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API + session, "Testnet", [inp1, inp2], [out1, out2], prev_txes=TX_API ) assert_tx_matches( @@ -163,7 +163,7 @@ def test_send_two_with_change(client: Client): ) -def test_send_mixed(client: Client): +def test_send_mixed(session: Session): inp1 = messages.TxInputType( # 2MutHjgAXkqo3jxX2DZWorLAckAnwTxSM9V address_n=parse_path("m/49h/1h/1h/0/0"), @@ -222,8 +222,8 @@ def test_send_mixed(client: Client): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ # process inputs request_input(0), @@ -233,19 +233,19 @@ def test_send_mixed(client: Client): # approve outputs request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(2), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(3), messages.ButtonRequest(code=B.ConfirmOutput), request_output(4), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), # verify inputs request_input(0), @@ -293,12 +293,12 @@ def test_send_mixed(client: Client): request_input(0), request_input(1), request_input(2), - (client.model is models.T1B1, request_input(3)), + (session.model is models.T1B1, request_input(3)), request_finished(), ] ) _, serialized_tx = btc.sign_tx( - client, + session, "Testnet", [inp1, inp2, inp3, inp4], [out1, out2, out3, out4, out5], @@ -312,13 +312,12 @@ def test_send_mixed(client: Client): ) -def test_attack_script_type(client: Client): +def test_attack_script_type(session: Session): # Scenario: The attacker falsely claims that the transaction is Taproot-only to # avoid prev tx streaming and gives a lower amount for one of the inputs. The # correct input types and amounts are revelaled only in step6_sign_segwit_inputs() # to get a valid signature. This results in a transaction which pays a fee much # larger than what the user confirmed. - inp1 = messages.TxInputType( address_n=parse_path("m/84h/1h/0h/1/0"), amount=7_289_000, @@ -354,16 +353,16 @@ def attack_processor(msg): return msg - with client: - client.set_filter(messages.TxAck, attack_processor) - client.set_expected_responses( + with session: + session.set_filter(messages.TxAck, attack_processor) + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), - (is_core(client), messages.ButtonRequest(code=B.SignTx)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.SignTx)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_input(1), @@ -374,7 +373,7 @@ def attack_processor(msg): ] ) with pytest.raises(TrezorFailure) as exc: - btc.sign_tx(client, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API) assert exc.value.code == messages.FailureType.ProcessError assert exc.value.message.endswith("Transaction has changed during signing") @@ -392,7 +391,7 @@ def attack_processor(msg): "tb1pllllllllllllllllllllllllllllllllllllllllllllallllscqgl4zhn", ), ) -def test_send_invalid_address(client: Client, address: str): +def test_send_invalid_address(session: Session, address: str): inp1 = messages.TxInputType( # tb1pn2d0yjeedavnkd8z8lhm566p0f2utm3lgvxrsdehnl94y34txmts5s7t4c address_n=parse_path("m/86h/1h/0h/1/0"), @@ -407,12 +406,12 @@ def test_send_invalid_address(client: Client, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure): - client.set_expected_responses( + with session, pytest.raises(TrezorFailure): + session.set_expected_responses( [ request_input(0), request_output(0), messages.Failure, ] ) - btc.sign_tx(client, "Testnet", [inp1], [out1], prev_txes=TX_API) + btc.sign_tx(session, "Testnet", [inp1], [out1], prev_txes=TX_API) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 86389d8a515..88907e318dc 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -19,12 +19,12 @@ import pytest from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -35,9 +35,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "mirio8q3gtv7fhdnmb3TpZ4EuafdzSs7zL", bytes.fromhex( @@ -49,9 +49,9 @@ def test_message_testnet(client: Client): @pytest.mark.altcoin -def test_message_grs(client: Client): +def test_message_grs(session: Session): ret = btc.verify_message( - client, + session, "Groestlcoin", "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM", base64.b64decode( @@ -62,9 +62,9 @@ def test_message_grs(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -76,7 +76,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -88,7 +88,7 @@ def test_message_verify(client: Client): # uncompressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1JwSSubhmg6iPtRjtyqhUYYH7bZg3Lfy1T", bytes.fromhex( @@ -100,7 +100,7 @@ def test_message_verify(client: Client): # compressed pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -112,7 +112,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -124,7 +124,7 @@ def test_message_verify(client: Client): # compressed pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "1C7zdTfnkzmr13HfA2vNm5SJYRK6nEKyq8", bytes.fromhex( @@ -136,7 +136,7 @@ def test_message_verify(client: Client): # trezor pubkey - OK res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -148,7 +148,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -160,7 +160,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -172,9 +172,9 @@ def test_message_verify(client: Client): @pytest.mark.altcoin -def test_message_verify_bcash(client: Client): +def test_message_verify_bcash(session: Session): res = btc.verify_message( - client, + session, "Bcash", "bitcoincash:qqj22md58nm09vpwsw82fyletkxkq36zxyxh322pru", bytes.fromhex( @@ -185,9 +185,9 @@ def test_message_verify_bcash(client: Client): assert res is True -def test_verify_bitcoind(client: Client): +def test_verify_bitcoind(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "1KzXE97kV7DrpxCViCN3HbGbiKhzzPM7TQ", bytes.fromhex( @@ -199,12 +199,12 @@ def test_verify_bitcoind(client: Client): assert res is True -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -214,7 +214,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit.py b/tests/device_tests/bitcoin/test_verifymessage_segwit.py index 84f04442646..9c3169e0c78 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "2N4VkePSzKH2sv5YBikLHGvzUYvfPxV6zS9", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "3L6TyTisPBmrDAj6RoKmDzNnj4eQi54gD2", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "3CwYaeWxhpXXiHue3ciQez1DLaTEAXcKa1", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py index 5bea51f7dc1..3a4ed68e5da 100644 --- a/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py +++ b/tests/device_tests/bitcoin/test_verifymessage_segwit_native.py @@ -15,12 +15,12 @@ # If not, see . from trezorlib import btc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_message_long(client: Client): +def test_message_long(session: Session): ret = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -31,9 +31,9 @@ def test_message_long(client: Client): assert ret is True -def test_message_testnet(client: Client): +def test_message_testnet(session: Session): ret = btc.verify_message( - client, + session, "Testnet", "tb1qyjjkmdpu7metqt5r36jf872a34syws336p3n3p", bytes.fromhex( @@ -44,9 +44,9 @@ def test_message_testnet(client: Client): assert ret is True -def test_message_verify(client: Client): +def test_message_verify(session: Session): res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -58,7 +58,7 @@ def test_message_verify(client: Client): # no script type res = btc.verify_message( - client, + session, "Bitcoin", "bc1qannfxke2tfd4l7vhepehpvt05y83v3qsf6nfkk", bytes.fromhex( @@ -70,7 +70,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong sig res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -82,7 +82,7 @@ def test_message_verify(client: Client): # trezor pubkey - FAIL - wrong msg res = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -93,12 +93,12 @@ def test_message_verify(client: Client): assert res is False -def test_verify_utf(client: Client): +def test_verify_utf(session: Session): words_nfkd = "Pr\u030ci\u0301s\u030cerne\u030c z\u030clut\u030couc\u030cky\u0301 ku\u030an\u030c u\u0301pe\u030cl d\u030ca\u0301belske\u0301 o\u0301dy za\u0301ker\u030cny\u0301 uc\u030cen\u030c be\u030cz\u030ci\u0301 pode\u0301l zo\u0301ny u\u0301lu\u030a" words_nfc = "P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f" res_nfkd = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( @@ -108,7 +108,7 @@ def test_verify_utf(client: Client): ) res_nfc = btc.verify_message( - client, + session, "Bitcoin", "bc1qyjjkmdpu7metqt5r36jf872a34syws33s82q2j", bytes.fromhex( diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index dc959199a35..adb99589150 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -57,7 +57,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_v3_not_supported(client: Client): +def test_v3_not_supported(session: Session): # prevout: aaf51e4606c264e47e5c42c958fe4cf1539c5172684721e38e69f4ef634d75dc:1 # input 1: 3.0 TAZ @@ -75,9 +75,9 @@ def test_v3_not_supported(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client, pytest.raises(TrezorFailure, match="DataError"): + with session, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -88,7 +88,7 @@ def test_v3_not_supported(client: Client): ) -def test_one_one_fee_sapling(client: Client): +def test_one_one_fee_sapling(session: Session): # prevout: e3820602226974b1dd87b7113cc8aea8c63e5ae29293991e7bfa80c126930368:0 # input 1: 3.0 TAZ @@ -106,13 +106,13 @@ def test_one_one_fee_sapling(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -128,7 +128,7 @@ def test_one_one_fee_sapling(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -145,7 +145,7 @@ def test_one_one_fee_sapling(client: Client): ) -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -161,7 +161,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -170,7 +170,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_old_versions(client: Client): +def test_spend_old_versions(session: Session): # NOTE: fake input tx used input_v1 = messages.TxInputType( @@ -210,9 +210,9 @@ def test_spend_old_versions(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: + with session: _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", inputs, [output], @@ -229,7 +229,7 @@ def test_spend_old_versions(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -259,14 +259,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_e38206), @@ -289,7 +289,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d7c02e6b6dc..d8ec9288eb7 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -22,7 +22,7 @@ get_public_key, parse_optional_bytes, ) -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import CardanoAddressType, CardanoDerivationType from trezorlib.tools import parse_path @@ -48,15 +48,15 @@ "cardano/get_base_address.derivations.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_cardano_get_address(client: Client, chunkify: bool, parameters, result): - client.init_device(new_session=True, derive_cardano=True) +def test_cardano_get_address(session: Session, chunkify: bool, parameters, result): + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] address = get_address( - client, + session, address_parameters=create_address_parameters( address_type=getattr( CardanoAddressType, parameters["address_type"].upper() @@ -94,17 +94,17 @@ def test_cardano_get_address(client: Client, chunkify: bool, parameters, result) "cardano/get_public_key.slip39.json", "cardano/get_public_key.derivations.json", ) -def test_cardano_get_public_key(client: Client, parameters, result): - with client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(client.ui.passphrase)) +def test_cardano_get_public_key(session: Session, parameters, result): + with session, session.client as client: + IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) client.set_input_flow(IF.get()) - client.init_device(new_session=True, derive_cardano=True) + # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ parameters.get("derivation_type", "ICARUS_TREZOR") ] key = get_public_key( - client, parse_path(parameters["path"]), derivation_type, show_display=True + session, parse_path(parameters["path"]), derivation_type, show_display=True ) assert key.node.public_key.hex() == result["public_key"] diff --git a/tests/device_tests/cardano/test_derivations.py b/tests/device_tests/cardano/test_derivations.py index 656c31a8bde..148a0a85037 100644 --- a/tests/device_tests/cardano/test_derivations.py +++ b/tests/device_tests/cardano/test_derivations.py @@ -17,7 +17,7 @@ import pytest from trezorlib.cardano import get_public_key -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import CardanoDerivationType as D from trezorlib.tools import parse_path @@ -26,35 +26,29 @@ pytestmark = [ pytest.mark.altcoin, - pytest.mark.cardano, pytest.mark.models("core"), ] ADDRESS_N = parse_path("m/1852h/1815h/0h") -def test_bad_session(client: Client): - client.init_device(new_session=True) +def test_bad_session(session: Session): with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) + get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) - client.init_device(new_session=True, derive_cardano=False) - with pytest.raises(TrezorFailure, match="not enabled"): - get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) +def test_ledger_available_without_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) -def test_ledger_available_always(client: Client): - client.init_device(new_session=True, derive_cardano=False) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) - client.init_device(new_session=True, derive_cardano=True) - get_public_key(client, ADDRESS_N, derivation_type=D.LEDGER) +@pytest.mark.cardano # derive_cardano=True +def test_ledger_available_with_cardano(session: Session): + get_public_key(session, ADDRESS_N, derivation_type=D.LEDGER) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.parametrize("derivation_type", D) # try ALL derivation types -def test_derivation_irrelevant_on_slip39(client: Client, derivation_type): - client.init_device(new_session=True, derive_cardano=False) - pubkey = get_public_key(client, ADDRESS_N, derivation_type=D.ICARUS) - test_pubkey = get_public_key(client, ADDRESS_N, derivation_type=derivation_type) +def test_derivation_irrelevant_on_slip39(session: Session, derivation_type): + pubkey = get_public_key(session, ADDRESS_N, derivation_type=D.ICARUS) + test_pubkey = get_public_key(session, ADDRESS_N, derivation_type=derivation_type) assert pubkey == test_pubkey diff --git a/tests/device_tests/cardano/test_get_native_script_hash.py b/tests/device_tests/cardano/test_get_native_script_hash.py index 63ee56d16fb..2859d69a41a 100644 --- a/tests/device_tests/cardano/test_get_native_script_hash.py +++ b/tests/device_tests/cardano/test_get_native_script_hash.py @@ -18,7 +18,7 @@ from trezorlib import messages from trezorlib.cardano import get_native_script_hash, parse_native_script -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import parametrize_using_common_fixtures @@ -32,11 +32,9 @@ @parametrize_using_common_fixtures( "cardano/get_native_script_hash.json", ) -def test_cardano_get_native_script_hash(client: Client, parameters, result): - client.init_device(new_session=True, derive_cardano=True) - +def test_cardano_get_native_script_hash(session: Session, parameters, result): native_script_hash = get_native_script_hash( - client, + session, native_script=parse_native_script(parameters["native_script"]), display_format=messages.CardanoNativeScriptHashDisplayFormat.__members__[ parameters["display_format"] diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 447b2596d10..83b8a075825 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -18,6 +18,7 @@ from trezorlib import cardano, device, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -58,9 +59,9 @@ def show_details_input_flow(client: Client): "cardano/sign_tx.plutus.json", "cardano/sign_tx.slip39.json", ) -def test_cardano_sign_tx(client: Client, parameters, result): +def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( - client, + session, parameters, input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), ) @@ -68,8 +69,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): @parametrize_using_common_fixtures("cardano/sign_tx.show_details.json") -def test_cardano_sign_tx_show_details(client: Client, parameters, result): - response = call_sign_tx(client, parameters, show_details_input_flow, chunkify=True) +def test_cardano_sign_tx_show_details(session: Session, parameters, result): + response = call_sign_tx(session, parameters, show_details_input_flow, chunkify=True) assert response == _transform_expected_result(result) @@ -79,13 +80,13 @@ def test_cardano_sign_tx_show_details(client: Client, parameters, result): "cardano/sign_tx.multisig.failed.json", "cardano/sign_tx.plutus.failed.json", ) -def test_cardano_sign_tx_failed(client: Client, parameters, result): +def test_cardano_sign_tx_failed(session: Session, parameters, result): with pytest.raises(TrezorFailure, match=result["error_message"]): - call_sign_tx(client, parameters, None) + call_sign_tx(session, parameters, None) -def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = False): - client.init_device(new_session=True, derive_cardano=True) +def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = False): + # session.init_device(new_session=True, derive_cardano=True) signing_mode = messages.CardanoTxSigningMode.__members__[parameters["signing_mode"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]] @@ -116,18 +117,18 @@ def call_sign_tx(client: Client, parameters, input_flow=None, chunkify: bool = F if parameters.get("security_checks") == "prompt": device.apply_settings( - client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily + session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) else: - device.apply_settings(client, safety_checks=messages.SafetyCheckLevel.Strict) + device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with client: + with session.client as client: if input_flow is not None: client.watch_layout() client.set_input_flow(input_flow(client)) return cardano.sign_tx( - client=client, + session=session, signing_mode=signing_mode, inputs=inputs, outputs=outputs, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index 1b518e95f2e..d99c54cb2b6 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.eos import get_public_key from trezorlib.tools import parse_path @@ -28,12 +28,12 @@ @pytest.mark.eos @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_eos_get_public_key(client: Client): - with client: +def test_eos_get_public_key(session: Session): + with session.client as client: IF = InputFlowShowXpubQRCode(client) client.set_input_flow(IF.get()) public_key = get_public_key( - client, parse_path("m/44h/194h/0h/0/0"), show_display=True + session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) assert ( public_key.wif_public_key @@ -43,7 +43,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02015fabe197c955036bab25f4e7c16558f9f672f9f625314ab1ec8f64f7b1198e" ) - public_key = get_public_key(client, parse_path("m/44h/194h/0h/0/1")) + public_key = get_public_key(session, parse_path("m/44h/194h/0h/0/1")) assert ( public_key.wif_public_key == "EOS5d1VP15RKxT4dSakWu2TFuEgnmaGC2ckfSvQwND7pZC1tXkfLP" @@ -52,7 +52,7 @@ def test_eos_get_public_key(client: Client): public_key.raw_public_key.hex() == "02608bc2c431521dee0b9d5f2fe34053e15fc3b20d2895e0abda857b9ed8e77a78" ) - public_key = get_public_key(client, parse_path("m/44h/194h/1h/0/0")) + public_key = get_public_key(session, parse_path("m/44h/194h/1h/0/0")) assert ( public_key.wif_public_key == "EOS7UuNeTf13nfcG85rDB7AHGugZi4C4wJ4ft12QRotqNfxdV2NvP" diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 57fd051bb4a..54ebece6a9a 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import eos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import EosSignedTx from trezorlib.tools import parse_path @@ -35,7 +35,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_eos_signtx_transfer_token(client: Client, chunkify: bool): +def test_eos_signtx_transfer_token(session: Session, chunkify: bool): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -60,8 +60,8 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -69,7 +69,7 @@ def test_eos_signtx_transfer_token(client: Client, chunkify: bool): ) -def test_eos_signtx_buyram(client: Client): +def test_eos_signtx_buyram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -93,8 +93,8 @@ def test_eos_signtx_buyram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -102,7 +102,7 @@ def test_eos_signtx_buyram(client: Client): ) -def test_eos_signtx_buyrambytes(client: Client): +def test_eos_signtx_buyrambytes(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -126,8 +126,8 @@ def test_eos_signtx_buyrambytes(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -135,7 +135,7 @@ def test_eos_signtx_buyrambytes(client: Client): ) -def test_eos_signtx_sellram(client: Client): +def test_eos_signtx_sellram(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -155,8 +155,8 @@ def test_eos_signtx_sellram(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -164,7 +164,7 @@ def test_eos_signtx_sellram(client: Client): ) -def test_eos_signtx_delegate(client: Client): +def test_eos_signtx_delegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -190,8 +190,8 @@ def test_eos_signtx_delegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -199,7 +199,7 @@ def test_eos_signtx_delegate(client: Client): ) -def test_eos_signtx_undelegate(client: Client): +def test_eos_signtx_undelegate(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -224,8 +224,8 @@ def test_eos_signtx_undelegate(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -233,7 +233,7 @@ def test_eos_signtx_undelegate(client: Client): ) -def test_eos_signtx_refund(client: Client): +def test_eos_signtx_refund(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -253,8 +253,8 @@ def test_eos_signtx_refund(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -262,7 +262,7 @@ def test_eos_signtx_refund(client: Client): ) -def test_eos_signtx_linkauth(client: Client): +def test_eos_signtx_linkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -287,8 +287,8 @@ def test_eos_signtx_linkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -296,7 +296,7 @@ def test_eos_signtx_linkauth(client: Client): ) -def test_eos_signtx_unlinkauth(client: Client): +def test_eos_signtx_unlinkauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -320,8 +320,8 @@ def test_eos_signtx_unlinkauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -329,7 +329,7 @@ def test_eos_signtx_unlinkauth(client: Client): ) -def test_eos_signtx_updateauth(client: Client): +def test_eos_signtx_updateauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -376,8 +376,8 @@ def test_eos_signtx_updateauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -385,7 +385,7 @@ def test_eos_signtx_updateauth(client: Client): ) -def test_eos_signtx_deleteauth(client: Client): +def test_eos_signtx_deleteauth(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -405,8 +405,8 @@ def test_eos_signtx_deleteauth(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -414,7 +414,7 @@ def test_eos_signtx_deleteauth(client: Client): ) -def test_eos_signtx_vote(client: Client): +def test_eos_signtx_vote(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -468,8 +468,8 @@ def test_eos_signtx_vote(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -477,7 +477,7 @@ def test_eos_signtx_vote(client: Client): ) -def test_eos_signtx_vote_proxy(client: Client): +def test_eos_signtx_vote_proxy(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -497,8 +497,8 @@ def test_eos_signtx_vote_proxy(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -506,7 +506,7 @@ def test_eos_signtx_vote_proxy(client: Client): ) -def test_eos_signtx_unknown(client: Client): +def test_eos_signtx_unknown(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -526,8 +526,8 @@ def test_eos_signtx_unknown(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -535,7 +535,7 @@ def test_eos_signtx_unknown(client: Client): ) -def test_eos_signtx_newaccount(client: Client): +def test_eos_signtx_newaccount(session: Session): transaction = { "expiration": "2018-07-14T10:43:28", "ref_block_num": 6439, @@ -602,8 +602,8 @@ def test_eos_signtx_newaccount(client: Client): "transaction_extensions": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature @@ -611,7 +611,7 @@ def test_eos_signtx_newaccount(client: Client): ) -def test_eos_signtx_setcontract(client: Client): +def test_eos_signtx_setcontract(session: Session): transaction = { "expiration": "2018-06-19T13:29:53", "ref_block_num": 30587, @@ -638,8 +638,8 @@ def test_eos_signtx_setcontract(client: Client): "context_free_data": [], } - with client: - resp = eos.sign_tx(client, ADDRESS_N, transaction, CHAIN_ID) + with session: + resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( resp.signature diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 314189ca597..9cc3fd57043 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -40,60 +40,60 @@ } -def test_builtin(client: Client) -> None: +def test_builtin(session: Session) -> None: # Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided - ethereum.sign_tx(client, **DEFAULT_TX_PARAMS) + ethereum.sign_tx(session, **DEFAULT_TX_PARAMS) -def test_chain_id_allowed(client: Client) -> None: +def test_chain_id_allowed(session: Session) -> None: # Any chain id is allowed as long as the SLIP44 stays the same params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=222222) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_disallowed(client: Client) -> None: +def test_slip44_disallowed(session: Session) -> None: # SLIP44 is not allowed without a valid network definition params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0")) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) -def test_slip44_external(client: Client) -> None: +def test_slip44_external(session: Session) -> None: # to use a non-default SLIP44, a valid network definition must be provided network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_slip44_external_disallowed(client: Client) -> None: +def test_slip44_external_disallowed(session: Session) -> None: # network definition does not allow a different SLIP44 network = common.encode_network(chain_id=66666, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_chain_id_mismatch(client: Client) -> None: +def test_chain_id_mismatch(session: Session) -> None: # network definition for a different chain id will be rejected network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_TX_PARAMS.copy() params.update(chain_id=55555) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) -def test_definition_does_not_override_builtin(client: Client) -> None: +def test_definition_does_not_override_builtin(session: Session) -> None: # The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used # even if a valid definition with a different SLIP44 is provided network = common.encode_network(chain_id=1, slip44=66666) params = DEFAULT_TX_PARAMS.copy() params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1) with pytest.raises(TrezorFailure, match="Forbidden key path"): - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO: test that the builtin definition will not show different symbol @@ -102,50 +102,50 @@ def test_definition_does_not_override_builtin(client: Client) -> None: # all tokens are currently accepted, we would need to check the screenshots -def test_builtin_token(client: Client) -> None: +def test_builtin_token(session: Session) -> None: # The builtin definition for USDT (ERC20) will be used even if not provided params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN) - ethereum.sign_tx(client, **params) + ethereum.sign_tx(session, **params) # TODO check that USDT symbol is shown # TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided) -def test_external_token(client: Client) -> None: +def test_external_token(session: Session) -> None: # A valid token definition must be provided to use a non-builtin token token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) - ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(None, token)) # TODO check that FakeTok symbol is shown -def test_external_chain_without_token(client: Client) -> None: - with client: +def test_external_chain_without_token(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, None)) # TODO check that UNKN token is used, FAKE network -def test_external_chain_token_ok(client: Client) -> None: +def test_external_chain_token_ok(session: Session) -> None: # when providing an external chain and matching token, everything works network = common.encode_network(chain_id=66666, slip44=60) token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx(session, **params, definitions=common.make_defs(network, token)) # TODO check that FakeTok is used, FAKE network -def test_external_chain_token_mismatch(client: Client) -> None: - with client: +def test_external_chain_token_mismatch(session: Session) -> None: + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) # when providing external defs, we explicitly allow, but not use, tokens @@ -156,31 +156,33 @@ def test_external_chain_token_mismatch(client: Client) -> None: ) params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666) - ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token)) + ethereum.sign_tx( + session, **params, definitions=common.make_defs(network, token) + ) # TODO check that UNKN is used for token, FAKE for network -def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None: +def _call_getaddress(session: Session, slip44: int, network: bytes | None) -> None: ethereum.get_address( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), show_display=False, encoded_network=network, ) -def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None: +def _call_signmessage(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_message( - client, + session, parse_path(f"m/44h/{slip44}h/0h"), b"hello", encoded_network=network, ) -def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None: +def _call_sign_typed_data(session: Session, slip44: int, network: bytes | None) -> None: ethereum.sign_typed_data( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), TYPED_DATA, metamask_v4_compat=True, @@ -189,10 +191,10 @@ def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> def _call_sign_typed_data_hash( - client: Client, slip44: int, network: bytes | None + session: Session, slip44: int, network: bytes | None ) -> None: ethereum.sign_typed_data_hash( - client, + session, parse_path(f"m/44h/{slip44}h/0h/0/0"), b"\x00" * 32, b"\xff" * 32, @@ -200,7 +202,7 @@ def _call_sign_typed_data_hash( ) -MethodType = Callable[[Client, int, "bytes | None"], None] +MethodType = Callable[[Session, int, "bytes | None"], None] METHODS = ( @@ -212,29 +214,29 @@ def _call_sign_typed_data_hash( @pytest.mark.parametrize("method", METHODS) -def test_method_builtin(client: Client, method: MethodType) -> None: +def test_method_builtin(session: Session, method: MethodType) -> None: # calling a method with a builtin slip44 will work - method(client, 60, None) + method(session, 60, None) @pytest.mark.parametrize("method", METHODS) -def test_method_def_missing(client: Client, method: MethodType) -> None: +def test_method_def_missing(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has no definition will fail with pytest.raises(TrezorFailure, match="Forbidden key path"): - method(client, 66666, None) + method(session, 66666, None) @pytest.mark.parametrize("method", METHODS) -def test_method_external(client: Client, method: MethodType) -> None: +def test_method_external(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition will work network = common.encode_network(slip44=66666) - method(client, 66666, network) + method(session, 66666, network) @pytest.mark.parametrize("method", METHODS) -def test_method_external_mismatch(client: Client, method: MethodType) -> None: +def test_method_external_mismatch(session: Session, method: MethodType) -> None: # calling a method with a slip44 that has an external definition that does not match # the slip44 will fail network = common.encode_network(slip44=77777) with pytest.raises(TrezorFailure, match="Network definition mismatch"): - method(client, 66666, network) + method(session, 66666, network) diff --git a/tests/device_tests/ethereum/test_definitions_bad.py b/tests/device_tests/ethereum/test_definitions_bad.py index 3f21195643f..ae917105ae9 100644 --- a/tests/device_tests/ethereum/test_definitions_bad.py +++ b/tests/device_tests/ethereum/test_definitions_bad.py @@ -5,7 +5,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import EthereumDefinitionType from trezorlib.tools import parse_path @@ -16,99 +16,99 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] -def fails(client: Client, network: bytes, match: str) -> None: +def fails(session: Session, network: bytes, match: str) -> None: with pytest.raises(TrezorFailure, match=match): ethereum.get_address( - client, + session, parse_path("m/44h/666666h/0h"), show_display=False, encoded_network=network, ) -def test_short_message(client: Client) -> None: - fails(client, b"\x00", "Invalid Ethereum definition") +def test_short_message(session: Session) -> None: + fails(session, b"\x00", "Invalid Ethereum definition") -def test_mangled_signature(client: Client) -> None: +def test_mangled_signature(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_signature = signature[:-1] + b"\xff" - fails(client, payload + proof + bad_signature, "Invalid definition signature") + fails(session, payload + proof + bad_signature, "Invalid definition signature") -def test_not_enough_signatures(client: Client) -> None: +def test_not_enough_signatures(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [], threshold=1) - fails(client, payload + proof + signature, "Invalid definition signature") + fails(session, payload + proof + signature, "Invalid definition signature") -def test_missing_signature(client: Client) -> None: +def test_missing_signature(session: Session) -> None: payload = make_payload() proof, _ = sign_payload(payload, []) - fails(client, payload + proof, "Invalid Ethereum definition") + fails(session, payload + proof, "Invalid Ethereum definition") -def test_mangled_payload(client: Client) -> None: +def test_mangled_payload(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_payload = payload[:-1] + b"\xff" - fails(client, bad_payload + proof + signature, "Invalid definition signature") + fails(session, bad_payload + proof + signature, "Invalid definition signature") -def test_proof_length_mismatch(client: Client) -> None: +def test_proof_length_mismatch(session: Session) -> None: payload = make_payload() _, signature = sign_payload(payload, []) bad_proof = b"\x01" - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_proof(client: Client) -> None: +def test_bad_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, [sha256(b"x").digest()]) bad_proof = proof[:-1] + b"\xff" - fails(client, payload + bad_proof + signature, "Invalid definition signature") + fails(session, payload + bad_proof + signature, "Invalid definition signature") -def test_trimmed_proof(client: Client) -> None: +def test_trimmed_proof(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) bad_proof = proof[:-1] - fails(client, payload + bad_proof + signature, "Invalid Ethereum definition") + fails(session, payload + bad_proof + signature, "Invalid Ethereum definition") -def test_bad_prefix(client: Client) -> None: +def test_bad_prefix(session: Session) -> None: payload = make_payload() payload = b"trzd2" + payload[5:] proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_bad_type(client: Client) -> None: +def test_bad_type(session: Session) -> None: # assuming we expect a network definition payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token()) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition type mismatch") + fails(session, payload + proof + signature, "Definition type mismatch") -def test_outdated(client: Client) -> None: +def test_outdated(session: Session) -> None: payload = make_payload(timestamp=0) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Definition is outdated") + fails(session, payload + proof + signature, "Definition is outdated") -def test_malformed_protobuf(client: Client) -> None: +def test_malformed_protobuf(session: Session) -> None: payload = make_payload(message=b"\x00") proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") -def test_protobuf_mismatch(client: Client) -> None: +def test_protobuf_mismatch(session: Session) -> None: payload = make_payload( data_type=EthereumDefinitionType.NETWORK, message=make_token() ) proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature, "Invalid Ethereum definition") + fails(session, payload + proof + signature, "Invalid Ethereum definition") payload = make_payload( data_type=EthereumDefinitionType.TOKEN, message=make_network() @@ -119,13 +119,13 @@ def test_protobuf_mismatch(client: Client) -> None: params = DEFAULT_ERC20_PARAMS.copy() params.update(to=ERC20_FAKE_ADDRESS) ethereum.sign_tx( - client, + session, **params, definitions=make_defs(None, payload + proof + signature), ) -def test_trailing_garbage(client: Client) -> None: +def test_trailing_garbage(session: Session) -> None: payload = make_payload() proof, signature = sign_payload(payload, []) - fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition") + fails(session, payload + proof + signature + b"\x00", "Invalid Ethereum definition") diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index 3add0ad92fb..b57fcd6afd3 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -27,21 +27,21 @@ @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress(client: Client, parameters, result): +def test_getaddress(session: Session, parameters, result): address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True) == result["address"] + ethereum.get_address(session, address_n, show_display=True) == result["address"] ) @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") -def test_getaddress_chunkify_details(client: Client, parameters, result): - with client: +def test_getaddress_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( - ethereum.get_address(client, address_n, show_display=True, chunkify=True) + ethereum.get_address(session, address_n, show_display=True, chunkify=True) == result["address"] ) diff --git a/tests/device_tests/ethereum/test_getpublickey.py b/tests/device_tests/ethereum/test_getpublickey.py index 103b261f579..586abf736d7 100644 --- a/tests/device_tests/ethereum/test_getpublickey.py +++ b/tests/device_tests/ethereum/test_getpublickey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -27,9 +27,9 @@ @parametrize_using_common_fixtures("ethereum/getpublickey.json") -def test_ethereum_getpublickey(client: Client, parameters, result): +def test_ethereum_getpublickey(session: Session, parameters, result): path = parse_path(parameters["path"]) - res = ethereum.get_public_node(client, path) + res = ethereum.get_public_node(session, path) assert res.node.depth == len(path) assert res.node.fingerprint == result["fingerprint"] assert res.node.child_num == result["child_num"] @@ -38,14 +38,14 @@ def test_ethereum_getpublickey(client: Client, parameters, result): assert res.xpub == result["xpub"] -def test_slip25_disallowed(client: Client): +def test_slip25_disallowed(session: Session): path = parse_path("m/10025'/60'/0'/0/0") with pytest.raises(TrezorFailure): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) @pytest.mark.models("legacy") -def test_legacy_restrictions(client: Client): +def test_legacy_restrictions(session: Session): path = parse_path("m/46'") with pytest.raises(TrezorFailure, match="Invalid path for EthereumGetPublicKey"): - ethereum.get_public_node(client, path) + ethereum.get_public_node(session, path) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index 14dda4bdbe0..43b872af4b2 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum, exceptions -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -28,11 +28,11 @@ @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( - client, + session, address_n, parameters["data"], metamask_v4_compat=parameters["metamask_v4_compat"], @@ -43,11 +43,11 @@ def test_ethereum_sign_typed_data(client: Client, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") -def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): - with client: +def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): + with session: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( - client, + session, address_n, ethereum.decode_hex(parameters["domain_separator_hash"]), # message hash is empty for domain-only hashes @@ -96,13 +96,13 @@ def test_ethereum_sign_typed_data_blind(client: Client, parameters, result): @pytest.mark.models("core", skip="mercury", reason="Not yet implemented in new UI") -def test_ethereum_sign_typed_data_show_more_button(client: Client): - with client: +def test_ethereum_sign_typed_data_show_more_button(session: Session): + with session.client as client: client.watch_layout() IF = InputFlowEIP712ShowMore(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, @@ -110,13 +110,13 @@ def test_ethereum_sign_typed_data_show_more_button(client: Client): @pytest.mark.models("core") -def test_ethereum_sign_typed_data_cancel(client: Client): - with client, pytest.raises(exceptions.Cancelled): +def test_ethereum_sign_typed_data_cancel(session: Session): + with session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() IF = InputFlowEIP712Cancel(client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( - client, + session, parse_path("m/44h/60h/0h/0/0"), DATA, metamask_v4_compat=True, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index 8cf2680ad82..7e50bd205ad 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ethereum -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -26,18 +26,18 @@ @parametrize_using_common_fixtures("ethereum/signmessage.json") -def test_signmessage(client: Client, parameters, result): +def test_signmessage(session: Session, parameters, result): res = ethereum.sign_message( - client, parse_path(parameters["path"]), parameters["msg"] + session, parse_path(parameters["path"]), parameters["msg"] ) assert res.address == result["address"] assert res.signature.hex() == result["sig"] @parametrize_using_common_fixtures("ethereum/verifymessage.json") -def test_verify(client: Client, parameters, result): +def test_verify(session: Session, parameters, result): res = ethereum.verify_message( - client, + session, parameters["address"], bytes.fromhex(parameters["sig"]), parameters["msg"], @@ -45,7 +45,7 @@ def test_verify(client: Client, parameters, result): assert res is True -def test_verify_invalid(client: Client): +def test_verify_invalid(session: Session): # First vector from the verifymessage JSON fixture msg = "This is an example of a signed message." address = "0xEa53AF85525B1779eE99ece1a5560C0b78537C3b" @@ -54,7 +54,7 @@ def test_verify_invalid(client: Client): ) res = ethereum.verify_message( - client, + session, address, sig, msg, @@ -63,7 +63,7 @@ def test_verify_invalid(client: Client): # Changing the signature, expecting failure res = ethereum.verify_message( - client, + session, address, sig[:-1] + b"\x00", msg, @@ -72,7 +72,7 @@ def test_verify_invalid(client: Client): # Changing the message, expecting failure res = ethereum.verify_message( - client, + session, address, sig, msg + "abc", @@ -81,7 +81,7 @@ def test_verify_invalid(client: Client): # Changing the address, expecting failure res = ethereum.verify_message( - client, + session, address[:-1] + "a", sig, msg, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index a550322dbd0..178f63710de 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -17,6 +17,7 @@ import pytest from trezorlib import ethereum, exceptions, messages, models +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters from trezorlib.exceptions import TrezorFailure @@ -56,28 +57,28 @@ def make_defs(parameters: dict) -> messages.EthereumDefinitions: "ethereum/sign_tx_eip155.json", ) @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx(client: Client, chunkify: bool, parameters: dict, result: dict): +def test_signtx(session: Session, chunkify: bool, parameters: dict, result: dict): input_flow = ( - InputFlowConfirmAllWarnings(client).get() - if not client.debug.legacy_debug + InputFlowConfirmAllWarnings(session.client).get() + if not session.client.debug.legacy_debug else None ) - _do_test_signtx(client, parameters, result, input_flow, chunkify=chunkify) + _do_test_signtx(session, parameters, result, input_flow, chunkify=chunkify) def _do_test_signtx( - client: Client, + session: Session, parameters: dict, result: dict, input_flow=None, chunkify: bool = False, ): - with client: + with session.client as client: if input_flow: client.watch_layout() client.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -120,10 +121,10 @@ def _do_test_signtx( @pytest.mark.models("core", reason="T1 does not support input flows") -def test_signtx_fee_info(client: Client): - input_flow = InputFlowEthereumSignTxShowFeeInfo(client).get() +def test_signtx_fee_info(session: Session): + input_flow = InputFlowEthereumSignTxShowFeeInfo(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -135,10 +136,10 @@ def test_signtx_fee_info(client: Client): skip="mercury", reason="T1 does not support input flows; Mercury can't send Cancel on Summary", ) -def test_signtx_go_back_from_summary(client: Client): - input_flow = InputFlowEthereumSignTxGoBackFromSummary(client).get() +def test_signtx_go_back_from_summary(session: Session): + input_flow = InputFlowEthereumSignTxGoBackFromSummary(session.client).get() _do_test_signtx( - client, + session, example_input_data["parameters"], example_input_data["result"], input_flow, @@ -147,12 +148,14 @@ def test_signtx_go_back_from_summary(client: Client): @parametrize_using_common_fixtures("ethereum/sign_tx_eip1559.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result: dict): - with client: +def test_signtx_eip1559( + session: Session, chunkify: bool, parameters: dict, result: dict +): + with session, session.client as client: if not client.debug.legacy_debug: client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_limit=int(parameters["gas_limit"], 16), @@ -171,14 +174,14 @@ def test_signtx_eip1559(client: Client, chunkify: bool, parameters: dict, result assert sig_v == result["sig_v"] -def test_sanity_checks(client: Client): +def test_sanity_checks(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -191,7 +194,7 @@ def test_sanity_checks(client: Client): # gas overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -204,7 +207,7 @@ def test_sanity_checks(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=123_456, gas_price=20_000, @@ -215,13 +218,13 @@ def test_sanity_checks(client: Client): ) -def test_data_streaming(client: Client): +def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with client: - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + with session: + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), (is_t1, messages.ButtonRequest(code=messages.ButtonRequestType.SignTx)), @@ -259,7 +262,7 @@ def test_data_streaming(client: Client): ) ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0, gas_price=20_000, @@ -271,11 +274,11 @@ def test_data_streaming(client: Client): ) -def test_signtx_eip1559_access_list(client: Client): - with client: +def test_signtx_eip1559_access_list(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -310,11 +313,11 @@ def test_signtx_eip1559_access_list(client: Client): ) -def test_signtx_eip1559_access_list_larger(client: Client): - with client: +def test_signtx_eip1559_access_list_larger(session: Session): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -363,14 +366,14 @@ def test_signtx_eip1559_access_list_larger(client: Client): ) -def test_sanity_checks_eip1559(client: Client): +def test_sanity_checks_eip1559(session: Session): """Is not vectorized because these are internal-only tests that do not need to be exposed to the public. """ # contract creation without data should fail. with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -384,7 +387,7 @@ def test_sanity_checks_eip1559(client: Client): # max fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -398,7 +401,7 @@ def test_sanity_checks_eip1559(client: Client): # priority fee overflow with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, @@ -412,7 +415,7 @@ def test_sanity_checks_eip1559(client: Client): # bad chain ID with pytest.raises(TrezorFailure, match=r"Chain ID out of bounds"): ethereum.sign_tx_eip1559( - client, + session, n=parse_path("m/44h/60h/0h/0/100"), nonce=0, gas_limit=20, @@ -443,10 +446,10 @@ def input_flow_data_go_back(client: Client, cancel: bool = False): "flow", (input_flow_data_skip, input_flow_data_scroll_down, input_flow_data_go_back) ) @pytest.mark.models("core", skip="mercury", reason="Not yet implemented in new UI") -def test_signtx_data_pagination(client: Client, flow): +def test_signtx_data_pagination(session: Session, flow): def _sign_tx_call(): ethereum.sign_tx( - client, + session, n=parse_path("m/44h/60h/0h/0/0"), nonce=0x0, gas_price=0x14, @@ -458,13 +461,13 @@ def _sign_tx_call(): data=bytes.fromhex(HEXDATA), ) - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(flow(client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with client, pytest.raises(exceptions.Cancelled): + with session, session.client as client, pytest.raises(exceptions.Cancelled): client.watch_layout() client.set_input_flow(flow(client, cancel=True)) _sign_tx_call() @@ -473,20 +476,22 @@ def _sign_tx_call(): @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking.json") @pytest.mark.parametrize("chunkify", (True, False)) -def test_signtx_staking(client: Client, chunkify: bool, parameters: dict, result: dict): - input_flow = InputFlowEthereumSignTxStaking(client).get() +def test_signtx_staking( + session: Session, chunkify: bool, parameters: dict, result: dict +): + input_flow = InputFlowEthereumSignTxStaking(session.client).get() _do_test_signtx( - client, parameters, result, input_flow=input_flow, chunkify=chunkify + session, parameters, result, input_flow=input_flow, chunkify=chunkify ) @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_data_error.json") -def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dict): +def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: dict): # result not needed with pytest.raises(TrezorFailure, match=r"DataError"): ethereum.sign_tx( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), gas_price=int(parameters["gas_price"], 16), @@ -503,10 +508,10 @@ def test_signtx_staking_bad_inputs(client: Client, parameters: dict, result: dic @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") -def test_signtx_staking_eip1559(client: Client, parameters: dict, result: dict): - with client: +def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): + with session: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( - client, + session, n=parse_path(parameters["path"]), nonce=int(parameters["nonce"], 16), max_gas_fee=int(parameters["max_gas_fee"], 16), diff --git a/tests/device_tests/misc/test_msg_cipherkeyvalue.py b/tests/device_tests/misc/test_msg_cipherkeyvalue.py index 7a9fe664206..4efec7ab060 100644 --- a/tests/device_tests/misc/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/misc/test_msg_cipherkeyvalue.py @@ -17,15 +17,15 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_encrypt(client: Client): +def test_encrypt(session: Session): res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -35,7 +35,7 @@ def test_encrypt(client: Client): assert res.hex() == "676faf8f13272af601776bc31bc14e8f" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -45,7 +45,7 @@ def test_encrypt(client: Client): assert res.hex() == "5aa0fbcb9d7fa669880745479d80c622" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -55,7 +55,7 @@ def test_encrypt(client: Client): assert res.hex() == "958d4f63269b61044aaedc900c8d6208" res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message!", @@ -66,7 +66,7 @@ def test_encrypt(client: Client): # different key res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test2", b"testing message!", @@ -77,7 +77,7 @@ def test_encrypt(client: Client): # different message res = misc.encrypt_keyvalue( - client, + session, [0, 1, 2], "test", b"testing message! it is different", @@ -90,7 +90,7 @@ def test_encrypt(client: Client): # different path res = misc.encrypt_keyvalue( - client, + session, [0, 1, 3], "test", b"testing message!", @@ -101,9 +101,9 @@ def test_encrypt(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_decrypt(client: Client): +def test_decrypt(session: Session): res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), @@ -113,7 +113,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), @@ -123,7 +123,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), @@ -133,7 +133,7 @@ def test_decrypt(client: Client): assert res == b"testing message!" res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), @@ -144,7 +144,7 @@ def test_decrypt(client: Client): # different key res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), @@ -155,7 +155,7 @@ def test_decrypt(client: Client): # different message res = misc.decrypt_keyvalue( - client, + session, [0, 1, 2], "test", bytes.fromhex( @@ -168,7 +168,7 @@ def test_decrypt(client: Client): # different path res = misc.decrypt_keyvalue( - client, + session, [0, 1, 3], "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), @@ -178,11 +178,11 @@ def test_decrypt(client: Client): assert res == b"testing message!" -def test_encrypt_badlen(client: Client): +def test_encrypt_badlen(session: Session): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.encrypt_keyvalue(session, [0, 1, 2], "test", b"testing") -def test_decrypt_badlen(client: Client): +def test_decrypt_badlen(session: Session): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing") + misc.decrypt_keyvalue(session, [0, 1, 2], "test", b"testing") diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index bac9f23e3aa..2bd0d3d9752 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -17,16 +17,16 @@ import pytest from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.models("core") -def test_encrypt(client: Client): +def test_encrypt(session: Session): misc.encrypt_keyvalue( - client, + session, [], "Enable labeling?", b"", diff --git a/tests/device_tests/misc/test_msg_getecdhsessionkey.py b/tests/device_tests/misc/test_msg_getecdhsessionkey.py index 8c38f612b1d..d7c532dc5a0 100644 --- a/tests/device_tests/misc/test_msg_getecdhsessionkey.py +++ b/tests/device_tests/misc/test_msg_getecdhsessionkey.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_ecdh(client: Client): +def test_ecdh(session: Session): identity = messages.IdentityType( proto="gpg", user="", @@ -37,7 +37,7 @@ def test_ecdh(client: Client): "0407f2c6e5becf3213c1d07df0cfbe8e39f70a8c643df7575e5c56859ec52c45ca950499c019719dae0fda04248d851e52cf9d66eeb211d89a77be40de22b6c89d" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="secp256k1", @@ -55,7 +55,7 @@ def test_ecdh(client: Client): "04811a6c2bd2a547d0dd84747297fec47719e7c3f9b0024f027c2b237be99aac39a9230acbd163d0cb1524a0f5ea4bfed6058cec6f18368f72a12aa0c4d083ff64" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="nist256p1", @@ -73,7 +73,7 @@ def test_ecdh(client: Client): "40a8cf4b6a64c4314e80f15a8ea55812bd735fbb365936a48b2d78807b575fa17a" ) result = misc.get_ecdh_session_key( - client, + session, identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name="curve25519", diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index 593fb1a76c1..d5d19425f9b 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -20,7 +20,7 @@ from trezorlib import messages as m from trezorlib import misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session ENTROPY_LENGTHS_POW2 = [2**l for l in range(10)] ENTROPY_LENGTHS_POW2_1 = [2**l + 1 for l in range(10)] @@ -40,11 +40,11 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) -def test_entropy(client: Client, entropy_length): - with client: - client.set_expected_responses( +def test_entropy(session: Session, entropy_length): + with session: + session.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) - ent = misc.get_entropy(client, entropy_length) + ent = misc.get_entropy(session, entropy_length) assert len(ent) == entropy_length print(f"{entropy_length} bytes: entropy = {entropy(ent)}") diff --git a/tests/device_tests/misc/test_msg_signidentity.py b/tests/device_tests/misc/test_msg_signidentity.py index bc9e7f5bd4e..6715387d387 100644 --- a/tests/device_tests/misc/test_msg_signidentity.py +++ b/tests/device_tests/misc/test_msg_signidentity.py @@ -17,13 +17,13 @@ import pytest from trezorlib import messages, misc -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_sign(client: Client): +def test_sign(session: Session): hidden = bytes.fromhex( "cd8552569d6e4509266ef137584d1e62c7579b5b8ed69bbafa4b864c6521e7c2" ) @@ -40,7 +40,7 @@ def test_sign(client: Client): path="/login", index=0, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "17F17smBTX9VTZA9Mj8LM5QGYNZnmziCjL" assert ( sig.public_key.hex() @@ -62,7 +62,7 @@ def test_sign(client: Client): path="/pub", index=3, ) - sig = misc.sign_identity(client, identity, hidden, visual) + sig = misc.sign_identity(session, identity, hidden, visual) assert sig.address == "1KAr6r5qF2kADL8bAaRQBjGKYEGxn9WrbS" assert ( sig.public_key.hex() @@ -80,7 +80,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="nist256p1" + session, identity, hidden, visual, ecdsa_curve_name="nist256p1" ) assert sig.address is None assert ( @@ -99,7 +99,7 @@ def test_sign(client: Client): proto="ssh", user="satoshi", host="bitcoin.org", port="", path="", index=47 ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -116,7 +116,7 @@ def test_sign(client: Client): proto="gpg", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( @@ -133,7 +133,7 @@ def test_sign(client: Client): proto="signify", user="satoshi", host="bitcoin.org", port="", path="" ) sig = misc.sign_identity( - client, identity, hidden, visual, ecdsa_curve_name="ed25519" + session, identity, hidden, visual, ecdsa_curve_name="ed25519" ) assert sig.address is None assert ( diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index dfd0ce5ab09..1a6d3ffc01c 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -47,19 +47,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_monero_getaddress(client: Client, path: str, expected_address: bytes): - address = monero.get_address(client, parse_path(path), show_display=True) +def test_monero_getaddress(session: Session, path: str, expected_address: bytes): + address = monero.get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_monero_getaddress_chunkify_details( - client: Client, path: str, expected_address: bytes + session: Session, path: str, expected_address: bytes ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = monero.get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/monero/test_getwatchkey.py b/tests/device_tests/monero/test_getwatchkey.py index eee83d0445f..30e3d7b1140 100644 --- a/tests/device_tests/monero/test_getwatchkey.py +++ b/tests/device_tests/monero/test_getwatchkey.py @@ -17,7 +17,7 @@ import pytest from trezorlib import monero -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -27,8 +27,8 @@ @pytest.mark.monero @pytest.mark.models("core") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_monero_getwatchkey(client: Client): - res = monero.get_watch_key(client, parse_path("m/44h/128h/0h")) +def test_monero_getwatchkey(session: Session): + res = monero.get_watch_key(session, parse_path("m/44h/128h/0h")) assert ( res.address == b"4Ahp23WfMrMFK3wYL2hLWQFGt87ZTeRkufS6JoQZu6MEFDokAQeGWmu9MA3GFq1yVLSJQbKJqVAn9F9DLYGpRzRAEXqAXKM" @@ -37,7 +37,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "8722520a581e2a50cc1adab4a1692401effd37b0d63b9d9b60fd7f34ea2b950e" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/1h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/1h")) assert ( res.address == b"44iAazhoAkv5a5RqLNVyh82a1n3ceNggmN4Ho7bUBJ14WkEVR8uFTe9f7v5rNnJ2kEbVXxfXiRzsD5Jtc6NvBi4D6WNHPie" @@ -46,7 +46,7 @@ def test_monero_getwatchkey(client: Client): res.watch_key.hex() == "1f70b7d9e86c11b7a5bee883b75c43d6be189c8f812726ea1ecd94b06bb7db04" ) - res = monero.get_watch_key(client, parse_path("m/44h/128h/2h")) + res = monero.get_watch_key(session, parse_path("m/44h/128h/2h")) assert ( res.address == b"47ejhmbZ4wHUhXaqA4b7PN667oPMkokf4ZkNdWrMSPy9TNaLVr7vLqVUQHh2MnmaAEiyrvLsX8xUf99q3j1iAeMV8YvSFcH" diff --git a/tests/device_tests/nem/test_getaddress.py b/tests/device_tests/nem/test_getaddress.py index b2b20c529ec..920dd974904 100644 --- a/tests/device_tests/nem/test_getaddress.py +++ b/tests/device_tests/nem/test_getaddress.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -28,10 +28,10 @@ @pytest.mark.models("t1b1", "t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_getaddress(client: Client, chunkify: bool): +def test_nem_getaddress(session: Session, chunkify: bool): assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x68, show_display=True, @@ -41,7 +41,7 @@ def test_nem_getaddress(client: Client, chunkify: bool): ) assert ( nem.get_address( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), 0x98, show_display=True, diff --git a/tests/device_tests/nem/test_signtx_mosaics.py b/tests/device_tests/nem/test_signtx_mosaics.py index 51cfd556a77..3e6b835f953 100644 --- a/tests/device_tests/nem/test_signtx_mosaics.py +++ b/tests/device_tests/nem/test_signtx_mosaics.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -32,9 +32,9 @@ ] -def test_nem_signtx_mosaic_supply_change(client: Client): +def test_nem_signtx_mosaic_supply_change(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_mosaic_supply_change(client: Client): ) -def test_nem_signtx_mosaic_creation(client: Client): +def test_nem_signtx_mosaic_creation(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -93,9 +93,9 @@ def test_nem_signtx_mosaic_creation(client: Client): ) -def test_nem_signtx_mosaic_creation_properties(client: Client): +def test_nem_signtx_mosaic_creation_properties(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, @@ -130,9 +130,9 @@ def test_nem_signtx_mosaic_creation_properties(client: Client): ) -def test_nem_signtx_mosaic_creation_levy(client: Client): +def test_nem_signtx_mosaic_creation_levy(session: Session): tx = nem.sign_tx( - client, + session, ADDRESS_N, { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_multisig.py b/tests/device_tests/nem/test_signtx_multisig.py index d153547c424..ef641e52f39 100644 --- a/tests/device_tests/nem/test_signtx_multisig.py +++ b/tests/device_tests/nem/test_signtx_multisig.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,9 +31,9 @@ # assertion data from T1 -def test_nem_signtx_aggregate_modification(client: Client): +def test_nem_signtx_aggregate_modification(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -61,9 +61,9 @@ def test_nem_signtx_aggregate_modification(client: Client): ) -def test_nem_signtx_multisig(client: Client): +def test_nem_signtx_multisig(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 1, @@ -98,7 +98,7 @@ def test_nem_signtx_multisig(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -132,9 +132,9 @@ def test_nem_signtx_multisig(client: Client): ) -def test_nem_signtx_multisig_signer(client: Client): +def test_nem_signtx_multisig_signer(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 333, @@ -169,7 +169,7 @@ def test_nem_signtx_multisig_signer(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 900000, diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index f775c60cdf6..9760d8c5235 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -17,7 +17,7 @@ import pytest from trezorlib import nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -31,10 +31,10 @@ # assertion data from T1 -def test_nem_signtx_importance_transfer(client: Client): - with client: +def test_nem_signtx_importance_transfer(session: Session): + with session: tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 12349215, @@ -60,9 +60,9 @@ def test_nem_signtx_importance_transfer(client: Client): ) -def test_nem_signtx_provision_namespace(client: Client): +def test_nem_signtx_provision_namespace(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 0388b30ffb4..2df62b55936 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, nem -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core @@ -32,16 +32,16 @@ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) -def test_nem_signtx_simple(client: Client, chunkify: bool): - with client: - client.set_expected_responses( +def test_nem_signtx_simple(session: Session, chunkify: bool): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Unencrypted message messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -53,7 +53,7 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -82,16 +82,16 @@ def test_nem_signtx_simple(client: Client, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_encrypted_payload(client: Client): - with client: - client.set_expected_responses( +def test_nem_signtx_encrypted_payload(session: Session): + with session: + session.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), # Ask for encryption messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), ( - is_core(client), + is_core(session), messages.ButtonRequest( code=messages.ButtonRequestType.ConfirmOutput ), @@ -103,7 +103,7 @@ def test_nem_signtx_encrypted_payload(client: Client): ) tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 74649215, @@ -134,9 +134,9 @@ def test_nem_signtx_encrypted_payload(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_xem_as_mosaic(client: Client): +def test_nem_signtx_xem_as_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -168,9 +168,9 @@ def test_nem_signtx_xem_as_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_unknown_mosaic(client: Client): +def test_nem_signtx_unknown_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -202,9 +202,9 @@ def test_nem_signtx_unknown_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic(client: Client): +def test_nem_signtx_known_mosaic(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -236,9 +236,9 @@ def test_nem_signtx_known_mosaic(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_known_mosaic_with_levy(client: Client): +def test_nem_signtx_known_mosaic_with_levy(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, @@ -270,9 +270,9 @@ def test_nem_signtx_known_mosaic_with_levy(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_nem_signtx_multiple_mosaics(client: Client): +def test_nem_signtx_multiple_mosaics(session: Session): tx = nem.sign_tx( - client, + session, parse_path("m/44h/1h/0h/0h/0h"), { "timeStamp": 76809215, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 416fef78eac..8841a52426f 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, exceptions, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import ( @@ -28,9 +28,9 @@ ) -def do_recover_legacy(client: Client, mnemonic: list[str]): +def do_recover_legacy(session: Session, mnemonic: list[str]): def input_callback(_): - word, pos = client.debug.read_recovery_word() + word, pos = session.client.debug.read_recovery_word() if pos != 0 and pos is not None: word = mnemonic[pos - 1] mnemonic[pos - 1] = None @@ -39,7 +39,7 @@ def input_callback(_): return word ret = device.recover( - client, + session, type=messages.RecoveryType.DryRun, word_count=len(mnemonic), input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, @@ -50,58 +50,59 @@ def input_callback(_): return ret -def do_recover_core(client: Client, mnemonic: list[str], mismatch: bool = False): - with client: +def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): + with session.client as client: client.watch_layout() IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) - return device.recover(client, type=messages.RecoveryType.DryRun) + return device.recover(session, type=messages.RecoveryType.DryRun) -def do_recover(client: Client, mnemonic: list[str], mismatch: bool = False): - if client.model is models.T1B1: - return do_recover_legacy(client, mnemonic) +def do_recover(session: Session, mnemonic: list[str], mismatch: bool = False): + if session.model is models.T1B1: + return do_recover_legacy(session, mnemonic) else: - return do_recover_core(client, mnemonic, mismatch) + return do_recover_core(session, mnemonic, mismatch) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_dry_run(client: Client): - ret = do_recover(client, MNEMONIC12.split(" ")) +def test_dry_run(session: Session): + ret = do_recover(session, MNEMONIC12.split(" ")) assert isinstance(ret, messages.Success) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_seed_mismatch(client: Client): +def test_seed_mismatch(session: Session): with pytest.raises( exceptions.TrezorFailure, match="does not match the one in the device" ): - do_recover(client, ["all"] * 12, mismatch=True) + do_recover(session, ["all"] * 12, mismatch=True) @pytest.mark.models("legacy") -def test_invalid_seed_t1(client: Client): +def test_invalid_seed_t1(session: Session): with pytest.raises(exceptions.TrezorFailure, match="Invalid seed"): - do_recover(client, ["stick"] * 12) + do_recover(session, ["stick"] * 12) @pytest.mark.models("core") -def test_invalid_seed_core(client: Client): - with client: +def test_invalid_seed_core(session: Session): + with session, session.client as client: client.watch_layout() - IF = InputFlowBip39RecoveryDryRunInvalid(client) + IF = InputFlowBip39RecoveryDryRunInvalid(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( - client, + session, type=messages.RecoveryType.DryRun, ) @pytest.mark.setup_client(uninitialized=True) -def test_uninitialized(client: Client): +@pytest.mark.uninitialized_session +def test_uninitialized(session: Session): with pytest.raises(exceptions.TrezorFailure, match="not initialized"): - do_recover(client, ["all"] * 12) + do_recover(session, ["all"] * 12) DRY_RUN_ALLOWED_FIELDS = ( @@ -140,7 +141,7 @@ def _make_bad_params(): @pytest.mark.parametrize("field_name, field_value", _make_bad_params()) -def test_bad_parameters(client: Client, field_name: str, field_value: Any): +def test_bad_parameters(session: Session, field_name: str, field_value: Any): msg = messages.RecoveryDevice( type=messages.RecoveryType.DryRun, word_count=12, @@ -152,4 +153,4 @@ def test_bad_parameters(client: Client, field_name: str, field_value: Any): exceptions.TrezorFailure, match="Forbidden field set in dry-run", ): - client.call(msg) + session.call(msg) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py index 4f2eab6147b..7ddc634b8d4 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t1.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import MNEMONIC12 @@ -29,9 +29,10 @@ @pytest.mark.setup_client(uninitialized=True) -def test_pin_passphrase(client: Client): +def test_pin_passphrase(session: Session): + debug = session.client.debug mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -43,30 +44,30 @@ def test_pin_passphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -76,23 +77,26 @@ def test_pin_passphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_nopin_nopassphrase(client: Client): +def test_nopin_nopassphrase(session: Session): mnemonic = MNEMONIC12.split(" ") - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -104,19 +108,20 @@ def test_nopin_nopassphrase(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug = session.client.debug + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) fakes = 0 for _ in range(int(12 * 2)): assert isinstance(ret, messages.WordRequest) - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word=mnemonic[pos - 1])) + ret = session.call_raw(messages.WordAck(word=mnemonic[pos - 1])) mnemonic[pos - 1] = None else: - ret = client.call_raw(messages.WordAck(word=word)) + ret = session.call_raw(messages.WordAck(word=word)) fakes += 1 # Workflow succesfully ended @@ -126,21 +131,26 @@ def test_nopin_nopassphrase(client: Client): assert fakes == 12 assert mnemonic == [None] * 12 + raise Exception("TEST IS USING INIT MESSAGE - TODO CHANGE") + # Mnemonic is the same - client.init_device() - assert client.debug.state().mnemonic_secret == MNEMONIC12.encode() + # session.init_device() + assert debug.state().mnemonic_secret == MNEMONIC12.encode() - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_word_fail(client: Client): - ret = client.call_raw( +def test_word_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -152,23 +162,24 @@ def test_word_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.WordRequest) for _ in range(int(12 * 2)): - (word, pos) = client.debug.read_recovery_word() + (word, pos) = debug.read_recovery_word() if pos != 0: - ret = client.call_raw(messages.WordAck(word="kwyjibo")) + ret = session.call_raw(messages.WordAck(word="kwyjibo")) assert isinstance(ret, messages.Failure) break else: - client.call_raw(messages.WordAck(word=word)) + session.call_raw(messages.WordAck(word=word)) @pytest.mark.setup_client(uninitialized=True) -def test_pin_fail(client: Client): - ret = client.call_raw( +def test_pin_fail(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.RecoveryDevice( word_count=12, passphrase_protection=True, @@ -180,36 +191,36 @@ def test_pin_fail(client: Client): # click through confirmation assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin(PIN4) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN4) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time, but different one - pin_encoded = client.debug.encode_pin(PIN6) - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin(PIN6) + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Failure should be raised assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): device.recover( - client, + session, word_count=12, pin_protection=False, passphrase_protection=False, label="label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) - ret = client.call_raw( + ret = session.call_raw( messages.RecoveryDevice( word_count=12, input_method=messages.RecoveryDeviceInputMethod.ScrambledWords, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index 6046e85ca78..abca75bbee6 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC12 from ...input_flows import InputFlowBip39Recovery @@ -26,47 +26,49 @@ @pytest.mark.setup_client(uninitialized=True) -def test_tt_pin_passphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" @pytest.mark.setup_client(uninitialized=True) -def test_tt_nopin_nopassphrase(client: Client): - with client: +@pytest.mark.uninitialized_session +def test_tt_nopin_nopassphrase(session: Session): + with session.client as client: IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="hello", ) - assert client.debug.state().mnemonic_secret.decode() == MNEMONIC12 - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Bip39 - assert client.features.label == "hello" + assert session.client.debug.state().mnemonic_secret.decode() == MNEMONIC12 + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Bip39 + assert session.features.label == "hello" -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(RuntimeError): - device.recover(client) + device.recover(session) with pytest.raises(exceptions.TrezorFailure, match="Already initialized"): - client.call(messages.RecoveryDevice()) + session.call(messages.RecoveryDevice()) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index fa181117357..ce964ec3fc7 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_33 from ...input_flows import ( @@ -28,7 +28,7 @@ InputFlowSlip39AdvancedRecoveryThresholdReached, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] EXTRA_GROUP_SHARE = [ "eraser senior decision smug corner ruin rescue cubic angel tackle skin skunk program roster trash rumor slush angel flea amazing" @@ -46,13 +46,13 @@ # To allow reusing functionality for multiple tests def _test_secret( - client: Client, shares: list[str], secret: str, click_info: bool = False + session: Session, shares: list[str], secret: str, click_info: bool = False ): - with client: + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", @@ -60,86 +60,86 @@ def _test_secret( # Workflow succesfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Advanced - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Advanced + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_secret(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret) +def test_secret(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret) @pytest.mark.parametrize("shares, secret", VECTORS) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models(skip="safe3", reason="safe3 does not have info button") -def test_secret_click_info_button(client: Client, shares: list[str], secret: str): - _test_secret(client, shares, secret, click_info=True) +def test_secret_click_info_button(session: Session, shares: list[str], secret: str): + _test_secret(session, shares, secret, click_info=True) @pytest.mark.setup_client(uninitialized=True) -def test_extra_share_entered(client: Client): +def test_extra_share_entered(session: Session): _test_secret( - client, + session, shares=EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20, secret=VECTORS[0][1], ) @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): # we choose the second share from the fixture because # the 1st is 1of1 and group threshold condition is reached first first_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ") # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_group_threshold_reached(client: Client): +def test_group_threshold_reached(session: Session): # first share in the fixture is 1of1 so we choose that first_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ") # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with client: + with session, session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( - client, first_share, second_share + session, first_share, second_share ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 73e18a8686c..136be18bb6c 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import MNEMONIC_SLIP39_ADVANCED_20 @@ -39,14 +39,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -60,9 +60,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( @@ -70,7 +70,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 3f7ed75e730..9c4117dd866 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, exceptions, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -36,7 +36,7 @@ InputFlowSlip39BasicRecoveryWrongNthWord, ) -pytestmark = pytest.mark.models("core") +pytestmark = [pytest.mark.models("core"), pytest.mark.uninitialized_session] MNEMONIC_SLIP39_BASIC_20_1of1 = [ "academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic academic rebuild aquatic spew" @@ -70,32 +70,32 @@ @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("shares, secret, backup_type", VECTORS) def test_secret( - client: Client, shares: list[str], secret: str, backup_type: messages.BackupType + session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with client: + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is backup_type + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is backup_type # Check mnemonic - assert client.debug.state().mnemonic_secret.hex() == secret + assert session.client.debug.state().mnemonic_secret.hex() == secret @pytest.mark.setup_client(uninitialized=True) -def test_recover_with_pin_passphrase(client: Client): - with client: +def test_recover_with_pin_passphrase(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery( client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=True, passphrase_protection=True, label="label", @@ -103,99 +103,99 @@ def test_recover_with_pin_passphrase(client: Client): # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is True - assert client.features.passphrase_protection is True - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.pin_protection is True + assert session.features.passphrase_protection is True + assert session.features.backup_type is messages.BackupType.Slip39_Basic @pytest.mark.setup_client(uninitialized=True) -def test_abort(client: Client): - with client: +def test_abort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbort(client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_abort_between_shares(client: Client): - with client: +def test_abort_between_shares(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False - assert client.features.recovery_status is messages.RecoveryStatus.Nothing + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False + assert session.features.recovery_status is messages.RecoveryStatus.Nothing @pytest.mark.setup_client(uninitialized=True) -def test_noabort(client: Client): - with client: +def test_noabort(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) client.set_input_flow(IF.get()) - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is True + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is True @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_first_share(client: Client): - with client: - IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(client) +def test_invalid_mnemonic_first_share(session: Session): + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) -def test_invalid_mnemonic_second_share(client: Client): - with client: +def test_invalid_mnemonic_second_share(session: Session): + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") - client.init_device() - assert client.features.initialized is False + device.recover(session, pin_protection=False, label="label") + # TODO remove? session.init_device() + assert session.features.initialized is False @pytest.mark.setup_client(uninitialized=True) @pytest.mark.parametrize("nth_word", range(3)) -def test_wrong_nth_word(client: Client, nth_word: int): +def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoveryWrongNthWord(client, share, nth_word) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_same_share(client: Client): +def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with client: - IF = InputFlowSlip39BasicRecoverySameShare(client, share) + with session, session.client as client: + IF = InputFlowSlip39BasicRecoverySameShare(session, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): - device.recover(client, pin_protection=False, label="label") + device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) -def test_1of1(client: Client): - with client: +def test_1of1(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, pin_protection=False, passphrase_protection=False, label="label", @@ -203,7 +203,7 @@ def test_1of1(client: Client): # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.initialized is True - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is messages.BackupType.Slip39_Basic + assert session.features.initialized is True + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is messages.BackupType.Slip39_Basic diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index 4c9ddf8036f..3fcd7b51bd2 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -17,7 +17,7 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...input_flows import InputFlowSlip39BasicRecoveryDryRun @@ -37,12 +37,12 @@ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_dryrun(client: Client): - with client: +def test_2of3_dryrun(session: Session): + with session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) ret = device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", @@ -56,9 +56,9 @@ def test_2of3_dryrun(client: Client): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) -def test_2of3_invalid_seed_dryrun(client: Client): +def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with client, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( @@ -66,7 +66,7 @@ def test_2of3_invalid_seed_dryrun(client: Client): ) client.set_input_flow(IF.get()) device.recover( - client, + session, passphrase_protection=False, pin_protection=False, label="label", diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index 148087d4f4b..f08600817f8 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -19,7 +19,7 @@ from shamir_mnemonic import shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import BackupAvailability, BackupType from ...common import WITH_MOCK_URANDOM @@ -31,32 +31,32 @@ ) -def backup_flow_bip39(client: Client) -> bytes: - with client: +def backup_flow_bip39(session: Session) -> bytes: + with session.client as client: IF = InputFlowBip39Backup(client) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) assert IF.mnemonic is not None return IF.mnemonic.encode() -def backup_flow_slip39_basic(client: Client): - with client: +def backup_flow_slip39_basic(session: Session): + with session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) ems = shamir.recover_ems(groups) return ems.ciphertext -def backup_flow_slip39_advanced(client: Client): - with client: +def backup_flow_slip39_advanced(session: Session): + with session.client as client: IF = InputFlowSlip39AdvancedBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] groups = shamir.decode_mnemonics(mnemonics) @@ -74,32 +74,35 @@ def backup_flow_slip39_advanced(client: Client): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_msg(client: Client, backup_type, backup_flow): - with WITH_MOCK_URANDOM, client: +@pytest.mark.uninitialized_session +def test_skip_backup_msg(session: Session, backup_type, backup_flow): + assert session.features.initialized is False + + with WITH_MOCK_URANDOM, session: device.reset( - client, + session, skip_backup=True, passphrase_protection=False, pin_protection=False, backup_type=backup_type, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret @@ -107,32 +110,35 @@ def test_skip_backup_msg(client: Client, backup_type, backup_flow): @pytest.mark.models("core") @pytest.mark.parametrize("backup_type, backup_flow", VECTORS) @pytest.mark.setup_client(uninitialized=True) -def test_skip_backup_manual(client: Client, backup_type: BackupType, backup_flow): - with WITH_MOCK_URANDOM, client: +@pytest.mark.uninitialized_session +def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): + assert session.features.initialized is False + + with WITH_MOCK_URANDOM, session, session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.reset( - client, + session, pin_protection=False, passphrase_protection=False, backup_type=backup_type, ) - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.Required - assert client.features.unfinished_backup is False - assert client.features.no_backup is False - assert client.features.backup_type is backup_type + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.Required + assert session.features.unfinished_backup is False + assert session.features.no_backup is False + assert session.features.backup_type is backup_type - secret = backup_flow(client) + secret = backup_flow(session) - client.init_device() - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.unfinished_backup is False - assert client.features.backup_type is backup_type + session = session.client.get_session() + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.unfinished_backup is False + assert session.features.backup_type is backup_type assert secret is not None - state = client.debug.state() + state = session.client.debug.state() assert state.mnemonic_type is backup_type assert state.mnemonic_secret == secret diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py index 803818b3752..b9989ff8520 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_skipbackup.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -28,8 +28,10 @@ @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -40,17 +42,17 @@ def test_reset_device_skip_backup(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False @@ -61,14 +63,14 @@ def test_reset_device_skip_backup(client: Client): expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -78,9 +80,9 @@ def test_reset_device_skip_backup(client: Client): mnemonic = [] for _ in range(STRENGTH // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.Success) @@ -90,13 +92,15 @@ def test_reset_device_skip_backup(client: Client): assert mnemonic == expected_mnemonic # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_skip_backup_break(client: Client): - ret = client.call_raw( +@pytest.mark.uninitialized_session +def test_reset_device_skip_backup_break(session: Session): + debug = session.client.debug + ret = session.call_raw( messages.ResetDevice( strength=STRENGTH, passphrase_protection=False, @@ -107,26 +111,26 @@ def test_reset_device_skip_backup_break(client: Client): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) assert isinstance(ret, messages.Success) # Check if device is properly initialized - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.GetFeatures()) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.Required assert ret.unfinished_backup is False assert ret.no_backup is False # start Backup workflow - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) # send Initialize -> break workflow - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -134,11 +138,11 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False # start backup again - should fail - ret = client.call_raw(messages.BackupDevice()) + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) # read Features again - ret = client.call_raw(messages.Initialize()) + ret = session.call_raw(messages.Initialize()) assert isinstance(ret, messages.Features) assert ret.initialized is True assert ret.backup_availability == messages.BackupAvailability.NotAvailable @@ -146,6 +150,6 @@ def test_reset_device_skip_backup_break(client: Client): assert ret.no_backup is False -def test_initialized_device_backup_fail(client: Client): - ret = client.call_raw(messages.BackupDevice()) +def test_initialized_device_backup_fail(session: Session): + ret = session.call_raw(messages.BackupDevice()) assert isinstance(ret, messages.Failure) diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py index 689b81b0d61..7e7f28bcce7 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t1.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t1.py @@ -18,7 +18,7 @@ from mnemonic import Mnemonic from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import EXTERNAL_ENTROPY, generate_entropy @@ -26,9 +26,10 @@ pytestmark = pytest.mark.models("legacy") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): + debug = session.client.debug # No PIN, no passphrase - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=False, @@ -38,13 +39,13 @@ def reset_device(client: Client, strength: int): ) assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -53,9 +54,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + session.debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -65,9 +66,9 @@ def reset_device(client: Client, strength: int): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(session.debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -77,32 +78,38 @@ def reset_device(client: Client, strength: int): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False assert resp.passphrase_protection is False # Do pin & passphrase-protected action, PassphraseRequest should NOT be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.Address) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_128(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_128(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_256_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_256_pin(session: Session): + debug = session.client.debug strength = 256 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -113,24 +120,24 @@ def test_reset_device_256_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("654") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("654") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) # Provide entropy assert isinstance(ret, messages.EntropyRequest) - internal_entropy = client.debug.state().reset_entropy - ret = client.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) + internal_entropy = debug.state().reset_entropy + ret = session.call_raw(messages.EntropyAck(entropy=EXTERNAL_ENTROPY)) # Generate mnemonic locally entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) @@ -139,9 +146,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + session.call_raw(messages.ButtonAck()) mnemonic = " ".join(mnemonic) @@ -151,9 +158,9 @@ def test_reset_device_256_pin(client: Client): mnemonic = [] for _ in range(strength // 32 * 3): assert isinstance(ret, messages.ButtonRequest) - mnemonic.append(client.debug.read_reset_word()) - client.debug.press_yes() - resp = client.call_raw(messages.ButtonAck()) + mnemonic.append(debug.read_reset_word()) + debug.press_yes() + resp = session.call_raw(messages.ButtonAck()) assert isinstance(resp, messages.Success) @@ -163,23 +170,27 @@ def test_reset_device_256_pin(client: Client): assert mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True assert resp.passphrase_protection is True # Do passphrase-protected action, PassphraseRequest should be raised - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PassphraseRequest) - client.call_raw(messages.Cancel()) + session.call_raw(messages.Cancel()) @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice( strength=strength, passphrase_protection=True, @@ -190,27 +201,27 @@ def test_failed_pin(client: Client): # Do you want ... ? assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for first time - pin_encoded = client.debug.encode_pin("1234") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("1234") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.PinMatrixRequest) # Enter PIN for second time - pin_encoded = client.debug.encode_pin("6789") - ret = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + pin_encoded = debug.encode_pin("6789") + ret = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(ret, messages.Failure) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index f4e48d81d39..d4732dc96f9 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -19,7 +19,7 @@ from trezorlib import device, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ...common import EXTERNAL_ENTROPY, MNEMONIC12, WITH_MOCK_URANDOM, generate_entropy @@ -32,14 +32,15 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): - with WITH_MOCK_URANDOM, client: +def reset_device(session: Session, strength: int): + debug = session.client.debug + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -47,7 +48,7 @@ def reset_device(client: Client, strength: int): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -55,7 +56,7 @@ def reset_device(client: Client, strength: int): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False @@ -64,30 +65,34 @@ def reset_device(client: Client, strength: int): # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device(client: Client): - reset_device(client, 128) # 12 words +@pytest.mark.uninitialized_session +def test_reset_device(session: Session): + reset_device(session, 128) # 12 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_192(client: Client): - reset_device(client, 192) # 18 words +@pytest.mark.uninitialized_session +def test_reset_device_192(session: Session): + reset_device(session, 192) # 18 words @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_pin(client: Client): +@pytest.mark.uninitialized_session +def test_reset_device_pin(session: Session): + debug = session.client.debug strength = 256 # 24 words - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetPIN(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.reset( - client, + session, strength=strength, passphrase_protection=True, pin_protection=True, @@ -95,7 +100,7 @@ def test_reset_device_pin(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -103,7 +108,7 @@ def test_reset_device_pin(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is True @@ -111,16 +116,18 @@ def test_reset_device_pin(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_reset_failed_check(client: Client): +@pytest.mark.uninitialized_session +def test_reset_failed_check(session: Session): + debug = session.client.debug strength = 256 # 24 words - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetFailedCheck(client) client.set_input_flow(IF.get()) # PIN, passphrase, display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -128,7 +135,7 @@ def test_reset_failed_check(client: Client): ) # generate mnemonic locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = debug.state().reset_entropy entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -136,7 +143,7 @@ def test_reset_failed_check(client: Client): assert IF.mnemonic == expected_mnemonic # Check if device is properly initialized - resp = client.call_raw(messages.Initialize()) + resp = session.call_raw(messages.GetFeatures()) assert resp.initialized is True assert resp.backup_availability == messages.BackupAvailability.NotAvailable assert resp.pin_protection is False @@ -145,45 +152,56 @@ def test_reset_failed_check(client: Client): @pytest.mark.setup_client(uninitialized=True) -def test_failed_pin(client: Client): +@pytest.mark.uninitialized_session +def test_failed_pin(session: Session): + debug = session.client.debug strength = 128 - ret = client.call_raw( + ret = session.call_raw( messages.ResetDevice(strength=strength, pin_protection=True, label="test") ) # Confirm Reset assert isinstance(ret, messages.ButtonRequest) - client._raw_write(messages.ButtonAck()) - client.debug.press_yes() + + # client._raw_write(messages.ButtonAck()) + # client.debug.press_yes() + + # # Enter PIN for first time + # client.debug.input("654") + # ret = client.call_raw(messages.ButtonAck()) + + debug.press_yes() # TODO test fails here on T3T1 + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for first time - client.debug.input("654") - ret = client.call_raw(messages.ButtonAck()) + assert isinstance(ret, messages.ButtonRequest) + debug.input("654") + ret = session.call_raw(messages.ButtonAck()) # Re-enter PIN for TR - if client.layout_type is LayoutType.TR: + if session.client.layout_type is LayoutType.TR: assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) # Enter PIN for second time assert isinstance(ret, messages.ButtonRequest) - client.debug.input("456") - ret = client.call_raw(messages.ButtonAck()) + debug.input("456") + ret = session.call_raw(messages.ButtonAck()) # PIN mismatch assert isinstance(ret, messages.ButtonRequest) - client.debug.press_yes() - ret = client.call_raw(messages.ButtonAck()) + debug.press_yes() + ret = session.call_raw(messages.ButtonAck()) assert isinstance(ret, messages.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_already_initialized(client: Client): +def test_already_initialized(session: Session): with pytest.raises(Exception): device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=True, diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index 89c327fb8fe..2d478fba874 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -29,25 +30,30 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonic = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_management_session() + mnemonic = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - recover(client, mnemonic) - address_after = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_management_session()) + set_language(session, lang[:2]) + recover(session, mnemonic) + session = client.get_session() + address_after = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) assert address_before == address_after -def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str: - with WITH_MOCK_URANDOM, client: +def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowBip39ResetBackup(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -56,26 +62,26 @@ def reset(client: Client, strength: int = 128, skip_backup: bool = False) -> str ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False assert IF.mnemonic is not None return IF.mnemonic -def recover(client: Client, mnemonic: str): +def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with client: + with session.client as client: IF = InputFlowBip39Recovery(client, words) client.set_input_flow(IF.get()) client.watch_layout() - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index 8b42940d758..f4de1b03c74 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -32,8 +33,10 @@ @pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_management_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) # we're generating 3of5 groups 3of5 shares each test_combinations = [ mnemonics[0:3] # shares 1-3 from groups 1-3 @@ -50,25 +53,28 @@ def test_reset_recovery(client: Client): + mnemonics[22:25], ] for combination in test_combinations: + session = client.get_management_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) - - recover(client, combination) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_management_session()) + set_language(session, lang[:2]) + recover(session, combination) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with WITH_MOCK_URANDOM, client: +def reset(session: Session, strength: int = 128) -> list[str]: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -77,25 +83,25 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39AdvancedRecovery(client, shares, False) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 6b72246a106..f539c2ad449 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -19,6 +19,7 @@ import pytest from trezorlib import btc, device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import BackupType from trezorlib.tools import parse_path @@ -35,29 +36,35 @@ @pytest.mark.setup_client(uninitialized=True) @WITH_MOCK_URANDOM def test_reset_recovery(client: Client): - mnemonics = reset(client) - address_before = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) + session = client.get_management_session() + mnemonics = reset(session) + session = client.get_session() + address_before = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0")) for share_subset in itertools.combinations(mnemonics, 3): + session = client.get_management_session() lang = client.features.language or "en" - device.wipe(client) - set_language(client, lang[:2]) + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_management_session()) + set_language(session, lang[:2]) selected_mnemonics = share_subset - recover(client, selected_mnemonics) + recover(session, selected_mnemonics) + session = client.get_session() address_after = btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0") + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0") ) assert address_before == address_after -def reset(client: Client, strength: int = 128) -> list[str]: - with client: +def reset(session: Session, strength: int = 128) -> list[str]: + with session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -66,25 +73,25 @@ def reset(client: Client, strength: int = 128) -> list[str]: ) # Check if device is properly initialized - assert client.features.initialized is True + assert session.features.initialized is True assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable return IF.mnemonics -def recover(client: Client, shares: list[str]): - with client: +def recover(session: Session, shares: list[str]): + with session.client as client: IF = InputFlowSlip39BasicRecovery(client, shares) client.set_input_flow(IF.get()) - ret = device.recover(client, pin_protection=False, label="label") + ret = device.recover(session, pin_protection=False, label="label") # Workflow successfully ended assert ret == messages.Success(message="Device recovered") - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 6aa9d2bf3d0..04698ae16cb 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -37,10 +37,10 @@ def test_reset_device_slip39_advanced(client: Client): with WITH_MOCK_URANDOM, client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) - + session = client.get_management_session() # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -54,17 +54,17 @@ def test_reset_device_slip39_advanced(client: Client): # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Advanced_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Advanced_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) def validate_mnemonics( diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index 8eb5d7830fa..335d34172c4 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -20,7 +20,7 @@ from shamir_mnemonic import MnemonicError, shamir from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import BackupAvailability, BackupType @@ -30,16 +30,16 @@ pytestmark = pytest.mark.models("core") -def reset_device(client: Client, strength: int): +def reset_device(session: Session, strength: int): member_threshold = 3 - with WITH_MOCK_URANDOM, client: + with WITH_MOCK_URANDOM, session.client as client: IF = InputFlowSlip39BasicResetRecovery(client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.reset( - client, + session, strength=strength, passphrase_protection=False, pin_protection=False, @@ -48,32 +48,34 @@ def reset_device(client: Client, strength: int): ) # generate secret locally - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # validate that all combinations will result in the correct master secret validate_mnemonics(IF.mnemonics, member_threshold, secret) - + session = session.client.get_session() # Check if device is properly initialized - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # backup attempt fails because backup was done in reset with pytest.raises(TrezorFailure, match="ProcessError: Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic(client: Client): - reset_device(client, 128) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic(session: Session): + reset_device(session, 128) @pytest.mark.setup_client(uninitialized=True) -def test_reset_device_slip39_basic_256(client: Client): - reset_device(client, 256) +@pytest.mark.uninitialized_session +def test_reset_device_slip39_basic_256(session: Session): + reset_device(session, 256) def validate_mnemonics(mnemonics, threshold, expected_ems): diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 0d35b6c5b93..2a066926cd8 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.ripple import get_address from trezorlib.tools import parse_path @@ -43,28 +43,28 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_ripple_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_ripple_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_ripple_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address @pytest.mark.setup_client(mnemonic=CUSTOM_MNEMONIC) -def test_ripple_get_address_other(client: Client): +def test_ripple_get_address_other(session: Session): # data from https://github.com/you21979/node-ripple-bip32/blob/master/test/test.js - address = get_address(client, parse_path("m/44h/144h/0h/0/0")) + address = get_address(session, parse_path("m/44h/144h/0h/0/0")) assert address == "r4ocGE47gm4G4LkA9mriVHQqzpMLBTgnTY" - address = get_address(client, parse_path("m/44h/144h/0h/0/1")) + address = get_address(session, parse_path("m/44h/144h/0h/0/1")) assert address == "rUt9ULSrUvfCmke8HTFU1szbmFpWzVbBXW" diff --git a/tests/device_tests/ripple/test_sign_tx.py b/tests/device_tests/ripple/test_sign_tx.py index a03a29d4bec..82911c8abe8 100644 --- a/tests/device_tests/ripple/test_sign_tx.py +++ b/tests/device_tests/ripple/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import ripple -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -29,7 +29,7 @@ @pytest.mark.parametrize("chunkify", (True, False)) -def test_ripple_sign_simple_tx(client: Client, chunkify: bool): +def test_ripple_sign_simple_tx(session: Session, chunkify: bool): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -43,7 +43,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/0"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -66,7 +66,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -92,7 +92,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): } ) resp = ripple.sign_tx( - client, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify + session, parse_path("m/44h/144h/0h/0/2"), msg, chunkify=chunkify ) assert ( resp.signature.hex() @@ -104,7 +104,7 @@ def test_ripple_sign_simple_tx(client: Client, chunkify: bool): ) -def test_ripple_sign_invalid_fee(client: Client): +def test_ripple_sign_invalid_fee(session: Session): msg = ripple.create_sign_tx_msg( { "TransactionType": "Payment", @@ -121,4 +121,4 @@ def test_ripple_sign_invalid_fee(client: Client): TrezorFailure, match="ProcessError: Fee must be in the range of 10 to 10,000 drops", ): - ripple.sign_tx(client, parse_path("m/44h/144h/0h/0/2"), msg) + ripple.sign_tx(session, parse_path("m/44h/144h/0h/0/2"), msg) diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index dca1126c056..ce17d7c2a3b 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_address from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_address.json", ) -def test_solana_get_address(client: Client, parameters, result): +def test_solana_get_address(session: Session, parameters, result): actual_result = get_address( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.address == result["expected_address"] diff --git a/tests/device_tests/solana/test_public_key.py b/tests/device_tests/solana/test_public_key.py index 864852b116a..abe24dfc8f1 100644 --- a/tests/device_tests/solana/test_public_key.py +++ b/tests/device_tests/solana/test_public_key.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import get_public_key from trezorlib.tools import parse_path @@ -32,9 +32,9 @@ @parametrize_using_common_fixtures( "solana/get_public_key.json", ) -def test_solana_get_public_key(client: Client, parameters, result): +def test_solana_get_public_key(session: Session, parameters, result): actual_result = get_public_key( - client, address_n=parse_path(parameters["path"]), show_display=True + session, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.public_key.hex() == result["expected_public_key"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 241a3d3b34f..d5685e1ed75 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.solana import sign_tx from trezorlib.tools import parse_path @@ -42,13 +42,11 @@ "solana/sign_tx.unknown_instructions.json", "solana/sign_tx.predefined_transactions.json", ) -def test_solana_sign_tx(client: Client, parameters, result): - client.init_device(new_session=True) - +def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) actual_result = sign_tx( - client, + session, address_n=parse_path(parameters["address"]), serialized_tx=serialized_tx, additional_info=( diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 8e214ab1135..1d5c59e1f8e 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -55,7 +55,7 @@ import pytest from trezorlib import messages, protobuf, stellar -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ...common import parametrize_using_common_fixtures @@ -87,10 +87,10 @@ def make_op(operation_data): @parametrize_using_common_fixtures("stellar/sign_tx.json") -def test_sign_tx(client: Client, parameters, result): +def test_sign_tx(session: Session, parameters, result): tx, operations = parameters_to_proto(parameters) response = stellar.sign_tx( - client, tx, operations, tx.address_n, tx.network_passphrase + session, tx, operations, tx.address_n, tx.network_passphrase ) assert response.public_key.hex() == result["public_key"] assert b64encode(response.signature).decode() == result["signature"] @@ -113,20 +113,20 @@ def test_xdr(parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address(client: Client, parameters, result): +def test_get_address(session: Session, parameters, result): address_n = parse_path(parameters["path"]) - address = stellar.get_address(client, address_n, show_display=True) + address = stellar.get_address(session, address_n, show_display=True) assert address == result["address"] @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") -def test_get_address_chunkify_details(client: Client, parameters, result): - with client: +def test_get_address_chunkify_details(session: Session, parameters, result): + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( - client, address_n, show_display=True, chunkify=True + session, address_n, show_display=True, chunkify=True ) assert address == result["address"] diff --git a/tests/device_tests/test_authenticate_device.py b/tests/device_tests/test_authenticate_device.py index f2ffb5d7157..5e697b4f070 100644 --- a/tests/device_tests/test_authenticate_device.py +++ b/tests/device_tests/test_authenticate_device.py @@ -5,7 +5,7 @@ from cryptography.x509 import extensions as ext from trezorlib import device, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import compact_size @@ -35,16 +35,16 @@ ), ), ) -def test_authenticate_device(client: Client, challenge: bytes) -> None: +def test_authenticate_device(session: Session, challenge: bytes) -> None: # NOTE Applications must generate a random challenge for each request. # Issue an AuthenticateDevice challenge to Trezor. - proof = device.authenticate(client, challenge) + proof = device.authenticate(session, challenge) certs = [x509.load_der_x509_certificate(cert) for cert in proof.certificates] # Verify the last certificate in the certificate chain against trust anchor. root_public_key = ec.EllipticCurvePublicKey.from_encoded_point( - ec.SECP256R1(), ROOT_PUBLIC_KEY[client.model] + ec.SECP256R1(), ROOT_PUBLIC_KEY[session.model] ) root_public_key.verify( certs[-1].signature, @@ -78,11 +78,11 @@ def test_authenticate_device(client: Client, challenge: bytes) -> None: # Verify that the common name matches the Trezor model. common_name = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0] - if client.model == models.T3B1: + if session.model == models.T3B1: # XXX TODO replace as soon as we have T3B1 staging internal_model = "T2B1" else: - internal_model = client.model.internal_name + internal_model = session.model.internal_name assert common_name.value.startswith(internal_model) # Verify the signature of the challenge. diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index dc0f69a1df9..a310ff3841e 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -19,7 +19,7 @@ import pytest from trezorlib import device, messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from ..common import TEST_ADDRESS_N, get_test_address @@ -29,42 +29,42 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) -def pin_request(client: Client): +def pin_request(session: Session): return ( messages.PinMatrixRequest - if client.model is models.T1B1 + if session.model is models.T1B1 else messages.ButtonRequest ) -def set_autolock_delay(client: Client, delay): - with client: +def set_autolock_delay(session: Session, delay): + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) -def test_apply_auto_lock_delay(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_apply_auto_lock_delay(session: Session): + set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with client: + with session: # No PIN protection is required. - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([pin_request(session), messages.Address]) + get_test_address(session) @pytest.mark.parametrize( @@ -78,44 +78,45 @@ def test_apply_auto_lock_delay(client: Client): 536870, # 149 hours, maximum ], ) -def test_apply_auto_lock_delay_valid(client: Client, seconds): - set_autolock_delay(client, seconds * 1000) - assert client.features.auto_lock_delay_ms == seconds * 1000 +def test_apply_auto_lock_delay_valid(session: Session, seconds): + set_autolock_delay(session, seconds * 1000) + assert session.features.auto_lock_delay_ms == seconds * 1000 -def test_autolock_default_value(client: Client): - assert client.features.auto_lock_delay_ms is None - with client: +def test_autolock_default_value(session: Session): + assert session.features.auto_lock_delay_ms is None + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, label="pls unlock") - client.refresh_features() - assert client.features.auto_lock_delay_ms == 60 * 10 * 1000 + device.apply_settings(session, label="pls unlock") + session.refresh_features() + assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @pytest.mark.parametrize( "seconds", [0, 1, 9, 536871, 2**22], ) -def test_apply_auto_lock_delay_out_of_range(client: Client, seconds): - with client: +def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - pin_request(client), + pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), ] ) delay = seconds * 1000 with pytest.raises(TrezorFailure): - device.apply_settings(client, auto_lock_delay_ms=delay) + device.apply_settings(session, auto_lock_delay_ms=delay) @pytest.mark.models("core") -def test_autolock_cancels_ui(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_cancels_ui(session: Session): + set_autolock_delay(session, 10 * 1000) - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -126,44 +127,47 @@ def test_autolock_cancels_ui(client: Client): assert isinstance(resp, messages.ButtonRequest) # send an ack, do not read response - client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) # sleep more than auto-lock delay time.sleep(10.5) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, messages.Failure) assert resp.code == messages.FailureType.ActionCancelled -def test_autolock_ignores_initialize(client: Client): - set_autolock_delay(client, 10 * 1000) +def test_autolock_ignores_initialize(session: Session): + client = session.client + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() while time.monotonic() - start < 11: # init_device should always work even if locked - client.init_device() + client.resume_session(session) time.sleep(0.1) # after 11 seconds we are definitely locked - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False + +def test_autolock_ignores_getaddress(session: Session): -def test_autolock_ignores_getaddress(client: Client): - set_autolock_delay(client, 10 * 1000) + set_autolock_delay(session, 10 * 1000) - assert client.features.unlocked is True + assert session.features.unlocked is True start = time.monotonic() # let's continue for 8 seconds to give a little leeway to the slow CI while time.monotonic() - start < 8: - get_test_address(client) + get_test_address(session) time.sleep(0.1) # sleep 3 more seconds to wait for autolock time.sleep(3) # after 11 seconds we are definitely locked - client.refresh_features() - assert client.features.unlocked is False + session.refresh_features() + assert session.features.unlocked is False diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index c2d1202eb52..2955615e11f 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,44 +15,64 @@ # If not, see . from trezorlib import device, messages, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client def test_features(client: Client): - f0 = client.features - # client erases session_id from its features - f0.session_id = client.session_id - f1 = client.call(messages.Initialize(session_id=f0.session_id)) - assert f0 == f1 + session = client.get_session() + f0 = session.features + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: + # session erases session_id from its features + f0.session_id = session.id + f1 = session.call(messages.Initialize(session_id=session.id)) + assert f0 == f1 + else: + session2 = client.resume_session(session) + f1: messages.Features = session2.call(messages.GetFeatures()) + assert f1.session_id is None + assert f0 == f1 -def test_capabilities(client: Client): - assert (messages.Capability.Translations in client.features.capabilities) == ( - client.model is not models.T1B1 + +def test_capabilities(session: Session): + assert (messages.Capability.Translations in session.features.capabilities) == ( + session.model is not models.T1B1 ) -def test_ping(client: Client): - ping = client.call(messages.Ping(message="ahoj!")) +def test_ping(session: Session): + ping = session.call(messages.Ping(message="ahoj!")) assert ping == messages.Success(message="ahoj!") def test_device_id_same(client: Client): - id1 = client.get_device_id() - client.init_device() - id2 = client.get_device_id() + session1 = client.get_session() + session2 = client.get_session() + id1 = session1.features.device_id + session2.refresh_features() + id2 = session2.features.device_id + client = client.get_new_client() + session3 = client.get_session() + id3 = session3.features.device_id # ID must be at least 12 characters assert len(id1) >= 12 # Every resulf of UUID must be the same assert id1 == id2 + assert id2 == id3 def test_device_id_different(client: Client): - id1 = client.get_device_id() - device.wipe(client) - id2 = client.get_device_id() + session = client.get_management_session() + id1 = client.features.device_id + device.wipe(session) + client = client.get_new_client() + session = client.get_management_session() + + id2 = client.features.device_id # Device ID must be fresh after every reset assert id1 != id2 diff --git a/tests/device_tests/test_bip32_speed.py b/tests/device_tests/test_bip32_speed.py index 1d184c7e4a8..84d8cf9ae59 100644 --- a/tests/device_tests/test_bip32_speed.py +++ b/tests/device_tests/test_bip32_speed.py @@ -19,7 +19,7 @@ import pytest from trezorlib import btc, device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import H_ @@ -29,47 +29,47 @@ ] -def test_public_ckd(client: Client): +def test_public_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() - btc.get_address(client, "Bitcoin", range(depth)) + btc.get_address(session, "Bitcoin", range(depth)) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_private_ckd(client: Client): +def test_private_ckd(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) - btc.get_address(client, "Bitcoin", []) # to compute root node via BIP39 + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) + btc.get_address(session, "Bitcoin", []) # to compute root node via BIP39 for depth in range(8): start = time.time() address_n = [H_(-i) for i in range(-depth, 0)] - btc.get_address(client, "Bitcoin", address_n) + btc.get_address(session, "Bitcoin", address_n) delay = time.time() - start expected = (depth + 1) * 0.26 print("DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay) assert delay <= expected -def test_cache(client: Client): +def test_cache(session: Session): # disable safety checks to access non-standard paths - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) + btc.get_address(session, "Bitcoin", [x, 2, 3, 4, 5, 6, 7, 8]) nocache_time = time.time() - start start = time.time() for x in range(10): - btc.get_address(client, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) + btc.get_address(session, "Bitcoin", [1, 2, 3, 4, 5, 6, 7, x]) cache_time = time.time() - start print("NOCACHE TIME", nocache_time) diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 706745a1981..27fb1b23e6a 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -20,62 +20,66 @@ from trezorlib import btc, device from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path PIN = "1234" -def _assert_busy(client: Client, should_be_busy: bool, screen: str = "Homescreen"): - assert client.features.busy is should_be_busy - if client.layout_type is not LayoutType.T1: +def _assert_busy(session: Session, should_be_busy: bool, screen: str = "Homescreen"): + assert session.features.busy is should_be_busy + if session.client.layout_type is not LayoutType.T1: if should_be_busy: - assert "CoinJoinProgress" in client.debug.read_layout().all_components() + assert ( + "CoinJoinProgress" + in session.client.debug.read_layout().all_components() + ) else: - assert client.debug.read_layout().main_component() == screen + assert session.client.debug.read_layout().main_component() == screen @pytest.mark.setup_client(pin=PIN) -def test_busy_state(client: Client): - _assert_busy(client, False, "Lockscreen") - assert client.features.unlocked is False +def test_busy_state(session: Session): + _assert_busy(session, False, "Lockscreen") + assert session.features.unlocked is False # Show busy dialog for 1 minute. - device.set_busy(client, expiry_ms=60 * 1000) - _assert_busy(client, True) - assert client.features.unlocked is False + device.set_busy(session, expiry_ms=60 * 1000) + _assert_busy(session, True) + assert session.features.unlocked is False - with client: + with session.client as client: client.use_pin_sequence([PIN]) btc.get_address( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) - client.refresh_features() - _assert_busy(client, True) - assert client.features.unlocked is True + session.refresh_features() + _assert_busy(session, True) + assert session.features.unlocked is True # Hide the busy dialog. - device.set_busy(client, None) + device.set_busy(session, None) - _assert_busy(client, False) - assert client.features.unlocked is True + _assert_busy(session, False) + assert session.features.unlocked is True @pytest.mark.models("core") -def test_busy_expiry_core(client: Client): +def test_busy_expiry_core(session: Session): WAIT_TIME_MS = 1500 TOLERANCE = 1000 - _assert_busy(client, False) + _assert_busy(session, False) # Start a timer start = time.monotonic() # Show the busy dialog. - device.set_busy(client, expiry_ms=WAIT_TIME_MS) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=WAIT_TIME_MS) + _assert_busy(session, True) # Wait until the layout changes - client.debug.wait_layout() + time.sleep(0.1) # Improves stability of the test for devices with THP + session.client.debug.wait_layout() end = time.monotonic() # Check that the busy dialog was shown for at least WAIT_TIME_MS. @@ -84,26 +88,26 @@ def test_busy_expiry_core(client: Client): # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) @pytest.mark.flaky(max_runs=5) @pytest.mark.models("legacy") -def test_busy_expiry_legacy(client: Client): - _assert_busy(client, False) +def test_busy_expiry_legacy(session: Session): + _assert_busy(session, False) # Show the busy dialog. - device.set_busy(client, expiry_ms=1500) - _assert_busy(client, True) + device.set_busy(session, expiry_ms=1500) + _assert_busy(session, True) # Hasn't expired yet. time.sleep(0.1) - _assert_busy(client, True) + _assert_busy(session, True) # Wait for it to expire. Add some tolerance to account for CI/hardware slowness. time.sleep(4.0) # Check that the device is no longer busy. # Also needs to come back to Homescreen (for UI tests). - client.refresh_features() - _assert_busy(client, False) + session.refresh_features() + _assert_busy(session, False) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b72e95a88e9..9ab7e9165e8 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -17,7 +17,7 @@ import pytest import trezorlib.messages as m -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled from ..common import TEST_ADDRESS_N @@ -35,15 +35,15 @@ ), ], ) -def test_cancel_message_via_cancel(client: Client, message): +def test_cancel_message_via_cancel(session: Session, message): def input_flow(): yield - client.cancel() + session.cancel() - with client, pytest.raises(Cancelled): - client.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session, session.client as client, pytest.raises(Cancelled): + session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) - client.call(message) + session.call(message) @pytest.mark.parametrize( @@ -58,43 +58,45 @@ def input_flow(): ), ], ) -def test_cancel_message_via_initialize(client: Client, message): - resp = client.call_raw(message) +@pytest.mark.protocol("protocol_v1") +def test_cancel_message_via_initialize(session: Session, message): + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client._raw_write(m.Initialize()) + session._write(m.ButtonAck()) + session._write(m.Initialize()) - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.Features) @pytest.mark.models("core") -def test_cancel_on_paginated(client: Client): +def test_cancel_on_paginated(session: Session): """Check that device is responsive on paginated screen. See #1708.""" # In #1708, the device would ignore USB (or UDP) events while waiting for the user # to page through the screen. This means that this testcase, instead of failing, # would get stuck waiting for the _raw_read result. # I'm not spending the effort to modify the testcase to cause a _failure_ if that # happens again. Just be advised that this should not get stuck. + message = m.SignMessage( message=b"hello" * 64, address_n=TEST_ADDRESS_N, coin_name="Testnet", ) - resp = client.call_raw(message) + resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client._raw_write(m.ButtonAck()) - client.debug.press_yes() + session._write(m.ButtonAck()) + session.client.debug.press_yes() - resp = client._raw_read() + resp = session._read() assert isinstance(resp, m.ButtonRequest) assert resp.pages is not None - client._raw_write(m.ButtonAck()) + session._write(m.ButtonAck()) - client._raw_write(m.Cancel()) - resp = client._raw_read() + session._write(m.Cancel()) + resp = session._read() assert isinstance(resp, m.Failure) assert resp.code == m.FailureType.ActionCancelled diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 747613db127..d9445fddec8 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,8 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from trezorlib.transport import udp @@ -32,35 +34,41 @@ def test_layout(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_mnemonic(client: Client): - client.ensure_unlocked() - mnemonic = client.debug.state().mnemonic_secret +def test_mnemonic(session: Session): + session.ensure_unlocked() + mnemonic = session.client.debug.state().mnemonic_secret assert mnemonic == MNEMONIC12.encode() @pytest.mark.models("legacy") @pytest.mark.setup_client(mnemonic=MNEMONIC12, pin="1234", passphrase="") -def test_pin(client: Client): - resp = client.call_raw(messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0"))) +def test_pin(session: Session): + resp = session.call_raw( + messages.GetAddress(address_n=parse_path("m/44'/0'/0'/0/0")) + ) assert isinstance(resp, messages.PinMatrixRequest) - state = client.debug.state() - assert state.pin == "1234" - assert state.matrix != "" + with session.client as client: + state = client.debug.state() + assert state.pin == "1234" + assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") - resp = client.call_raw(messages.PinMatrixAck(pin=pin_encoded)) - assert isinstance(resp, messages.PassphraseRequest) + pin_encoded = client.debug.encode_pin("1234") + resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) + assert isinstance(resp, messages.PassphraseRequest) - resp = client.call_raw(messages.PassphraseAck(passphrase="")) - assert isinstance(resp, messages.Address) + resp = session.call_raw(messages.PassphraseAck(passphrase="")) + assert isinstance(resp, messages.Address) @pytest.mark.models("core") -def test_softlock_instability(client: Client): +def test_softlock_instability(session: Session): + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + raise Exception("THIS NEEDS TO BE CHANGED FOR THP") + def load_device(): debuglink.load_device( - client, + session, mnemonic=MNEMONIC12, pin="1234", passphrase_protection=False, @@ -68,27 +76,29 @@ def load_device(): ) # start from a clean slate: - resp = client.debug.reseed(0) + resp = session.client.debug.reseed(0) if isinstance(resp, messages.Failure) and not isinstance( - client.transport, udp.UdpTransport + session.client.transport, udp.UdpTransport ): pytest.xfail("reseed only supported on emulator") - device.wipe(client) - entropy_after_wipe = misc.get_entropy(client, 16) + device.wipe(session) + entropy_after_wipe = misc.get_entropy(session, 16) + session.refresh_features() # configure and wipe the device load_device() - client.debug.reseed(0) - device.wipe(client) - assert misc.get_entropy(client, 16) == entropy_after_wipe + session.client.debug.reseed(0) + device.wipe(session) + assert misc.get_entropy(session, 16) == entropy_after_wipe + session.refresh_features() load_device() # the device has PIN -> lock it - client.call(messages.LockDevice()) - client.debug.reseed(0) + session.call(messages.LockDevice()) + session.client.debug.reseed(0) # wipe_device should succeed with no need to unlock - device.wipe(client) + device.wipe(session) # the device is now trying to run the lockscreen, which attempts to unlock. # If the device actually called config.unlock(), it would use additional randomness. # That is undesirable. Assert that the returned entropy is still the same. - assert misc.get_entropy(client, 16) == entropy_after_wipe + assert misc.get_entropy(session, 16) == entropy_after_wipe diff --git a/tests/device_tests/test_firmware_hash.py b/tests/device_tests/test_firmware_hash.py index 50eb063c2b3..217be1c45d9 100644 --- a/tests/device_tests/test_firmware_hash.py +++ b/tests/device_tests/test_firmware_hash.py @@ -3,7 +3,7 @@ import pytest from trezorlib import firmware, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session # size of FIRMWARE_AREA, see core/embed/models/model_*_layout.c FIRMWARE_LENGTHS = { @@ -15,35 +15,35 @@ } -def test_firmware_hash_emu(client: Client) -> None: - if client.features.fw_vendor != "EMULATOR": +def test_firmware_hash_emu(session: Session) -> None: + if session.features.fw_vendor != "EMULATOR": pytest.skip("Only for emulator") - data = b"\xff" * FIRMWARE_LENGTHS[client.model] + data = b"\xff" * FIRMWARE_LENGTHS[session.model] expected_hash = blake2s(data).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash == expected_hash challenge = b"Hello Trezor" expected_hash = blake2s(data, key=challenge).digest() - hash = firmware.get_hash(client, challenge) + hash = firmware.get_hash(session, challenge) assert hash == expected_hash -def test_firmware_hash_hw(client: Client) -> None: - if client.features.fw_vendor == "EMULATOR": +def test_firmware_hash_hw(session: Session) -> None: + if session.features.fw_vendor == "EMULATOR": pytest.skip("Only for hardware") # TODO get firmware image from outside the environment, check for actual result challenge = b"Hello Trezor" - empty_data = b"\xff" * FIRMWARE_LENGTHS[client.model] + empty_data = b"\xff" * FIRMWARE_LENGTHS[session.model] empty_hash = blake2s(empty_data).digest() empty_hash_challenge = blake2s(empty_data, key=challenge).digest() - hash = firmware.get_hash(client, None) + hash = firmware.get_hash(session, None) assert hash != empty_hash - hash2 = firmware.get_hash(client, challenge) + hash2 = firmware.get_hash(session, challenge) assert hash != hash2 assert hash2 != empty_hash_challenge diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index d313608ee20..71f18bd9404 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -23,6 +23,7 @@ from trezorlib import debuglink, device, exceptions, messages, models from trezorlib._internal import translations +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import message_filters @@ -57,228 +58,235 @@ def get_ping_title(lang: str) -> str: @pytest.fixture -def client(client: Client) -> Iterator[Client]: - lang_before = client.features.language or "" +def session(session: Session) -> Iterator[Session]: + lang_before = session.features.language or "" try: - set_language(client, "en") - yield client + set_language(session, "en") + yield session finally: - set_language(client, lang_before[:2]) + set_language(session, lang_before[:2]) -def _check_ping_screen_texts(client: Client, title: str, right_button: str) -> None: - def ping_input_flow(client: Client, title: str, right_button: str): +def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> None: + def ping_input_flow(session: Session, title: str, right_button: str): yield - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert layout.title().upper() == title.upper() assert layout.button_contents()[-1].upper() == right_button.upper() - client.debug.press_yes() + session.client.debug.press_yes() # TT does not have a right button text (but a green OK tick) - if client.model in (models.T2T1, models.T3T1): + if session.model in (models.T2T1, models.T3T1): right_button = "-" - with client: + with session, session.client as client: client.watch_layout(True) - client.set_input_flow(ping_input_flow(client, title, right_button)) - ping = client.call(messages.Ping(message="ahoj!", button_protection=True)) + client.set_input_flow(ping_input_flow(session, title, right_button)) + ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") -def test_error_too_long(client: Client): - assert client.features.language == "en-US" +def test_error_too_long(session: Session): + assert session.features.language == "en-US" # Translations too long # Sending more than allowed by the flash capacity - max_length = MAX_DATA_LENGTH[client.model] - with pytest.raises(exceptions.TrezorFailure, match="Translations too long"), client: + max_length = MAX_DATA_LENGTH[session.model] + with pytest.raises( + exceptions.TrezorFailure, match="Translations too long" + ), session: bad_data = (max_length + 1) * b"a" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_length(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_length(session: Session): + assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), client: - good_data = build_and_sign_blob("cs", client) + with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_header_magic(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_header_magic(session: Session): + assert session.features.language == "en-US" # Invalid header magic # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] - device.change_language(client, language_data=bad_data) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + device.change_language(session, language_data=bad_data) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_data_hash(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_data_hash(session: Session): + assert session.features.language == "en-US" # Invalid data hash # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), client: - good_data = build_and_sign_blob("cs", client) + ), session: + good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( - client, + session, language_data=bad_data, ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_version_mismatch(client: Client): - assert client.features.language == "en-US" +def test_error_version_mismatch(session: Session): + assert session.features.language == "en-US" # Translations version mismatch # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), client: - blob = prepare_blob("cs", client.model, (3, 5, 4, 0)) + ), session: + blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( - client, + session, language_data=sign_blob(blob), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) -def test_error_invalid_signature(client: Client): - assert client.features.language == "en-US" +def test_error_invalid_signature(session: Session): + assert session.features.language == "en-US" # Invalid signature # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), client: - blob = prepare_blob("cs", client.model, client.version) + ), session: + blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], sigmask=0b011, signature=b"a" * 64, ) device.change_language( - client, + session, language_data=blob.build(), ) - assert client.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + assert session.features.language == "en-US" + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) @pytest.mark.parametrize("lang", LANGUAGES) -def test_full_language_change(client: Client, lang: str): - assert client.features.language == "en-US" - assert client.features.language_version_matches is True +def test_full_language_change(session: Session, lang: str): + assert session.features.language == "en-US" + assert session.features.language_version_matches is True # Setting selected language - set_language(client, lang) - assert client.features.language[:2] == lang - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + set_language(session, lang) + assert session.features.language[:2] == lang + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) # Setting the default language via empty data - set_language(client, "en") - assert client.features.language == "en-US" - assert client.features.language_version_matches is True - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + set_language(session, "en") + assert session.features.language == "en-US" + assert session.features.language_version_matches is True + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def test_language_is_removed_after_wipe(client: Client): - assert client.features.language == "en-US" + session = Session(client.get_session()) + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Setting cs language - set_language(client, "cs") - assert client.features.language == "cs-CZ" + set_language(session, "cs") + assert session.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Wipe device - device.wipe(client) - assert client.features.language == "en-US" + device.wipe(session) + client = client.get_new_client() + session = Session(client.get_management_session()) + assert session.features.language == "en-US" # Load it again debuglink.load_device( - client, + session, mnemonic=" ".join(["all"] * 12), pin=None, passphrase_protection=False, label="test", ) - assert client.features.language == "en-US" + assert session.features.language == "en-US" + + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) +def test_translations_renders_on_screen(session: Session): -def test_translations_renders_on_screen(client: Client): czech_data = get_lang_json("cs") # Setting some values of words__confirm key and checking that in ping screen title - assert client.features.language == "en-US" + assert session.features.language == "en-US" # Normal english - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) - + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) # Normal czech - set_language(client, "cs") - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title("cs"), get_ping_button("cs")) + set_language(session, "cs") + + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title("cs"), get_ping_button("cs")) # Modified czech - changed value czech_data_copy = deepcopy(czech_data) new_czech_confirm = "ABCD" czech_data_copy["translations"]["words__confirm"] = new_czech_confirm device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, new_czech_confirm, get_ping_button("cs")) + _check_ping_screen_texts(session, new_czech_confirm, get_ping_button("cs")) # Modified czech - key deleted completely, english is shown czech_data_copy = deepcopy(czech_data) del czech_data_copy["translations"]["words__confirm"] device.change_language( - client, - language_data=build_and_sign_blob(czech_data_copy, client), + session, + language_data=build_and_sign_blob(czech_data_copy, session), ) - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("cs")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("cs")) + +def test_reject_update(session: Session): -def test_reject_update(client: Client): - assert client.features.language == "en-US" + assert session.features.language == "en-US" lang = "cs" - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) def input_flow_reject(): yield - client.debug.press_no() + session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), client: + with pytest.raises(exceptions.Cancelled), session, session.client as client: client.set_input_flow(input_flow_reject) - device.change_language(client, language_data) + device.change_language(session, language_data) - assert client.features.language == "en-US" + assert session.features.language == "en-US" - _check_ping_screen_texts(client, get_ping_title("en"), get_ping_button("en")) + _check_ping_screen_texts(session, get_ping_title("en"), get_ping_button("en")) def _maybe_confirm_set_language( - client: Client, lang: str, show_display: bool | None, is_displayed: bool + session: Session, lang: str, show_display: bool | None, is_displayed: bool ) -> None: - language_data = build_and_sign_blob(lang, client) + language_data = build_and_sign_blob(lang, session) CHUNK_SIZE = 1024 @@ -289,34 +297,35 @@ def chunks(data, size): expected_responses_silent: list[Any] = [ messages.TranslationDataRequest(data_offset=off, data_length=len) for off, len in chunks(language_data, CHUNK_SIZE) - ] + [message_filters.Success(), message_filters.Features()] + ] + [message_filters.Success()] + # , message_filters.Features()] expected_responses_confirm = expected_responses_silent[:] # confirmation after first TranslationDataRequest expected_responses_confirm.insert(1, message_filters.ButtonRequest()) # success screen before Success / Features - expected_responses_confirm.insert(-2, message_filters.ButtonRequest()) + expected_responses_confirm.insert(-1, message_filters.ButtonRequest()) if is_displayed: expected_responses = expected_responses_confirm else: expected_responses = expected_responses_silent - with client: - client.set_expected_responses(expected_responses) - device.change_language(client, language_data, show_display=show_display) - assert client.features.language is not None - assert client.features.language[:2] == lang + with session: + session.set_expected_responses(expected_responses) + device.change_language(session, language_data, show_display=show_display) + assert session.features.language is not None + assert session.features.language[:2] == lang # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and client.actual_responses == expected_responses_silent: + if is_displayed and session.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and client.actual_responses == expected_responses_confirm: + if not is_displayed and session.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will - # be raised by the client context manager + # be raised by the session context manager @pytest.mark.parametrize( @@ -328,61 +337,64 @@ def chunks(data, size): ], ) @pytest.mark.setup_client(uninitialized=True) -def test_silent_first_install(client: Client, show_display: bool, is_displayed: bool): - assert not client.features.initialized - _maybe_confirm_set_language(client, "cs", show_display, is_displayed) +@pytest.mark.uninitialized_session +def test_silent_first_install(session: Session, show_display: bool, is_displayed: bool): + assert not session.features.initialized + _maybe_confirm_set_language(session, "cs", show_display, is_displayed) @pytest.mark.parametrize("show_display", (True, None)) -def test_switch_from_english(client: Client, show_display: bool | None): - assert client.features.initialized - assert client.features.language == "en-US" - _maybe_confirm_set_language(client, "cs", show_display, True) +def test_switch_from_english(session: Session, show_display: bool | None): + assert session.features.initialized + assert session.features.language == "en-US" + _maybe_confirm_set_language(session, "cs", show_display, True) -def test_switch_from_english_not_silent(client: Client): - assert client.features.initialized - assert client.features.language == "en-US" +def test_switch_from_english_not_silent(session: Session): + assert session.features.initialized + assert session.features.language == "en-US" with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) @pytest.mark.setup_client(uninitialized=True) -def test_switch_language(client: Client): - assert not client.features.initialized - assert client.features.language == "en-US" +@pytest.mark.uninitialized_session +def test_switch_language(session: Session): + assert not session.features.initialized + assert session.features.language == "en-US" # switch to Czech silently - _maybe_confirm_set_language(client, "cs", False, False) + _maybe_confirm_set_language(session, "cs", False, False) # switch to French silently with pytest.raises( exceptions.TrezorFailure, match="Cannot change language without user prompt" ): - _maybe_confirm_set_language(client, "fr", False, False) + _maybe_confirm_set_language(session, "fr", False, False) # switch to French with display, explicitly - _maybe_confirm_set_language(client, "fr", True, True) + _maybe_confirm_set_language(session, "fr", True, True) # switch back to Czech with display, implicitly - _maybe_confirm_set_language(client, "cs", None, True) + _maybe_confirm_set_language(session, "cs", None, True) -def test_header_trailing_data(client: Client): +def test_header_trailing_data(session: Session): """Adding trailing data to _header_ section specifically must be accepted by firmware, as long as the blob is otherwise valid and signed. (this ensures forwards compatibility if we extend the header) """ - assert client.features.language == "en-US" + + assert session.features.language == "en-US" lang = "cs" - blob = prepare_blob(lang, client.model, client.version) + blob = prepare_blob(lang, session.model, session.version) blob.header_bytes += b"trailing dataa" assert len(blob.header_bytes) % 2 == 0, "Trailing data must keep the 2-alignment" language_data = sign_blob(blob) - device.change_language(client, language_data) - assert client.features.language == "cs-CZ" - _check_ping_screen_texts(client, get_ping_title(lang), get_ping_button(lang)) + device.change_language(session, language_data) + assert session.features.language == "cs-CZ" + _check_ping_screen_texts(session, get_ping_title(lang), get_ping_button(lang)) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 65ea9357481..18fde33506a 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,7 +19,8 @@ import pytest from trezorlib import btc, device, exceptions, messages, misc, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -30,7 +31,7 @@ EXPECTED_RESPONSES_NOPIN = [ messages.ButtonRequest(), messages.Success, - messages.Features, + # messages.Features, ] EXPECTED_RESPONSES_PIN_T1 = [messages.PinMatrixRequest()] + EXPECTED_RESPONSES_NOPIN EXPECTED_RESPONSES_PIN_TT = [messages.ButtonRequest()] + EXPECTED_RESPONSES_NOPIN @@ -38,7 +39,7 @@ EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES = [ messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] PIN4 = "1234" @@ -50,173 +51,178 @@ TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:. from trezorlib import messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session -def test_ping(client: Client): - with client: - client.set_expected_responses([messages.Success]) - res = client.ping("random data") - assert res == "random data" +def test_ping(session: Session): + with session: + session.set_expected_responses([messages.Success]) + res = session.call(messages.Ping(message="random data")) + assert res.message == "random data" - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.Success, ] ) - res = client.ping("random data", button_protection=True) - assert res == "random data" + res = session.call( + messages.Ping(message="random data 2", button_protection=True) + ) + assert res.message == "random data 2" diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index fb305613825..60c55c85221 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op @@ -26,64 +27,71 @@ pytestmark = [pytest.mark.models("core", skip="safe3"), pytest.mark.sd_card] -def test_enable_disable(client: Client): - assert client.features.sd_protection is False +def test_enable_disable(session: Session): + assert session.features.sd_protection is False # Disabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.DISABLE) + device.sd_protect(session, Op.DISABLE) # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Enabling SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False -def test_refresh(client: Client): - assert client.features.sd_protection is False +def test_refresh(session: Session): + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is True + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is True # Disable SD protection - device.sd_protect(client, Op.DISABLE) - assert client.features.sd_protection is False + device.sd_protect(session, Op.DISABLE) + assert session.features.sd_protection is False # Refreshing SD protection should fail with pytest.raises(TrezorFailure): - device.sd_protect(client, Op.REFRESH) - assert client.features.sd_protection is False + device.sd_protect(session, Op.REFRESH) + assert session.features.sd_protection is False def test_wipe(client: Client): + session = client.get_management_session() # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Wipe device (this wipes internal storage) - device.wipe(client) - assert client.features.sd_protection is False + device.wipe(session) + client = client.get_new_client() + session = client.get_management_session() + assert session.features.sd_protection is False # Restore device to working status debuglink.load_device( - client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + session, + mnemonic=MNEMONIC12, + pin=None, + passphrase_protection=False, + label="test", ) - assert client.features.sd_protection is False + assert session.features.sd_protection is False # Enable SD protection - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True # Refresh SD protection - device.sd_protect(client, Op.REFRESH) + device.sd_protect(session, Op.REFRESH) diff --git a/tests/device_tests/test_msg_show_device_tutorial.py b/tests/device_tests/test_msg_show_device_tutorial.py index 52904c50c50..f6a083879f3 100644 --- a/tests/device_tests/test_msg_show_device_tutorial.py +++ b/tests/device_tests/test_msg_show_device_tutorial.py @@ -17,11 +17,11 @@ import pytest from trezorlib import device -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("safe") -def test_tutorial(client: Client): - device.show_device_tutorial(client) - assert client.features.initialized is False +def test_tutorial(session: Session): + device.show_device_tutorial(session) + assert session.features.initialized is False diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index d94f392f1b5..8275bfc715f 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -19,6 +19,7 @@ import pytest from trezorlib import device, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from ..common import get_test_address @@ -31,33 +32,39 @@ def test_wipe_device(client: Client): assert client.features.initialized is True assert client.features.label == "test" assert client.features.passphrase_protection is True - device_id = client.get_device_id() - - device.wipe(client) + device_id = client.features.device_id + device.wipe(client.get_session()) + client = client.get_new_client() assert client.features.initialized is False assert client.features.label is None assert client.features.passphrase_protection is False - assert client.get_device_id() != device_id + assert client.features.device_id != device_id @pytest.mark.setup_client(pin=PIN4) -def test_autolock_not_retained(client: Client): +def test_autolock_not_retained(session: Session): + client = session.client with client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, auto_lock_delay_ms=10_000) + device.apply_settings(session, auto_lock_delay_ms=10_000) + + assert session.features.auto_lock_delay_ms == 10_000 - assert client.features.auto_lock_delay_ms == 10_000 + device.wipe(session) + client = client.get_new_client() + session = client.get_management_session() - device.wipe(client) assert client.features.auto_lock_delay_ms > 10_000 with client: client.use_pin_sequence([PIN4, PIN4]) - device.reset(client, skip_backup=True, pin_protection=True) + device.reset(session, skip_backup=True, pin_protection=True) time.sleep(10.5) - with client: + session = Session(client.get_session()) + + with session, client: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_passphrase_slip39_advanced.py b/tests/device_tests/test_passphrase_slip39_advanced.py index 64ef1f5e577..89a68fb1de2 100644 --- a/tests/device_tests/test_passphrase_slip39_advanced.py +++ b/tests/device_tests/test_passphrase_slip39_advanced.py @@ -34,14 +34,14 @@ def test_128bit_passphrase(client: Client): xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mkKDUMRR1CcK8eLAzCZAjKnNbCquPoWPxN" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare + assert address_compare == "n1HeeeojjHgQnG6Bf5VWkM1gcpQkkXqSGw" @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_33, passphrase=True) @@ -53,11 +53,10 @@ def test_256bit_passphrase(client: Client): xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c """ assert client.features.passphrase_protection is True - client.use_passphrase("TREZOR") - address = get_test_address(client) + session = client.get_session(passphrase="TREZOR") + address = get_test_address(session) assert address == "mxVtGxUJ898WLzPMmy6PT1FDHD1GUCWGm7" - client.clear_session() - client.use_passphrase("ROZERT") - address_compare = get_test_address(client) + session = client.get_session(passphrase="ROZERT") + address_compare = get_test_address(session) assert address != address_compare diff --git a/tests/device_tests/test_passphrase_slip39_basic.py b/tests/device_tests/test_passphrase_slip39_basic.py index de0e7a734b2..120a6f556ec 100644 --- a/tests/device_tests/test_passphrase_slip39_basic.py +++ b/tests/device_tests/test_passphrase_slip39_basic.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from ..common import ( MNEMONIC_SLIP39_BASIC_20_3of6, @@ -28,14 +28,14 @@ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6, passphrase="TREZOR") -def test_3of6_passphrase(client: Client): +def test_3of6_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mi4HXfRJAqCDyEdet5veunBvXLTKSxpuim" @@ -46,25 +46,25 @@ def test_3of6_passphrase(client: Client): ), passphrase="TREZOR", ) -def test_2of5_passphrase(client: Client): +def test_2of5_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: provided by Andrew, address calculated via https://iancoleman.io/bip39/ xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "mjXH4pN7TtbHp3tWLqVKktKuaQeByHMoBZ" @pytest.mark.setup_client( mnemonic=MNEMONIC_SLIP39_BASIC_EXT_20_2of3, passphrase="TREZOR" ) -def test_2of3_ext_passphrase(client: Client): +def test_2of3_ext_passphrase(session: Session): """ BIP32 Root Key for passphrase TREZOR: xprv9s21ZrQH143K4FS1qQdXYAFVAHiSAnjj21YAKGh2CqUPJ2yQhMmYGT4e5a2tyGLiVsRgTEvajXkxhg92zJ8zmWZas9LguQWz7WZShfJg6RS """ - assert client.features.passphrase_protection is True - address = get_test_address(client) + assert session.features.passphrase_protection is True + address = get_test_address(session) assert address == "moELJhDbGK41k6J2ePYh2U8uc5qskC663C" diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index ee58790c046..c911dfee503 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -19,7 +19,7 @@ import pytest from trezorlib import messages, models -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import PinException from ..common import check_pin_backoff_time, get_test_address @@ -32,18 +32,18 @@ @pytest.mark.setup_client(pin=None) -def test_no_protection(client: Client): - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_no_protection(session: Session): + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) -def test_correct_pin(client: Client): - with client: +def test_correct_pin(session: Session): + with session, session.client as client: client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT - is_t1 = client.model is models.T1B1 - client.set_expected_responses( + is_t1 = session.model is models.T1B1 + session.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -53,45 +53,44 @@ def test_correct_pin(client: Client): messages.Address, ] ) - # client.set_expected_responses([messages.ButtonRequest, messages.Address]) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_incorrect_pin_t1(client: Client): +def test_incorrect_pin_t1(session: Session): with pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + session.client.use_pin_sequence([BAD_PIN]) + get_test_address(session) @pytest.mark.models("core") -def test_incorrect_pin_t2(client: Client): - with client: +def test_incorrect_pin_t2(session: Session): + with session, session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.Address, ] ) - get_test_address(client) + get_test_address(session) @pytest.mark.models("legacy") -def test_exponential_backoff_t1(client: Client): +def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with client, pytest.raises(PinException): + with session, session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) - get_test_address(client) + get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") -def test_exponential_backoff_t2(client: Client): - with client: +def test_exponential_backoff_t2(session: Session): + with session.client as client: IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) client.set_input_flow(IF.get()) - get_test_address(client) + get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 22ffb13b7f9..7825cbacee3 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -17,8 +17,9 @@ import pytest from trezorlib import btc, device, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -43,196 +44,234 @@ pytestmark = pytest.mark.setup_client(pin=PIN4, passphrase=True) -def _pin_request(client: Client): +def _pin_request(session: Session): """Get appropriate PIN request for each model""" - if client.model is models.T1B1: + if session.model is models.T1B1: return messages.PinMatrixRequest else: return messages.ButtonRequest(code=B.PinEntry) def _assert_protection( - client: Client, pin: bool = True, passphrase: bool = True -) -> None: + session: Session, pin: bool = True, passphrase: bool = True +) -> Session: """Make sure PIN and passphrase protection have expected values""" - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.ensure_unlocked() + session.ensure_unlocked() + client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - client.clear_session() + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + new_session = session.client.get_session() + session.lock() + session.end() + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + new_session = session.client.get_session() + return Session(new_session) -def test_initialize(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses([messages.Features]) - client.init_device() +def test_initialize(session: Session): + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: + # Test is skipped for THP + return + + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.ensure_unlocked() + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.Features]) + session.call(messages.Initialize(session_id=session.id)) @pytest.mark.models("core") @pytest.mark.setup_client(pin=PIN4) @pytest.mark.parametrize("passphrase", (True, False)) -def test_passphrase_reporting(client: Client, passphrase): +def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - device.apply_settings(client, use_passphrase=passphrase) + device.apply_settings(session, use_passphrase=passphrase) - client.lock() + session.lock() # on a locked device, passphrase_protection should be None - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + assert session.features.unlocked is False + assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately - _assert_protection(client, pin=True, passphrase=passphrase) + session = _assert_protection(session, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again - client.lock() - assert client.features.unlocked is False - assert client.features.passphrase_protection is None + session.lock() + assert session.features.unlocked is False + assert session.features.passphrase_protection is None -def test_apply_settings(client: Client): - _assert_protection(client) - with client: +def test_apply_settings(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] - ) # TrezorClient reinitializes device - device.apply_settings(client, label="nazdar") + ) + device.apply_settings(session, label="nazdar") @pytest.mark.models("legacy") -def test_change_pin_t1(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t1(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - _pin_request(client), + _pin_request(session), + _pin_request(session), + _pin_request(session), messages.Success, messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.models("core") -def test_change_pin_t2(client: Client): - _assert_protection(client) - with client: +def test_change_pin_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, - _pin_request(client), - _pin_request(client), - (client.layout_type is LayoutType.TR, messages.ButtonRequest), - _pin_request(client), + _pin_request(session), + _pin_request(session), + (session.client.layout_type is LayoutType.TR, messages.ButtonRequest), + _pin_request(session), messages.ButtonRequest, messages.Success, - messages.Features, + # messages.Features, ] ) - device.change_pin(client) + device.change_pin(session) @pytest.mark.setup_client(pin=None, passphrase=False) -def test_ping(client: Client): - _assert_protection(client, pin=False, passphrase=False) - with client: - client.set_expected_responses([messages.ButtonRequest, messages.Success]) - client.ping("msg", True) +def test_ping(session: Session): + session = _assert_protection(session, pin=False, passphrase=False) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + session.call(messages.Ping(message="msg", button_protection=True)) -def test_get_entropy(client: Client): - _assert_protection(client) - with client: +def test_get_entropy(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest(code=B.ProtectCall), messages.Entropy, ] ) - misc.get_entropy(client, 10) + misc.get_entropy(session, 10) + +def test_get_public_key(session: Session): + session = _assert_protection(session) -def test_get_public_key(client: Client): - _assert_protection(client) - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.PublicKey, - ] - ) - btc.get_public_node(client, []) + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.PublicKey) -def test_get_address(client: Client): - _assert_protection(client) - with client: - client.use_pin_sequence([PIN4]) - client.set_expected_responses( - [ - _pin_request(client), - messages.PassphraseRequest, - messages.Address, - ] - ) - get_test_address(client) + session.set_expected_responses(expected_responses) + btc.get_public_node(session, []) -def test_wipe_device(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( - [messages.ButtonRequest, messages.Success, messages.Features] - ) - device.wipe(client) +def test_get_address(session: Session): + session = _assert_protection(session) + + with session, session.client as client: + client.use_pin_sequence([PIN4]) + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.append(messages.Address) + + session.set_expected_responses(expected_responses) + + get_test_address(session) + + +def test_wipe_device(session: Session): + # Precise cause of crash is not determined, it happens with some order of + # tests, but not with all. The following leads to crash: + # pytest --random-order-seed=675848 tests/device_tests/test_protection_levels.py + # + # Traceback (most recent call last): + # File "trezor/wire/__init__.py", line 70, in handle_session + # File "trezor/wire/thp_main.py", line 79, in thp_main_loop + # File "trezor/wire/thp_main.py", line 145, in _handle_allocated + # File "trezor/wire/thp/received_message_handler.py", line 123, in handle_received_message + # File "trezor/wire/thp/received_message_handler.py", line 231, in _handle_state_TH1 + # File "trezor/wire/thp/crypto.py", line 93, in handle_th1_crypto + # File "trezor/wire/thp/crypto.py", line 178, in _derive_static_key_pair + # File "storage/device.py", line 364, in get_device_secret + # File "storage/common.py", line 21, in set + # RuntimeError: Could not save value + + session = _assert_protection(session) + with session: + session.set_expected_responses([messages.ButtonRequest, messages.Success]) + device.wipe(session) + client = session.client.get_new_client() + session = Session(client.get_management_session()) + with session, session.client as client: + client.use_pin_sequence([PIN4]) + session.set_expected_responses([messages.Features]) + session.call(messages.GetFeatures()) @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_reset_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - with WITH_MOCK_URANDOM, client: - client.set_expected_responses( +def test_reset_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + with WITH_MOCK_URANDOM, session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 + [messages.Success, messages.Features] ) device.reset( - client, + session, strength=128, passphrase_protection=True, pin_protection=False, label="label", ) + session.call(messages.GetFeatures()) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.reset` has its own check - client.call( + session.call( messages.ResetDevice( strength=128, passphrase_protection=True, @@ -244,30 +283,30 @@ def test_reset_device(client: Client): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.models("legacy") -def test_recovery_device(client: Client): - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - client.use_mnemonic(MNEMONIC12) - with client: - client.set_expected_responses( +def test_recovery_device(session: Session): + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + session.client.use_mnemonic(MNEMONIC12) + with session: + session.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 - + [messages.Success, messages.Features] + + [messages.Success] # , messages.Features] ) device.recover( - client, + session, 12, False, False, "label", - input_callback=client.mnemonic_callback, + input_callback=session.client.mnemonic_callback, ) with pytest.raises(TrezorFailure): # This must fail, because device is already initialized # Using direct call because `device.recover` has its own check - client.call( + session.call( messages.RecoveryDevice( word_count=12, passphrase_protection=False, @@ -277,29 +316,37 @@ def test_recovery_device(client: Client): ) -def test_sign_message(client: Client): - _assert_protection(client) - with client: +def test_sign_message(session: Session): + session = _assert_protection(session) + + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + + expected_responses = [_pin_request(session)] + + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, messages.ButtonRequest, messages.ButtonRequest, messages.MessageSignature, ] ) + + session.set_expected_responses(expected_responses) + btc.sign_message( - client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" + session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), "testing message" ) @pytest.mark.models("legacy") -def test_verify_message_t1(client: Client): - _assert_protection(client) - with client: - client.set_expected_responses( +def test_verify_message_t1(session: Session): + session = _assert_protection(session) + with session: + session.set_expected_responses( [ messages.ButtonRequest, messages.ButtonRequest, @@ -308,7 +355,7 @@ def test_verify_message_t1(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -319,13 +366,13 @@ def test_verify_message_t1(client: Client): @pytest.mark.models("core") -def test_verify_message_t2(client: Client): - _assert_protection(client) - with client: +def test_verify_message_t2(session: Session): + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + session.set_expected_responses( [ - _pin_request(client), + _pin_request(session), messages.ButtonRequest, messages.ButtonRequest, messages.ButtonRequest, @@ -333,7 +380,7 @@ def test_verify_message_t2(client: Client): ] ) btc.verify_message( - client, + session, "Bitcoin", "14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e", bytes.fromhex( @@ -343,7 +390,7 @@ def test_verify_message_t2(client: Client): ) -def test_signtx(client: Client): +def test_signtx(session: Session): # input tx: 50f6f1209ca92d7359564be803cb2c932cde7d370f7cee50fd1fad6790f6206d inp1 = messages.TxInputType( @@ -359,17 +406,18 @@ def test_signtx(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _assert_protection(client) - with client: + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses( + expected_responses = [_pin_request(session)] + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + expected_responses.append(messages.PassphraseRequest) + expected_responses.extend( [ - _pin_request(client), - messages.PassphraseRequest, request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_meta(TXHASH_50f6f1), @@ -382,7 +430,9 @@ def test_signtx(client: Client): request_finished(), ] ) - btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) + session.set_expected_responses(expected_responses) + + btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=TxCache("Bitcoin")) # def test_firmware_erase(): @@ -393,29 +443,37 @@ def test_signtx(client: Client): @pytest.mark.setup_client(pin=PIN4, passphrase=False) -def test_unlocked(client: Client): - assert client.features.unlocked is False +def test_unlocked(session: Session): + assert session.features.unlocked is False + + session = _assert_protection(session, passphrase=False) - _assert_protection(client, passphrase=False) - with client: + with session, session.client as client: client.use_pin_sequence([PIN4]) - client.set_expected_responses([_pin_request(client), messages.Address]) - get_test_address(client) + session.set_expected_responses([_pin_request(session), messages.Address]) + get_test_address(session) - client.init_device() - assert client.features.unlocked is True - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session.refresh_features() + assert session.features.unlocked is True + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) @pytest.mark.setup_client(pin=None, passphrase=True) -def test_passphrase_cached(client: Client): - _assert_protection(client, pin=False) - with client: - client.set_expected_responses([messages.PassphraseRequest, messages.Address]) - get_test_address(client) - - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) +def test_passphrase_cached(session: Session): + session = _assert_protection(session, pin=False) + with session: + if session.protocol_version == 1: + session.set_expected_responses( + [messages.PassphraseRequest, messages.Address] + ) + elif session.protocol_version == 2: + session.set_expected_responses([messages.Address]) + else: + raise Exception("Unknown session type") + get_test_address(session) + + with session: + session.set_expected_responses([messages.Address]) + get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 3bf2d42510d..29aa9b538ef 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -17,8 +17,8 @@ import pytest -from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib import device, exceptions, messages +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from .. import translations as TR @@ -35,194 +35,198 @@ @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) @WITH_MOCK_URANDOM -def test_repeated_backup_upgrade_single(client: Client): +def test_repeated_backup_upgrade_single(session: Session): assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing - assert client.features.backup_type == messages.BackupType.Slip39_Single_Extendable + assert session.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) # backup type was upgraded: - assert client.features.backup_type == messages.BackupType.Slip39_Basic_Extendable + assert session.features.backup_type == messages.BackupType.Slip39_Basic_Extendable # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup_cancel(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_cancel(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a Cancel message with pytest.raises(Cancelled): - client.call(messages.Cancel()) + session.call(messages.Cancel()) - client.refresh_features() + session.refresh_features() # the backup feature is locked again... assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @WITH_MOCK_URANDOM -def test_repeated_backup_send_disallowed_message(client: Client): - assert client.features.backup_availability == messages.BackupAvailability.Required - assert client.features.recovery_status == messages.RecoveryStatus.Nothing +def test_repeated_backup_send_disallowed_message(session: Session): + assert session.features.backup_availability == messages.BackupAvailability.Required + assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with client: + with session, session.client as client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) - device.backup(client) + device.backup(session) mnemonics = IF.mnemonics assert len(mnemonics) == 5 # cannot backup, since we already just did that! assert ( - client.features.backup_availability == messages.BackupAvailability.NotAvailable + session.features.backup_availability == messages.BackupAvailability.NotAvailable ) - assert client.features.recovery_status == messages.RecoveryStatus.Nothing + assert session.features.recovery_status == messages.RecoveryStatus.Nothing with pytest.raises(TrezorFailure, match=r".*Seed already backed up"): - device.backup(client) + device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with client: + with session, session.client as client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) client.set_input_flow(IF.get()) - ret = device.recover(client, type=messages.RecoveryType.UnlockRepeatedBackup) + ret = device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ret == messages.Success(message="Backup unlocked") assert ( - client.features.backup_availability == messages.BackupAvailability.Available + session.features.backup_availability + == messages.BackupAvailability.Available ) - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = client.debug.read_layout() + layout = session.client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a GetAddress message - resp = client.call_raw( + resp = session.call_raw( messages.GetAddress( coin_name="Testnet", address_n=TEST_ADDRESS_N, @@ -233,10 +237,13 @@ def test_repeated_backup_send_disallowed_message(client: Client): assert isinstance(resp, messages.Failure) assert "not allowed" in resp.message - assert client.features.backup_availability == messages.BackupAvailability.Available - assert client.features.recovery_status == messages.RecoveryStatus.Backup + assert session.features.backup_availability == messages.BackupAvailability.Available + assert session.features.recovery_status == messages.RecoveryStatus.Backup # we are still on the confirmation screen! assert ( - TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content() + TR.recovery__unlock_repeated_backup + in session.client.debug.read_layout().text_content() ) + with pytest.raises(exceptions.Cancelled): + session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 69098d81df7..8d5c45b81fc 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -17,111 +17,117 @@ import pytest from trezorlib import device, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.messages import SdProtectOperationType as Op from .. import translations as TR +PIN = "1234" + pytestmark = pytest.mark.models("core", skip="safe3") @pytest.mark.sd_card(formatted=False) -def test_sd_format(client: Client): - device.sd_protect(client, Op.ENABLE) - assert client.features.sd_protection is True +def test_sd_format(session: Session): + device.sd_protect(session, Op.ENABLE) + assert session.features.sd_protection is True @pytest.mark.sd_card(formatted=False) -def test_sd_no_format(client: Client): +def test_sd_no_format(session: Session): + debug = session.client.debug + def input_flow(): yield # enable SD protection? - client.debug.press_yes() + debug.press_yes() yield # format SD card - client.debug.press_no() + debug.press_no() - with pytest.raises(TrezorFailure) as e, client: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @pytest.mark.sd_card -@pytest.mark.setup_client(pin="1234") -def test_sd_protect_unlock(client: Client): - layout = client.debug.read_layout +@pytest.mark.setup_client(pin=PIN) +def test_sd_protect_unlock(session: Session): + debug = session.client.debug + layout = debug.read_layout def input_flow_enable_sd_protect(): + # debug.press_yes() yield # Enter PIN to unlock device assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # do you really want to enable SD protection assert TR.sd_card__enable in layout().text_content() - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # you have successfully enabled SD protection assert TR.sd_card__enabled in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) - device.sd_protect(client, Op.ENABLE) + device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # enter new PIN again assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # Pin change successful assert TR.pin__changed in layout().text_content() - client.debug.press_yes() + debug.press_yes() - with client: + with session, session.client as client: client.watch_layout() client.set_input_flow(input_flow_change_pin) - device.change_pin(client) + device.change_pin(session) - client.debug.erase_sd_card(format=False) + debug.erase_sd_card(format=False) def input_flow_change_pin_format(): yield # do you really want to change PIN? assert layout().title() == TR.pin__title_settings - client.debug.press_yes() + debug.press_yes() yield # enter current PIN assert "PinKeyboard" in layout().all_components() - client.debug.input("1234") + debug.input(PIN) yield # SD card problem assert ( TR.sd_card__unplug_and_insert_correct in layout().text_content() or TR.sd_card__insert_correct_card in layout().text_content() ) - client.debug.press_no() # close + debug.press_no() # close - with client, pytest.raises(TrezorFailure) as e: + with session, session.client as client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) - device.change_pin(client) + device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index a8020d0354d..56b8ace9960 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -18,6 +18,8 @@ from trezorlib import cardano, messages, models from trezorlib.btc import get_public_node +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -30,6 +32,18 @@ PIN4 = "1234" +def test_thp_end_session(client: Client): + session = Session(client.get_session()) + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + # TODO: This test should be skipped on non-THP builds + return + + msg = session.call(messages.EndSession()) + assert isinstance(msg, messages.Success) + with pytest.raises(TrezorFailure, match="ThpUnallocatedSession"): + session.call(messages.GetFeatures()) + + @pytest.mark.setup_client(pin=PIN4, passphrase="") def test_clear_session(client: Client): is_t1 = client.model is models.T1B1 @@ -39,100 +53,105 @@ def test_clear_session(client: Client): ] cached_responses = [messages.PublicKey] - - with client: + session = Session(client.get_session()) + session.lock() + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - client.clear_session() + session.lock() + session.end() + session = Session(client.get_session()) # session cache is cleared - with client: + with client, session: client.use_pin_sequence([PIN4]) - client.set_expected_responses(init_responses + cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(init_responses + cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB - with client: + client.resume_session(session) + with session: # pin and passphrase are cached - client.set_expected_responses(cached_responses) - assert get_public_node(client, ADDRESS_N).xpub == XPUB + session.set_expected_responses(cached_responses) + assert get_public_node(session, ADDRESS_N).xpub == XPUB def test_end_session(client: Client): # client instance starts out not initialized # XXX do we want to change this? - assert client.session_id is not None + session = client.get_session() + assert session.id is not None # get_address will succeed - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + with Session(session) as session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - client.end_session() - assert client.session_id is None + session.end() + # assert client.session_id is None with pytest.raises(TrezorFailure) as exc: - get_test_address(client) + get_test_address(session) assert exc.value.code == messages.FailureType.InvalidSession assert exc.value.message.endswith("Invalid session") - client.init_device() - assert client.session_id is not None - with client: - client.set_expected_responses([messages.Address]) - get_test_address(client) + session = client.get_session() + assert session.id is not None + with Session(session) as session: + session.set_expected_responses([messages.Address]) + get_test_address(session) - with client: - # end_session should succeed on empty session too - client.set_expected_responses([messages.Success] * 2) - client.end_session() - client.end_session() + # TODO: is the following valid? I do not think so + # with Session(session) as session: + # # end_session should succeed on empty session too + # session.set_expected_responses([messages.Success] * 2) + # session.end_session() + # session.end_session() def test_cannot_resume_ended_session(client: Client): - session_id = client.session_id - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session = client.get_session() + session_id = session.id + + session_resumed = client.resume_session(session) - assert session_id == client.session_id + assert session_resumed.id == session_id - client.end_session() - with client: - client.set_expected_responses([messages.Features]) - client.init_device(session_id=session_id) + session.end() + session_resumed2 = client.resume_session(session) - assert session_id != client.session_id + assert session_resumed2.id != session_id def test_end_session_only_current(client: Client): """test that EndSession only destroys the current session""" - session_id_a = client.session_id - client.init_device(new_session=True) - session_id_b = client.session_id + session_a = client.get_session() + session_b = client.get_session() + session_b_id = session_b.id - client.end_session() - assert client.session_id is None + session_b.end() + # assert client.session_id is None # resume ended session - client.init_device(session_id=session_id_b) - assert client.session_id != session_id_b + session_b_resumed = client.resume_session(session_b) + assert session_b_resumed.id != session_b_id # resume first session that was not ended - client.init_device(session_id=session_id_a) - assert client.session_id == session_id_a + session_a_resumed = client.resume_session(session_a) + assert session_a_resumed.id == session_a.id @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): - session_id_orig = client.session_id - with client: - client.set_expected_responses( + session = Session(client.get_session(passphrase="TREZOR")) + with client, session: + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -141,20 +160,22 @@ def test_session_recycling(client: Client): ] ) client.use_passphrase("TREZOR") - address = get_test_address(client) + _ = get_test_address(session) + # address = get_test_address(session) # create and close 100 sessions - more than the session limit for _ in range(100): - client.init_device(new_session=True) - client.end_session() + session_x = client.get_session() + session_x.end() # it should still be possible to resume the original session - with client: - # passphrase should still be cached - client.set_expected_responses([messages.Features, messages.Address]) - client.use_passphrase("TREZOR") - client.init_device(session_id=session_id_orig) - assert address == get_test_address(client) + # TODO imo not True anymore + # with client, session: + # # passphrase should still be cached + # session.set_expected_responses([messages.Features, messages.Address]) + # client.use_passphrase("TREZOR") + # client.resume_session(session) + # assert address == get_test_address(session) @pytest.mark.altcoin @@ -162,18 +183,19 @@ def test_session_recycling(client: Client): @pytest.mark.models("core") def test_derive_cardano_empty_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=True) + # session_id = client.session_id # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session2 = client.resume_session(session) + assert session.id == session2.id # restarting same session should go well with any setting - client.init_device(derive_cardano=False) - assert session_id == client.session_id - client.init_device(derive_cardano=True) - assert session_id == client.session_id + # TODO I do not think that it holds True now + # client.init_device(derive_cardano=False) + # assert session_id == client.session_id + # client.init_device(derive_cardano=True) + # assert session_id == client.session_id @pytest.mark.altcoin @@ -181,43 +203,41 @@ def test_derive_cardano_empty_session(client: Client): @pytest.mark.models("core") def test_derive_cardano_running_session(client: Client): # start new session - client.init_device(new_session=True) - session_id = client.session_id + session = client.get_session(derive_cardano=False) + # force derivation of seed - get_test_address(client) + get_test_address(session) # session should not have Cardano capability with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session, parse_path("m/44h/1815h/0h")) # restarting same session should go well - client.init_device() - assert session_id == client.session_id + session2 = client.resume_session(session) + assert session.id == session2.id - # restarting same session should go well if we _don't_ want to derive cardano - client.init_device(derive_cardano=False) - assert session_id == client.session_id + # TODO restarting same session should go well if we _don't_ want to derive cardano + # # client.init_device(derive_cardano=False) + # # assert session_id == client.session_id # restarting with derive_cardano=True should kill old session and create new one - client.init_device(derive_cardano=True) - assert session_id != client.session_id - - session_id = client.session_id + session3 = client.get_session(derive_cardano=True) + assert session3.id != session.id # new session should have Cardano capability - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session3, parse_path("m/44h/1815h/0h")) # restarting with derive_cardano=True should keep same session - client.init_device(derive_cardano=True) - assert session_id == client.session_id + session4 = client.resume_session(session3) + assert session4.id == session3.id - # restarting with no setting should keep same session - client.init_device() - assert session_id == client.session_id + # # restarting with no setting should keep same session + # client.init_device() + # assert session_id == client.session_id - # restarting with derive_cardano=False should kill old session and create new one - client.init_device(derive_cardano=False) - assert session_id != client.session_id + # # restarting with derive_cardano=False should kill old session and create new one + # client.init_device(derive_cardano=False) + # assert session_id != client.session_id - with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) + # with pytest.raises(TrezorFailure, match="not enabled"): + # cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 1bb9cbd70a8..6aa7dced5b1 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -19,7 +19,9 @@ import pytest from trezorlib import device, exceptions, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel @@ -49,19 +51,13 @@ SESSIONS_STORED = 10 -def _init_session(client: Client, session_id=None, derive_cardano=False): - """Call Initialize, check and return the session ID.""" - response = client.call( - messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) - ) - assert isinstance(response, messages.Features) - assert len(response.session_id) == 32 - return response.session_id - - -def _get_xpub(client: Client, passphrase=None): +def _get_xpub( + session: Session, + expected_passphrase_req: bool = False, + passphrase_v1: str | None = None, +): """Get XPUB and check that the appropriate passphrase flow has happened.""" - if passphrase is not None: + if expected_passphrase_req: expected_responses = [ messages.PassphraseRequest, messages.ButtonRequest, @@ -70,111 +66,122 @@ def _get_xpub(client: Client, passphrase=None): ] else: expected_responses = [messages.PublicKey] - - with client: - client.use_passphrase(passphrase or "") - client.set_expected_responses(expected_responses) - result = client.call(XPUB_REQUEST) + if ( + passphrase_v1 is not None + and session.protocol_version == ProtocolVersion.PROTOCOL_V1 + ): + session.passphrase = passphrase_v1 + + with session: + session.set_expected_responses(expected_responses) + result = session.call(XPUB_REQUEST) return result.xpub @pytest.mark.setup_client(passphrase=True) def test_session_with_passphrase(client: Client): - # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = Session(client.get_session(passphrase="A")) + session_id = session.id # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must # not ask for it again, whilst returning the same xpub. - new_session_id = _init_session(client, session_id=session_id) - assert new_session_id == session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session2 = Session(client.resume_session(session)) + assert session2.id == session_id + assert _get_xpub(session2) == XPUB_PASSPHRASES["A"] # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. - new_session_id = _init_session(client) - assert new_session_id != session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session3 = Session(client.get_session(passphrase="A")) + assert session3 != session_id + assert _get_xpub(session3, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] - # Unknown session id is the same as setting it to None. - _init_session(client, session_id=b"X" * 32) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + # TODO: The following part is kept only for solving UI-diff in tests + # - it can be removed if fixtures are updated, imo + session4 = Session(client.get_session(passphrase="A")) + assert session4 != session_id + assert _get_xpub(session4, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.setup_client(passphrase=True) def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions session_ids = [] + sessions = [] for _ in range(SESSIONS_STORED): - session_ids.append(_init_session(client)) + session = client.get_session() + sessions.append(session) + session_ids.append(session.id) # Resume each session - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces the least-recently-used session - _init_session(client) + client.get_session() # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Resuming session 0 will not work - new_session_id = _init_session(client, session_ids[0]) - assert new_session_id != session_ids[0] + resumed_session = client.resume_session(sessions[0]) + assert session_ids[0] != resumed_session.id # New session bumped out the least-recently-used anonymous session. # Resuming session 1 through SESSIONS_STORED will still work - for session_id in session_ids[1:]: - new_session_id = _init_session(client, session_id) - assert session_id == new_session_id + for i in range(1, SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] == resumed_session.id # Creating a new session replaces session_ids[0] again - _init_session(client) + client.get_session() # Resuming all sessions one by one will in turn bump out the previous session. - for session_id in session_ids: - new_session_id = _init_session(client, session_id) - assert session_id != new_session_id + for i in range(SESSIONS_STORED): + resumed_session = client.resume_session(sessions[i]) + assert session_ids[i] != resumed_session.id @pytest.mark.setup_client(passphrase=True) def test_multiple_passphrases(client: Client): # start a session - session_a = _init_session(client) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + session_a = Session(client.get_session(passphrase="A")) + session_a_id = session_a.id + assert _get_xpub(session_a, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] # start it again wit the same session id - new_session_id = _init_session(client, session_id=session_a) + session_a_resumed = Session(client.resume_session(session_a)) # session is the same - assert new_session_id == session_a + assert session_a_resumed.id == session_a_id # passphrase is not prompted - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session_a_resumed) == XPUB_PASSPHRASES["A"] # start a second session - session_b = _init_session(client) + session_b = Session(client.get_session(passphrase="B")) + session_b_id = session_b.id # new session -> new session id and passphrase prompt - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + assert _get_xpub(session_b, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # provide the same session id -> must not ask for passphrase again. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed = Session(client.resume_session(session_b)) + assert session_b_resumed.id == session_b_id + assert _get_xpub(session_b_resumed) == XPUB_PASSPHRASES["B"] # provide the first session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_a) - assert new_session_id == session_a - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + session_a_resumed_again = Session(client.resume_session(session_a)) + assert session_a_resumed_again.id == session_a_id + assert _get_xpub(session_a_resumed_again) == XPUB_PASSPHRASES["A"] # provide the second session id -> must not ask for passphrase again and return the same result. - new_session_id = _init_session(client, session_id=session_b) - assert new_session_id == session_b - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + session_b_resumed_again = Session(client.resume_session(session_b)) + assert session_b_resumed_again.id == session_b_id + assert _get_xpub(session_b_resumed_again) == XPUB_PASSPHRASES["B"] @pytest.mark.slow @@ -185,11 +192,13 @@ def test_max_sessions_with_passphrases(client: Client): # start as many sessions as the limit is session_ids = {} + sessions = {} for passphrase, xpub in XPUB_PASSPHRASES.items(): - session_id = _init_session(client) - assert session_id not in session_ids.values() - session_ids[passphrase] = session_id - assert _get_xpub(client, passphrase=passphrase) == xpub + session = Session(client.get_session(passphrase=passphrase)) + assert session.id not in session_ids.values() + session_ids[passphrase] = session.id + sessions[passphrase] = session + assert _get_xpub(session, expected_passphrase_req=True) == xpub # passphrase is not prompted for the started the sessions, regardless the order # let's try 20 different orderings @@ -198,85 +207,90 @@ def test_max_sessions_with_passphrases(client: Client): for _ in range(20): random.shuffle(shuffling) for passphrase in shuffling: - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = Session(client.resume_session(sessions[passphrase])) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # make sure the usage order is the reverse of the creation order for passphrase in reversed(passphrases): - session_id = _init_session(client, session_id=session_ids[passphrase]) - assert session_id == session_ids[passphrase] - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES[passphrase] + resumed_session = Session(client.resume_session(sessions[passphrase])) + assert resumed_session.id == session_ids[passphrase] + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES[passphrase] # creating one more session will exceed the limit - _init_session(client) + new_session = Session(client.get_session(passphrase="XX")) # new session asks for passphrase - _get_xpub(client, passphrase="XX") + _get_xpub(new_session, expected_passphrase_req=True) # restoring the sessions in reverse will evict the next-up session for passphrase in reversed(passphrases): - _init_session(client, session_id=session_ids[passphrase]) - _get_xpub(client, passphrase="whatever") # passphrase is prompted + resumed_session = Session(client.resume_session(sessions[passphrase])) + _get_xpub( + resumed_session, + expected_passphrase_req=True, + passphrase_v1="whatever", + ) # passphrase is prompted def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = Session(client.get_session(passphrase="")) # Trezor will not prompt for passphrase because it is turned off. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + assert _get_xpub(session, expected_passphrase_req=False) == XPUB_PASSPHRASE_NONE # Turn on passphrase. # Emit the call explicitly to avoid ClearSession done by the library function - response = client.call(messages.ApplySettings(use_passphrase=True)) + response = session.call(messages.ApplySettings(use_passphrase=True)) assert isinstance(response, messages.Success) # The session id is unchanged, therefore we do not prompt for the passphrase. - new_session_id = _init_session(client, session_id=session_id) - assert session_id == new_session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_NONE + session_id = session.id + resumed_session = Session(client.resume_session(session)) + assert session_id == resumed_session.id + assert _get_xpub(resumed_session) == XPUB_PASSPHRASE_NONE # We clear the session id now, so the passphrase should be asked. - new_session_id = _init_session(client) - assert session_id != new_session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + new_session = Session(client.get_session(passphrase="A")) + assert session_id != new_session.id + assert _get_xpub(new_session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) def test_passphrase_on_device(client: Client): - _init_session(client) - + # _init_session(client) + session = client.get_session(passphrase="A") # try to get xpub with passphrase on host: - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) # using `client.call` to auto-skip subsequent ButtonRequests for "show passphrase" - response = client.call(messages.PassphraseAck(passphrase="A", on_device=False)) + response = session.call(messages.PassphraseAck(passphrase="A", on_device=False)) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session - _init_session(client) + session2 = session.client.get_session(passphrase="A") # try to get xpub with passphrase on device: - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(on_device=True)) + response = session2.call_raw(messages.PassphraseAck(on_device=True)) # no "show passphrase" here assert isinstance(response, messages.ButtonRequest) client.debug.input("A") - response = client.call_raw(messages.ButtonAck()) + response = session2.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached - response = client.call_raw(XPUB_REQUEST) + response = session2.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @@ -285,32 +299,33 @@ def test_passphrase_on_device(client: Client): @pytest.mark.setup_client(passphrase=True) def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. - session_id = _init_session(client) + session = client.get_session() + # session_id = _init_session(client) # Force passphrase entry on Trezor. - response = client.call(messages.ApplySettings(passphrase_always_on_device=True)) + response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) assert isinstance(response, messages.Success) # Since we enabled the always_on_device setting, Trezor will send ButtonRequests and ask for it on the device. - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("") # Input empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # Passphrase will not be prompted. The session id stays the same and the passphrase is cached. - _init_session(client, session_id=session_id) - response = client.call_raw(XPUB_REQUEST) + resumed_session = client.resume_session(session) + response = resumed_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASE_NONE # In case we want to add a new passphrase we need to send session_id = None. - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + new_session = client.get_session(passphrase="A") + response = new_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("A") # Input non-empty passphrase. - response = client.call_raw(messages.ButtonAck()) + response = new_session.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) assert response.xpub == XPUB_PASSPHRASES["A"] @@ -332,25 +347,27 @@ def test_passphrase_on_device_not_possible_on_t1(client: Client): @pytest.mark.setup_client(passphrase=True) -def test_passphrase_ack_mismatch(client: Client): - response = client.call_raw(XPUB_REQUEST) +def test_passphrase_ack_mismatch(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) + response = session.call_raw(messages.PassphraseAck(passphrase="A", on_device=True)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @pytest.mark.setup_client(passphrase="") -def test_passphrase_missing(client: Client): - response = client.call_raw(XPUB_REQUEST) +def test_passphrase_missing(session: Session): + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None)) + response = session.call_raw(messages.PassphraseAck(passphrase=None)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError - response = client.call_raw(XPUB_REQUEST) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) - response = client.call_raw(messages.PassphraseAck(passphrase=None, on_device=False)) + response = session.call_raw( + messages.PassphraseAck(passphrase=None, on_device=False) + ) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError @@ -358,11 +375,11 @@ def test_passphrase_missing(client: Client): @pytest.mark.setup_client(passphrase=True) def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): - _init_session(client) - response = client.call_raw(XPUB_REQUEST) + session = client.get_session(passphrase=passphrase) + response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) try: - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert expected_result is True, "Call should have failed" assert isinstance(response, messages.PublicKey) except exceptions.TrezorFailure as e: @@ -383,17 +400,18 @@ def call(passphrase: str, expected_result: bool): @pytest.mark.setup_client(passphrase=True) def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails + session = client.get_management_session() with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) - device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily) + device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) # Turning it on - device.apply_settings(client, hide_passphrase_from_host=True) + device.apply_settings(session, hide_passphrase_from_host=True) passphrase = "abc" - - with client: + session = Session(client.get_session(passphrase=passphrase)) + with client, session: def input_flow(): yield @@ -410,7 +428,7 @@ def input_flow(): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -418,17 +436,17 @@ def input_flow(): ] ) client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub # Turning it off - device.apply_settings(client, hide_passphrase_from_host=False) + device.apply_settings(session, hide_passphrase_from_host=False) # Starting new session, otherwise the passphrase would be cached - _init_session(client) + session = Session(client.get_session(passphrase=passphrase)) - with client: + with client, session: def input_flow(): yield @@ -445,7 +463,7 @@ def input_flow(): client.watch_layout() client.set_input_flow(input_flow) - client.set_expected_responses( + session.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -454,22 +472,22 @@ def input_flow(): ] ) client.use_passphrase(passphrase) - result = client.call(XPUB_REQUEST) + result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub assert xpub_hidden_passphrase == xpub_shown_passphrase -def _get_xpub_cardano(client: Client, passphrase): +def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): msg = messages.CardanoGetPublicKey( address_n=parse_path("m/44h/1815h/0h/0/0"), derivation_type=messages.CardanoDerivationType.ICARUS, ) - response = client.call_raw(msg) - if passphrase is not None: + response = session.call_raw(msg) + if expected_passphrase_req: assert isinstance(response, messages.PassphraseRequest) - response = client.call(messages.PassphraseAck(passphrase=passphrase)) + response = session.call(messages.PassphraseAck(passphrase=session.passphrase)) assert isinstance(response, messages.CardanoPublicKey) return response.xpub @@ -482,31 +500,37 @@ def test_cardano_passphrase(client: Client): # of the passphrase. # Historically, Cardano calls would ask for passphrase again. Now, they should not. - session_id = _init_session(client, derive_cardano=True) + # session_id = _init_session(client, derive_cardano=True) # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + session = Session(client.get_session(passphrase="B", derive_cardano=True)) + assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] # The passphrase is now cached for non-Cardano coins. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + assert _get_xpub(session) == XPUB_PASSPHRASES["B"] # The passphrase should be cached for Cardano as well - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # Initialize with the session id does not destroy the state - _init_session(client, session_id=session_id, derive_cardano=True) - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B + resumed_session = Session(client.resume_session(session)) + # _init_session(client, session_id=session_id, derive_cardano=True) + assert _get_xpub(resumed_session) == XPUB_PASSPHRASES["B"] + assert _get_xpub_cardano(resumed_session) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state - _init_session(client, derive_cardano=True) + new_session = Session(client.get_session(passphrase="A", derive_cardano=True)) + # _init_session(client, derive_cardano=True) # Cardano must ask for passphrase again - assert _get_xpub_cardano(client, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A + assert ( + _get_xpub_cardano(new_session, expected_passphrase_req=True) + == XPUB_CARDANO_PASSPHRASE_A + ) # Passphrase is now cached for Cardano - assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_A + assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A # Passphrase is cached for non-Cardano coins too - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + assert _get_xpub(new_session) == XPUB_PASSPHRASES["A"] diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 3e6b5423938..9f35118370d 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_address from trezorlib.tools import parse_path @@ -35,19 +35,19 @@ @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) -def test_tezos_get_address(client: Client, path: str, expected_address: str): - address = get_address(client, parse_path(path), show_display=True) +def test_tezos_get_address(session: Session, path: str, expected_address: str): + address = get_address(session, parse_path(path), show_display=True) assert address == expected_address @pytest.mark.parametrize("path, expected_address", TEST_VECTORS) def test_tezos_get_address_chunkify_details( - client: Client, path: str, expected_address: str + session: Session, path: str, expected_address: str ): - with client: + with session.client as client: IF = InputFlowShowAddressQRCode(client) client.set_input_flow(IF.get()) address = get_address( - client, parse_path(path), show_display=True, chunkify=True + session, parse_path(path), show_display=True, chunkify=True ) assert address == expected_address diff --git a/tests/device_tests/tezos/test_getpublickey.py b/tests/device_tests/tezos/test_getpublickey.py index 9f5bfcd0f74..8b1e72609d7 100644 --- a/tests/device_tests/tezos/test_getpublickey.py +++ b/tests/device_tests/tezos/test_getpublickey.py @@ -16,7 +16,7 @@ import pytest -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tezos import get_public_key from trezorlib.tools import parse_path @@ -24,11 +24,11 @@ @pytest.mark.altcoin @pytest.mark.tezos @pytest.mark.models("core") -def test_tezos_get_public_key(client: Client): +def test_tezos_get_public_key(session: Session): path = parse_path("m/44h/1729h/0h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkttLhEbVfMC3DhyVVFzdwh8ncRnEWiLD1x8TAuPU7vSJak7RtBX" path = parse_path("m/44h/1729h/1h") - pk = get_public_key(client, path) + pk = get_public_key(session, path) assert pk == "edpkuTPqWjcApwyD3VdJhviKM5C13zGk8c4m87crgFarQboF3Mp56f" diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index 06e17304db6..f70a4934d9c 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import messages, tezos -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import dict_to_proto from trezorlib.tools import parse_path @@ -32,10 +32,10 @@ ] -def test_tezos_sign_tx_proposal(client: Client): - with client: +def test_tezos_sign_tx_proposal(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -63,10 +63,10 @@ def test_tezos_sign_tx_proposal(client: Client): assert resp.operation_hash == "opLqntFUu984M7LnGsFvfGW6kWe9QjAz4AfPDqQvwJ1wPM4Si4c" -def test_tezos_sign_tx_multiple_proposals(client: Client): - with client: +def test_tezos_sign_tx_multiple_proposals(session: Session): + with session: resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -95,9 +95,9 @@ def test_tezos_sign_tx_multiple_proposals(client: Client): assert resp.operation_hash == "onobSyNgiitGXxSVFJN6949MhUomkkxvH4ZJ2owgWwNeDdntF9Y" -def test_tezos_sing_tx_ballot_yay(client: Client): +def test_tezos_sing_tx_ballot_yay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -119,9 +119,9 @@ def test_tezos_sing_tx_ballot_yay(client: Client): ) -def test_tezos_sing_tx_ballot_nay(client: Client): +def test_tezos_sing_tx_ballot_nay(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -142,9 +142,9 @@ def test_tezos_sing_tx_ballot_nay(client: Client): ) -def test_tezos_sing_tx_ballot_pass(client: Client): +def test_tezos_sing_tx_ballot_pass(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -167,9 +167,9 @@ def test_tezos_sing_tx_ballot_pass(client: Client): @pytest.mark.parametrize("chunkify", (True, False)) -def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): +def test_tezos_sign_tx_tranasaction(session: Session, chunkify: bool): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -202,9 +202,9 @@ def test_tezos_sign_tx_tranasaction(client: Client, chunkify: bool): assert resp.operation_hash == "oon8PNUsPETGKzfESv1Epv4535rviGS7RdCfAEKcPvzojrcuufb" -def test_tezos_sign_tx_delegation(client: Client): +def test_tezos_sign_tx_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_15, dict_to_proto( messages.TezosSignTx, @@ -232,9 +232,9 @@ def test_tezos_sign_tx_delegation(client: Client): assert resp.operation_hash == "op79C1tR7wkUgYNid2zC1WNXmGorS38mTXZwtAjmCQm2kG7XG59" -def test_tezos_sign_tx_origination(client: Client): +def test_tezos_sign_tx_origination(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -263,9 +263,9 @@ def test_tezos_sign_tx_origination(client: Client): assert resp.operation_hash == "onmq9FFZzvG2zghNdr1bgv9jzdbzNycXjSSNmCVhXCGSnV3WA9g" -def test_tezos_sign_tx_reveal(client: Client): +def test_tezos_sign_tx_reveal(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH, dict_to_proto( messages.TezosSignTx, @@ -305,9 +305,9 @@ def test_tezos_sign_tx_reveal(client: Client): assert resp.operation_hash == "oo9JFiWTnTSvUZfajMNwQe1VyFN2pqwiJzZPkpSAGfGD57Z6mZJ" -def test_tezos_smart_contract_delegation(client: Client): +def test_tezos_smart_contract_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -342,9 +342,9 @@ def test_tezos_smart_contract_delegation(client: Client): assert resp.operation_hash == "oo75gfQGGPEPChXZzcPPAGtYqCpsg2BS5q9gmhrU3NQP7CEffpU" -def test_tezos_kt_remove_delegation(client: Client): +def test_tezos_kt_remove_delegation(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -377,9 +377,9 @@ def test_tezos_kt_remove_delegation(client: Client): assert resp.operation_hash == "ootMi1tXbfoVgFyzJa8iXyR4mnHd5TxLm9hmxVzMVRkbyVjKaHt" -def test_tezos_smart_contract_transfer(client: Client): +def test_tezos_smart_contract_transfer(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, @@ -420,9 +420,9 @@ def test_tezos_smart_contract_transfer(client: Client): assert resp.operation_hash == "ooRGGtCmoQDgB36XvQqmM7govc3yb77YDUoa7p2QS7on27wGRns" -def test_tezos_smart_contract_transfer_to_contract(client: Client): +def test_tezos_smart_contract_transfer_to_contract(session: Session): resp = tezos.sign_tx( - client, + session, TEZOS_PATH_10, dict_to_proto( messages.TezosSignTx, diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 3fd7ca7fd95..7016e2f5f80 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -17,7 +17,7 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import Cancelled, TrezorFailure from ...common import MNEMONIC12 @@ -30,23 +30,23 @@ @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) -def test_add_remove(client: Client): - with client: +def test_add_remove(session: Session): + with session, session.client as client: IF = InputFlowFidoConfirm(client) client.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, 0) + fido.remove_credential(session, 0) # List should be empty. - assert fido.list_credentials(client) == [] + assert fido.list_credentials(session) == [] # Add valid credential #1. - fido.add_credential(client, CRED1) + fido.add_credential(session, CRED1) # Check that the credential was added and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name == "Example" @@ -59,10 +59,10 @@ def test_add_remove(client: Client): assert creds[0].hmac_secret is True # Add valid credential #2, which has same rpId and userId as credential #1. - fido.add_credential(client, CRED2) + fido.add_credential(session, CRED2) # Check that the credential #2 replaced credential #1 and parameters are correct. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 assert creds[0].rp_id == "example.com" assert creds[0].rp_name is None @@ -76,32 +76,32 @@ def test_add_remove(client: Client): # Adding an invalid credential should appear as if user cancelled. with pytest.raises(Cancelled): - fido.add_credential(client, CRED1[:-2]) + fido.add_credential(session, CRED1[:-2]) # Check that the invalid credential was not added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 1 # Add valid credential, which has same userId as #2, but different rpId. - fido.add_credential(client, CRED3) + fido.add_credential(session, CRED3) # Check that the credential was added. - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) assert len(creds) == 2 # Fill up the credential storage to maximum capacity. for cred in CREDS[: RK_CAPACITY - 2]: - fido.add_credential(client, cred) + fido.add_credential(session, cred) # Adding one more valid credential to full storage should fail. with pytest.raises(TrezorFailure): - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) # Removing the index, which is one past the end, should fail. with pytest.raises(TrezorFailure): - fido.remove_credential(client, RK_CAPACITY) + fido.remove_credential(session, RK_CAPACITY) # Remove index 2. - fido.remove_credential(client, 2) + fido.remove_credential(session, 2) # Adding another valid credential should succeed now. - fido.add_credential(client, CREDS[-1]) + fido.add_credential(session, CREDS[-1]) diff --git a/tests/device_tests/webauthn/test_u2f_counter.py b/tests/device_tests/webauthn/test_u2f_counter.py index d99467f2b9d..c140ba54578 100644 --- a/tests/device_tests/webauthn/test_u2f_counter.py +++ b/tests/device_tests/webauthn/test_u2f_counter.py @@ -17,15 +17,15 @@ import pytest from trezorlib import fido -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session @pytest.mark.altcoin -def test_u2f_counter(client: Client): - assert fido.get_next_counter(client) == 0 - assert fido.get_next_counter(client) == 1 - fido.set_counter(client, 111111) - assert fido.get_next_counter(client) == 111112 - assert fido.get_next_counter(client) == 111113 - fido.set_counter(client, 0) - assert fido.get_next_counter(client) == 1 +def test_u2f_counter(session: Session): + assert fido.get_next_counter(session) == 0 + assert fido.get_next_counter(session) == 1 + fido.set_counter(session, 111111) + assert fido.get_next_counter(session) == 111112 + assert fido.get_next_counter(session) == 111113 + fido.set_counter(session, 0) + assert fido.get_next_counter(session) == 1 diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index d689c8af969..4d7df800903 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -17,7 +17,7 @@ import pytest from trezorlib import btc, messages -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -53,7 +53,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.zcash] -def test_version_group_id_missing(client: Client): +def test_version_group_id_missing(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -69,7 +69,7 @@ def test_version_group_id_missing(client: Client): with pytest.raises(TrezorFailure, match="Version group ID must be set."): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -77,7 +77,7 @@ def test_version_group_id_missing(client: Client): ) -def test_spend_v4_input(client: Client): +def test_spend_v4_input(session: Session): # 4b6cecb81c825180786ebe07b65bcc76078afc5be0f1c64e08d764005012380d is a v4 tx inp1 = messages.TxInputType( @@ -95,13 +95,13 @@ def test_spend_v4_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -110,7 +110,7 @@ def test_spend_v4_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -126,7 +126,7 @@ def test_spend_v4_input(client: Client): ) -def test_send_to_multisig(client: Client): +def test_send_to_multisig(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/8"), @@ -143,13 +143,13 @@ def test_send_to_multisig(client: Client): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -158,7 +158,7 @@ def test_send_to_multisig(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -174,7 +174,7 @@ def test_send_to_multisig(client: Client): ) -def test_spend_v5_input(client: Client): +def test_spend_v5_input(session: Session): inp1 = messages.TxInputType( # tmBMyeJebzkP5naji8XUKqLyL1NDwNkgJFt address_n=parse_path("m/44h/1h/0h/0/9"), @@ -190,13 +190,13 @@ def test_spend_v5_input(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), @@ -205,7 +205,7 @@ def test_spend_v5_input(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -221,7 +221,7 @@ def test_spend_v5_input(client: Client): ) -def test_one_two(client: Client): +def test_one_two(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -243,13 +243,13 @@ def test_one_two(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -260,7 +260,7 @@ def test_one_two(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -277,7 +277,7 @@ def test_one_two(client: Client): @pytest.mark.models("core") -def test_unified_address(client: Client): +def test_unified_address(session: Session): # identical to the test_one_two # but receiver address is unified with an orchard address inp1 = messages.TxInputType( @@ -301,13 +301,13 @@ def test_unified_address(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), request_output(1), messages.ButtonRequest(code=B.SignTx), request_input(0), @@ -318,7 +318,7 @@ def test_unified_address(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -335,7 +335,7 @@ def test_unified_address(client: Client): @pytest.mark.models("core") -def test_external_presigned(client: Client): +def test_external_presigned(session: Session): inp1 = messages.TxInputType( # tmQoJ3PTXgQLaRRZZYT6xk8XtjRbr2kCqwu address_n=parse_path("m/44h/1h/0h/0/0"), @@ -365,14 +365,14 @@ def test_external_presigned(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with client: - client.set_expected_responses( + with session: + session.set_expected_responses( [ request_input(0), request_input(1), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(1), request_input(0), @@ -383,7 +383,7 @@ def test_external_presigned(client: Client): ) _, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1, inp2], [out1], @@ -399,7 +399,7 @@ def test_external_presigned(client: Client): ) -def test_refuse_replacement_tx(client: Client): +def test_refuse_replacement_tx(session: Session): inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/4"), amount=174998, @@ -437,7 +437,7 @@ def test_refuse_replacement_tx(client: Client): TrezorFailure, match="Replacement transactions are not supported." ): btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1, out2], @@ -447,12 +447,12 @@ def test_refuse_replacement_tx(client: Client): ) -def test_spend_multisig(client: Client): +def test_spend_multisig(session: Session): # Cloned from tests/device_tests/bitcoin/test_multisig.py::test_2_of_3 nodes = [ btc.get_public_node( - client, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" + session, parse_path(f"m/48h/1h/{index}h/0h"), coin_name="Zcash Testnet" ).node for index in range(1, 4) ] @@ -482,17 +482,17 @@ def test_spend_multisig(client: Client): request_input(0), request_output(0), messages.ButtonRequest(code=B.ConfirmOutput), - (is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)), + (is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), request_input(0), request_output(0), request_finished(), ] - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( - client, + session, "Zcash Testnet", [inp1], [out1], @@ -529,10 +529,10 @@ def test_spend_multisig(client: Client): multisig=multisig, ) - with client: - client.set_expected_responses(expected_responses) + with session: + session.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( - client, + session, "Zcash Testnet", [inp3], [out1], diff --git a/tests/input_flows.py b/tests/input_flows.py index 24cf6a5093b..4ca42c2ec68 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -16,6 +16,7 @@ from trezorlib import messages from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import multipage_content @@ -129,13 +130,15 @@ def input_two_different_pins() -> BRGeneratorType: class InputFlowCodeChangeFail(InputFlowBase): + def __init__( - self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str + self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str ): - super().__init__(client) + super().__init__(session.client) self.current_pin = current_pin self.new_pin_1 = new_pin_1 self.new_pin_2 = new_pin_2 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield # do you want to change pin? @@ -150,7 +153,7 @@ def input_flow_common(self) -> BRGeneratorType: # failed retry yield # enter current pin again - self.client.cancel() + self.session.cancel() class InputFlowWrongPIN(InputFlowBase): @@ -1876,9 +1879,11 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.invalid_mnemonic = ["stick"] * 12 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_dry_run() @@ -1887,7 +1892,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_invalid_recovery_seed() yield - self.client.cancel() + self.session.cancel() class InputFlowBip39Recovery(InputFlowBase): @@ -1970,15 +1975,17 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -1990,19 +1997,21 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_group_threshold_reached() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase): + def __init__( self, - client: Client, + session: Session, first_share: list[str], second_share: list[str], ): - super().__init__(client) + super().__init__(session.client) self.first_share = first_share self.second_share = second_share + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2014,7 +2023,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase): @@ -2113,10 +2122,12 @@ def input_flow_common(self) -> BRGeneratorType: class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase): - def __init__(self, client: Client): - super().__init__(client) + + def __init__(self, session: Session): + super().__init__(session.client) self.first_invalid = ["slush"] * 20 self.second_invalid = ["slush"] * 33 + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2128,16 +2139,18 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_invalid_recovery_share() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase): - def __init__(self, client: Client, shares: list[str]): - super().__init__(client) + + def __init__(self, session: Session, shares: list[str]): + super().__init__(session.client) self.shares = shares self.first_share = shares[0].split(" ") self.invalid_share = self.first_share[:3] + ["slush"] * 17 self.second_share = shares[1].split(" ") + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2150,16 +2163,18 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.success_more_shares_needed(1) yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase): - def __init__(self, client: Client, share: list[str], nth_word: int): - super().__init__(client) + + def __init__(self, session: Session, share: list[str], nth_word: int): + super().__init__(session.client) self.share = share self.nth_word = nth_word # Invalid share - just enough words to trigger the warning self.modified_share = share[:nth_word] + [self.share[-1]] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2170,15 +2185,17 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_from_another_shamir() yield - self.client.cancel() + self.session.cancel() class InputFlowSlip39BasicRecoverySameShare(InputFlowBase): - def __init__(self, client: Client, share: list[str]): - super().__init__(client) + + def __init__(self, session: Session, share: list[str]): + super().__init__(session.client) self.share = share # Second duplicate share - only 4 words are needed to verify it self.duplicate_share = self.share[:4] + self.session = session def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() @@ -2189,7 +2206,7 @@ def input_flow_common(self) -> BRGeneratorType: yield from self.REC.warning_share_already_entered() yield - self.client.cancel() + self.session.cancel() class InputFlowResetSkipBackup(InputFlowBase): diff --git a/tests/persistence_tests/test_safety_checks.py b/tests/persistence_tests/test_safety_checks.py index 1cbf7d75516..c2552f04b3f 100644 --- a/tests/persistence_tests/test_safety_checks.py +++ b/tests/persistence_tests/test_safety_checks.py @@ -20,16 +20,18 @@ def test_safety_checks_level_after_reboot( core_emulator: Emulator, set_level: SafetyCheckLevel, after_level: SafetyCheckLevel ): - device.wipe(core_emulator.client) + device.wipe(core_emulator.client.get_management_session()) + core_emulator.client = core_emulator.client.get_new_client() debuglink.load_device( - core_emulator.client, + core_emulator.client.get_management_session(), mnemonic=MNEMONIC12, pin="", passphrase_protection=False, label="SAFETYLEVEL", ) - device.apply_settings(core_emulator.client, safety_checks=set_level) + device.apply_settings(core_emulator.client.get_session(), safety_checks=set_level) + core_emulator.client.refresh_features() assert core_emulator.client.features.safety_checks == set_level core_emulator.restart() diff --git a/tests/persistence_tests/test_shamir_persistence.py b/tests/persistence_tests/test_shamir_persistence.py index e24f16eeb67..1524bd52034 100644 --- a/tests/persistence_tests/test_shamir_persistence.py +++ b/tests/persistence_tests/test_shamir_persistence.py @@ -16,7 +16,8 @@ import pytest -from trezorlib import device +from trezorlib import device, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import DebugLink, LayoutType from trezorlib.messages import RecoveryStatus @@ -45,7 +46,7 @@ def test_abort(core_emulator: Emulator): assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) layout = debug.read_layout() @@ -82,7 +83,7 @@ def test_recovery_single_reset(core_emulator: Emulator): assert features.initialized is False assert features.recovery_status == RecoveryStatus.Nothing - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -129,7 +130,7 @@ def assert_mnemonic_keyboard(debug: DebugLink) -> None: assert features.recovery_status == RecoveryStatus.Nothing # enter recovery mode - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) @@ -157,7 +158,8 @@ def assert_mnemonic_keyboard(debug: DebugLink) -> None: layout = debug.read_layout() # while keyboard is open, hit the device with Initialize/GetFeatures - device_handler.client.init_device() + if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1: + device_handler.client.get_management_session().call(messages.Initialize()) device_handler.client.refresh_features() # try entering remaining 19 words @@ -207,7 +209,7 @@ def enter_shares_with_restarts(debug: DebugLink) -> None: assert features.recovery_status == RecoveryStatus.Nothing # start device and recovery - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery.confirm_recovery(debug) diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 2497a708f6e..cb06eeb2cd8 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,46 +11,55 @@ def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_management_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_management_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_management_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client) + device.wipe(client.get_management_session()) + client = client.get_new_client() debuglink.load_device( - client, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE" + client.get_management_session(), + MNEMONIC12, + pin, + passphrase_protection=False, + label="WIPECODE", ) with client: client.use_pin_sequence([pin, wipe_code, wipe_code]) - device.change_wipe_code(client) + device.change_wipe_code(client.get_management_session()) @core_only def test_wipe_code_activate_core(core_emulator: Emulator): # set up device setup_device_core(core_emulator.client, PIN, WIPE_CODE) - - core_emulator.client.init_device() + session = core_emulator.client.get_session() device_id = core_emulator.client.features.device_id # Initiate Change pin process - ret = core_emulator.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) assert ret.name == "change_pin" core_emulator.client.debug.press_yes() - ret = core_emulator.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN expected = message_filters.ButtonRequest(code=messages.ButtonRequestType.PinEntry) assert expected.match(ret) - core_emulator.client._raw_write(messages.ButtonAck()) + session._write(messages.ButtonAck()) core_emulator.client.debug.input(WIPE_CODE) # preserving screenshots even after it dies and starts again @@ -75,25 +84,26 @@ def test_wipe_code_activate_legacy(): # set up device setup_device_legacy(emu.client, PIN, WIPE_CODE) - emu.client.init_device() + session = emu.client.get_session() device_id = emu.client.features.device_id # Initiate Change pin process - ret = emu.client.call_raw(messages.ChangePin(remove=False)) + ret = session.call_raw(messages.ChangePin(remove=False)) assert isinstance(ret, messages.ButtonRequest) emu.client.debug.press_yes() - ret = emu.client.call_raw(messages.ButtonAck()) + ret = session.call_raw(messages.ButtonAck()) # Enter the wipe code instead of the current PIN assert isinstance(ret, messages.PinMatrixRequest) wipe_code_encoded = emu.client.debug.encode_pin(WIPE_CODE) - emu.client._raw_write(messages.PinMatrixAck(pin=wipe_code_encoded)) + session._write(messages.PinMatrixAck(pin=wipe_code_encoded)) # wait 30 seconds for emulator to shut down # this will raise a TimeoutError if the emulator doesn't die. emu.wait(30) emu.start() + emu.client.refresh_features() assert emu.client.features.initialized is False assert emu.client.features.pin_protection is False assert emu.client.features.wipe_code_protection is False diff --git a/tests/translations.py b/tests/translations.py index afb12a5fec2..34f79888bae 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -8,7 +8,7 @@ from trezorlib import cosi, device, models from trezorlib._internal import translations -from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.debuglink import SessionDebugWrapper as Session from . import common @@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes: def build_and_sign_blob( lang_or_def: translations.JsonDef | Path | str, - client: Client, + session: Session, ) -> bytes: - blob = prepare_blob(lang_or_def, client.model, client.version) + blob = prepare_blob(lang_or_def, session.model, session.version) return sign_blob(blob) -def set_language(client: Client, lang: str): +def set_language(session: Session, lang: str): if lang.startswith("en"): language_data = b"" else: - language_data = build_and_sign_blob(lang, client) - with client: - device.change_language(client, language_data) # type: ignore - _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang] + language_data = build_and_sign_blob(lang, session) + with session: + device.change_language(session, language_data) # type: ignore def get_lang_json(lang: str) -> translations.JsonDef: diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index c7ab04a45ee..8f2ad3dc10b 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -68,7 +68,8 @@ def screen_recording( if record_text_layout: client.debug.set_screen_text_file(None) client.debug.watch_layout(False) - client.init_device() + # Instead of client.init_device() we create a new management session + client.get_management_session() client.debug.stop_recording() result = testcase.build_result(request) diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index bafe67f511d..94682a9b191 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -20,7 +20,9 @@ import pytest from shamir_mnemonic import shamir -from trezorlib import btc, debuglink, device, exceptions, fido, models +from trezorlib import btc, debuglink, device, exceptions, fido, messages, models +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import ( ApplySettings, BackupAvailability, @@ -57,15 +59,19 @@ @for_all() def test_upgrade_load(gen: str, tag: str) -> None: def asserts(client: "Client"): + client.refresh_features() assert not client.features.pin_protection assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert ( + btc.get_address(client.get_session(passphrase=""), "Bitcoin", PATH) + == ADDRESS + ) with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_management_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, @@ -89,12 +95,14 @@ def asserts(client: "Client") -> None: assert not client.features.passphrase_protection assert client.features.initialized assert client.features.label == LABEL - client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + session = Session(client.get_session()) + with client, session: + client.use_pin_sequence([PIN]) + assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + Session(emu.client.get_management_session()), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -130,11 +138,11 @@ def asserts(client: "Client") -> None: assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tags[0]) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_management_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -164,11 +172,11 @@ def asserts(client: "Client"): assert client.features.initialized assert client.features.label == LABEL client.use_pin_sequence([PIN]) - assert btc.get_address(client, "Bitcoin", PATH) == ADDRESS + assert btc.get_address(client.get_session(), "Bitcoin", PATH) == ADDRESS with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_management_session(), mnemonic=MNEMONIC, pin=PIN, passphrase_protection=False, @@ -177,7 +185,9 @@ def asserts(client: "Client"): # Set wipe code. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) - device.change_wipe_code(emu.client) + session = Session(emu.client.get_management_session()) + session.refresh_features() + device.change_wipe_code(session) device_id = emu.client.features.device_id asserts(emu.client) @@ -189,11 +199,13 @@ def asserts(client: "Client"): # Check that wipe code is set by changing the PIN to it. emu.client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) + session = Session(emu.client.get_management_session()) + session.refresh_features() with pytest.raises( exceptions.TrezorFailure, match="The new PIN must be different from your wipe code", ): - return device.change_pin(emu.client) + return device.change_pin(session) @for_all("legacy") @@ -209,7 +221,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.reset( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -217,13 +229,13 @@ def asserts(client: "Client"): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all() @@ -239,7 +251,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.reset( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -248,13 +260,13 @@ def asserts(client: "Client"): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address @for_all(legacy_minimum_version=(1, 7, 2)) @@ -270,7 +282,7 @@ def asserts(client: "Client"): with EmulatorWrapper(gen, tag) as emu: device.reset( - emu.client, + emu.client.get_session(), strength=STRENGTH, passphrase_protection=False, pin_protection=False, @@ -279,13 +291,13 @@ def asserts(client: "Client"): ) device_id = emu.client.features.device_id asserts(emu.client) - address = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(emu.client.get_session(), "Bitcoin", PATH) storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: assert device_id == emu.client.features.device_id asserts(emu.client) - assert btc.get_address(emu.client, "Bitcoin", PATH) == address + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address # Although Shamir was introduced in 2.1.2 already, the debug instrumentation was not present until 2.1.9. @@ -298,7 +310,7 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): emu.client.watch_layout(True) debug = device_handler.debuglink() - device_handler.run(device.recover, pin_protection=False) + device_handler.run_with_session(device.recover, pin_protection=False) recovery_old.confirm_recovery(debug) recovery_old.select_number_of_words(debug) @@ -343,9 +355,10 @@ def test_upgrade_shamir_recovery(gen: str, tag: Optional[str]): @for_all("core", core_minimum_version=(2, 1, 9)) def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with EmulatorWrapper(gen, tag) as emu: + session = Session(emu.client.get_management_session()) # Generate a new encrypted master secret and record it. device.reset( - emu.client, + session, pin_protection=False, skip_backup=True, backup_type=BackupType.Slip39_Basic, @@ -355,14 +368,16 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): mnemonic_secret = emu.client.debug.state().mnemonic_secret # Set passphrase_source = HOST. - resp = emu.client.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) + session = Session(emu.client.get_session()) + resp = session.call(ApplySettings(_passphrase_source=2, use_passphrase=True)) assert isinstance(resp, Success) # Get a passphrase-less and a passphrased address. - address = btc.get_address(emu.client, "Bitcoin", PATH) - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - address_passphrase = btc.get_address(emu.client, "Bitcoin", PATH) + address = btc.get_address(session, "Bitcoin", PATH) + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(new_session=True)) + new_session = emu.client.get_session(passphrase="TREZOR") + address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required storage = emu.get_storage() @@ -375,7 +390,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): with emu.client: IF = InputFlowSlip39BasicBackup(emu.client, False) emu.client.set_input_flow(IF.get()) - device.backup(emu.client) + device.backup(emu.client.get_session()) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable ) @@ -396,10 +411,13 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): assert ems.ciphertext == mnemonic_secret # Check that addresses are the same after firmware upgrade and backup. - assert btc.get_address(emu.client, "Bitcoin", PATH) == address - emu.client.init_device(new_session=True) - emu.client.use_passphrase("TREZOR") - assert btc.get_address(emu.client, "Bitcoin", PATH) == address_passphrase + assert btc.get_address(emu.client.get_session(), "Bitcoin", PATH) == address + assert ( + btc.get_address( + emu.client.get_session(passphrase="TREZOR"), "Bitcoin", PATH + ) + == address_passphrase + ) @for_all(legacy_minimum_version=(1, 8, 4), core_minimum_version=(2, 1, 9)) @@ -407,22 +425,22 @@ def test_upgrade_u2f(gen: str, tag: str): """Check U2F counter stayed the same after an upgrade.""" with EmulatorWrapper(gen, tag) as emu: debuglink.load_device_by_mnemonic( - emu.client, + emu.client.get_management_session(), mnemonic=MNEMONIC, pin="", passphrase_protection=False, label=LABEL, ) - - success = fido.set_counter(emu.client, 10) + session = emu.client.get_management_session() + success = fido.set_counter(session, 10) assert "U2F counter set" in success - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 11 storage = emu.get_storage() with EmulatorWrapper(gen, storage=storage) as emu: - counter = fido.get_next_counter(emu.client) + counter = fido.get_next_counter(session) assert counter == 12 diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 67a4c406fb9..bdeb74cabf6 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,6 +20,8 @@ from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator +from trezorlib.client import ProtocolVersion +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -47,11 +49,12 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]: with EmulatorWrapper(gen, tag) as emu: # set up a passphrase-protected device device.reset( - emu.client, + emu.client.get_management_session(), pin_protection=False, skip_backup=True, ) - resp = emu.client.call( + emu.client = emu.client.get_new_client() + resp = emu.client.get_management_session().call( ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ) assert isinstance(resp, messages.Success) @@ -87,11 +90,10 @@ def test_passphrase_works(emulator: Emulator): messages.ButtonRequest, messages.Address, ] - - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with Session(emu_session) as session: + session.set_expected_responses(expected_responses) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) @for_all( @@ -131,13 +133,18 @@ def test_init_device(emulator: Emulator): messages.Address, ] - with emulator.client: - emulator.client.use_passphrase("TREZOR") - emulator.client.set_expected_responses(expected_responses) + emu_session = emulator.client.get_session(passphrase="TREZOR") + with Session(emu_session) as session: + session.set_expected_responses(expected_responses) - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) + btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest - session_id = emulator.client.session_id - emulator.client.init_device() - btc.get_address(emulator.client, "Testnet", parse_path("44h/1h/0h/0/0")) - assert session_id == emulator.client.session_id + session_id = session.id + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: + session.call(messages.Initialize(session_id=session_id)) + btc.get_address( + session, + "Testnet", + parse_path("44h/1h/0h/0/0"), + ) + assert session_id == session.id diff --git a/vendor/fido2-tests b/vendor/fido2-tests index 93a68b36f6e..c827648dd6d 160000 --- a/vendor/fido2-tests +++ b/vendor/fido2-tests @@ -1 +1 @@ -Subproject commit 93a68b36f6eeaa605b11d7330aa04f6ae874cf61 +Subproject commit c827648dd6d44ce1935f8296d905afb9df1de685