diff --git a/.pylintrc b/.pylintrc index dd792729..769ae821 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,5 @@ [MASTER] -init-hook='sys.path.insert(0, "mocks")' +init-hook='sys.path.append("mocks"); sys.path.append("src/lib")' [MESSAGES CONTROL] -disable=C0111,C0103,W0603 +disable=C0111,C0103,W0603,W0703 diff --git a/src/lib/protobuf.py b/src/lib/protobuf.py index 1fb47e95..f97bcd7f 100644 --- a/src/lib/protobuf.py +++ b/src/lib/protobuf.py @@ -1,150 +1,53 @@ ''' -Streaming protobuf codec. - -Handles asynchronous encoding and decoding of protobuf value streams. - -Value format: ((field_name, field_type, field_flags), field_value) - field_name (str): Field name string. - field_type (Type): Subclass of Type. - field_flags (int): Field bit flags: `FLAG_REPEATED`. - field_value (Any): Depends on field_type. - MessageTypes have `field_value == None`. - -Type classes are either scalar or message-like. `load()` generators of -scalar types return the value, message types stream it to a target -generator as described above. All types can be loaded and dumped -synchronously with `loads()` and `dumps()`. +Extremely minimal streaming codec for a subset of protobuf. Supports uint32, +bytes, string, embedded message and repeated fields. ''' from micropython import const -from streams import StreamReader, BufferWriter + +_UVARINT_BUFFER = bytearray(1) -def build_message(msg_type, callback=None, *args): - msg = msg_type() - try: - while True: - field, fvalue = yield - fname, ftype, fflags = field - if issubclass(ftype, MessageType): - fvalue = yield from build_message(ftype) - if fflags & FLAG_REPEATED: - prev_value = getattr(msg, fname, []) - prev_value.append(fvalue) - fvalue = prev_value - setattr(msg, fname, fvalue) - except EOFError: - fill_missing_fields(msg) - if callback is not None: - callback(msg, *args) - return msg +async def load_uvarint(reader): + buffer = _UVARINT_BUFFER + result = 0 + shift = 0 + byte = 0x80 + while byte & 0x80: + await reader.readinto(buffer) + byte = buffer[0] + result += (byte & 0x7F) << shift + shift += 7 + return result -def fill_missing_fields(msg): - for tag in msg.FIELDS: - field = msg.FIELDS[tag] - if not hasattr(msg, field[0]): - setattr(msg, field[0], None) +async def dump_uvarint(writer, n): + buffer = _UVARINT_BUFFER + shifted = True + while shifted: + shifted = n >> 7 + buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) + await writer.write(buffer) + n = shifted -class Type: - - @classmethod - def loads(cls, value): - source = StreamReader(value, len(value)) - loader = cls.load(source) - try: - while True: - loader.send(None) - except StopIteration as e: - return e.value - - @classmethod - def dumps(cls, value): - target = BufferWriter() - dumper = cls.dump(value, target) - try: - while True: - dumper.send(None) - except StopIteration: - return target.buffer - - -_uvarint_buffer = bytearray(1) - - -class UVarintType(Type): +class UVarintType: WIRE_TYPE = 0 - @staticmethod - async def load(source): - value, shift, quantum = 0, 0, 0x80 - while quantum & 0x80: - await source.read_into(_uvarint_buffer) - quantum = _uvarint_buffer[0] - value = value + ((quantum & 0x7F) << shift) - shift += 7 - return value - @staticmethod - async def dump(value, target): - shifted = True - while shifted: - shifted = value >> 7 - _uvarint_buffer[0] = (value & 0x7F) | (0x80 if shifted else 0x00) - await target.write(_uvarint_buffer) - value = shifted - - -class BoolType(Type): +class BoolType: WIRE_TYPE = 0 - @staticmethod - async def load(source): - return await 
UVarintType.load(source) != 0 - @staticmethod - async def dump(value, target): - await target.write(b'\x01' if value else b'\x00') - - -class BytesType(Type): +class BytesType: WIRE_TYPE = 2 - @staticmethod - async def load(source): - size = await UVarintType.load(source) - data = bytearray(size) - await source.read_into(data) - return data - @staticmethod - async def dump(value, target): - await UVarintType.dump(len(value), target) - await target.write(value) - - -class UnicodeType(Type): +class UnicodeType: WIRE_TYPE = 2 - @staticmethod - async def load(source): - size = await UVarintType.load(source) - data = bytearray(size) - await source.read_into(data) - return str(data, 'utf-8') - @staticmethod - async def dump(value, target): - data = bytes(value, 'utf-8') - await UVarintType.dump(len(data), target) - await target.write(data) - - -FLAG_REPEATED = const(1) - - -class MessageType(Type): +class MessageType: WIRE_TYPE = 2 FIELDS = {} @@ -159,61 +62,141 @@ class MessageType(Type): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.__dict__) - @classmethod - async def load(cls, source=None, target=None): - if target is None: - target = build_message(cls) - if source is None: - source = StreamReader() - try: - while True: - fkey = await UVarintType.load(source) - ftag = fkey >> 3 - wtype = fkey & 7 - if ftag in cls.FIELDS: - field = cls.FIELDS[ftag] - ftype = field[1] - if wtype != ftype.WIRE_TYPE: - raise TypeError( - 'Value of tag %s has incorrect wiretype %s, %s expected.' % - (ftag, wtype, ftype.WIRE_TYPE)) - else: - ftype = {0: UVarintType, 2: BytesType}[wtype] - await ftype.load(source) - continue - if issubclass(ftype, MessageType): - flen = await UVarintType.load(source) - slen = source.set_limit(flen) - target.send((field, None)) - await ftype.load(source, target) - source.set_limit(slen) - else: - fvalue = await ftype.load(source) - target.send((field, fvalue)) - except EOFError as e: - try: - target.throw(e) - except StopIteration as e: - return e.value - @classmethod - async def dump(cls, msg, target): - for ftag in cls.FIELDS: - fname, ftype, fflags = cls.FIELDS[ftag] - fvalue = getattr(msg, fname, None) - if fvalue is None: - continue - key = (ftag << 3) | ftype.WIRE_TYPE - if fflags & FLAG_REPEATED: - for svalue in fvalue: - await UVarintType.dump(key, target) - if issubclass(ftype, MessageType): - await BytesType.dump(ftype.dumps(svalue), target) - else: - await ftype.dump(svalue, target) +class LimitedReader: + + def __init__(self, reader, limit): + self.reader = reader + self.limit = limit + + async def readinto(self, buf): + if self.limit < len(buf): + raise EOFError + else: + nread = await self.reader.readinto(buf) + self.limit -= nread + return nread + + +class CountingWriter: + + def __init__(self): + self.size = 0 + + async def write(self, buf): + nwritten = len(buf) + self.size += nwritten + return nwritten + + +FLAG_REPEATED = const(1) + + +async def load_message(reader, msg_type): + fields = msg_type.FIELDS + msg = msg_type() + + while True: + try: + fkey = await load_uvarint(reader) + except EOFError: + break # no more fields to load + + ftag = fkey >> 3 + wtype = fkey & 7 + + field = fields.get(ftag, None) + + if field is None: # unknown field, skip it + if wtype == 0: + await load_uvarint(reader) + elif wtype == 2: + ivalue = await load_uvarint(reader) + await reader.readinto(bytearray(ivalue)) else: - await UVarintType.dump(key, target) - if issubclass(ftype, MessageType): - await BytesType.dump(ftype.dumps(fvalue), target) - else: - await 
ftype.dump(fvalue, target) + raise ValueError + continue + + fname, ftype, fflags = field + if wtype != ftype.WIRE_TYPE: + raise TypeError # parsed wire type differs from the schema + + ivalue = await load_uvarint(reader) + + if ftype is UVarintType: + fvalue = ivalue + elif ftype is BoolType: + fvalue = bool(ivalue) + elif ftype is BytesType: + fvalue = bytearray(ivalue) + await reader.readinto(fvalue) + elif ftype is UnicodeType: + fvalue = bytearray(ivalue) + await reader.readinto(fvalue) + fvalue = str(fvalue, 'utf8') + elif issubclass(ftype, MessageType): + fvalue = await load_message(LimitedReader(reader, ivalue), ftype) + else: + raise TypeError # field type is unknown + + if fflags & FLAG_REPEATED: + pvalue = getattr(msg, fname, []) + pvalue.append(fvalue) + fvalue = pvalue + setattr(msg, fname, fvalue) + + # fill missing fields + for tag in msg.FIELDS: + field = msg.FIELDS[tag] + if not hasattr(msg, field[0]): + setattr(msg, field[0], None) + + return msg + + +async def dump_message(writer, msg): + repvalue = [0] + mtype = msg.__class__ + fields = mtype.FIELDS + + for ftag in fields: + field = fields[ftag] + fname = field[0] + ftype = field[1] + fflags = field[2] + + fvalue = getattr(msg, fname, None) + if fvalue is None: + continue + + fkey = (ftag << 3) | ftype.WIRE_TYPE + + if not fflags & FLAG_REPEATED: + repvalue[0] = fvalue + fvalue = repvalue + + for svalue in fvalue: + await dump_uvarint(writer, fkey) + + if ftype is UVarintType: + await dump_uvarint(writer, svalue) + + elif ftype is BoolType: + await dump_uvarint(writer, int(svalue)) + + elif ftype is BytesType: + await dump_uvarint(writer, len(svalue)) + await writer.write(svalue) + + elif ftype is UnicodeType: + await dump_uvarint(writer, len(svalue)) + await writer.write(bytes(svalue, 'utf8')) + + elif issubclass(ftype, MessageType): + counter = CountingWriter() + await dump_message(counter, svalue) + await dump_uvarint(writer, counter.size) + await dump_message(writer, svalue) + + else: + raise TypeError diff --git a/src/lib/streams.py b/src/lib/streams.py deleted file mode 100644 index 09d02979..00000000 --- a/src/lib/streams.py +++ /dev/null @@ -1,65 +0,0 @@ -from trezor.utils import memcpy - - -class StreamReader: - - def __init__(self, buffer=None, limit=None): - if buffer is None: - buffer = bytearray() - self._buffer = buffer - self._limit = limit - self._ofs = 0 - - async def read_into(self, dst): - ''' - Read exactly `len(dst)` bytes into writable buffer-like `dst`. - - Raises `EOFError` if the internal limit was reached or the - backing IO strategy signalled an EOF. - ''' - n = len(dst) - - if self._limit is not None: - if self._limit < n: - raise EOFError() - self._limit -= n - - buf = self._buffer - ofs = self._ofs - i = 0 - while i < n: - if ofs >= len(buf): - buf = yield - ofs = 0 - # memcpy caps on the buffer lengths, no need for exact byte count - nb = memcpy(dst, i, buf, ofs, n) - ofs += nb - i += nb - self._buffer = buf - self._ofs = ofs - - def set_limit(self, n): - ''' - Makes this reader to signal EOF after reading `n` bytes. - - Returns the number of bytes that the reader can read after - raising EOF (intended to be restored with another call to - `set_limit`). 
- ''' - if self._limit is not None and n is not None: - rem = self._limit - n - else: - rem = None - self._limit = n - return rem - - -class BufferWriter: - - def __init__(self, buffer=None): - if buffer is None: - buffer = bytearray() - self.buffer = buffer - - async def write(self, b): - self.buffer.extend(b) diff --git a/src/main.py b/src/main.py index 1b6077e1..f0c43214 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,8 @@ from trezor import config from trezor import msg from trezor import ui from trezor import wire +from trezor import loop +from trezor.wire import codec_v2 config.init() diff --git a/src/trezor/loop.py b/src/trezor/loop.py index 8ac38921..d634159a 100644 --- a/src/trezor/loop.py +++ b/src/trezor/loop.py @@ -103,11 +103,11 @@ def run_forever(): log_delay_rb[log_delay_pos] = delay log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len - if trezormsg.poll(_paused_tasks, msg_entry, delay): + if io.poll(_paused_tasks, msg_entry, delay): # message received, run tasks paused on the interface msg_tasks = _paused_tasks.pop(msg_entry[0], ()) for task in msg_tasks: - _step_task(task, msg_entry[1]) + _step_task(task, msg_entry[1]) else: # timeout occurred, run the first scheduled task if _scheduled_tasks: @@ -292,6 +292,72 @@ class Wait(Syscall): raise +class Put(Syscall): + + def __init__(self, chan, value=None): + self.chan = chan + self.value = value + + def __call__(self, value): + self.value = value + return self + + def handle(self, task): + self.chan.schedule_put(schedule_task, task, self.value) + + +class Take(Syscall): + + def __init__(self, chan): + self.chan = chan + + def __call__(self): + return self + + def handle(self, task): + if self.chan.schedule_take(schedule_task, task) and self.chan.id is not None: + _pause_task(self.chan, self.chan.id) + + +class Chan: + + def __init__(self, id=None): + self.id = id + self.putters = [] + self.takers = [] + self.put = Put(self) + self.take = Take(self) + + def schedule_publish(self, schedule, value): + if self.takers: + for taker in self.takers: + schedule(taker, value) + self.takers.clear() + return True + else: + return False + + def schedule_put(self, schedule, putter, value): + if self.takers: + taker = self.takers.pop(0) + schedule(taker, value) + schedule(putter, value) + return True + else: + self.putters.append((putter, value)) + return False + + def schedule_take(self, schedule, taker): + if self.putters: + putter, value = self.putters.pop(0) + schedule(taker, value) + schedule(putter, value) + return True + else: + self.takers.append(taker) + return False + + select = Select sleep = Sleep wait = Wait diff --git a/src/trezor/main.py b/src/trezor/main.py index 2dfa6304..dfafbcfd 100644 --- a/src/trezor/main.py +++ b/src/trezor/main.py @@ -1,42 +1,31 @@ -import sys -sys.path.append('lib') - import gc +import micropython +import sys + +sys.path.append('lib') from trezor import loop from trezor import workflow from trezor import log log.level = log.DEBUG -# log.level = log.INFO - - -def perf_info_debug(): - while True: - queue_len = len(loop._scheduled_tasks) - - delay_avg = sum(loop.log_delay_rb) / loop.log_delay_rb_len - delay_last = loop.log_delay_rb[loop.log_delay_pos] - - mem_alloc = gc.mem_alloc() - gc.collect() - log.debug(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue_len: %d", - mem_alloc, gc.mem_alloc(), delay_avg, delay_last, queue_len) - - yield loop.Sleep(1000000) def perf_info(): + prev = 0 + peak = 0 + sleep = loop.sleep(100000) while True: gc.collect() - log.info(__name__, 
"mem_alloc: %d", gc.mem_alloc()) - yield loop.Sleep(1000000) + used = gc.mem_alloc() + if used != prev: + prev = used + peak = max(peak, used) + print('peak %d, used %d' % (peak, used)) + yield sleep def run(default_workflow): - # if __debug__: - # loop.schedule_task(perf_info_debug()) - # else: - # loop.schedule_task(perf_info()) + # loop.schedule_task(perf_info()) workflow.start_default(default_workflow) loop.run_forever() diff --git a/src/trezor/messages/__init__.py b/src/trezor/messages/__init__.py index e626cc27..dac057b0 100644 --- a/src/trezor/messages/__init__.py +++ b/src/trezor/messages/__init__.py @@ -1,13 +1,13 @@ from . import wire_types -def get_protobuf_type_name(wire_type): +def get_type_name(wire_type): for name in dir(wire_types): if getattr(wire_types, name) == wire_type: return name -def get_protobuf_type(wire_type): - name = get_protobuf_type_name(wire_type) +def get_type(wire_type): + name = get_type_name(wire_type) module = __import__('trezor.messages.%s' % name, None, None, (name, ), 0) return getattr(module, name) diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 2917abe5..93a660f4 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -1,186 +1,134 @@ -import ubinascii import protobuf from trezor import log from trezor import loop from trezor import messages -from trezor import msg from trezor import workflow from . import codec_v1 from . import codec_v2 -from . import sessions -_interface = None - -_workflow_callbacks = {} # wire type -> function returning workflow -_workflow_args = {} # wire type -> args +workflows = {} -def register(wire_type, callback, *args): - if wire_type in _workflow_callbacks: - raise KeyError('Message %d already registered' % wire_type) - _workflow_callbacks[wire_type] = callback - _workflow_args[wire_type] = args +def register(wire_type, handler, *args): + if wire_type in workflows: + raise KeyError + workflows[wire_type] = (handler, args) -def setup(iface): - global _interface - - # setup wire interface for reading and writing - _interface = iface - - # implicitly register v1 codec on its session. v2 sessions are - # opened/closed explicitely through session control messages. 
- _session_open(codec_v1.SESSION) - - # run session dispatcher - loop.schedule_task(_dispatch_reports()) +def setup(interface): + session_supervisor = codec_v2.SesssionSupervisor(interface, + session_handler) + session_supervisor.open(codec_v1.SESSION_ID) + loop.schedule_task(session_supervisor.listen()) -async def read(session_id, *wire_types): - log.info(__name__, 'session %x: read(%s)', session_id, wire_types) - signal = loop.Signal() - sessions.listen(session_id, _handle_response, wire_types, signal) - return await signal +class Context: + def __init__(self, interface, session_id): + self.interface = interface + self.session_id = session_id + + def get_reader(self): + if self.session_id == codec_v1.SESSION_ID: + return codec_v1.Reader(self.interface) + else: + return codec_v2.Reader(self.interface, self.session_id) + + def get_writer(self, mtype, msize): + if self.session_id == codec_v1.SESSION_ID: + return codec_v1.Writer(self.interface, mtype, msize) + else: + return codec_v2.Writer(self.interface, self.session_id, mtype, msize) + + async def read(self, types): + reader = self.get_reader() + await reader.open() + if reader.type not in types: + raise UnexpectedMessageError(reader) + return await protobuf.load_message(reader, + messages.get_type(reader.type)) + + async def write(self, msg): + counter = protobuf.CountingWriter() + await protobuf.dump_message(counter, msg) + writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size) + await protobuf.dump_message(writer, msg) + await writer.close() + + async def call(self, msg, types): + await self.write(msg) + return await self.read(types) -async def write(session_id, pbuf_msg): - log.info(__name__, 'session %x: write(%s)', session_id, pbuf_msg) - pbuf_type = pbuf_msg.__class__ - msg_data = pbuf_type.dumps(pbuf_msg) - msg_type = pbuf_type.MESSAGE_WIRE_TYPE - sessions.get_codec(session_id).encode( - session_id, msg_type, msg_data, _write_report) - - -async def call(session_id, pbuf_msg, *response_types): - await write(session_id, pbuf_msg) - return await read(session_id, *response_types) +class UnexpectedMessageError(Exception): + def __init__(self, reader): + super().__init__() + self.reader = reader class FailureError(Exception): - - def to_protobuf(self): - from trezor.messages.Failure import Failure - code, message = self.args - return Failure(code=code, message=message) + def __init__(self, code, message): + super().__init__() + self.code = code + self.message = message -class CloseWorkflow(Exception): - pass +class Workflow: + def __init__(self, default): + self.handlers = {} + self.default = default - -def protobuf_workflow(session_id, msg_type, data_len, callback, *args): - return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args) - - -def _start_protobuf_workflow(pbuf_msg, session_id, callback, args): - wf = callback(session_id, pbuf_msg, *args) - wf = _wrap_protobuf_workflow(wf, session_id) - workflow.start(wf) - - -async def _wrap_protobuf_workflow(wf, session_id): - try: - result = await wf - - except CloseWorkflow: - return - - except FailureError as e: - await write(session_id, e.to_protobuf()) - raise - - except Exception as e: - from trezor.messages.Failure import Failure - from trezor.messages.FailureType import FirmwareError - await write(session_id, Failure( - code=FirmwareError, message='Firmware Error')) - raise - - else: - if result is not None: - await write(session_id, result) - return result - - finally: - if session_id in sessions.opened: - sessions.listen(session_id, 
_handle_workflow) - - -def _build_protobuf(msg_type, callback, *args): - pbuf_type = messages.get_protobuf_type(msg_type) - builder = protobuf.build_message(pbuf_type, callback, *args) - builder.send(None) - return pbuf_type.load(target=builder) - - -def _handle_response(session_id, msg_type, data_len, response_types, signal): - if msg_type in response_types: - return _build_protobuf(msg_type, signal.send) - else: - signal.send(CloseWorkflow()) - return _handle_workflow(session_id, msg_type, data_len) - - -def _handle_workflow(session_id, msg_type, data_len): - if msg_type in _workflow_callbacks: - callback = _workflow_callbacks[msg_type] - args = _workflow_args[msg_type] - return callback(session_id, msg_type, data_len, *args) - else: - return _handle_unexpected(session_id, msg_type, data_len) - - -def _handle_unexpected(session_id, msg_type, data_len): - log.warning( - __name__, 'session %x: skip type %d, len %d', session_id, msg_type, data_len) - - # read the message in full - try: + async def __call__(self, interface, session_id): + ctx = Context(interface, session_id) while True: - yield - except EOFError: - pass + try: + reader = ctx.get_reader() + await reader.open() + try: + handler = self.handlers[reader.type] + except KeyError: + handler = self.default + try: + await handler(ctx, reader) + except UnexpectedMessageError as unexp_msg: + reader = unexp_msg.reader + except Exception as e: + log.exception(__name__, e) + +async def protobuf_workflow(ctx, reader, handler, *args): + msg = await protobuf.load_message(reader, messages.get_type(reader.type)) + try: + res = await handler(reader.sid, msg, *args) + except Exception as exc: + if not isinstance(exc, UnexpectedMessageError): + await ctx.write(make_failure_msg(exc)) + raise + else: + if res: + await ctx.write(res) + + +async def handle_unexp_msg(ctx, reader): + # receive the message and throw it away + while reader.size > 0: + buf = bytearray(reader.size) + await reader.readinto(buf) # respond with an unknown message error from trezor.messages.Failure import Failure from trezor.messages.FailureType import UnexpectedMessage - failure = Failure(code=UnexpectedMessage, message='Unexpected message') - failure = Failure.dumps(failure) - sessions.get_codec(session_id).encode( - session_id, Failure.MESSAGE_WIRE_TYPE, failure, _write_report) + await ctx.write( + Failure(code=UnexpectedMessage, message='Unexpected message')) -def _write_report(report): - # if __debug__: - # log.debug(__name__, 'write report %s', ubinascii.hexlify(report)) - msg.send(_interface, report) - - -def _dispatch_reports(): - read = loop.select(_interface) - while True: - report = yield read - # if __debug__: - # log.debug(__name__, 'read report %s', ubinascii.hexlify(report)) - sessions.dispatch( - memoryview(report), _session_open, _session_close, _session_unknown) - - -def _session_open(session_id=None): - session_id = sessions.open(session_id) - sessions.listen(session_id, _handle_workflow) - sessions.get_codec(session_id).encode_session_open( - session_id, _write_report) - - -def _session_close(session_id): - sessions.close(session_id) - sessions.get_codec(session_id).encode_session_close( - session_id, _write_report) - - -def _session_unknown(session_id, report_data): - log.warning(__name__, 'report on unknown session %x', session_id) +def make_failure_msg(exc): + from trezor.messages.Failure import Failure + from trezor.messages.FailureType import FirmwareError + if isinstance(exc, FailureError): + code = exc.code + message = exc.message + else: + code = 
FirmwareError + message = 'Firmware Error' + return Failure(code=code, message=message) diff --git a/src/trezor/wire/codec_v1.py b/src/trezor/wire/codec_v1.py index 6657f60a..9a3dfc87 100644 --- a/src/trezor/wire/codec_v1.py +++ b/src/trezor/wire/codec_v1.py @@ -1,114 +1,145 @@ from micropython import const - import ustruct -SESSION = const(0) -REP_MARKER = const(63) # ord('?') -REP_MARKER_LEN = const(1) # len('?') +from trezor import io +from trezor import loop +from trezor import utils _REP_LEN = const(64) -_MSG_HEADER_MAGIC = const(35) # org('#') -_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length -_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER) + +_REP_MARKER = const(63) # ord('?') +_REP_MAGIC = const(35) # org('#') +_REP_INIT = '>BBBHL' # marker, magic, magic, wire type, data length +_REP_INIT_DATA = const(9) # offset of data in the initial report +_REP_CONT_DATA = const(1) # offset of data in the continuation report + +SESSION_ID = const(0) -def detect(data): - return data[0] == REP_MARKER +class Reader: + ''' + Decoder for legacy codec over the HID layer. Provides readable + async-file-like interface. + ''' + + def __init__(self, iface): + self.iface = iface + self.type = None + self.size = None + self.data = None + self.ofs = 0 + + def __repr__(self): + return '' % (self.type, self.size) + + async def open(self): + ''' + Begin the message transmission by waiting for initial V2 message report + on this session. `self.type` and `self.size` are initialized and + available after `open()` returns. + ''' + read = loop.select(self.iface | loop.READ) + while True: + # wait for initial report + report = await read + marker = report[0] + if marker == _REP_MARKER: + _, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report) + if m1 != _REP_MAGIC or m2 != _REP_MAGIC: + raise ValueError + break + + # load received message header + self.type = mtype + self.size = msize + self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize] + self.ofs = 0 + + async def readinto(self, buf): + ''' + Read exactly `len(buf)` bytes into `buf`, waiting for additional + reports, if needed. Raises `EOFError` if end-of-message is encountered + before the full read can be completed. + ''' + if self.size < len(buf): + raise EOFError + + read = loop.select(self.iface | loop.READ) + nread = 0 + while nread < len(buf): + if self.ofs == len(self.data): + # we are at the end of received data + # wait for continuation report + while True: + report = await read + marker = report[0] + if marker == _REP_MARKER: + break + self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size] + self.ofs = 0 + + # copy as much as possible to target buffer + nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf)) + nread += nbytes + self.ofs += nbytes + self.size -= nbytes + + return nread -def parse_report(data): - if len(data) != _REP_LEN: - raise ValueError('Invalid buffer size') - return None, SESSION, data[1:] +class Writer: + ''' + Encoder for legacy codec over the HID layer. Provides writable + async-file-like interface. 
+ ''' + def __init__(self, iface, mtype, msize): + self.iface = iface + self.type = mtype + self.size = msize + self.data = bytearray(_REP_LEN) + self.ofs = _REP_INIT_DATA -def parse_message(data): - magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data) - if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: - raise ValueError('Corrupted magic bytes') - return msg_type, data_len, data[_MSG_HEADER_LEN:] + # load the report with initial header + ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize) + def __repr__(self): + return '' % (self.type, self.size) -def serialize_message_header(data, msg_type, msg_len): - if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN: - raise ValueError('Invalid buffer size') - if msg_type < 0 or msg_type > 65535: - raise ValueError('Value is out of range') - ustruct.pack_into( - _MSG_HEADER, data, REP_MARKER_LEN, - _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) + async def write(self, buf): + ''' + Encode and write every byte from `buf`. Does not need to be called in + case message has zero length. Raises `EOFError` if the length of `buf` + exceeds the remaining message length. + ''' + if self.size < len(buf): + raise EOFError + write = loop.select(self.iface | loop.WRITE) + nwritten = 0 + while nwritten < len(buf): + # copy as much as possible to report buffer + nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf)) + nwritten += nbytes + self.ofs += nbytes + self.size -= nbytes -def decode_stream(session_id, callback, *args): - '''Decode a v1 wire message from the report data and stream it to target. + if self.ofs == _REP_LEN: + # we are at the end of the report, flush it + await write + io.send(self.iface, self.data) + self.ofs = _REP_CONT_DATA -Receives report payloads. After first report, creates target by calling -`callback(session_id, msg_type, data_len, *args)` and sends chunks of message -data. Throws `EOFError` to target after last data chunk. + return nwritten -Pass report payloads as `memoryview` for cheaper slicing. -''' + async def close(self): + '''Flush and close the message transmission.''' + if self.ofs != _REP_CONT_DATA: + # we didn't write anything or last write() wasn't report-aligned, + # pad the final report and flush it + while self.ofs < _REP_LEN: + self.data[self.ofs] = 0x00 + self.ofs += 1 - message = yield # read first report - msg_type, data_len, data = parse_message(message) - - target = callback(session_id, msg_type, data_len, *args) - target.send(None) - - while data_len > 0: - - data_chunk = data[:data_len] # slice off the garbage at the end - data = data[len(data_chunk):] # slice off what we have read - data_len -= len(data_chunk) - target.send(data_chunk) - - if data_len > 0: - data = yield # read next report - - target.throw(EOFError()) - - -def encode(session_id, msg_type, msg_data, callback): - '''Encode a full v1 wire message directly to reports and stream it to callback. - -Callback receives `memoryview`s of HID reports which are valid until the -callback returns. 
-''' - report = memoryview(bytearray(_REP_LEN)) - report[0] = REP_MARKER - serialize_message_header(report, msg_type, len(msg_data)) - - source_data = memoryview(msg_data) - target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:] - - while True: - # move as much as possible from source to target - n = min(len(target_data), len(source_data)) - target_data[:n] = source_data[:n] - source_data = source_data[n:] - target_data = target_data[n:] - - # fill the rest of the report with 0x00 - x = 0 - to_fill = len(target_data) - while x < to_fill: - target_data[x] = 0 - x += 1 - - callback(report) - - if not source_data: - break - - # reset to skip the magic, not the whole header anymore - target_data = report[REP_MARKER_LEN:] - - -def encode_session_open(session_id, callback): - # v1 codec does not have explicit session support - pass - - -def encode_session_close(session_id, callback): - # v1 codec does not have explicit session support - pass + await loop.select(self.iface | loop.WRITE) + io.send(self.iface, self.data) diff --git a/src/trezor/wire/codec_v2.py b/src/trezor/wire/codec_v2.py index 63f7fd90..808b4fa4 100644 --- a/src/trezor/wire/codec_v2.py +++ b/src/trezor/wire/codec_v2.py @@ -1,190 +1,232 @@ from micropython import const import ustruct -import ubinascii -# trezor wire protocol #2: -# -# # hid report (64B) -# - report marker (1B) -# - session id (4B, BE) -# - payload (59B) -# -# # message -# - streamed as payloads of hid reports -# - message type (4B, BE) -# - data length (4B, BE) -# - data (var-length) -# - data crc32 checksum (4B, BE) -# -# # sessions -# - reports are interleaved, need to be dispatched by session id +from trezor import io +from trezor import loop +from trezor import utils +from trezor.crypto import random -REP_MARKER_HEADER = const(72) # ord('H') -REP_MARKER_DATA = const(68) # ord('D') -REP_MARKER_OPEN = const(79) # ord('O') -REP_MARKER_CLOSE = const(67) # ord('C') - -_REP_HEADER = '>BL' # marker, session id -_MSG_HEADER = '>LL' # msg type, data length -_MSG_FOOTER = '>L' # data checksum +# TREZOR wire protocol #2: +# +# # Initial message report +# uint8_t marker; // REP_MARKER_INIT +# uint32_t session_id; // Big-endian +# uint32_t message_type; // Big-endian +# uint32_t message_size; // Big-endian +# uint8_t data[]; +# +# # Continuation message report +# uint8_t marker; // REP_MARKER_CONT +# uint32_t session_id; // Big-endian +# uint32_t sequence; // Big-endian, 0 for 1st continuation report +# uint8_t data[]; _REP_LEN = const(64) -_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER) -_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER) -_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER) + +_REP_MARKER_INIT = const(0x01) +_REP_MARKER_CONT = const(0x02) +_REP_MARKER_OPEN = const(0x03) +_REP_MARKER_CLOSE = const(0x04) + +_REP = '>BL' # marker, session_id +_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size +_REP_CONT = '>BLL' # marker, session_id, sequence +_REP_INIT_DATA = const(13) # offset of data in init report +_REP_CONT_DATA = const(9) # offset of data in cont report -def parse_report(data): - if len(data) != _REP_LEN: - raise ValueError('Invalid buffer size') - marker, session_id = ustruct.unpack(_REP_HEADER, data) - return marker, session_id, data[_REP_HEADER_LEN:] - - -def parse_message(data): - if len(data) != _REP_LEN - _REP_HEADER_LEN: - raise ValueError('Invalid buffer size') - msg_type, data_len = ustruct.unpack(_MSG_HEADER, data) - return msg_type, data_len, data[_MSG_HEADER_LEN:] - - -def parse_message_footer(data): - if len(data) != 
_MSG_FOOTER_LEN: - raise ValueError('Invalid buffer size') - data_checksum, = ustruct.unpack(_MSG_FOOTER, data) - return data_checksum, - - -def serialize_report_header(data, marker, session_id): - if len(data) < _REP_HEADER_LEN: - raise ValueError('Invalid buffer size') - ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id) - - -def serialize_message_header(data, msg_type, msg_len): - if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN: - raise ValueError('Invalid buffer size') - ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len) - - -def serialize_message_footer(data, checksum): - if len(data) < _MSG_FOOTER_LEN: - raise ValueError('Invalid buffer size') - ustruct.pack_into(_MSG_FOOTER, data, 0, checksum) - - -def serialize_opened_session(data, session_id): - serialize_report_header(data, REP_MARKER_OPEN, session_id) - - -class MessageChecksumError(Exception): - pass - - -def decode_stream(session_id, callback, *args): - '''Decode a wire message from the report data and stream it to target. - -Receives report payloads. After first report, creates target by calling -`callback(session_id, msg_type, data_len, *args)` and sends chunks of message -data. -Throws `EOFError` to target after last data chunk, in case of valid checksum. -Throws `MessageChecksumError` to target if data doesn't match the checksum. - -Pass report payloads as `memoryview` for cheaper slicing. -''' - message = yield # read first report - msg_type, data_len, data_tail = parse_message(message) - - target = callback(session_id, msg_type, data_len, *args) - target.send(None) - - checksum = 0 # crc32 - - while data_len > 0: - - data_chunk = data_tail[:data_len] # slice off the garbage at the end - data_tail = data_tail[len(data_chunk):] # slice off what we have read - data_len -= len(data_chunk) - target.send(data_chunk) - - checksum = ubinascii.crc32(data_chunk, checksum) - - if data_len > 0: - data_tail = yield # read next report - - msg_footer = data_tail[:_MSG_FOOTER_LEN] - if len(msg_footer) < _MSG_FOOTER_LEN: - data_tail = yield # read report with the rest of checksum - footer_tail = data_tail[:_MSG_FOOTER_LEN - len(msg_footer)] - msg_footer = bytearray(msg_footer) - msg_footer.extend(footer_tail) - - data_checksum, = parse_message_footer(msg_footer) - if data_checksum != checksum: - target.throw(MessageChecksumError((checksum, data_checksum))) - else: - target.throw(EOFError()) - - -def encode(session_id, msg_type, msg_data, callback): - '''Encode a full wire message directly to reports and stream it to callback. - -Callback receives `memoryview`s of HID reports which are valid until the -callback returns. +class Reader: + ''' + Decoder for v2 codec over the HID layer. Provides readable async-file-like + interface. ''' - report = memoryview(bytearray(_REP_LEN)) - serialize_report_header(report, REP_MARKER_HEADER, session_id) - serialize_message_header(report, msg_type, len(msg_data)) - source_data = memoryview(msg_data) - target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:] + def __init__(self, iface, sid): + self.iface = iface + self.sid = sid + self.type = None + self.size = None + self.data = None + self.ofs = 0 + self.seq = 0 - checksum = ubinascii.crc32(msg_data) + def __repr__(self): + return '' % (self.sid, self.type, self.size) - msg_footer = bytearray(_MSG_FOOTER_LEN) - serialize_message_footer(msg_footer, checksum) + async def open(self): + ''' + Begin the message transmission by waiting for initial V2 message report + on this session. 
`self.type` and `self.size` are initialized and + available after `open()` returns. + ''' + read = loop.select(self.iface | loop.READ) + while True: + # wait for initial report + report = await read + marker, sid, mtype, msize = ustruct.unpack(_REP_INIT, report) + if sid == self.sid and marker == _REP_MARKER_INIT: + break - first = True + # load received message header + self.type = mtype + self.size = msize + self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize] + self.ofs = 0 + self.seq = 0 - while True: - # move as much as possible from source to target - n = min(len(target_data), len(source_data)) - target_data[:n] = source_data[:n] - source_data = source_data[n:] - target_data = target_data[n:] + async def readinto(self, buf): + ''' + Read exactly `len(buf)` bytes into `buf`, waiting for additional + reports, if needed. Raises `EOFError` if end-of-message is encountered + before the full read can be completed. + ''' + if self.size < len(buf): + raise EOFError - # continue with the footer if source is empty and we have space - if not source_data and target_data and msg_footer: - source_data = msg_footer - msg_footer = None - continue + read = loop.select(self.iface | loop.READ) + nread = 0 + while nread < len(buf): + if self.ofs == len(self.data): + # we are at the end of received data + # wait for continuation report + while True: + report = await read + marker, sid, seq = ustruct.unpack(_REP_CONT, report) + if sid == self.sid and marker == _REP_MARKER_CONT: + if seq != self.seq: + raise ValueError + break + self.data = report[_REP_CONT_DATA:_REP_CONT_DATA + self.size] + self.seq += 1 + self.ofs = 0 - # fill the rest of the report with 0x00 - x = 0 - to_fill = len(target_data) - while x < to_fill: - target_data[x] = 0 - x += 1 + # copy as much as possible to target buffer + nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf)) + nread += nbytes + self.ofs += nbytes + self.size -= nbytes - callback(report) - - if not source_data and not msg_footer: - break - - # reset to skip the magic and session ID - if first: - serialize_report_header(report, REP_MARKER_DATA, session_id) - first = False - target_data = report[_REP_HEADER_LEN:] + return nread -def encode_session_open(session_id, callback): - report = bytearray(_REP_LEN) - serialize_report_header(report, REP_MARKER_OPEN, session_id) - callback(report) +class Writer: + ''' + Encoder for v2 codec over the HID layer. Provides writable async-file-like + interface. + ''' + + def __init__(self, iface, sid, mtype, msize): + self.iface = iface + self.sid = sid + self.type = mtype + self.size = msize + self.data = bytearray(_REP_LEN) + self.ofs = _REP_INIT_DATA + self.seq = 0 + + # load the report with initial header + ustruct.pack_into(_REP_INIT, self.data, 0, + _REP_MARKER_INIT, sid, mtype, msize) + + async def write(self, buf): + ''' + Encode and write every byte from `buf`. Does not need to be called in + case message has zero length. Raises `EOFError` if the length of `buf` + exceeds the remaining message length. 
+ ''' + if self.size < len(buf): + raise EOFError + + write = loop.select(self.iface | loop.WRITE) + nwritten = 0 + while nwritten < len(buf): + # copy as much as possible to report buffer + nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf)) + nwritten += nbytes + self.ofs += nbytes + self.size -= nbytes + + if self.ofs == _REP_LEN: + # we are at the end of the report, flush it, and prepare header + await write + io.send(self.iface, self.data) + ustruct.pack_into(_REP_CONT, self.data, 0, + _REP_MARKER_CONT, self.sid, self.seq) + self.ofs = _REP_CONT_DATA + self.seq += 1 + + return nwritten + + async def close(self): + '''Flush and close the message transmission.''' + if self.ofs != _REP_CONT_DATA: + # we didn't write anything or last write() wasn't report-aligned, + # pad the final report and flush it + while self.ofs < _REP_LEN: + self.data[self.ofs] = 0x00 + self.ofs += 1 + + await loop.select(self.iface | loop.WRITE) + io.send(self.iface, self.data) -def encode_session_close(session_id, callback): - report = bytearray(_REP_LEN) - serialize_report_header(report, REP_MARKER_CLOSE, session_id) - callback(report) +class SesssionSupervisor: + '''Handles session open/close requests on v2 protocol layer.''' + + def __init__(self, iface, handler): + self.iface = iface + self.handler = handler + self.handling_tasks = {} + self.session_report = bytearray(_REP_LEN) + + async def listen(self): + ''' + Listen for open/close requests on configured interface. After open + request, session is started and a new task is scheduled to handle it. + After close request, the handling task is closed and session terminated. + Both requests receive responses confirming the operation. + ''' + read = loop.select(self.iface | loop.READ) + write = loop.select(self.iface | loop.WRITE) + while True: + report = await read + repmarker, repsid = ustruct.unpack(_REP, report) + # because tasks paused on I/O have a priority over time-scheduled + # tasks, we need to `yield` explicitly before sending a response to + # open/close request. Otherwise the handler would have no chance to + # run and schedule communication. + if repmarker == _REP_MARKER_OPEN: + newsid = self.newsid() + self.open(newsid) + yield + await write + self.sendopen(newsid) + elif repmarker == _REP_MARKER_CLOSE: + self.close(repsid) + yield + await write + self.sendclose(repsid) + + def open(self, sid): + if sid not in self.handling_tasks: + task = self.handling_tasks[sid] = self.handler(self.iface, sid) + loop.schedule_task(task) + + def close(self, sid): + if sid in self.handling_tasks: + task = self.handling_tasks.pop(sid) + task.close() + + def newsid(self): + while True: + sid = random.uniform(0xffffffff) + 1 + if sid not in self.handling_tasks: + return sid + + def sendopen(self, sid): + ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid) + io.send(self.iface, self.session_report) + + def sendclose(self, sid): + ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid) + io.send(self.iface, self.session_report) diff --git a/src/trezor/wire/sessions.py b/src/trezor/wire/sessions.py deleted file mode 100644 index 9fc01608..00000000 --- a/src/trezor/wire/sessions.py +++ /dev/null @@ -1,82 +0,0 @@ -from trezor import log -from trezor.crypto import random - -from . import codec_v1 -from . 
import codec_v2 - -opened = set() # opened session ids -readers = {} # session id -> generator - - -def generate(): - while True: - session_id = random.uniform(0xffffffff) + 1 - if session_id not in opened: - return session_id - - -def open(session_id=None): - if session_id is None: - session_id = generate() - log.info(__name__, 'session %x: open', session_id) - opened.add(session_id) - return session_id - - -def close(session_id): - log.info(__name__, 'session %x: close', session_id) - opened.discard(session_id) - readers.pop(session_id, None) - - -def get_codec(session_id): - if session_id == codec_v1.SESSION: - return codec_v1 - else: - return codec_v2 - - -def listen(session_id, handler, *args): - if session_id not in opened: - raise KeyError('Session %x is unknown' % session_id) - if session_id in readers: - raise KeyError('Session %x is already being listened on' % session_id) - log.info(__name__, 'session %x: listening', session_id) - decoder = get_codec(session_id).decode_stream(session_id, handler, *args) - decoder.send(None) - readers[session_id] = decoder - - -def dispatch(report, open_callback, close_callback, unknown_callback): - ''' - Dispatches payloads of reports adhering to one of the wire codecs. - ''' - - if codec_v1.detect(report): - marker, session_id, report_data = codec_v1.parse_report(report) - else: - marker, session_id, report_data = codec_v2.parse_report(report) - - if marker == codec_v2.REP_MARKER_OPEN: - log.debug(__name__, 'request for new session') - open_callback() - return - elif marker == codec_v2.REP_MARKER_CLOSE: - log.debug(__name__, 'request for closing session %x', session_id) - close_callback(session_id) - return - - if session_id not in readers: - log.warning(__name__, 'report on unknown session %x', session_id) - unknown_callback(session_id, report_data) - return - - log.debug(__name__, 'report on session %x', session_id) - reader = readers[session_id] - - try: - reader.send(report_data) - except StopIteration: - readers.pop(session_id) - except Exception as e: - log.exception(__name__, e) diff --git a/src/trezor/workflow.py b/src/trezor/workflow.py index 19d208d7..6b881dde 100644 --- a/src/trezor/workflow.py +++ b/src/trezor/workflow.py @@ -17,14 +17,14 @@ def start_default(genfunc): def close_default(): global _default - log.info(__name__, 'close default %s', _default) - _default.close() - _default = None + if _default is not None: + log.info(__name__, 'close default %s', _default) + _default.close() + _default = None def start(workflow): - if _default is not None: - close_default() + close_default() _started.append(workflow) log.info(__name__, 'start %s', workflow) loop.schedule_task(_watch(workflow)) diff --git a/tests/test_trezor.wire.codec_v1.py b/tests/test_trezor.wire.codec_v1.py index 4638f144..915eb0f2 100644 --- a/tests/test_trezor.wire.codec_v1.py +++ b/tests/test_trezor.wire.codec_v1.py @@ -1,178 +1,164 @@ -from common import * +import sys -import ustruct +sys.path.append('../src') +sys.path.append('../src/lib') +from utest import * +from ustruct import pack, unpack +from ubinascii import hexlify, unhexlify + +from trezor import msg +from trezor.loop import Select, Syscall, READ, WRITE from trezor.crypto import random from trezor.utils import chunks - from trezor.wire import codec_v1 -class TestWireCodecV1(unittest.TestCase): - # pylint: disable=C0301 - def test_detect(self): - for i in range(0, 256): - if i == ord(b'?'): - self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63)) - else: - self.assertFalse(codec_v1.detect(bytes([i]) 
+ b'\x00' * 63)) +def test_reader(): + rep_len = 64 + interface = 0xdeadbeef + message_type = 0x4321 + message_len = 250 + reader = codec_v1.Reader(interface, codec_v1.SESSION_ID) - def test_parse(self): - d = bytes(range(0, 55)) - m = b'##\x00\x00\x00\x00\x00\x37' + d - r = b'?' + m + message = bytearray(range(message_len)) + report_header = bytearray(unhexlify('3f23234321000000fa')) - rm, rs, rd = codec_v1.parse_report(r) - self.assertEqual(rm, None) - self.assertEqual(rs, 0) - self.assertEqual(rd, m) + # open, expected one read + first_report = report_header + message[:rep_len - len(report_header)] + assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) + assert_eq(reader.type, message_type) + assert_eq(reader.size, message_len) - mt, ml, md = codec_v1.parse_message(m) - self.assertEqual(mt, 0) - self.assertEqual(ml, len(d)) - self.assertEqual(md, d) + # empty read + empty_buffer = bytearray() + assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) + assert_eq(len(empty_buffer), 0) + assert_eq(reader.size, message_len) - for i in range(0, 1024): - if i != 64: - with self.assertRaises(ValueError): - codec_v1.parse_report(bytes(range(0, i))) + # short read, expected no read + short_buffer = bytearray(32) + assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) + assert_eq(len(short_buffer), 32) + assert_eq(short_buffer, message[:len(short_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer)) - for hx in range(0, 256): - for hy in range(0, 256): - if hx != ord(b'#') and hy != ord(b'#'): - with self.assertRaises(ValueError): - codec_v1.parse_message(bytes([hx, hy]) + m[2:]) + # aligned read, expected no read + aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) + assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) + assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) - def test_serialize(self): - data = bytearray(range(0, 10)) - codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) - self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09') + # one byte read, expected one read + next_report_header = bytearray(unhexlify('3f')) + next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] + onebyte_buffer = bytearray(1) + assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) + assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) - data = bytearray(9) - with self.assertRaises(ValueError): - codec_v1.serialize_message_header(data, 65536, 0) + # too long read, raises eof + assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) - for i in range(0, 8): - data = bytearray(i) - with self.assertRaises(ValueError): - codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) + # long read, expect multiple reads + start_size = reader.size + long_buffer = bytearray(start_size) + report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] + report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] + report_payload_rest = report_payload[len(report_payload_head):] + report_payload_rest = 
list(chunks(report_payload_rest, rep_len - len(next_report_header))) + report_payloads = [report_payload_head] + report_payload_rest + next_reports = [next_report_header + r for r in report_payloads] + expected_syscalls = [] + for i, _ in enumerate(next_reports): + prev_report = next_reports[i - 1] if i > 0 else None + expected_syscalls.append((prev_report, Select(READ | interface))) + expected_syscalls.append((next_reports[-1], StopIteration())) + assert_async(reader.readinto(long_buffer), expected_syscalls) + assert_eq(long_buffer, message[-start_size:]) + assert_eq(reader.size, 0) - def test_decode_empty(self): - message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 + # one byte read, raises eof + assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) - record = [] - genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy') - decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) - try: - decoder.send(message) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), 1) - self.assertIsInstance(record[0], EOFError) +def test_writer(): + rep_len = 64 + interface = 0xdeadbeef + message_type = 0x87654321 + message_len = 1024 + writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len) - def test_decode_one_report_aligned(self): - data = bytes(range(0, 55)) - message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data + # init header corresponding to the data above + report_header = bytearray(unhexlify('3f2323432100000400')) - record = [] - genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy') - decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) + assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) - try: - decoder.send(message) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), 2) - self.assertEqual(record[0], data) - self.assertIsInstance(record[1], EOFError) + # empty write + start_size = writer.size + assert_async(writer.write(bytearray()), [(None, StopIteration()),]) + assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) + assert_eq(writer.size, start_size) - def test_decode_generated_range(self): - for data_len in range(1, 512): - data = random.bytes(data_len) - data_chunks = [data[:55]] + list(chunks(data[55:], 63)) + # short write, expected no report + start_size = writer.size + short_payload = bytearray(range(4)) + assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_eq(writer.size, start_size - len(short_payload)) + assert_eq(writer.data, + report_header + + short_payload + + bytearray(rep_len - len(report_header) - len(short_payload))) - msg_type = 0xabcd - header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len) + # aligned write, expected one report + start_size = writer.size + aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) + msg.send = mock_call(msg.send, [ + (interface, report_header + + short_payload + + aligned_payload + + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) + assert_eq(writer.size, start_size - len(aligned_payload)) + msg.send.assert_called_n_times(1) + msg.send = msg.send.original - message = header + data - message_chunks = [c + '\x00' * (63 - 
len(c)) for c in list(chunks(message, 63))] + # short write, expected no report, but data starts with correct seq and cont marker + report_header = bytearray(unhexlify('3f')) + start_size = writer.size + assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_eq(writer.size, start_size - len(short_payload)) + assert_eq(writer.data[:len(report_header) + len(short_payload)], + report_header + short_payload) - record = [] - genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy') - decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) - - res = 1 - try: - for c in message_chunks: - decoder.send(c) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), len(data_chunks) + 1) - for i in range(0, len(data_chunks)): - self.assertEqual(record[i], data_chunks[i]) - self.assertIsInstance(record[-1], EOFError) - - def test_encode_empty(self): - record = [] - target = self._record(record)() - target.send(None) - - codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send) - self.assertEqual(len(record), 1) - self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) - - def test_encode_one_report_aligned(self): - data = bytes(range(0, 55)) - - record = [] - target = self._record(record)() - target.send(None) - - codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send) - self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) - - def test_encode_generated_range(self): - for data_len in range(1, 1024): - data = random.bytes(data_len) - - msg_type = 0xabcd - header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len) - - message = header + data - reports = [b'?' + c for c in chunks(message, 63)] - reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1])) - - received = 0 - def genfunc(): - nonlocal received - while True: - self.assertEqual((yield), reports[received]) - received += 1 - target = genfunc() - target.send(None) - - codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send) - self.assertEqual(received, len(reports)) - - def _record(self, record, *_args): - def genfunc(*args): - self.assertEqual(args, _args) - while True: - try: - v = yield - except Exception as e: - record.append(e) - else: - record.append(v) - return genfunc + # long write, expected multiple reports + start_size = writer.size + long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) + long_payload_rest = bytearray(range(start_size - len(long_payload_head))) + long_payload = long_payload_head + long_payload_rest + expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) + expected_reports = [report_header + r for r in expected_payloads] + expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) + # test write + expected_write_reports = expected_reports[:-1] + msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) + assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_eq(writer.size, start_size - len(long_payload)) + msg.send.assert_called_n_times(len(expected_write_reports)) + msg.send = msg.send.original + # test write raises eof + msg.send = mock_call(msg.send, []) + assert_async(writer.write(bytearray(1)), [(None, EOFError())]) + msg.send.assert_called_n_times(0) + msg.send = msg.send.original + # test close + 
expected_close_reports = expected_reports[-1:] + msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) + assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_eq(writer.size, 0) + msg.send.assert_called_n_times(len(expected_close_reports)) + msg.send = msg.send.original if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/tests/test_trezor.wire.codec_v2.py b/tests/test_trezor.wire.codec_v2.py index 60226b4e..56fc6ab0 100644 --- a/tests/test_trezor.wire.codec_v2.py +++ b/tests/test_trezor.wire.codec_v2.py @@ -1,219 +1,167 @@ -from common import * +import sys -import ustruct -import ubinascii +sys.path.append('../src') +sys.path.append('../src/lib') -from trezor.crypto import random +from utest import * +from ustruct import pack, unpack +from ubinascii import hexlify, unhexlify + +from trezor import msg +from trezor.loop import Select, Syscall, READ, WRITE from trezor.utils import chunks - from trezor.wire import codec_v2 -class TestWireCodec(unittest.TestCase): - # pylint: disable=C0301 - def test_parse(self): - d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) +def test_reader(): + rep_len = 64 + interface = 0xdeadbeef + session_id = 0x12345678 + message_type = 0x87654321 + message_len = 250 + reader = codec_v2.Reader(interface, session_id) - m, s, d = codec_v2.parse_report(d) - self.assertEqual(m, b'O'[0]) - self.assertEqual(s, 0x01234567) - self.assertEqual(d, bytes(range(0, 59))) + message = bytearray(range(message_len)) + report_header = bytearray(unhexlify('011234567887654321000000fa')) - t, l, d = codec_v2.parse_message(d) - self.assertEqual(t, 0x00010203) - self.assertEqual(l, 0x04050607) - self.assertEqual(d, bytes(range(8, 59))) + # open, expected one read + first_report = report_header + message[:rep_len - len(report_header)] + assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) + assert_eq(reader.type, message_type) + assert_eq(reader.size, message_len) - f, = codec_v2.parse_message_footer(d[0:4]) - self.assertEqual(f, 0x08090a0b) + # empty read + empty_buffer = bytearray() + assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) + assert_eq(len(empty_buffer), 0) + assert_eq(reader.size, message_len) - for i in range(0, 1024): - if i != 64: - with self.assertRaises(ValueError): - codec_v2.parse_report(bytes(range(0, i))) - if i != 59: - with self.assertRaises(ValueError): - codec_v2.parse_message(bytes(range(0, i))) - if i != 4: - with self.assertRaises(ValueError): - codec_v2.parse_message_footer(bytes(range(0, i))) + # short read, expected no read + short_buffer = bytearray(32) + assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) + assert_eq(len(short_buffer), 32) + assert_eq(short_buffer, message[:len(short_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer)) - def test_serialize(self): - data = bytearray(range(0, 6)) - codec_v2.serialize_report_header(data, 0x12, 0x3456789a) - self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') + # aligned read, expected no read + aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) + assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) + assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) - data = bytearray(range(0, 6)) - codec_v2.serialize_opened_session(data, 
0x3456789a) - self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') + # one byte read, expected one read + next_report_header = bytearray(unhexlify('021234567800000000')) + next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] + onebyte_buffer = bytearray(1) + assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) + assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) + assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) - data = bytearray(range(0, 14)) - codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef) - self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') + # too long read, raises eof + assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) - data = bytearray(range(0, 5)) - codec_v2.serialize_message_footer(data, 0x89abcdef) - self.assertEqual(data, b'\x89\xab\xcd\xef\x04') + # long read, expect multiple reads + start_size = reader.size + long_buffer = bytearray(start_size) + report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] + report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] + report_payload_rest = report_payload[len(report_payload_head):] + report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header))) + report_payloads = [report_payload_head] + report_payload_rest + next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r for i, r in enumerate(report_payloads)] + expected_syscalls = [] + for i, _ in enumerate(next_reports): + prev_report = next_reports[i - 1] if i > 0 else None + expected_syscalls.append((prev_report, Select(READ | interface))) + expected_syscalls.append((next_reports[-1], StopIteration())) + assert_async(reader.readinto(long_buffer), expected_syscalls) + assert_eq(long_buffer, message[-start_size:]) + assert_eq(reader.size, 0) - for i in range(0, 13): - data = bytearray(i) - if i < 4: - with self.assertRaises(ValueError): - codec_v2.serialize_message_footer(data, 0x00) - if i < 5: - with self.assertRaises(ValueError): - codec_v2.serialize_report_header(data, 0x00, 0x00) - with self.assertRaises(ValueError): - codec_v2.serialize_opened_session(data, 0x00) - with self.assertRaises(ValueError): - codec_v2.serialize_message_header(data, 0x00, 0x00) + # one byte read, raises eof + assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) - def test_decode_empty(self): - message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 - record = [] - genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy') - decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) +def test_writer(): + rep_len = 64 + interface = 0xdeadbeef + session_id = 0x12345678 + message_type = 0x87654321 + message_len = 1024 + writer = codec_v2.Writer(interface, session_id, message_type, message_len) - try: - decoder.send(message) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), 1) - self.assertIsInstance(record[0], EOFError) + # init header corresponding to the data above + report_header = bytearray(unhexlify('01123456788765432100000400')) - def test_decode_one_report_aligned_correct(self): - data = bytes(range(0, 47)) - footer = b'\x2f\x1c\x12\xce' - 
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer + assert_eq(writer.data, report_header + bytearray(64 - len(report_header))) - record = [] - genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy') - decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) + # empty write + start_size = writer.size + assert_async(writer.write(bytearray()), [(None, StopIteration()),]) + assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) + assert_eq(writer.size, start_size) - try: - decoder.send(message) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), 2) - self.assertEqual(record[0], data) - self.assertIsInstance(record[1], EOFError) + # short write, expected no report + start_size = writer.size + short_payload = bytearray(range(4)) + assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_eq(writer.size, start_size - len(short_payload)) + assert_eq(writer.data, + report_header + + short_payload + + bytearray(rep_len - len(report_header) - len(short_payload))) - def test_decode_one_report_aligned_incorrect(self): - data = bytes(range(0, 47)) - footer = bytes(4) # wrong checksum - message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer + # aligned write, expected one report + start_size = writer.size + aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) + msg.send = mock_call(msg.send, [ + (interface, report_header + + short_payload + + aligned_payload + + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) + assert_eq(writer.size, start_size - len(aligned_payload)) + msg.send.assert_called_n_times(1) + msg.send = msg.send.original - record = [] - genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy') - decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) + # short write, expected no report, but data starts with correct seq and cont marker + report_header = bytearray(unhexlify('021234567800000000')) + start_size = writer.size + assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_eq(writer.size, start_size - len(short_payload)) + assert_eq(writer.data[:len(report_header) + len(short_payload)], + report_header + short_payload) - try: - decoder.send(message) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), 2) - self.assertEqual(record[0], data) - self.assertIsInstance(record[1], codec_v2.MessageChecksumError) - - def test_decode_generated_range(self): - for data_len in range(1, 512): - data = random.bytes(data_len) - data_chunks = [data[:51]] + list(chunks(data[51:], 59)) - - msg_type = 0xabcdef12 - data_csum = ubinascii.crc32(data) - header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len) - footer = ustruct.pack('>L', data_csum) - - message = header + data + footer - message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))] - - record = [] - genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy') - decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') - decoder.send(None) - - res = 1 - try: - for c in message_chunks: - decoder.send(c) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(len(record), len(data_chunks) + 1) - for i 
in range(0, len(data_chunks)): - self.assertEqual(record[i], data_chunks[i]) - self.assertIsInstance(record[-1], EOFError) - - def test_encode_empty(self): - record = [] - target = self._record(record)() - target.send(None) - - codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send) - self.assertEqual(len(record), 1) - self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) - - def test_encode_one_report_aligned(self): - data = bytes(range(0, 47)) - footer = b'\x2f\x1c\x12\xce' - - record = [] - target = self._record(record)() - target.send(None) - - codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send) - self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) - - def test_encode_generated_range(self): - for data_len in range(1, 1024): - data = random.bytes(data_len) - - msg_type = 0xabcdef12 - session_id = 0xdeadbeef - - data_csum = ubinascii.crc32(data) - header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len) - footer = ustruct.pack('>L', data_csum) - session_header = ustruct.pack('>L', session_id) - - message = header + data + footer - report0 = b'H' + session_header + message[:59] - reports = [b'D' + session_header + c for c in chunks(message[59:], 59)] - reports.insert(0, report0) - reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1])) - - received = 0 - def genfunc(): - nonlocal received - while True: - self.assertEqual((yield), reports[received]) - received += 1 - target = genfunc() - target.send(None) - - codec_v2.encode(session_id, msg_type, data, target.send) - self.assertEqual(received, len(reports)) - - def _record(self, record, *_args): - def genfunc(*args): - self.assertEqual(args, _args) - while True: - try: - v = yield - except Exception as e: - record.append(e) - else: - record.append(v) - return genfunc + # long write, expected multiple reports + start_size = writer.size + long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) + long_payload_rest = bytearray(range(start_size - len(long_payload_head))) + long_payload = long_payload_head + long_payload_rest + expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) + expected_reports = [ + bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep + for seq, rep in enumerate(expected_payloads)] + expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) + # test write + expected_write_reports = expected_reports[:-1] + msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) + assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_eq(writer.size, start_size - len(long_payload)) + msg.send.assert_called_n_times(len(expected_write_reports)) + msg.send = msg.send.original + # test write raises eof + msg.send = mock_call(msg.send, []) + assert_async(writer.write(bytearray(1)), [(None, EOFError())]) + msg.send.assert_called_n_times(0) + msg.send = msg.send.original + # test close + expected_close_reports = expected_reports[-1:] + msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) + assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_eq(writer.size, 0) + msg.send.assert_called_n_times(len(expected_close_reports)) + msg.send = msg.send.original if __name__ == '__main__': - 
unittest.main() + run_tests() diff --git a/tests/utest.py b/tests/utest.py new file mode 100644 index 00000000..34cc52e3 --- /dev/null +++ b/tests/utest.py @@ -0,0 +1,142 @@ +import sys +import uio +import ure + +__all__ = [ + 'run_tests', + 'run_test', + 'assert_eq', + 'assert_not_eq', + 'assert_is_instance', + 'mock_call', +] + + +# Running + + +def run_tests(mod_name='__main__'): + ntotal = 0 + nok = 0 + nfailed = 0 + + for name, test in get_tests(mod_name): + result = run_test(test) + report_test(name, test, result) + ntotal += 1 + if result: + nok += 1 + else: + nfailed += 1 + break + report_total(ntotal, nok, nfailed) + + if nfailed > 0: + sys.exit(1) + + +def get_tests(mod_name): + module = __import__(mod_name) + for name in dir(module): + if name.startswith('test_'): + yield name, getattr(module, name) + + +def run_test(test): + try: + test() + except Exception as e: + report_exception(e) + return False + else: + return True + + +# Reporting + + +def report_test(name, test, result): + if result: + print('OK', name) + else: + print('ERR', name) + + +def report_exception(exc): + sio = uio.StringIO() + sys.print_exception(exc, sio) + print(sio.getvalue()) + + +def report_total(total, ok, failed): + print('Total:', total, 'OK:', ok, 'Failed:', failed) + + +# Assertions + + +def assert_eq(a, b, msg=None): + assert a == b, msg or format_eq(a, b) + + +def assert_not_eq(a, b, msg=None): + assert a != b, msg or format_not_eq(a, b) + + +def assert_is_instance(obj, cls, msg=None): + assert isinstance(obj, cls), msg or format_is_instance(obj, cls) + + +def assert_eq_obj(a, b, msg=None): + assert_is_instance(a, b.__class__, msg) + assert_eq(a.__dict__, b.__dict__, msg) + + +def format_eq(a, b): + return '\n%r\nvs (expected)\n%r' % (a, b) + + +def format_not_eq(a, b): + return '%r not expected to be equal %r' % (a, b) + + +def format_is_instance(obj, cls): + return '%r expected to be instance of %r' % (obj, cls) + + +def assert_async(task, syscalls): + for prev_result, expected in syscalls: + if isinstance(expected, Exception): + with assert_raises(expected.__class__): + task.send(prev_result) + else: + syscall = task.send(prev_result) + assert_eq_obj(syscall, expected) + + +class assert_raises: + + def __init__(self, exc_type): + self.exc_type = exc_type + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + assert exc_type is not None, '%r not raised' % self.exc_type + return issubclass(exc_type, self.exc_type) + + +class mock_call: + + def __init__(self, original, expected): + self.original = original + self.expected = expected + self.record = [] + + def __call__(self, *args): + self.record.append(args) + assert_eq(args, self.expected.pop(0)) + + def assert_called_n_times(self, n, msg=None): + assert_eq(len(self.record), n, msg) \ No newline at end of file
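
Usage note (not part of the diff above): a minimal sketch of how a test module is expected to use the new utest helpers, based only on the API shown in tests/utest.py. The module name, syscall class and test bodies below are hypothetical and purely illustrative.

# hypothetical module, e.g. tests/test_example.py
import sys
sys.path.append('../src')

from utest import *


def test_assertions():
    # plain value assertions from utest
    assert_eq(1 + 1, 2)
    assert_not_eq(b'', b'\x00')
    assert_is_instance(EOFError(), Exception)


def test_async_task():
    # stand-in for a syscall object such as trezor.loop.Select
    class Sleep:
        def __init__(self, delay):
            self.delay = delay

    def task():
        # yields one syscall-like object, then finishes
        result = yield Sleep(10)
        assert_eq(result, 'woke')

    # assert_async drives the generator: each tuple is
    # (value to send in, expected yielded syscall or expected exception)
    assert_async(task(), [
        (None, Sleep(10)),
        ('woke', StopIteration()),
    ])


def test_mocked_call():
    def send(iface, data):
        return len(data)

    # mock_call records arguments and checks them against the expected list;
    # it never invokes the wrapped callable, which stays available as .original
    send = mock_call(send, [(1, b'abc')])
    send(1, b'abc')
    send.assert_called_n_times(1)
    send = send.original


if __name__ == '__main__':
    run_tests()

This mirrors the pattern used by the codec tests above: generator-based tasks are stepped with assert_async against an explicit syscall sequence, and msg.send-style callables are temporarily replaced with mock_call and restored from .original afterwards.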