python-trezor/trezorlib/transport_hid.py

163 lines
5.1 KiB
Python
Raw Normal View History

2016-11-25 13:53:55 -08:00
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
2013-03-10 08:55:59 -07:00
'''USB HID implementation of Transport.'''
import hid
2013-09-24 16:14:54 -07:00
import time
2016-06-26 12:29:29 -07:00
from .transport import TransportV1, TransportV2, ConnectionError
def enumerate():
    """
    Return a list of available TREZOR devices.

    HID interfaces whose (vendor_id, product_id) pair is listed in
    DEVICE_IDS are grouped by serial number.  Each device is reported as a
    two-item list ``[path_normal, path_debuglink]``; an item stays None when
    the corresponding interface was not seen.
    """
    devices = {}
    for d in hid.enumerate(0, 0):
        if (d['vendor_id'], d['product_id']) not in DEVICE_IDS:
            continue
        serial_number = d['serial_number']
        interface_number = d['interface_number']
        usage_page = d['usage_page']
        path = d['path']

        paths = devices.setdefault(serial_number, [None, None])
        # first match by usage_page, then try interface number
        if usage_page == 0xFF00 or interface_number == 0:  # normal link
            paths[0] = path
        elif usage_page == 0xFF01 or interface_number == 1:  # debug link
            paths[1] = path

    # List of two-item lists [path_normal, path_debuglink].
    # Sort with a None-aware key: on Python 3, ordering None against a path
    # raises TypeError, which a plain sorted() hits whenever only some
    # devices have a given interface.  Missing paths sort first, which matches
    # how Python 2 ordered None.
    return sorted(devices.values(),
                  key=lambda pair: [(p is not None, p) for p in pair])
2016-06-26 12:29:29 -07:00
def path_to_transport(path):
    """
    Return the transport class for the attached HID device at ``path``.

    The device is looked up among the currently enumerated HID devices and
    its (vendor_id, product_id) pair is mapped through DEVICE_TRANSPORTS.

    :raises ConnectionError: no enumerated HID device has the given path.
    :raises Exception: the device's VID/PID has no known transport.
    """
    device = next((d for d in hid.enumerate(0, 0) if d['path'] == path), None)
    if device is None:
        raise ConnectionError("Connection failed")

    # VID/PID found, let's find proper transport
    vid, pid = device['vendor_id'], device['product_id']
    try:
        return DEVICE_TRANSPORTS[(vid, pid)]
    except KeyError:
        # DEVICE_TRANSPORTS is a dict, so a missing key raises KeyError.
        raise Exception("Unknown transport for VID:PID %04x:%04x" % (vid, pid))
class _HidTransport(object):
    """
    HID I/O shared by HidTransportV1 and HidTransportV2.

    ``device`` is a two-item sequence ``[path_normal, path_debuglink]`` (the
    shape produced by ``enumerate()``).  The ``debug_link`` keyword argument
    selects which of the two paths is handed to the base transport's
    ``__init__``.  NOTE(review): the base transport presumably stores it as
    ``self.device``, which ``_open`` and ``is_connected`` read.
    """

    def __init__(self, device, *args, **kwargs):
        self.hid = None
        self.hid_version = None

        # debug_link truthy -> second path (debug link), else first (normal).
        device = device[int(bool(kwargs.get('debug_link')))]
        super(_HidTransport, self).__init__(device, *args, **kwargs)

    def is_connected(self):
        """
        Check if the device is still connected.

        Returns True when a currently enumerated HID device has the same path
        as this transport's device.
        """
        return any(d['path'] == self.device for d in hid.enumerate(0, 0))

    def _open(self):
        """
        Open the HID device in non-blocking mode and determine its HID version.

        HidTransportV2 always uses version 2.  For other transports the version
        is probed: first a 65-byte report with a leading zero byte (the framing
        ``_write_chunk`` uses for version 2), then a plain 64-byte report
        (version 1).  Each probe succeeds if ``write`` returns the full length.

        :raises ConnectionError: neither probe succeeded.  The device handle
            is closed before raising so it is not leaked.
        """
        self.hid = hid.device()
        self.hid.open_path(self.device)
        self.hid.set_nonblocking(True)

        # determine hid_version
        if isinstance(self, HidTransportV2):
            self.hid_version = 2
            return

        r = self.hid.write([0, 63, ] + [0xFF] * 63)
        if r == 65:
            self.hid_version = 2
            return
        r = self.hid.write([63, ] + [0xFF] * 63)
        if r == 64:
            self.hid_version = 1
            return

        # Handshake failed: release the handle instead of leaking it.
        self._close()
        raise ConnectionError("Unknown HID version")

    def _close(self):
        """Close the HID device handle if one is open; safe to call twice."""
        if self.hid is not None:
            self.hid.close()
        self.hid = None

    def _write_chunk(self, chunk):
        """
        Write one 64-byte chunk to the device.

        For HID version 2 a zero byte is prepended to the chunk.

        :raises Exception: ``chunk`` is not exactly 64 bytes long.
        """
        if len(chunk) != 64:
            raise Exception("Unexpected data length")

        if self.hid_version == 2:
            self.hid.write(b'\0' + chunk)
        else:
            self.hid.write(chunk)

    def _read_chunk(self):
        """
        Read one 64-byte chunk from the device.

        The handle is non-blocking (see ``_open``), so this polls, sleeping
        1 ms between empty reads.  After each 10 s stretch without data it
        checks whether the device is still connected.

        :returns: the chunk as a ``bytearray``.
        :raises ConnectionError: the device disappeared while waiting.
        :raises Exception: the device returned a chunk that is not 64 bytes.
        """
        start = time.time()
        while True:
            data = self.hid.read(64)
            if data:
                break
            if time.time() - start > 10:
                # Over 10 s of no response, let's check if
                # device is still alive
                if not self.is_connected():
                    raise ConnectionError("Connection failed")
                # Restart timer
                start = time.time()
            time.sleep(0.001)

        if len(data) != 64:
            raise Exception("Unexpected chunk size: %d" % len(data))
        return bytearray(data)
class HidTransportV1(_HidTransport, TransportV1):
    """Combines the HID I/O of _HidTransport with the TransportV1 framing."""
class HidTransportV2(_HidTransport, TransportV2):
    """Combines the HID I/O of _HidTransport with the TransportV2 framing."""
# Single source of truth for supported devices: ((vendor_id, product_id),
# transport class) pairs, in the order DEVICE_IDS lists them.
_SUPPORTED_DEVICES = [
    ((0x534c, 0x0001), HidTransportV1),  # TREZOR
    ((0x1209, 0x53C0), HidTransportV2),  # TREZORv2 Bootloader
    ((0x1209, 0x53C1), HidTransportV2),  # TREZORv2
]

# (vendor_id, product_id) pairs that enumerate() accepts.
DEVICE_IDS = [ids for ids, _ in _SUPPORTED_DEVICES]

# (vendor_id, product_id) -> transport class, used by path_to_transport().
DEVICE_TRANSPORTS = dict(_SUPPORTED_DEVICES)
2016-06-26 12:29:29 -07:00
# Backward compatible factory: the concrete transport class is chosen from
# the VID/PID of the device's normal-link path.
def HidTransport(device, *args, **kwargs):
    """
    Instantiate the right HID transport for ``device``.

    ``device`` is a ``[path_normal, path_debuglink]`` pair.  The class is
    looked up by ``path_to_transport`` using the first (normal link) path.
    All arguments, including ``device`` itself, are forwarded unchanged.
    """
    normal_path = device[0]
    transport_class = path_to_transport(normal_path)
    return transport_class(device, *args, **kwargs)
2013-03-10 08:55:59 -07:00
2016-06-26 12:29:29 -07:00
# Backward compatibility hack: HidTransport is now a factory function rather
# than a class, so expose the module-level enumerate() as an attribute to keep
# ``HidTransport.enumerate()`` call sites working.
setattr(HidTransport, 'enumerate', enumerate)