diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 0431cca3..40040f63 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -12,9 +12,7 @@ from .wire_codec import \ decode_wire_stream, encode_wire_message, \ encode_session_open_message, encode_session_close_message from .wire_codec_v1 import \ - SESSION_V1, \ - decode_wire_v1_stream, \ - encode_wire_v1_message + SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message _session_handlers = {} # session id -> generator _workflow_genfuncs = {} # wire type -> (generator function, args) @@ -138,8 +136,8 @@ async def monitor_workflow(workflow, session_id): finally: if session_id in _opened_sessions: if session_id == SESSION_V1: - wire_decoder = decode_wire_v1_stream(_handle_registered_type, - SESSION_V1) + wire_decoder = decode_wire_v1_stream( + _handle_registered_type, session_id) else: wire_decoder = decode_wire_stream( _handle_registered_type, session_id) diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/wire_codec.py index 1ccb8f6b..7edee032 100644 --- a/src/trezor/wire/wire_codec.py +++ b/src/trezor/wire/wire_codec.py @@ -83,12 +83,8 @@ Throws MessageChecksumError to target if data doesn't match the checksum. target.send(None) checksum = 0 # crc32 - nreports = 1 while data_len > 0: - if nreports > 1: - data_tail = memoryview((yield)) # read next report - nreports += 1 data_chunk = data_tail[:data_len] # slice off the garbage at the end data_tail = data_tail[len(data_chunk):] # slice off what we have read @@ -97,6 +93,9 @@ Throws MessageChecksumError to target if data doesn't match the checksum. checksum = ubinascii.crc32(data_chunk, checksum) + if data_len > 0: + data_tail = memoryview((yield)) # read next report + msg_footer = data_tail[:_MSG_FOOTER_LEN] if len(msg_footer) < _MSG_FOOTER_LEN: data_tail = yield # read report with the rest of checksum @@ -137,7 +136,7 @@ def encode_wire_message(msg_type, msg_data, session_id, target): msg_footer = None continue - # FIXME: Optimize speed + # FIXME: optimize speed x = 0 to_fill = len(target_data) while x < to_fill: diff --git a/src/trezor/wire/wire_codec_v1.py b/src/trezor/wire/wire_codec_v1.py index 384a6f58..c0c1b296 100644 --- a/src/trezor/wire/wire_codec_v1.py +++ b/src/trezor/wire/wire_codec_v1.py @@ -1,28 +1,36 @@ import ustruct SESSION_V1 = const(0) -REP_MARKER_V1 = const(63) # ord('?) -REP_MARKER_V1_LEN = const(1) # len('?') +REP_MARKER_V1 = const(63) # ord('?) +REP_MARKER_V1_LEN = const(1) # len('?') -_MSG_HEADER_MAGIC = const(35) # org('#') -_MSG_HEADER_V1 = '>BBHL' # wire type, data length +_REP_LEN = const(64) +_MSG_HEADER_MAGIC = const(35) # org('#') +_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length _MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1) + def detect_v1(data): return (data[0] == REP_MARKER_V1) + def parse_report_v1(data): return None, SESSION_V1, data[1:] + def parse_message(data): magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data) if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: - raise Exception("Corrupted magic bytes") + raise Exception('Corrupted magic bytes') return msg_type, data_len, data[_MSG_HEADER_V1_LEN:] + def serialize_message_header(data, msg_type, msg_len): - ustruct.pack_into(_MSG_HEADER_V1, data, REP_MARKER_V1_LEN, _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) + ustruct.pack_into( + _MSG_HEADER_V1, data, REP_MARKER_V1_LEN, + _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) + def decode_wire_v1_stream(genfunc, session_id, *args): '''Decode a v1 wire message from the report data and stream it to target. @@ -32,8 +40,8 @@ Sends (msg_type, data_len) to target, followed by data chunks. Throws EOFError after last data chunk, in case of valid checksum. Throws MessageChecksumError to target if data doesn't match the checksum. ''' - - message = yield # read first report + + message = yield # read first report msg_type, data_len, data = parse_message(message) print(msg_type, data_len, bytes(data)) @@ -48,17 +56,18 @@ Throws MessageChecksumError to target if data doesn't match the checksum. target.send(data_chunk) if data_len > 0: - data = yield # First next record + data = yield # read next report target.throw(EOFError()) + def encode_wire_v1_message(msg_type, msg_data, target): - report = memoryview(bytearray(64)) # Maximum report length - report[0] = REP_MARKER_V1 # Put report marker + report = memoryview(bytearray(_REP_LEN)) + report[0] = REP_MARKER_V1 serialize_message_header(report, msg_type, len(msg_data)) source_data = memoryview(msg_data) - target_data = report[REP_MARKER_V1_LEN+_MSG_HEADER_V1_LEN:] + target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:] while True: # move as much as possible from source to target @@ -67,7 +76,7 @@ def encode_wire_v1_message(msg_type, msg_data, target): source_data = source_data[n:] target_data = target_data[n:] - # FIXME: Optimize speed + # FIXME: optimize speed x = 0 to_fill = len(target_data) while x < to_fill: @@ -79,4 +88,5 @@ def encode_wire_v1_message(msg_type, msg_data, target): if not source_data: break - target_data = report[REP_MARKER_V1_LEN:] \ No newline at end of file + # reset to skip the magic, not the whole header anymore + target_data = report[REP_MARKER_V1_LEN:]