From 421f17bfeea3acb82d43169db12e92623b418960 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Mon, 30 May 2016 16:15:34 +0200 Subject: [PATCH] rework events around interfaces, split msg.py to wire.py Touch events are sent on special interface now. --- src/trezor/loop.py | 82 ++++++++++++++++++++---------------------- src/trezor/msg.py | 81 ----------------------------------------- src/trezor/wire.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 124 deletions(-) create mode 100644 src/trezor/wire.py diff --git a/src/trezor/loop.py b/src/trezor/loop.py index daed7244..59a566a3 100644 --- a/src/trezor/loop.py +++ b/src/trezor/loop.py @@ -12,13 +12,14 @@ if __debug__: log_delay_rb_len = const(10) log_delay_rb = array.array('i', [0] * log_delay_rb_len) +# Touch interface +TOUCH = const(256) # 0-255 is reserved for USB interfaces TOUCH_START = const(1) TOUCH_MOVE = const(2) -TOUCH_END = const(4) -HID_READ = const(8) +TOUCH_END = const(3) -blocked_gens = {} # event -> generator -time_queue = [] # [(int, int, generator, any)] +msg_handlers = {} # Interface -> generator +time_queue = [] time_ticket = 0 @@ -37,37 +38,43 @@ def unschedule(gen): heapify(time_queue) -def block(gen, event): - curr_gen = blocked_gens.get(event, None) - if curr_gen is not None: - log.warning(__name__, 'Closing %s blocked on %s', curr_gen, event) - curr_gen.close() - blocked_gens[event] = gen +def block(gen, iface): + curr = msg_handlers.get(iface, None) + if curr: + log.warning(__name__, 'Closing %s blocked on %s', curr, iface) + curr.close() + msg_handlers[iface] = gen def unblock(gen): - for key in blocked_gens: - if blocked_gens[key] is gen: - blocked_gens[key] = None + for iface in msg_handlers: + if msg_handlers[iface] is gen: + msg_handlers[iface] = None -class Sleep(): +class Syscall(): + pass + + +class Sleep(Syscall): def __init__(self, us): self.time = utime.ticks_us() + us - -class Select(): - - def __init__(self, *events): - self.events = events - - def handle(self, gen): - for event in self.events: - block(gen, event) + def register(self, gen): + schedule(gen, self, self.time) -class Wait(): +class Select(Syscall): + + def __init__(self, iface): + self.iface = iface + + def register(self, gen): + block(gen, self.iface) + + +class Wait(Syscall): def __init__(self, gens, wait_for=1, exit_others=True): self.gens = gens @@ -77,8 +84,8 @@ class Wait(): self.finished = [] self.callback = None - def handle(self, gen): - self.scheduled = [schedule(self._wait(gen)) for gen in self.gens] + def register(self, gen): + self.scheduled = [schedule(self._wait(g)) for g in self.gens] self.callback = gen def exit(self): @@ -111,7 +118,6 @@ class Wait(): def run_forever(): if __debug__: global log_delay_pos, log_delay_rb, log_delay_rb_len - global blocked_events, blocked_gen DELAY_MAX = const(1000000) @@ -132,14 +138,12 @@ def run_forever(): message = msg.select(delay) if message: # Run interrupt handler right away, they have priority - event = message[0] + iface = message[0] data = message - gen = blocked_gens.pop(event, None) + gen = msg_handlers.pop(iface, None) if not gen: - log.info(__name__, 'No handler for event: %s', event) + log.info(__name__, 'No handler for message: %s', iface) continue - # Cancel other registrations of this handler - unblock(gen) else: # Run something from the time queue if time_queue: @@ -159,17 +163,9 @@ def run_forever(): log.exception(__name__, e) continue - if isinstance(result, Sleep): - # Sleep until result.time, call us later - schedule(gen, result, result.time) - - elif isinstance(result, Select): - # Wait for one or more types of event - result.handle(gen) - - elif isinstance(result, Wait): - # Register us as a waiting callback - result.handle(gen) + if isinstance(result, Syscall): + # Execute the syscall + result.register(gen) elif result is None: # Just call us asap diff --git a/src/trezor/msg.py b/src/trezor/msg.py index c9dcec3f..405c1d61 100644 --- a/src/trezor/msg.py +++ b/src/trezor/msg.py @@ -16,84 +16,3 @@ def select(timeout_us): def send(iface, msg): return _msg.send(iface, msg) - - -REPORT_LEN = const(64) -REPORT_NUM = const(63) -HEADER_MAGIC = const(35) # '#' - - -def read(): - _, iface, rep = yield loop.Select(loop.HID_READ) - assert rep[0] == REPORT_NUM - return rep - - -def read_wire_msg(): - rep = yield from read() - assert rep[1] == HEADER_MAGIC - assert rep[2] == HEADER_MAGIC - (mtype, mlen) = ustruct.unpack_from('>HL', rep, 3) - - # TODO: validate mlen for sane values - - rep = memoryview(rep) - data = rep[9:] - data = data[:mlen] - - mbuf = bytearray(data) # TODO: allocate mlen bytes - remaining = mlen - len(mbuf) - - while remaining > 0: - rep = yield from read() - rep = memoryview(rep) - data = rep[1:] - data = data[:remaining] - mbuf.extend(data) - remaining -= len(data) - - return (mtype, mbuf) - - -def write_wire_msg(mtype, mbuf): - rep = bytearray(REPORT_LEN) - rep[0] = REPORT_NUM - rep[1] = HEADER_MAGIC - rep[2] = HEADER_MAGIC - ustruct.pack_into('>HL', rep, 3, mtype, len(mbuf)) - - rep = memoryview(rep) - mbuf = memoryview(mbuf) - data = rep[9:] - - while mbuf: - n = min(len(data), len(mbuf)) - data[:n] = mbuf[:n] - i = n - while i < len(data): - data[i] = 0 - i += 1 - send(0, rep) - mbuf = mbuf[n:] - data = rep[1:] - - -def read_msg(*types): - mtype, mbuf = yield from read_wire_msg() - for t in types: - if t.wire_type == mtype: - return t.loads(mbuf) - else: - raise Exception('Unexpected message') - - -def write_msg(msg): - mbuf = msg.dumps() - mtype = msg.message_type.wire_type - write_wire_msg(mtype, mbuf) - - -def call(req, *types): - write_msg(req) - res = yield from read_msg(*types) - return res diff --git a/src/trezor/wire.py b/src/trezor/wire.py new file mode 100644 index 00000000..78d9a9f0 --- /dev/null +++ b/src/trezor/wire.py @@ -0,0 +1,90 @@ +import ustruct +from . import msg +from . import loop + +IFACE = const(0) + +REPORT_LEN = const(64) +REPORT_NUM = const(63) +HEADER_MAGIC = const(35) # + + +def read_report(): + _, rep = yield loop.Select(IFACE) + assert rep[0] == REPORT_NUM, 'Report number malformed' + return rep + + +def write_report(rep): + size = msg.send(IFACE, rep) + assert size == REPORT_LEN, 'HID write failed' + + +def read_wire_msg(): + rep = yield from read_report() + assert rep[1] == HEADER_MAGIC + assert rep[2] == HEADER_MAGIC + (mtype, mlen) = ustruct.unpack_from('>HL', rep, 3) + + # TODO: validate mlen for sane values + + rep = memoryview(rep) + data = rep[9:] + data = data[:mlen] + + mbuf = bytearray(data) # TODO: allocate mlen bytes + remaining = mlen - len(mbuf) + + while remaining > 0: + rep = yield from read_report() + rep = memoryview(rep) + data = rep[1:] + data = data[:remaining] + mbuf.extend(data) + remaining -= len(data) + + return (mtype, mbuf) + + +def write_wire_msg(mtype, mbuf): + rep = bytearray(REPORT_LEN) + rep[0] = REPORT_NUM + rep[1] = HEADER_MAGIC + rep[2] = HEADER_MAGIC + ustruct.pack_into('>HL', rep, 3, mtype, len(mbuf)) + + rep = memoryview(rep) + mbuf = memoryview(mbuf) + data = rep[9:] + + while mbuf: + n = min(len(data), len(mbuf)) + data[:n] = mbuf[:n] + i = n + while i < len(data): + data[i] = 0 + i += 1 + write_report(rep) + mbuf = mbuf[n:] + data = rep[1:] + + +def read_msg(*types): + mtype, mbuf = yield from read_wire_msg() + for t in types: + if t.wire_type == mtype: + return t.loads(mbuf) + else: + raise Exception('Unexpected message') + + +def write_msg(m): + mbuf = m.dumps() + mtype = m.message_type.wire_type + write_wire_msg(mtype, mbuf) + + +def call(req, *types): + write_msg(req) + res = yield from read_msg(*types) + return res