From eda280213fde156f45bbfc183dcf464a2502e758 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Thu, 1 Mar 2018 05:13:01 +0100 Subject: [PATCH] src/apps/fido_u2f: fix confirmation, refactor --- src/apps/fido_u2f/__init__.py | 259 ++++++++++++++++++---------------- src/trezor/workflow.py | 25 ++-- 2 files changed, 147 insertions(+), 137 deletions(-) diff --git a/src/apps/fido_u2f/__init__.py b/src/apps/fido_u2f/__init__.py index e691a859..228f3ca4 100644 --- a/src/apps/fido_u2f/__init__.py +++ b/src/apps/fido_u2f/__init__.py @@ -234,7 +234,6 @@ async def read_cmd(iface: io.HID) -> Cmd: read = loop.select(iface.iface_num() | io.POLL_READ) buf = await read - # log.debug(__name__, 'read init %s', buf) ifrm = overlay_struct(buf, desc_init) bcnt = ifrm.bcnt @@ -244,7 +243,8 @@ async def read_cmd(iface: io.HID) -> Cmd: if ifrm.cmd & _TYPE_MASK == _TYPE_CONT: # unexpected cont packet, abort current msg - log.warning(__name__, '_TYPE_CONT') + if __debug__: + log.warning(__name__, '_TYPE_CONT') return None if datalen < bcnt: @@ -256,7 +256,6 @@ async def read_cmd(iface: io.HID) -> Cmd: while datalen < bcnt: buf = await read - # log.debug(__name__, 'read cont %s', buf) cfrm = overlay_struct(buf, desc_cont) @@ -268,14 +267,16 @@ async def read_cmd(iface: io.HID) -> Cmd: if cfrm.cid != ifrm.cid: # cont frame for a different channel, reply with BUSY and skip - log.warning(__name__, '_ERR_CHANNEL_BUSY') + if __debug__: + log.warning(__name__, '_ERR_CHANNEL_BUSY') await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface) continue if cfrm.seq != seq: # cont frame for this channel, but incorrect seq number, abort # current msg - log.warning(__name__, '_ERR_INVALID_SEQ') + if __debug__: + log.warning(__name__, '_ERR_INVALID_SEQ') await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface) return None @@ -299,7 +300,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None: offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) iface.write(buf) - # log.debug(__name__, 'send init %s', buf) if offset < datalen: frm = overlay_struct(buf, cont_desc) @@ -312,7 +312,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None: await write if iface.write(buf) > 0: break - # log.debug(__name__, 'send cont %s', buf) seq += 1 @@ -321,53 +320,65 @@ def boot(iface: io.HID): async def handle_reports(iface: io.HID): + state = ConfirmState() + while True: try: req = await read_cmd(iface) if req is None: continue - resp = dispatch_cmd(req) + resp = dispatch_cmd(req, state) await send_cmd(resp, iface) except Exception as e: log.exception(__name__, e) -def dispatch_cmd(req: Cmd) -> Cmd: +def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd: if req.cmd == _CMD_MSG: m = req.to_msg() if m.cla != 0: - log.warning(__name__, '_SW_CLA_NOT_SUPPORTED') + if __debug__: + log.warning(__name__, '_SW_CLA_NOT_SUPPORTED') return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED) if m.lc + _APDU_DATA > len(req.data): - log.warning(__name__, '_SW_WRONG_LENGTH') + if __debug__: + log.warning(__name__, '_SW_WRONG_LENGTH') return msg_error(req.cid, _SW_WRONG_LENGTH) if m.ins == _MSG_REGISTER: - log.debug(__name__, '_MSG_REGISTER') - return msg_register(m) + if __debug__: + log.debug(__name__, '_MSG_REGISTER') + return msg_register(m, state) elif m.ins == _MSG_AUTHENTICATE: - log.debug(__name__, '_MSG_AUTHENTICATE') - return msg_authenticate(m) + if __debug__: + log.debug(__name__, '_MSG_AUTHENTICATE') + return msg_authenticate(m, state) elif m.ins == _MSG_VERSION: - log.debug(__name__, '_MSG_VERSION') + if __debug__: + log.debug(__name__, '_MSG_VERSION') return msg_version(m) else: - log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins) + if __debug__: + log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins) return msg_error(req.cid, _SW_INS_NOT_SUPPORTED) elif req.cmd == _CMD_INIT: - log.debug(__name__, '_CMD_INIT') + if __debug__: + log.debug(__name__, '_CMD_INIT') return cmd_init(req) elif req.cmd == _CMD_PING: - log.debug(__name__, '_CMD_PING') + if __debug__: + log.debug(__name__, '_CMD_PING') return req elif req.cmd == _CMD_WINK: - log.debug(__name__, '_CMD_WINK') + if __debug__: + log.debug(__name__, '_CMD_WINK') return req else: - log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd) + if __debug__: + log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd) return cmd_error(req.cid, _ERR_INVALID_CMD) @@ -394,6 +405,63 @@ def cmd_init(req: Cmd) -> Cmd: _CONFIRM_REGISTER = const(0) _CONFIRM_AUTHENTICATE = const(1) +_CONFIRM_TIMEOUT_MS = const(10 * 1000) + + +class ConfirmState: + + def __init__(self) -> None: + self.reset() + + def reset(self): + self.action = None + self.checksum = None + self.app_id = None + + self.confirmed = None + self.deadline = None + self.workflow = None + + def compare(self, action: int, checksum: bytes) -> bool: + if self.action != action or self.checksum != checksum: + return False + if utime.ticks_ms() >= self.deadline: + return False + return True + + def setup(self, action: int, checksum: bytes, app_id: bytes) -> None: + if self.workflow is not None: + loop.close(self.workflow) + if workflow.workflows: + return False + + self.action = action + self.checksum = checksum + self.app_id = app_id + + self.confirmed = None + self.workflow = self.confirm_workflow() + loop.schedule(self.workflow) + return True + + def keepalive(self): + self.deadline = utime.ticks_ms() + _CONFIRM_TIMEOUT_MS + + async def confirm_workflow(self) -> None: + try: + workflow.onstart(self.workflow) + await self.confirm_layout() + finally: + workflow.onclose(self.workflow) + self.workflow = None + + @ui.layout + async def confirm_layout(self) -> None: + from trezor.ui.confirm import ConfirmDialog, CONFIRMED + + content = ConfirmContent(self.action, self.app_id) + dialog = ConfirmDialog(content, ) + self.confirmed = await dialog == CONFIRMED class ConfirmContent(ui.Widget): @@ -420,7 +488,7 @@ class ConfirmContent(ui.Widget): name = knownapps.knownapps[app_id] try: icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_')) - except FileNotFoundError: + except Exception: icon = res.load('apps/fido_u2f/res/u2f_generic.toif') else: name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode()) @@ -438,75 +506,18 @@ class ConfirmContent(ui.Widget): ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG) -_CONFIRM_STATE_TIMEOUT_MS = const(10 * 1000) - - -class ConfirmState: - - def __init__(self, action: int, app_id: bytes) -> None: - self.action = action - self.app_id = app_id - self.deadline_ms = None - self.confirmed = None - self.task = None - - def fork(self) -> None: - self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS - self.task = self.confirm() - workflow.onstart(self.task) - loop.schedule(self.task) - - def kill(self) -> None: - if self.task is not None: - loop.close(self.task) - self.task = None - - async def confirm(self) -> None: - confirmed = False - try: - confirmed = await self.confirm_layout() - finally: - self.confirmed = confirmed - workflow.onclose(self.task) - - @ui.layout - async def confirm_layout(self) -> None: - from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED - from trezor.ui.text import Text - - if bytes(self.app_id) == _BOGUS_APPID: - text = Text( - 'U2F mismatch', ui.ICON_WRONG, - 'Another U2F device', - 'was used to register', - 'in this application.', - icon_color=ui.RED) - text.render() - await loop.sleep(3 * 1000 * 1000) - return True - - content = ConfirmContent(self.action, self.app_id) - dialog = HoldToConfirmDialog(content) - return await dialog == CONFIRMED - - -_state = None # type: Optional[ConfirmState] # state for msg_register and msg_authenticate -_lastreq = None # type: Optional[Msg] # last received register/authenticate request - - -def msg_register(req: Msg) -> Cmd: - global _state - global _lastreq - +def msg_register(req: Msg, state: ConfirmState) -> Cmd: from apps.common import storage if not storage.is_initialized(): - log.warning(__name__, 'not initialized') + if __debug__: + log.warning(__name__, 'not initialized') return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) # check length of input data if len(req.data) != 64: - log.warning(__name__, '_SW_WRONG_LENGTH req.data') + if __debug__: + log.warning(__name__, '_SW_WRONG_LENGTH req.data') return msg_error(req.cid, _SW_WRONG_LENGTH) # parse challenge and app_id @@ -514,26 +525,24 @@ def msg_register(req: Msg) -> Cmd: app_id = req.data[32:] # check equality with last request - if _lastreq is None or _lastreq.__dict__ != req.__dict__: - if _state is not None: - _state.kill() - _state = None - _lastreq = req + if not state.compare(_CONFIRM_REGISTER, req.data): + if not state.setup(_CONFIRM_REGISTER, req.data, app_id): + return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) + state.keepalive() # wait for a button or continue - if _state is not None and utime.ticks_ms() > _state.deadline_ms: - _state.kill() - _state = None - if _state is None: - _state = ConfirmState(_CONFIRM_REGISTER, app_id) - _state.fork() - if _state.confirmed is None: - log.info(__name__, 'waiting for button') + if not state.confirmed: + if __debug__: + log.info(__name__, 'waiting for button') return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) - _state = None + # sign the registration challenge and return + if __debug__: + log.info(__name__, 'signing register') buf = msg_register_sign(chal, app_id) + state.reset() + return Cmd(req.cid, _CMD_MSG, buf) @@ -586,26 +595,25 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes: return buf -def msg_authenticate(req: Msg) -> Cmd: - - global _state - global _lastreq - +def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd: from apps.common import storage if not storage.is_initialized(): - log.warning(__name__, 'not initialized') + if __debug__: + log.warning(__name__, 'not initialized') return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) # we need at least keyHandleLen if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN: - log.warning(__name__, '_SW_WRONG_LENGTH req.data') + if __debug__: + log.warning(__name__, '_SW_WRONG_LENGTH req.data') return msg_error(req.cid, _SW_WRONG_LENGTH) # check keyHandleLen khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN] if khlen != 64: - log.warning(__name__, '_SW_WRONG_LENGTH khlen') + if __debug__: + log.warning(__name__, '_SW_WRONG_LENGTH khlen') return msg_error(req.cid, _SW_WRONG_LENGTH) auth = overlay_struct(req.data, req_cmd_authenticate(khlen)) @@ -618,40 +626,39 @@ def msg_authenticate(req: Msg) -> Cmd: # if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already if req.p1 == _AUTH_CHECK_ONLY: - log.info(__name__, '_AUTH_CHECK_ONLY') + if __debug__: + log.info(__name__, '_AUTH_CHECK_ONLY') return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) # from now on, only _AUTH_ENFORCE is supported if req.p1 != _AUTH_ENFORCE: - log.info(__name__, '_AUTH_ENFORCE') + if __debug__: + log.info(__name__, '_AUTH_ENFORCE') return msg_error(req.cid, _SW_WRONG_DATA) # check equality with last request - if _lastreq is None or _lastreq.__dict__ != req.__dict__: - if _state is not None: - _state.kill() - _state = None - _lastreq = req + if not state.compare(_CONFIRM_AUTHENTICATE, req.data): + if not state.setup(_CONFIRM_AUTHENTICATE, req.data, auth.appId): + return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) + state.keepalive() # wait for a button or continue - if _state is not None and utime.ticks_ms() > _state.deadline_ms: - _state.kill() - _state = None - if _state is None: - _state = ConfirmState(_CONFIRM_AUTHENTICATE, auth.appId) - _state.fork() - if _state.confirmed is None: - log.info(__name__, 'waiting for button') + if not state.confirmed: + if __debug__: + log.info(__name__, 'waiting for button') return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) - _state = None + # sign the authentication challenge and return + if __debug__: + log.info(__name__, 'signing authentication') buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key()) + state.reset() + return Cmd(req.cid, _CMD_MSG, buf) def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes): - from apps.common import seed # unpack the keypath from the first half of keyhandle @@ -661,7 +668,8 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes): # check high bit for hardened keys for i in keypath: if not i & 0x80000000: - log.warning(__name__, 'invalid key path') + if __debug__: + log.warning(__name__, 'invalid key path') return None # derive the signing key @@ -675,7 +683,8 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes): # verify the hmac if keybase != keyhandle[32:]: - log.warning(__name__, 'invalid key handle') + if __debug__: + log.warning(__name__, 'invalid key handle') return None return node diff --git a/src/trezor/workflow.py b/src/trezor/workflow.py index 7d3f0712..cfd134cf 100644 --- a/src/trezor/workflow.py +++ b/src/trezor/workflow.py @@ -1,19 +1,19 @@ from trezor import loop -started = [] -default = None -default_handler = None +workflows = [] layouts = [] +default = None +default_layout = None def onstart(w): - started.append(w) + workflows.append(w) def onclose(w): - started.remove(w) - if not started and not layouts and default_handler: - startdefault(default_handler) + workflows.remove(w) + if not layouts and default_layout: + startdefault(default_layout) def closedefault(): @@ -24,13 +24,13 @@ def closedefault(): default = None -def startdefault(handler): +def startdefault(layout): global default - global default_handler + global default_layout if not default: - default_handler = handler - default = handler() + default_layout = layout + default = layout() loop.schedule(default) @@ -47,4 +47,5 @@ def onlayoutstart(l): def onlayoutclose(l): - layouts.remove(l) + if l in layouts: + layouts.remove(l)