python-trezor/trezorlib/transport_hid.py

86 lines
2.5 KiB
Python
Raw Normal View History

2013-03-10 08:55:59 -07:00
'''USB HID implementation of Transport.'''
import hid
2013-09-24 16:14:54 -07:00
import time
2013-03-10 08:55:59 -07:00
from transport import Transport, NotImplementedException
# USB (vendor_id, product_id) pairs that HidTransport.enumerate() accepts.
DEVICE_IDS = [
    (0x10C4, 0xEA80),  # Trezor Pi
    (0x534C, 0x0001),  # Trezor
]
2013-03-10 09:52:04 -07:00
class FakeRead(object):
    """Adapt a ``size -> data`` callable to a minimal file-like object.

    Only ``read(size)`` is provided: it forwards ``size`` to the wrapped
    callable and hands back that callable's result unchanged.
    """

    def __init__(self, func):
        self.func = func  # callable backing read()

    def read(self, size):
        """Return whatever the wrapped callable produces for ``size``."""
        fetch = self.func
        return fetch(size)
2013-03-10 08:55:59 -07:00
class HidTransport(Transport):
    """USB HID implementation of :class:`Transport`.

    ``device`` is expected to be an identifier string of the form returned by
    :meth:`enumerate`, i.e. ``"0x<vendor_id>:0x<product_id>:<serial_number>"``.
    """

    def __init__(self, device, *args, **kwargs):
        # hid.device handle; created in _open(), released in _close().
        self.hid = None
        # Bytes received from the device but not yet consumed by _raw_read().
        self.buffer = ''
        super(HidTransport, self).__init__(device, *args, **kwargs)

    @classmethod
    def enumerate(cls):
        """Return identifier strings for connected devices listed in DEVICE_IDS.

        Each identifier is ``"0x%04x:0x%04x:%s"`` formatted from the vendor ID,
        product ID and serial number reported by ``hid.enumerate``.
        """
        devices = []
        for d in hid.enumerate(0, 0):
            vendor_id = d.get('vendor_id')
            product_id = d.get('product_id')
            serial_number = d.get('serial_number')
            if (vendor_id, product_id) in DEVICE_IDS:
                devices.append("0x%04x:0x%04x:%s" % (vendor_id, product_id, serial_number))
        return devices

    def _open(self):
        """Open the HID device named by ``self.device`` and reset the UART.

        NOTE(review): only the vendor and product IDs from ``self.device`` are
        passed to ``hid.device().open``; the serial-number part is ignored, so
        with several identical devices attached this presumably opens whichever
        one hidapi picks first.
        """
        self.buffer = ''
        path = self.device.split(':')
        self.hid = hid.device()
        self.hid.open(int(path[0], 16), int(path[1], 16))
        # _raw_read() polls, so hidapi reads must return immediately.
        self.hid.set_nonblocking(True)
        self.hid.send_feature_report([0x41, 0x01])  # enable UART
        self.hid.send_feature_report([0x43, 0x03])  # purge TX/RX FIFOs

    def _close(self):
        """Close the HID handle if one is open and discard buffered input.

        Safe to call more than once, or before _open() has succeeded.
        """
        if self.hid is not None:
            self.hid.close()
            self.hid = None
        self.buffer = ''

    def ready_to_read(self):
        """Always return False; this transport does not report pending input."""
        return False

    def _write(self, msg):
        """Send ``msg`` to the device as a sequence of HID output reports.

        Each report is the report ID 63 followed by up to 63 payload bytes,
        zero-padded to exactly 63 bytes. An empty ``msg`` sends nothing.
        """
        msg = bytearray(msg)
        for offset in range(0, len(msg), 63):
            chunk = list(msg[offset:offset + 63])
            # Report ID, data padded to 63 bytes
            self.hid.write([63] + chunk + [0] * (63 - len(chunk)))

    def _read(self):
        """Read one message and return ``(msg_type, payload)``.

        Header parsing is delegated to ``_read_headers`` (not defined in this
        class), which is fed through a FakeRead adapter over _raw_read().
        """
        (msg_type, datalen) = self._read_headers(FakeRead(self._raw_read))
        return (msg_type, self._raw_read(datalen))

    def _raw_read(self, length):
        """Return exactly ``length`` bytes of payload, polling the device.

        Loops until ``self.buffer`` holds at least ``length`` bytes, sleeping
        50 ms whenever the non-blocking read returns nothing. There is no
        timeout. Bytes beyond ``length`` stay buffered for the next call.

        Raises:
            NotImplementedException: if a report with an ID above 63 (a
                command report) is received.
        """
        while len(self.buffer) < length:
            data = self.hid.read(64)
            if not data:
                # Nothing available yet; back off briefly before polling again.
                time.sleep(0.05)
                continue

            report_id = data[0]
            if report_id > 63:
                # Command report
                raise NotImplementedException("Not implemented")

            # Payload report: skip the report ID byte. Assumes (as _write()
            # does when it sends ID 63 with 63 bytes) that the ID equals the
            # payload length, so any extra trailing bytes are dropped.
            self.buffer += str(bytearray(data[1:1 + report_id]))

        ret = self.buffer[:length]
        self.buffer = self.buffer[length:]
        return ret