trezor-core/src/lib/protobuf.py

330 lines
10 KiB
Python
Raw Normal View History

2016-09-21 05:14:49 -07:00
'''Streaming protobuf codec.
2016-04-07 14:45:10 -07:00
2016-09-21 05:14:49 -07:00
Handles asynchronous encoding and decoding of protobuf value streams.
2016-04-07 14:45:10 -07:00
2016-09-21 05:14:49 -07:00
Value format: ((field_type, field_flags, field_name), field_value)
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
or an instance of EmbeddedMessage.
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
field_name (str): Field name string.
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
'''
2016-09-29 03:29:43 -07:00
from micropython import const
2016-09-21 05:14:49 -07:00
def build_protobuf_message(message_type, callback, *args):
message = message_type()
try:
while True:
field, field_value = yield
field_type, field_flags, field_name = field
if not _is_scalar_type(field_type):
field_value = yield from build_protobuf_message(field_type, callback)
if field_flags & FLAG_REPEATED:
2016-10-06 03:12:38 -07:00
prev_value = getattr(message, field_name, [])
prev_value.append(field_value)
field_value = prev_value
2016-09-21 05:14:49 -07:00
setattr(message, field_name, field_value)
except EOFError:
callback(message, *args)
def print_protobuf_message(message_type):
print('OPEN', message_type)
try:
while True:
field, field_value = yield
field_type, _, field_name = field
if not _is_scalar_type(field_type):
yield from print_protobuf_message(field_type)
else:
print('FIELD', field_name, field_type, field_value)
except EOFError:
print('CLOSE', message_type)
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
_UVARINT_DUMP_BUFFER = bytearray(1)
2016-04-07 14:45:10 -07:00
class UVarintType:
WIRE_TYPE = 0
@staticmethod
async def dump(target, value):
shifted = True
while shifted:
shifted = value >> 7
_UVARINT_DUMP_BUFFER[0] = (value & 0x7F) | (
0x80 if shifted else 0x00)
await target.write(_UVARINT_DUMP_BUFFER)
value = shifted
2016-04-07 14:45:10 -07:00
@staticmethod
async def load(source):
2016-04-07 14:45:10 -07:00
value, shift, quantum = 0, 0, 0x80
while (quantum & 0x80) == 0x80:
buffer = await source.read(1)
quantum = buffer[0]
value = value + ((quantum & 0x7F) << shift)
shift += 7
2016-04-07 14:45:10 -07:00
return value
2016-04-07 14:45:10 -07:00
class BoolType:
WIRE_TYPE = 0
@staticmethod
async def dump(target, value):
await target.write(b'\x01' if value else b'\x00')
2016-04-07 14:45:10 -07:00
@staticmethod
async def load(source):
varint = await UVarintType.load(source)
2016-09-21 05:14:49 -07:00
return varint != 0
2016-04-07 14:45:10 -07:00
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
class BytesType:
WIRE_TYPE = 2
@staticmethod
async def dump(target, value):
await UVarintType.dump(target, len(value))
await target.write(value)
2016-04-07 14:45:10 -07:00
@staticmethod
async def load(source):
size = await UVarintType.load(source)
data = await source.read(size)
2016-09-21 05:14:49 -07:00
return data
2016-04-07 14:45:10 -07:00
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
class UnicodeType:
WIRE_TYPE = 2
@staticmethod
async def dump(target, value):
await BytesType.dump(target, bytes(value, 'utf-8'))
2016-04-07 14:45:10 -07:00
@staticmethod
async def load(source):
data = await BytesType.load(source)
2016-09-25 06:35:04 -07:00
return str(data, 'utf-8', 'strict')
2016-04-07 14:45:10 -07:00
2016-09-21 05:14:49 -07:00
class EmbeddedMessage:
WIRE_TYPE = 2
2016-08-05 03:35:45 -07:00
2016-09-21 05:14:49 -07:00
def __init__(self, message_type):
'''Initializes a new instance. The argument is an underlying message type.'''
self.message_type = message_type
2016-08-05 03:35:45 -07:00
2016-09-21 05:14:49 -07:00
def __call__(self):
'''Creates a message of the underlying message type.'''
return self.message_type()
2016-04-07 14:45:10 -07:00
async def dump(self, target, value):
buf = self.message_type.dumps(value)
await BytesType.dump(target, buf)
2016-08-05 03:35:45 -07:00
async def load(self, target, source):
emb_size = await UVarintType.load(source)
2016-09-21 05:14:49 -07:00
emb_source = source.trim(emb_size)
await self.message_type.load(emb_source, target)
2016-09-21 05:14:49 -07:00
FLAG_SIMPLE = const(0)
FLAG_REQUIRED = const(1)
FLAG_REPEATED = const(2)
2016-04-07 14:45:10 -07:00
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
# Packs a tag and a wire_type into single int according to the protobuf spec.
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
2016-04-07 14:45:10 -07:00
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
_unpack_key = lambda key: (key >> 3, key & 7)
2016-09-21 05:14:49 -07:00
# Determines if a field type is a scalar or not.
_is_scalar_type = lambda field_type: not isinstance(
field_type, EmbeddedMessage)
class StreamReader:
def __init__(self, buf=None, limit=None):
self.buf = buf if buf is not None else bytearray()
self.limit = limit
def read(self, n):
if self.limit is not None:
if self.limit < n:
raise EOFError()
self.limit -= n
buf = self.buf
while len(buf) < n:
chunk = yield
buf.extend(chunk)
2016-09-25 06:35:04 -07:00
# TODO: is this the most officient way?
2016-09-21 05:14:49 -07:00
result = buf[:n]
buf[:] = buf[n:]
return result
def trim(self, limit):
return StreamReader(self.buf, limit)
class StreamWriter:
def __init__(self):
self.buffer = bytearray()
async def write(self, b):
self.buffer.extend(b)
2016-04-07 14:45:10 -07:00
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
class MessageType:
2016-09-21 05:14:49 -07:00
'''Represents a message type.'''
2016-04-07 14:45:10 -07:00
def __init__(self, name=None):
self.__name = name
2016-09-21 05:14:49 -07:00
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
self.__defaults = {} # tag -> default_value
2016-04-07 14:45:10 -07:00
2016-04-26 16:32:57 -07:00
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
2016-09-21 05:14:49 -07:00
'''Adds a field to the message type.'''
if tag in self.__fields:
2016-04-07 14:45:10 -07:00
raise ValueError('The tag %s is already used.' % tag)
2016-09-21 05:14:49 -07:00
if default is not None:
2016-04-26 16:32:57 -07:00
self.__defaults[tag] = default
2016-09-21 05:14:49 -07:00
self.__fields[tag] = (field_type, flags, name)
2016-04-07 14:45:10 -07:00
2016-06-01 05:06:41 -07:00
def __call__(self, **fields):
2016-09-21 05:14:49 -07:00
'''Creates an instance of this message type.'''
2016-06-01 05:06:41 -07:00
return Message(self, **fields)
2016-04-07 14:45:10 -07:00
2016-09-21 05:14:49 -07:00
def dumps(self, value):
target = StreamWriter()
dumper = self.dump(target, value)
try:
while True:
dumper.send(None)
except (StopIteration, EOFError):
return target.buffer
2016-09-21 05:14:49 -07:00
async def dump(self, target, value):
2016-09-21 05:14:49 -07:00
if self is not value.message_type:
raise TypeError('Incompatible type')
for tag, field in self.__fields.items():
field_type, field_flags, field_name = field
field_value = getattr(value, field_name, None)
if field_value is None:
2016-09-21 05:14:49 -07:00
if field_flags & FLAG_REQUIRED:
raise ValueError(
'The field with the tag %s is required but a value is missing.' % tag)
else:
continue
if field_flags & FLAG_REPEATED:
# repeated value
key = _pack_key(tag, field_type.WIRE_TYPE)
# send the values sequentially
for single_value in field_value:
await UVarintType.dump(target, key)
await field_type.dump(target, single_value)
2016-09-21 05:14:49 -07:00
else:
# single value
await UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
await field_type.dump(target, field_value)
def loads(self, value):
result = None
def callback(message):
nonlocal result
result = message
target = build_protobuf_message(self, callback)
target.send(None)
# TODO: avoid the copy!
source = StreamReader(bytearray(value), len(value))
loader = self.load(target, source)
try:
while True:
loader.send(None)
except (StopIteration, EOFError):
if result is None:
raise Exception('Failed to parse protobuf message')
return result
2016-09-21 05:14:49 -07:00
async def load(self, target, source=None):
2016-09-21 05:14:49 -07:00
if source is None:
source = StreamReader()
found_tags = set()
try:
while True:
key = await UVarintType.load(source)
2016-09-21 05:14:49 -07:00
tag, wire_type = _unpack_key(key)
found_tags.add(tag)
if tag in self.__fields:
# retrieve the field descriptor by tag
field = self.__fields[tag]
field_type, _, _ = field
2016-04-26 16:32:57 -07:00
if wire_type != field_type.WIRE_TYPE:
raise TypeError(
2016-08-05 03:35:45 -07:00
'Value of tag %s has incorrect wiretype %s, %s expected.' %
2016-04-26 16:32:57 -07:00
(tag, wire_type, field_type.WIRE_TYPE))
2016-04-07 14:45:10 -07:00
else:
2016-09-21 05:14:49 -07:00
# unknown field, skip it
field_type = {0: UVarintType, 2: BytesType}[wire_type]
await field_type.load(source)
2016-09-21 05:14:49 -07:00
continue
if _is_scalar_type(field_type):
field_value = await field_type.load(source)
2016-09-21 05:14:49 -07:00
target.send((field, field_value))
else:
await field_type.load(target, source)
2016-09-21 05:14:49 -07:00
except EOFError:
for tag, field in self.__fields.items():
# send the default value
if tag not in found_tags and tag in self.__defaults:
target.send((field, self.__defaults[tag]))
found_tags.add(tag)
# check if all required fields are present
_, field_flags, field_name = field
if field_flags & FLAG_REQUIRED and tag not in found_tags:
if field_flags & FLAG_REPEATED:
# no values were in input stream, but required field.
# send empty list
target.send((field, []))
else:
raise ValueError(
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
target.throw(EOFError)
def __repr__(self):
return '<MessageType: %s>' % self.__name
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
class Message:
2016-09-21 05:14:49 -07:00
'''Represents a message instance.'''
2016-04-07 14:45:10 -07:00
2016-06-01 05:06:41 -07:00
def __init__(self, message_type, **fields):
2016-09-21 05:14:49 -07:00
'''Initializes a new instance of the specified message type.'''
2016-04-07 14:45:10 -07:00
self.message_type = message_type
for key in fields:
setattr(self, key, fields[key])
2016-04-07 14:45:10 -07:00
async def dump(self, target):
return await self.message_type.dump(target, self)
def dumps(self):
return self.message_type.dumps(self)
2016-04-07 14:45:10 -07:00
def __repr__(self):
values = self.__dict__
2016-08-05 03:35:45 -07:00
values = {k: values[k] for k in values if k != 'message_type'}
return '<%s: %s>' % (self.message_type.__name, values)