apps.common: implement finish device state handling

This commit is contained in:
Pavol Rusnak 2018-02-24 18:58:02 +01:00
parent 35e1135c95
commit 502ecd7bcc
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
4 changed files with 41 additions and 19 deletions

View File

@ -1,11 +1,32 @@
from trezor.crypto import random, hashlib, hmac
from apps.common.storage import get_device_id
memory = {}
_seed = None
_state = None
_passphrase = None
_state_salt = None
def get_state():
global _state
return _state
def get_state(salt: bytes=None, passphrase: str=None):
global _passphrase, _state_salt
if salt is None:
# generate a random salt if not provided and not already cached
if _state_salt is None:
_state_salt = random.bytes(32)
else:
# otherwise copy provided salt to cached salt
_state_salt = salt
# state = HMAC(passphrase, salt || device_id)
if passphrase is None:
key = _passphrase if _passphrase is not None else ''
else:
key = passphrase
msg = _state_salt + get_device_id().encode()
state = hmac.new(key.encode(), msg, hashlib.sha256).digest()
return _state_salt + state
def get_seed():
@ -13,15 +34,13 @@ def get_seed():
return _seed
def set_seed(seed):
from trezor.crypto import bip32
from trezor.crypto.hashlib import blake2s
node = bip32.from_seed(seed, 'secp256k1')
state = blake2s(node.public_key()).digest()
global _seed, _state
_seed, _state = seed, state
def set_seed(seed, passphrase):
global _seed, _passphrase
_seed, _passphrase = seed, _passphrase
def clear():
global _seed, _state
_seed, _state = None, None
global _seed, _passphrase
global _state_salt
_seed, _passphrase = None, None
_state_salt = None

View File

@ -1,4 +1,5 @@
from trezor import res, ui, wire
from apps.common.cache import get_state
async def request_passphrase(ctx):
@ -47,7 +48,9 @@ async def request_passphrase(ctx):
raise wire.FailureError(ProcessError, 'Passphrase not provided')
passphrase = ack.passphrase
# TODO: process ack.state and check against the current device state, throw error if different
if ack.state is not None:
if ack.state != get_state(salt=ack.state[:32], passphrase=passphrase):
raise wire.FailureError(ProcessError, 'Passphrase mismatch')
return passphrase

View File

@ -16,12 +16,12 @@ async def derive_node(ctx: wire.Context, path=[], curve_name=_DEFAULT_CURVE):
async def _get_seed(ctx: wire.Context) -> bytes:
from . import cache
if cache.get_seed() is None:
seed = await _compute_seed(ctx)
cache.set_seed(seed)
seed, passphrase = await _compute_seed(ctx)
cache.set_seed(seed, passphrase)
return cache.get_seed()
async def _compute_seed(ctx: wire.Context) -> bytes:
async def _compute_seed(ctx: wire.Context) -> (bytes, str):
from trezor.messages.FailureType import ProcessError
from .request_passphrase import protect_by_passphrase
from . import storage
@ -30,7 +30,7 @@ async def _compute_seed(ctx: wire.Context) -> bytes:
raise wire.FailureError(ProcessError, 'Device is not initialized')
passphrase = await protect_by_passphrase(ctx)
return bip39.seed(storage.get_mnemonic(), passphrase)
return bip39.seed(storage.get_mnemonic(), passphrase), passphrase
def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE):

View File

@ -10,7 +10,7 @@ async def respond_Features(ctx, msg):
from trezor.messages.Features import Features
if msg.__qualname__ == 'Initialize':
if msg.state is None or msg.state != cache.get_state():
if msg.state is None or msg.state != cache.get_state(salt=msg.state[:32]):
cache.clear()
f = Features()