trezor-core/src/protobuf.py

238 lines
5.9 KiB
Python
Raw Permalink Normal View History

'''
2017-07-04 09:09:08 -07:00
Extremely minimal streaming codec for a subset of protobuf. Supports unsigned
varints (uint32), sint32, sint64, bool, bytes, string, embedded message and
repeated fields.
2017-08-21 04:22:35 -07:00
For deserializing (loading) protobuf types, an object with the `AsyncReader`
interface is required:
2017-09-16 06:00:31 -07:00
>>> class AsyncReader:
>>> async def areadinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
2017-08-21 04:22:35 -07:00
For serializing (dumping) protobuf types, an object with the `AsyncWriter`
interface is required:
2017-09-16 06:00:31 -07:00
>>> class AsyncWriter:
>>> async def awrite(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
2016-09-21 05:14:49 -07:00
'''
2016-09-29 03:29:43 -07:00
from micropython import const
2017-07-04 09:09:08 -07:00
# One-byte scratch buffer shared by `load_uvarint` and `dump_uvarint`; both
# functions read/write through it instead of creating a buffer of their own.
_UVARINT_BUFFER = bytearray(1)
2016-09-21 05:14:49 -07:00
2017-07-04 09:09:08 -07:00
async def load_uvarint(reader):
    '''
    Read a single varint from `reader` and return it as an int.

    Bytes are fetched one at a time through `reader.areadinto`. The low seven
    bits of each byte are payload (least-significant group first); a set high
    bit means another byte follows. Any exception raised by `areadinto`
    (e.g. `EOFError`) is propagated to the caller.
    '''
    scratch = _UVARINT_BUFFER
    value = 0
    offset = 0
    while True:
        await reader.areadinto(scratch)
        octet = scratch[0]
        # Payload groups never overlap, so OR-ing them in is the same as adding.
        value |= (octet & 0x7F) << offset
        offset += 7
        if not octet & 0x80:
            # Continuation bit is clear: this was the last byte.
            return value
2016-09-21 05:14:49 -07:00
2017-07-04 09:09:08 -07:00
async def dump_uvarint(writer, n):
    '''
    Write the non-negative integer `n` to `writer` as a varint.

    The value is emitted seven bits at a time, least-significant group first,
    one byte per `writer.awrite` call. Every byte except the last has its high
    bit set to signal that more bytes follow.

    Raises `ValueError` if `n` is negative.
    '''
    if n < 0:
        # A negative value never shifts down to zero (`-1 >> 7 == -1`), so the
        # loop below would keep writing 0xFF bytes forever.
        raise ValueError
    buffer = _UVARINT_BUFFER
    shifted = True
    while shifted:
        shifted = n >> 7
        # Set the continuation bit whenever higher-order bits remain.
        buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
        await writer.awrite(buffer)
        n = shifted
2017-07-04 09:09:08 -07:00
class UVarintType:
    '''Marker type for unsigned integer fields encoded as a varint.'''

    WIRE_TYPE = 0  # varint
class Sint32Type:
    '''Marker type for signed 32-bit integer fields carried as a varint.'''

    WIRE_TYPE = 0  # varint
class Sint64Type:
    '''Marker type for signed 64-bit integer fields carried as a varint.'''

    WIRE_TYPE = 0  # varint
2017-07-04 09:09:08 -07:00
class BoolType:
    '''Marker type for boolean fields carried as a varint.'''

    WIRE_TYPE = 0  # varint
2016-08-05 03:35:45 -07:00
2016-04-07 14:45:10 -07:00
2017-07-04 09:09:08 -07:00
class BytesType:
    '''Marker type for raw byte-string fields (length-delimited).'''

    WIRE_TYPE = 2  # length-delimited
2016-04-07 14:45:10 -07:00
2017-07-04 09:09:08 -07:00
class UnicodeType:
    '''Marker type for text fields (length-delimited UTF-8 bytes).'''

    WIRE_TYPE = 2  # length-delimited
2016-09-21 05:14:49 -07:00
2017-07-04 09:09:08 -07:00
class MessageType:
    '''
    Base class for protobuf messages.

    Subclasses describe their layout through `FIELDS`. Field values live as
    plain instance attributes, which is what equality compares.
    '''

    WIRE_TYPE = 2  # length-delimited
    FIELDS = {}

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute of the same name.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __eq__(self, rhs):
        # Messages are equal only when they are of exactly the same class and
        # carry identical attributes.
        if self.__class__ is not rhs.__class__:
            return False
        return self.__dict__ == rhs.__dict__

    def __repr__(self):
        cls_name = self.__class__.__name__
        return '<%s>' % cls_name
2017-07-04 09:09:08 -07:00
class LimitedReader:
    '''
    Reader wrapper that allows at most `limit` bytes to be read from `reader`.
    '''

    def __init__(self, reader, limit):
        self.reader = reader
        self.limit = limit

    async def areadinto(self, buf):
        '''
        Fill `buf` from the wrapped reader and return the count it reports.

        Raises `EOFError`, without touching the wrapped reader, when `buf` is
        larger than the remaining budget.
        '''
        if len(buf) > self.limit:
            raise EOFError
        count = await self.reader.areadinto(buf)
        # Charge the bytes reported by the wrapped reader against the budget.
        self.limit -= count
        return count
class CountingWriter:
    '''
    Writer that discards its input and only tallies how many bytes it was given.
    '''

    def __init__(self):
        self.size = 0

    async def awrite(self, buf):
        '''Add `len(buf)` to `size` and return that length.'''
        length = len(buf)
        self.size = self.size + length
        return length
# Bit in a field's flags value. `load_message` and `dump_message` test it to
# decide whether the field's attribute holds a list of items or a single value.
FLAG_REPEATED = const(1)
async def load_message(reader, msg_type):
    '''
    Read a message of type `msg_type` from `reader` and return it.

    `msg_type.FIELDS` maps a field tag to a `(name, type, flags)` tuple. Fields
    are read until `reader` raises `EOFError` at the start of a field key.
    Fields with an unknown tag are skipped. Fields listed in `FIELDS` that did
    not appear in the input are set to `None` on the returned message.

    Raises `ValueError` for an unknown field whose wire type is neither 0 nor 2,
    and `TypeError` when the wire type disagrees with the schema or the schema
    names an unsupported field type. An `EOFError` raised while reading a
    field's value (rather than its key) is propagated.
    '''
    fields = msg_type.FIELDS
    msg = msg_type()

    while True:
        try:
            fkey = await load_uvarint(reader)
        except EOFError:
            break  # no more fields to load

        # A field key packs the tag in the upper bits and the wire type in the
        # lowest three.
        ftag = fkey >> 3
        wtype = fkey & 7

        field = fields.get(ftag, None)

        if field is None:  # unknown field, skip it
            if wtype == 0:
                await load_uvarint(reader)
            elif wtype == 2:
                ivalue = await load_uvarint(reader)
                await reader.areadinto(bytearray(ivalue))
            else:
                raise ValueError
            continue

        fname, ftype, fflags = field
        if wtype != ftype.WIRE_TYPE:
            raise TypeError  # parsed wire type differs from the schema

        # For varint fields this is the value itself; for length-delimited
        # fields it is the payload length in bytes.
        ivalue = await load_uvarint(reader)

        if ftype is UVarintType:
            fvalue = ivalue
        elif ftype is Sint32Type or ftype is Sint64Type:
            # Zigzag decode: even wire values map to 0, 1, 2, ... and odd ones
            # to -1, -2, -3, ... (the low bit carries the sign).
            fvalue = (ivalue >> 1) ^ -(ivalue & 1)
        elif ftype is BoolType:
            fvalue = bool(ivalue)
        elif ftype is BytesType:
            fvalue = bytearray(ivalue)
            await reader.areadinto(fvalue)
        elif ftype is UnicodeType:
            fvalue = bytearray(ivalue)
            await reader.areadinto(fvalue)
            fvalue = str(fvalue, 'utf8')
        elif issubclass(ftype, MessageType):
            # The nested message ends when its byte budget is exhausted.
            fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
        else:
            raise TypeError  # field type is unknown

        if fflags & FLAG_REPEATED:
            # Accumulate repeated occurrences into a list on the message.
            pvalue = getattr(msg, fname, [])
            pvalue.append(fvalue)
            fvalue = pvalue
        setattr(msg, fname, fvalue)

    # fill missing fields
    for tag in msg.FIELDS:
        field = msg.FIELDS[tag]
        if not hasattr(msg, field[0]):
            setattr(msg, field[0], None)

    return msg
async def dump_message(writer, msg):
    '''
    Write the message `msg` to `writer`.

    Fields are taken from `msg.__class__.FIELDS`, which maps a field tag to a
    `(name, type, flags)` tuple. Fields whose attribute is missing or `None`
    are omitted. A field flagged with `FLAG_REPEATED` is iterated and every
    item is written with its own field key.

    Raises `TypeError` when the schema names an unsupported field type.
    '''
    # Reusable one-item list so a scalar field can share the loop used for
    # repeated fields.
    repvalue = [0]
    mtype = msg.__class__
    fields = mtype.FIELDS

    for ftag in fields:
        field = fields[ftag]
        fname = field[0]
        ftype = field[1]
        fflags = field[2]

        fvalue = getattr(msg, fname, None)
        if fvalue is None:
            continue

        fkey = (ftag << 3) | ftype.WIRE_TYPE

        if not fflags & FLAG_REPEATED:
            repvalue[0] = fvalue
            fvalue = repvalue

        for svalue in fvalue:
            await dump_uvarint(writer, fkey)

            if ftype is UVarintType:
                await dump_uvarint(writer, svalue)

            elif ftype is Sint32Type:
                # Zigzag encode. The mask must be applied after the XOR;
                # masking first leaves a negative result for negative inputs,
                # which cannot be written as an unsigned varint.
                await dump_uvarint(
                    writer, ((svalue << 1) ^ (svalue >> 31)) & 0xffffffff)

            elif ftype is Sint64Type:
                await dump_uvarint(
                    writer,
                    ((svalue << 1) ^ (svalue >> 63)) & 0xffffffffffffffff)

            elif ftype is BoolType:
                await dump_uvarint(writer, int(svalue))

            elif ftype is BytesType:
                await dump_uvarint(writer, len(svalue))
                await writer.awrite(svalue)

            elif ftype is UnicodeType:
                bvalue = bytes(svalue, 'utf8')
                await dump_uvarint(writer, len(bvalue))
                await writer.awrite(bvalue)

            elif issubclass(ftype, MessageType):
                # The length prefix has to precede the payload, so the nested
                # message is first dumped into a counter to learn its size.
                counter = CountingWriter()
                await dump_message(counter, svalue)
                await dump_uvarint(writer, counter.size)
                await dump_message(writer, svalue)

            else:
                raise TypeError