diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/wire_codec.py index 0be68816..364778af 100644 --- a/src/trezor/wire/wire_codec.py +++ b/src/trezor/wire/wire_codec.py @@ -35,7 +35,6 @@ _MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER) def parse_report(data): marker, session_id = ustruct.unpack(_REP_HEADER, data) - # TODO: handle v1 protocol return marker, session_id, data[_REP_HEADER_LEN:] @@ -77,7 +76,7 @@ Sends (msg_type, data_len) to target, followed by data chunks. Throws EOFError after last data chunk, in case of valid checksum. Throws MessageChecksumError to target if data doesn't match the checksum. ''' - message = yield # read first report + message = memoryview((yield)) # read first report msg_type, data_len, data_tail = parse_message(message) target = genfunc(msg_type, data_len, session_id, *args) @@ -86,11 +85,9 @@ Throws MessageChecksumError to target if data doesn't match the checksum. checksum = 0 # crc32 nreports = 1 - compute_checksum = hasattr(ubinascii, 'crc32') - while data_len > 0: if nreports > 1: - data_tail = yield # read next report + data_tail = memoryview((yield)) # read next report nreports += 1 data_chunk = data_tail[:data_len] # slice off the garbage at the end @@ -98,46 +95,35 @@ Throws MessageChecksumError to target if data doesn't match the checksum. data_len -= len(data_chunk) target.send(data_chunk) - if compute_checksum: - checksum = ubinascii.crc32(data_chunk, checksum) & 0xffffffff + checksum = ubinascii.crc32(data_chunk, checksum) msg_footer = data_tail[:_MSG_FOOTER_LEN] if len(msg_footer) < _MSG_FOOTER_LEN: data_tail = yield # read report with the rest of checksum msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)] - if compute_checksum: - data_checksum, = parse_message_footer(msg_footer) - else: - data_checksum = checksum + data_checksum, = parse_message_footer(msg_footer) if data_checksum != checksum: - target.throw(MessageChecksumError( - 'Message checksum mismatch, expected %d, received %d' % (checksum, data_checksum))) + target.throw(MessageChecksumError((checksum, data_checksum))) else: target.throw(EOFError()) def encode_wire_message(msg_type, msg_data, session_id, target): - report = bytearray(_REP_LEN) + report = memoryview(bytearray(_REP_LEN)) serialize_report_header(report, REP_MARKER_HEADER, session_id) serialize_message_header(report, msg_type, len(msg_data)) - msg_data = memoryview(msg_data) - report = memoryview(report) - - source_data = msg_data + source_data = memoryview(msg_data) target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:] - compute_checksum = hasattr(ubinascii, 'crc32') - - if compute_checksum: - checksum = ubinascii.crc32(msg_data) & 0xffffffff - else: - checksum = 0 + checksum = ubinascii.crc32(msg_data) msg_footer = bytearray(_MSG_FOOTER_LEN) serialize_message_footer(msg_footer, checksum) + first = True + while True: # move as much as possible from source to target n = min(len(target_data), len(source_data)) @@ -157,8 +143,10 @@ def encode_wire_message(msg_type, msg_data, session_id, target): break # reset to skip the magic and session ID - serialize_report_header(report, REP_MARKER_DATA, session_id) - target_data = report[_REP_HEADER_LEN:] + if first: + serialize_report_header(report, REP_MARKER_DATA, session_id) + target_data = report[_REP_HEADER_LEN:] + first = False def encode_session_open_message(session_id, target):