src/apps/fido_u2f: fix confirmation, refactor

This commit is contained in:
Jan Pochyla 2018-03-01 05:13:01 +01:00
parent f74cbead5e
commit eda280213f
2 changed files with 147 additions and 137 deletions

View File

@ -234,7 +234,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
read = loop.select(iface.iface_num() | io.POLL_READ)
buf = await read
# log.debug(__name__, 'read init %s', buf)
ifrm = overlay_struct(buf, desc_init)
bcnt = ifrm.bcnt
@ -244,7 +243,8 @@ async def read_cmd(iface: io.HID) -> Cmd:
if ifrm.cmd & _TYPE_MASK == _TYPE_CONT:
# unexpected cont packet, abort current msg
log.warning(__name__, '_TYPE_CONT')
if __debug__:
log.warning(__name__, '_TYPE_CONT')
return None
if datalen < bcnt:
@ -256,7 +256,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
while datalen < bcnt:
buf = await read
# log.debug(__name__, 'read cont %s', buf)
cfrm = overlay_struct(buf, desc_cont)
@ -268,14 +267,16 @@ async def read_cmd(iface: io.HID) -> Cmd:
if cfrm.cid != ifrm.cid:
# cont frame for a different channel, reply with BUSY and skip
log.warning(__name__, '_ERR_CHANNEL_BUSY')
if __debug__:
log.warning(__name__, '_ERR_CHANNEL_BUSY')
await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface)
continue
if cfrm.seq != seq:
# cont frame for this channel, but incorrect seq number, abort
# current msg
log.warning(__name__, '_ERR_INVALID_SEQ')
if __debug__:
log.warning(__name__, '_ERR_INVALID_SEQ')
await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface)
return None
@ -299,7 +300,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
iface.write(buf)
# log.debug(__name__, 'send init %s', buf)
if offset < datalen:
frm = overlay_struct(buf, cont_desc)
@ -312,7 +312,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
await write
if iface.write(buf) > 0:
break
# log.debug(__name__, 'send cont %s', buf)
seq += 1
@ -321,53 +320,65 @@ def boot(iface: io.HID):
async def handle_reports(iface: io.HID):
state = ConfirmState()
while True:
try:
req = await read_cmd(iface)
if req is None:
continue
resp = dispatch_cmd(req)
resp = dispatch_cmd(req, state)
await send_cmd(resp, iface)
except Exception as e:
log.exception(__name__, e)
def dispatch_cmd(req: Cmd) -> Cmd:
def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd:
if req.cmd == _CMD_MSG:
m = req.to_msg()
if m.cla != 0:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
if __debug__:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
if m.lc + _APDU_DATA > len(req.data):
log.warning(__name__, '_SW_WRONG_LENGTH')
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH')
return msg_error(req.cid, _SW_WRONG_LENGTH)
if m.ins == _MSG_REGISTER:
log.debug(__name__, '_MSG_REGISTER')
return msg_register(m)
if __debug__:
log.debug(__name__, '_MSG_REGISTER')
return msg_register(m, state)
elif m.ins == _MSG_AUTHENTICATE:
log.debug(__name__, '_MSG_AUTHENTICATE')
return msg_authenticate(m)
if __debug__:
log.debug(__name__, '_MSG_AUTHENTICATE')
return msg_authenticate(m, state)
elif m.ins == _MSG_VERSION:
log.debug(__name__, '_MSG_VERSION')
if __debug__:
log.debug(__name__, '_MSG_VERSION')
return msg_version(m)
else:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
if __debug__:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
elif req.cmd == _CMD_INIT:
log.debug(__name__, '_CMD_INIT')
if __debug__:
log.debug(__name__, '_CMD_INIT')
return cmd_init(req)
elif req.cmd == _CMD_PING:
log.debug(__name__, '_CMD_PING')
if __debug__:
log.debug(__name__, '_CMD_PING')
return req
elif req.cmd == _CMD_WINK:
log.debug(__name__, '_CMD_WINK')
if __debug__:
log.debug(__name__, '_CMD_WINK')
return req
else:
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd)
if __debug__:
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd)
return cmd_error(req.cid, _ERR_INVALID_CMD)
@ -394,6 +405,63 @@ def cmd_init(req: Cmd) -> Cmd:
_CONFIRM_REGISTER = const(0)
_CONFIRM_AUTHENTICATE = const(1)
_CONFIRM_TIMEOUT_MS = const(10 * 1000)
class ConfirmState:
def __init__(self) -> None:
self.reset()
def reset(self):
self.action = None
self.checksum = None
self.app_id = None
self.confirmed = None
self.deadline = None
self.workflow = None
def compare(self, action: int, checksum: bytes) -> bool:
if self.action != action or self.checksum != checksum:
return False
if utime.ticks_ms() >= self.deadline:
return False
return True
def setup(self, action: int, checksum: bytes, app_id: bytes) -> None:
if self.workflow is not None:
loop.close(self.workflow)
if workflow.workflows:
return False
self.action = action
self.checksum = checksum
self.app_id = app_id
self.confirmed = None
self.workflow = self.confirm_workflow()
loop.schedule(self.workflow)
return True
def keepalive(self):
self.deadline = utime.ticks_ms() + _CONFIRM_TIMEOUT_MS
async def confirm_workflow(self) -> None:
try:
workflow.onstart(self.workflow)
await self.confirm_layout()
finally:
workflow.onclose(self.workflow)
self.workflow = None
@ui.layout
async def confirm_layout(self) -> None:
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
content = ConfirmContent(self.action, self.app_id)
dialog = ConfirmDialog(content, )
self.confirmed = await dialog == CONFIRMED
class ConfirmContent(ui.Widget):
@ -420,7 +488,7 @@ class ConfirmContent(ui.Widget):
name = knownapps.knownapps[app_id]
try:
icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_'))
except FileNotFoundError:
except Exception:
icon = res.load('apps/fido_u2f/res/u2f_generic.toif')
else:
name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode())
@ -438,75 +506,18 @@ class ConfirmContent(ui.Widget):
ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG)
_CONFIRM_STATE_TIMEOUT_MS = const(10 * 1000)
class ConfirmState:
def __init__(self, action: int, app_id: bytes) -> None:
self.action = action
self.app_id = app_id
self.deadline_ms = None
self.confirmed = None
self.task = None
def fork(self) -> None:
self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS
self.task = self.confirm()
workflow.onstart(self.task)
loop.schedule(self.task)
def kill(self) -> None:
if self.task is not None:
loop.close(self.task)
self.task = None
async def confirm(self) -> None:
confirmed = False
try:
confirmed = await self.confirm_layout()
finally:
self.confirmed = confirmed
workflow.onclose(self.task)
@ui.layout
async def confirm_layout(self) -> None:
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
from trezor.ui.text import Text
if bytes(self.app_id) == _BOGUS_APPID:
text = Text(
'U2F mismatch', ui.ICON_WRONG,
'Another U2F device',
'was used to register',
'in this application.',
icon_color=ui.RED)
text.render()
await loop.sleep(3 * 1000 * 1000)
return True
content = ConfirmContent(self.action, self.app_id)
dialog = HoldToConfirmDialog(content)
return await dialog == CONFIRMED
_state = None # type: Optional[ConfirmState] # state for msg_register and msg_authenticate
_lastreq = None # type: Optional[Msg] # last received register/authenticate request
def msg_register(req: Msg) -> Cmd:
global _state
global _lastreq
def msg_register(req: Msg, state: ConfirmState) -> Cmd:
from apps.common import storage
if not storage.is_initialized():
log.warning(__name__, 'not initialized')
if __debug__:
log.warning(__name__, 'not initialized')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# check length of input data
if len(req.data) != 64:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req.cid, _SW_WRONG_LENGTH)
# parse challenge and app_id
@ -514,26 +525,24 @@ def msg_register(req: Msg) -> Cmd:
app_id = req.data[32:]
# check equality with last request
if _lastreq is None or _lastreq.__dict__ != req.__dict__:
if _state is not None:
_state.kill()
_state = None
_lastreq = req
if not state.compare(_CONFIRM_REGISTER, req.data):
if not state.setup(_CONFIRM_REGISTER, req.data, app_id):
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
state.keepalive()
# wait for a button or continue
if _state is not None and utime.ticks_ms() > _state.deadline_ms:
_state.kill()
_state = None
if _state is None:
_state = ConfirmState(_CONFIRM_REGISTER, app_id)
_state.fork()
if _state.confirmed is None:
log.info(__name__, 'waiting for button')
if not state.confirmed:
if __debug__:
log.info(__name__, 'waiting for button')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None
# sign the registration challenge and return
if __debug__:
log.info(__name__, 'signing register')
buf = msg_register_sign(chal, app_id)
state.reset()
return Cmd(req.cid, _CMD_MSG, buf)
@ -586,26 +595,25 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
return buf
def msg_authenticate(req: Msg) -> Cmd:
global _state
global _lastreq
def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
from apps.common import storage
if not storage.is_initialized():
log.warning(__name__, 'not initialized')
if __debug__:
log.warning(__name__, 'not initialized')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# we need at least keyHandleLen
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req.cid, _SW_WRONG_LENGTH)
# check keyHandleLen
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
if khlen != 64:
log.warning(__name__, '_SW_WRONG_LENGTH khlen')
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH khlen')
return msg_error(req.cid, _SW_WRONG_LENGTH)
auth = overlay_struct(req.data, req_cmd_authenticate(khlen))
@ -618,40 +626,39 @@ def msg_authenticate(req: Msg) -> Cmd:
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
if req.p1 == _AUTH_CHECK_ONLY:
log.info(__name__, '_AUTH_CHECK_ONLY')
if __debug__:
log.info(__name__, '_AUTH_CHECK_ONLY')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# from now on, only _AUTH_ENFORCE is supported
if req.p1 != _AUTH_ENFORCE:
log.info(__name__, '_AUTH_ENFORCE')
if __debug__:
log.info(__name__, '_AUTH_ENFORCE')
return msg_error(req.cid, _SW_WRONG_DATA)
# check equality with last request
if _lastreq is None or _lastreq.__dict__ != req.__dict__:
if _state is not None:
_state.kill()
_state = None
_lastreq = req
if not state.compare(_CONFIRM_AUTHENTICATE, req.data):
if not state.setup(_CONFIRM_AUTHENTICATE, req.data, auth.appId):
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
state.keepalive()
# wait for a button or continue
if _state is not None and utime.ticks_ms() > _state.deadline_ms:
_state.kill()
_state = None
if _state is None:
_state = ConfirmState(_CONFIRM_AUTHENTICATE, auth.appId)
_state.fork()
if _state.confirmed is None:
log.info(__name__, 'waiting for button')
if not state.confirmed:
if __debug__:
log.info(__name__, 'waiting for button')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None
# sign the authentication challenge and return
if __debug__:
log.info(__name__, 'signing authentication')
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
state.reset()
return Cmd(req.cid, _CMD_MSG, buf)
def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
from apps.common import seed
# unpack the keypath from the first half of keyhandle
@ -661,7 +668,8 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# check high bit for hardened keys
for i in keypath:
if not i & 0x80000000:
log.warning(__name__, 'invalid key path')
if __debug__:
log.warning(__name__, 'invalid key path')
return None
# derive the signing key
@ -675,7 +683,8 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# verify the hmac
if keybase != keyhandle[32:]:
log.warning(__name__, 'invalid key handle')
if __debug__:
log.warning(__name__, 'invalid key handle')
return None
return node

View File

@ -1,19 +1,19 @@
from trezor import loop
started = []
default = None
default_handler = None
workflows = []
layouts = []
default = None
default_layout = None
def onstart(w):
started.append(w)
workflows.append(w)
def onclose(w):
started.remove(w)
if not started and not layouts and default_handler:
startdefault(default_handler)
workflows.remove(w)
if not layouts and default_layout:
startdefault(default_layout)
def closedefault():
@ -24,13 +24,13 @@ def closedefault():
default = None
def startdefault(handler):
def startdefault(layout):
global default
global default_handler
global default_layout
if not default:
default_handler = handler
default = handler()
default_layout = layout
default = layout()
loop.schedule(default)
@ -47,4 +47,5 @@ def onlayoutstart(l):
def onlayoutclose(l):
layouts.remove(l)
if l in layouts:
layouts.remove(l)