diff --git a/src/apps/common/request_passphrase.py b/src/apps/common/request_passphrase.py index 8e60f459..88cb418c 100644 --- a/src/apps/common/request_passphrase.py +++ b/src/apps/common/request_passphrase.py @@ -2,8 +2,9 @@ from trezor import ui, wire async def request_passphrase(session_id): + from trezor.messages.FailureType import ActionCancelled from trezor.messages.PassphraseRequest import PassphraseRequest - from trezor.messages.wire_types import PassphraseAck + from trezor.messages.wire_types import PassphraseAck, Cancel from trezor.ui.text import Text ui.display.clear() @@ -11,5 +12,17 @@ async def request_passphrase(session_id): 'Please enter passphrase', 'on your computer.') text.render() - ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck) + ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel) + if ack.MESSAGE_WIRE_TYPE == Cancel: + raise wire.FailureError(ActionCancelled, 'Passphrase cancelled') + return ack.passphrase + + +async def protect_by_passphrase(session_id): + from apps.common import storage + + if storage.is_protected_by_passphrase(): + return await request_passphrase(session_id) + else: + return '' diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 8e9e950f..8c3fa223 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -22,7 +22,7 @@ async def get_seed(session_id: int) -> bytes: async def compute_seed(session_id: int) -> bytes: from trezor.messages.FailureType import Other - from .request_passphrase import request_passphrase + from .request_passphrase import protect_by_passphrase from .request_pin import protect_by_pin from . import storage @@ -31,8 +31,5 @@ async def compute_seed(session_id: int) -> bytes: await protect_by_pin(session_id) - if storage.is_protected_by_passphrase(): - passphrase = await request_passphrase(session_id) - else: - passphrase = '' + passphrase = await protect_by_passphrase(session_id) return bip39.seed(storage.get_mnemonic(), passphrase)