transports: refactor, split protocol code

This commit is contained in:
Jan Pochyla 2017-08-24 14:29:27 +02:00
parent 7019438a49
commit bc42eb68d6
11 changed files with 503 additions and 559 deletions

View File

@ -59,16 +59,16 @@ def pipe_exists(path):
return False return False
if HID_ENABLED and len(HidTransport.enumerate()) > 0: if HID_ENABLED and HidTransport.enumerate():
devices = HidTransport.enumerate() devices = HidTransport.enumerate()
print('Using TREZOR') print('Using TREZOR')
TRANSPORT = HidTransport TRANSPORT = HidTransport
TRANSPORT_ARGS = (devices[0],) TRANSPORT_ARGS = (devices[0],)
TRANSPORT_KWARGS = {'debug_link': False} TRANSPORT_KWARGS = {}
DEBUG_TRANSPORT = HidTransport DEBUG_TRANSPORT = HidTransport
DEBUG_TRANSPORT_ARGS = (devices[0],) DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),)
DEBUG_TRANSPORT_KWARGS = {'debug_link': True} DEBUG_TRANSPORT_KWARGS = {}
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):

View File

@ -65,11 +65,12 @@ def cli(ctx, transport, path, verbose, is_json):
if ctx.invoked_subcommand == 'list': if ctx.invoked_subcommand == 'list':
ctx.obj = transport ctx.obj = transport
else: else:
t = get_transport(transport, path) def connect():
return get_transport(transport, path)
if verbose: if verbose:
ctx.obj = TrezorClientVerbose(t) ctx.obj = TrezorClientVerbose(connect)
else: else:
ctx.obj = TrezorClient(t) ctx.obj = TrezorClient(connect)
@cli.resultcallback() @cli.resultcallback()
@ -108,11 +109,7 @@ def print_result(res, transport, path, verbose, is_json):
def ls(transport_name): def ls(transport_name):
transport_class = get_transport_class_by_name(transport_name) transport_class = get_transport_class_by_name(transport_name)
devices = transport_class.enumerate() devices = transport_class.enumerate()
if transport_name == 'usb': return devices
return [dev[0] for dev in devices]
if transport_name == 'bridge':
return devices
return []
# #

View File

@ -153,11 +153,11 @@ def session(f):
# with session activation / deactivation # with session activation / deactivation
def wrapped_f(*args, **kwargs): def wrapped_f(*args, **kwargs):
client = args[0] client = args[0]
client.get_transport().session_begin()
try: try:
client.transport.session_begin()
return f(*args, **kwargs) return f(*args, **kwargs)
finally: finally:
client.transport.session_end() client.get_transport().session_end()
return wrapped_f return wrapped_f
@ -179,17 +179,23 @@ def normalize_nfc(txt):
class BaseClient(object): class BaseClient(object):
# Implements very basic layer of sending raw protobuf # Implements very basic layer of sending raw protobuf
# messages to device and getting its response back. # messages to device and getting its response back.
def __init__(self, transport, **kwargs): def __init__(self, connect, **kwargs):
self.transport = transport self.connect = connect
self.transport = None
super(BaseClient, self).__init__() # *args, **kwargs) super(BaseClient, self).__init__() # *args, **kwargs)
def get_transport(self):
if self.transport is None:
self.transport = self.connect()
return self.transport
def cancel(self): def cancel(self):
self.transport.write(proto.Cancel()) self.get_transport().write(proto.Cancel())
@session @session
def call_raw(self, msg): def call_raw(self, msg):
self.transport.write(msg) self.get_transport().write(msg)
return self.transport.read_blocking() return self.get_transport().read()
@session @session
def call(self, msg): def call(self, msg):
@ -212,9 +218,6 @@ class BaseClient(object):
raise CallException(msg.code, msg.message) raise CallException(msg.code, msg.message)
def close(self):
self.transport.close()
class VerboseWireMixin(object): class VerboseWireMixin(object):
def call_raw(self, msg): def call_raw(self, msg):

View File

@ -43,14 +43,13 @@ class DebugLink(object):
def close(self): def close(self):
self.transport.session_end() self.transport.session_end()
self.transport.close()
def _call(self, msg, nowait=False): def _call(self, msg, nowait=False):
print("DEBUGLINK SEND", pprint(msg)) print("DEBUGLINK SEND", pprint(msg))
self.transport.write(msg) self.transport.write(msg)
if nowait: if nowait:
return return
ret = self.transport.read_blocking() ret = self.transport.read()
print("DEBUGLINK RECV", pprint(ret)) print("DEBUGLINK RECV", pprint(ret))
return ret return ret

80
trezorlib/protocol_v1.py Normal file
View File

@ -0,0 +1,80 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import struct
from . import mapping
REPLEN = 64
class ProtocolV1(object):
def session_begin(self, transport):
pass
def session_end(self, transport):
pass
def write(self, transport, msg):
ser = msg.SerializeToString()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while data:
# Report ID, data padded to 63 bytes
chunk = b'?' + data[:REPLEN-1]
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
data = data[63:]
def read(self, transport):
# Read header with first part of message data
chunk = transport.read_chunk()
(msg_type, datalen, data) = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding zeros
data = data[:datalen]
# Parse to protobuf
msg = mapping.get_class(msg_type)()
msg.ParseFromString(bytes(data))
return msg
def parse_first(self, chunk):
if chunk[:3] != b'?##':
raise Exception('Unexpected magic characters')
try:
headerlen = struct.calcsize('>HL')
(msg_type, datalen) = struct.unpack('>HL', bytes(chunk[3:3 + headerlen]))
except:
raise Exception('Cannot parse header')
data = chunk[3 + headerlen:]
return (msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[:1] != b'?':
raise Exception('Unexpected magic characters')
return chunk[1:]

127
trezorlib/protocol_v2.py Normal file
View File

@ -0,0 +1,127 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import struct
from . import mapping
REPLEN = 64
class ProtocolV2(object):
def __init__(self):
self.session = None
def session_begin(self, transport):
chunk = struct.pack('>B', 0x03)
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
resp = transport.read_chunk()
self.session = self.parse_session_open(resp)
def session_end(self, transport):
if not self.session:
return
chunk = struct.pack('>BL', 0x04, self.session)
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
resp = transport.read_chunk()
if resp[0] != 0x04:
raise Exception('Expected session close')
self.session = None
def write(self, transport, msg):
if not self.session:
raise Exception('Missing session for v2 protocol')
# Serialize whole message
data = bytearray(msg.SerializeToString())
dataheader = struct.pack('>LL', mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack('>BL', 0x01, self.session)
else:
repheader = struct.pack('>BLL', 0x02, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self, transport):
if not self.session:
raise Exception('Missing session for v2 protocol')
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
next_data = self.parse_next(chunk)
data.extend(next_data)
# Strip padding
data = data[:datalen]
# Parse to protobuf
msg = mapping.get_class(msg_type)()
msg.ParseFromString(bytes(data))
return msg
def parse_first(self, chunk):
try:
headerlen = struct.calcsize('>BLLL')
(magic, session, msg_type, datalen) = struct.unpack('>BLLL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x01:
raise Exception('Unexpected magic character')
if session != self.session:
raise Exception('Session id mismatch')
return msg_type, datalen, chunk[headerlen:]
def parse_next(self, chunk):
try:
headerlen = struct.calcsize('>BLL')
(magic, session, sequence) = struct.unpack('>BLL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x02:
raise Exception('Unexpected magic characters')
if session != self.session:
raise Exception('Session id mismatch')
return chunk[headerlen:]
def parse_session_open(self, chunk):
try:
headerlen = struct.calcsize('>BL')
(magic, session) = struct.unpack('>BL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x03:
raise Exception('Unexpected magic character')
return session

View File

@ -2,6 +2,7 @@
# #
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com> # Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com> # Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
# Copyright (C) 2016 Jochen Hoenicke <hoenicke@gmail.com>
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by # it under the terms of the GNU Lesser General Public License as published by
@ -18,258 +19,24 @@
from __future__ import absolute_import from __future__ import absolute_import
import struct
import binascii
from . import mapping
class NotImplementedException(Exception):
pass
class ConnectionError(Exception):
pass
class Transport(object): class Transport(object):
def __init__(self, device, *args, **kwargs):
self.device = device def __init__(self):
self.session_id = 0 self.session_counter = 0
self.session_depth = 0
self._open()
def session_begin(self): def session_begin(self):
""" if self.session_counter == 0:
Apply a lock to the device in order to preform synchronous multistep "conversations" with the device. For example, before entering the transaction signing workflow, one begins a session. After the transaction is complete, the session may be ended. self.open()
""" self.session_counter += 1
if self.session_depth == 0:
self._session_begin()
self.session_depth += 1
def session_end(self): def session_end(self):
""" self.session_counter = max(self.session_counter - 1, 0)
End a session. Se session_begin for an in depth description of TREZOR sessions. if self.session_counter == 0:
""" self.close()
self.session_depth -= 1
self.session_depth = max(0, self.session_depth) def open(self):
if self.session_depth == 0: raise NotImplementedError
self._session_end()
def close(self): def close(self):
""" raise NotImplementedError
Close the connection to the physical device or file descriptor represented by the Transport.
"""
self._close()
def write(self, msg):
"""
Write mesage to tansport. msg should be a member of a valid `protobuf class <https://developers.google.com/protocol-buffers/docs/pythontutorial>`_ with a SerializeToString() method.
"""
raise NotImplementedException("Not implemented")
def read(self):
"""
If there is data available to be read from the transport, reads the data and tries to parse it as a protobuf message. If the parsing succeeds, return a protobuf object.
Otherwise, returns None.
"""
if not self._ready_to_read():
return None
data = self._read()
if data is None:
return None
return self._parse_message(data)
def read_blocking(self):
"""
Same as read, except blocks until data is available to be read.
"""
while True:
data = self._read()
if data is not None:
break
return self._parse_message(data)
def _parse_message(self, data):
(session_id, msg_type, data) = data
# Raise exception if we get the response with unexpected session ID
if session_id != self.session_id:
raise Exception("Session ID mismatch. Have %d, got %d" %
(self.session_id, session_id))
if msg_type == 'protobuf':
return data
else:
inst = mapping.get_class(msg_type)()
inst.ParseFromString(bytes(data))
return inst
# Functions to be implemented in specific transports:
def _open(self):
raise NotImplementedException("Not implemented")
def _close(self):
raise NotImplementedException("Not implemented")
def _write_chunk(self, chunk):
raise NotImplementedException("Not implemented")
def _read_chunk(self):
raise NotImplementedException("Not implemented")
def _ready_to_read(self):
"""
Returns True if there is data to be read from the transport. Otherwise, False.
"""
raise NotImplementedException("Not implemented")
def _session_begin(self):
pass
def _session_end(self):
pass
class TransportV1(Transport):
def write(self, msg):
ser = msg.SerializeToString()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while len(data):
# Report ID, data padded to 63 bytes
chunk = b'?' + data[:63] + b'\0' * (63 - len(data[:63]))
self._write_chunk(chunk)
data = data[63:]
def _read(self):
chunk = self._read_chunk()
(msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding zeros
data = data[:datalen]
return (0, msg_type, data)
def parse_first(self, chunk):
if chunk[:3] != b"?##":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
(msg_type, datalen) = struct.unpack(">HL", bytes(chunk[3:3 + headerlen]))
except:
raise Exception("Cannot parse header")
data = chunk[3 + headerlen:]
return (msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[0:1] != b"?":
raise Exception("Unexpected magic characters")
return chunk[1:]
class TransportV2(Transport):
def write(self, msg):
if not self.session_id:
raise Exception('Missing session_id for v2 transport')
data = bytearray(msg.SerializeToString())
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
while len(data):
if seq < 0:
repheader = struct.pack(">BL", 0x01, self.session_id)
else:
repheader = struct.pack(">BLL", 0x02, self.session_id, seq)
datalen = 64 - len(repheader)
chunk = repheader + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
self._write_chunk(chunk)
data = data[datalen:]
seq += 1
def _read(self):
if not self.session_id:
raise Exception('Missing session_id for v2 transport')
chunk = self._read_chunk()
(session_id, msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
(next_session_id, next_data) = self.parse_next(chunk)
if next_session_id != session_id:
raise Exception("Session id mismatch")
data.extend(next_data)
data = data[:datalen] # Strip padding
return (session_id, msg_type, data)
def parse_first(self, chunk):
try:
headerlen = struct.calcsize(">BLLL")
(magic, session_id, msg_type, datalen) = struct.unpack(">BLLL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x01:
raise Exception("Unexpected magic character")
return (session_id, msg_type, datalen, chunk[headerlen:])
def parse_next(self, chunk):
try:
headerlen = struct.calcsize(">BLL")
(magic, session_id, sequence) = struct.unpack(">BLL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x02:
raise Exception("Unexpected magic characters")
return (session_id, chunk[headerlen:])
def parse_session_open(self, chunk):
try:
headerlen = struct.calcsize(">BL")
(magic, session_id) = struct.unpack(">BL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x03:
raise Exception("Unexpected magic character")
return session_id
def _session_begin(self):
self._write_chunk(bytearray(b'\x03' + b'\0' * 63))
self.session_id = self.parse_session_open(self._read_chunk())
def _session_end(self):
header = struct.pack(">L", self.session_id)
self._write_chunk(bytearray(b'\x04' + header + b'\0' * (63 - len(header))))
if self._read_chunk()[0] != 0x04:
raise Exception("Expected session close")
self.session_id = None
'''
def read_headers(self, read_f):
c = read_f.read(2)
if c != b"?!":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
(session_id, msg_type, datalen) = struct.unpack(">LLL", read_f.read(headerlen))
except:
raise Exception("Cannot parse header length")
return (0, msg_type, datalen)
'''

View File

@ -17,14 +17,13 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
'''BridgeTransport implements transport TREZOR Bridge (aka trezord).''' from __future__ import absolute_import
import binascii
import json
import requests import requests
from google.protobuf import json_format from google.protobuf import json_format
from . import messages_pb2 as proto
from .transport import TransportV1 from . import messages_pb2
from .transport import Transport
TREZORD_HOST = 'https://localback.net:21324' TREZORD_HOST = 'https://localback.net:21324'
CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin' CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin'
@ -34,93 +33,84 @@ def get_error(resp):
return ' (error=%d str=%s)' % (resp.status_code, resp.json()['error']) return ' (error=%d str=%s)' % (resp.status_code, resp.json()['error'])
class BridgeTransport(TransportV1): class BridgeTransport(Transport):
'''
BridgeTransport implements transport through TREZOR Bridge (aka trezord).
'''
CONFIGURED = False configured = False
def __init__(self, device, *args, **kwargs): def __init__(self, device):
self.configure() super(BridgeTransport, self).__init__()
self.path = device['path']
self.device = device
self.conn = requests.Session()
self.session = None self.session = None
self.response = None self.response = None
self.conn = requests.Session()
super(BridgeTransport, self).__init__(device, *args, **kwargs) def __str__(self):
return self.device['path']
@staticmethod @staticmethod
def configure(): def configure():
if BridgeTransport.CONFIGURED: if BridgeTransport.configured:
return return
r = requests.get(CONFIG_URL, verify=False) r = requests.get(CONFIG_URL, verify=False)
if r.status_code != 200: if r.status_code != 200:
raise Exception('Could not fetch config from %s' % CONFIG_URL) raise Exception('Could not fetch config from %s' % CONFIG_URL)
r = requests.post(TREZORD_HOST + '/configure', data=r.text)
config = r.text
r = requests.post(TREZORD_HOST + '/configure', data=config)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not configure' + get_error(r)) raise Exception('trezord: Could not configure' + get_error(r))
BridgeTransport.CONFIGURED = True BridgeTransport.configured = True
@classmethod @staticmethod
def enumerate(cls): def enumerate():
""" BridgeTransport.configure()
Return a list of available TREZOR devices.
"""
cls.configure()
r = requests.get(TREZORD_HOST + '/enumerate') r = requests.get(TREZORD_HOST + '/enumerate')
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not enumerate devices' + get_error(r)) raise Exception('trezord: Could not enumerate devices' +
enum = r.json() get_error(r))
return enum return [BridgeTransport(dev) for dev in r.json()]
@classmethod @staticmethod
def find_by_path(cls, path=None): def find_by_path(path):
""" for transport in BridgeTransport.enumerate():
Finds a device by transport-specific path. if path is None or transport.device['path'] == path:
If path is not set, return first device. return transport
""" raise Exception('Bridge device not found')
devices = cls.enumerate()
for dev in devices:
if not path or dev['path'] == binascii.hexlify(path):
return cls(dev)
raise Exception('Device not found')
def _open(self): def open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path) r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path'])
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not acquire session' + get_error(r)) raise Exception('trezord: Could not acquire session' +
resp = r.json() get_error(r))
self.session = resp['session'] self.session = r.json()['session']
def _close(self): def close(self):
if not self.session:
return
r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session) r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not release session' + get_error(r)) raise Exception('trezord: Could not release session' +
else: get_error(r))
self.session = None self.session = None
def _ready_to_read(self): def write(self, msg):
return self.response is not None msgname = msg.__class__.__name__
msgjson = json_format.MessageToJson(
def write(self, protobuf_msg): msg, preserving_proto_field_name=True)
# Override main 'write' method, HTTP transport cannot be payload = '{"type": "%s", "message": %s}' % (msgname, msgjson)
# splitted to chunks r = self.conn.post(
cls = protobuf_msg.__class__.__name__ TREZORD_HOST + '/call/%s' % self.session, data=payload)
msg = json_format.MessageToJson(protobuf_msg, preserving_proto_field_name=True)
payload = '{"type": "%s", "message": %s}' % (cls, msg)
r = self.conn.post(TREZORD_HOST + '/call/%s' % self.session, data=payload)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not write message' + get_error(r)) raise Exception('trezord: Could not write message' + get_error(r))
else: self.response = r.json()
self.response = r.json()
def _read(self): def read(self):
if self.response is None: if self.response is None:
raise Exception('No response stored') raise Exception('No response stored')
cls = getattr(proto, self.response['type']) msgtype = getattr(messages_pb2, self.response['type'])
inst = cls() msg = msgtype()
pb = json_format.ParseDict(self.response['message'], inst) msg = json_format.ParseDict(self.response['message'], msg)
return (0, 'protobuf', pb) self.response = None
return msg

View File

@ -16,168 +16,136 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
'''USB HID implementation of Transport.''' from __future__ import absolute_import
import time import time
import hid import hid
from .transport import TransportV1, TransportV2, ConnectionError
from .protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2
from .transport import Transport
DEV_TREZOR1 = (0x534c, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53c0)
DEV_TREZOR2_BL = (0x1209, 0x1201)
def enumerate(): class HidTransport(Transport):
""" '''
Return a list of available TREZOR devices. HidTransport implements transport over USB HID interface.
""" '''
devices = {}
for d in hid.enumerate(0, 0):
vendor_id = d['vendor_id']
product_id = d['product_id']
serial_number = d['serial_number']
interface_number = d['interface_number']
usage_page = d['usage_page']
path = d['path']
if (vendor_id, product_id) in DEVICE_IDS: def __init__(self, device, protocol=None):
devices.setdefault(serial_number, [None, None]) super(HidTransport, self).__init__()
# first match by usage_page, then try interface number
if usage_page == 0xFF00 or interface_number == 0: # normal link
devices[serial_number][0] = path
elif usage_page == 0xFF01 or interface_number == 1: # debug link
devices[serial_number][1] = path
# List of two-tuples (path_normal, path_debuglink) if protocol is None:
return sorted(devices.values()) if is_trezor2(device):
protocol = ProtocolV2()
else:
def find_by_path(path=None): protocol = ProtocolV1()
""" self.device = device
Finds a device by transport-specific path. self.protocol = protocol
If path is not set, return first device.
"""
devices = enumerate()
for dev in devices:
if not path or path in dev:
return HidTransport(dev)
raise Exception('Device not found')
def path_to_transport(path):
try:
device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0]
except IndexError:
raise ConnectionError("Connection failed")
# VID/PID found, let's find proper transport
try:
transport = DEVICE_TRANSPORTS[(device['vendor_id'], device['product_id'])]
except IndexError:
raise Exception("Unknown transport for VID:PID %04x:%04x" % (device['vendor_id'], device['product_id']))
return transport
class _HidTransport(object):
def __init__(self, device, *args, **kwargs):
self.hid = None self.hid = None
self.hid_version = None self.hid_version = None
device = device[int(bool(kwargs.get('debug_link')))] def __str__(self):
super(_HidTransport, self).__init__(device, *args, **kwargs) return self.device['path']
def is_connected(self): @staticmethod
""" def enumerate(debug=False):
Check if the device is still connected. return [
""" HidTransport(dev) for dev in hid.enumerate(0, 0)
for d in hid.enumerate(0, 0): if ((is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)) and
if d['path'] == self.device: (is_debug(dev) == debug))
return True ]
return False
def _open(self): @staticmethod
def find_by_path(path=None):
for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport
raise Exception('HID device not found')
def find_debug(self):
if isinstance(self.protocol, ProtocolV2):
# For v2 protocol, lets use the same HID interface, but with a different session
debug = HidTransport(self.device, ProtocolV2())
debug.hid = self.hid
debug.hid_version = self.hid_version
return debug
if isinstance(self.protocol, ProtocolV1):
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device['serial_number'] == self.device['serial_number']:
return debug
def open(self):
if self.hid:
return
self.hid = hid.device() self.hid = hid.device()
self.hid.open_path(self.device) self.hid.open_path(self.device['path'])
self.hid.set_nonblocking(True) self.hid.set_nonblocking(True)
if is_trezor1(self.device):
# determine hid_version self.hid_version = self.probe_hid_version()
if isinstance(self, HidTransportV2):
self.hid_version = 2
else: else:
r = self.hid.write([0, 63, ] + [0xFF] * 63) self.hid_version = 2
if r == 65: self.protocol.session_begin(self)
self.hid_version = 2
return
r = self.hid.write([63, ] + [0xFF] * 63)
if r == 64:
self.hid_version = 1
return
raise ConnectionError("Unknown HID version")
def _close(self): def close(self):
self.hid.close() self.protocol.session_end(self)
try:
self.hid.close()
except OSError:
pass # Failing to close the handle is not a problem
self.hid = None self.hid = None
self.hid_version = None
def _write_chunk(self, chunk): def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception("Unexpected data length") raise Exception('Unexpected chunk size: %d' % len(chunk))
if self.hid_version == 2: if self.hid_version == 2:
self.hid.write(b'\0' + chunk) self.hid.write(b'\0' + chunk)
else: else:
self.hid.write(chunk) self.hid.write(chunk)
def _read_chunk(self): def read_chunk(self):
start = time.time()
while True: while True:
data = self.hid.read(64) chunk = self.hid.read(64)
if not len(data): if chunk:
if time.time() - start > 10: break
# Over 10 s of no response, let's check if else:
# device is still alive
if not self.is_connected():
raise ConnectionError("Connection failed")
# Restart timer
start = time.time()
time.sleep(0.001) time.sleep(0.001)
continue if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk)
break def probe_hid_version(self):
n = self.hid.write([0, 63] + [0xFF] * 63)
if len(data) != 64: if n == 65:
raise Exception("Unexpected chunk size: %d" % len(data)) return 2
n = self.hid.write([63] + [0xFF] * 63)
return bytearray(data) if n == 64:
return 1
raise Exception('Unknown HID version')
class HidTransportV1(_HidTransport, TransportV1): def is_trezor1(dev):
pass return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR1
class HidTransportV2(_HidTransport, TransportV2): def is_trezor2(dev):
pass return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2
DEVICE_IDS = [ def is_trezor2_bl(dev):
(0x534c, 0x0001), # TREZOR return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2_BL
(0x1209, 0x53c0), # TREZORv2 Bootloader
(0x1209, 0x53c1), # TREZORv2
]
DEVICE_TRANSPORTS = {
(0x534c, 0x0001): HidTransportV1, # TREZOR
(0x1209, 0x53c0): HidTransportV1, # TREZORv2 Bootloader
(0x1209, 0x53c1): HidTransportV2, # TREZORv2
}
# Backward compatible wrapper, decides for proper transport def is_debug(dev):
# based on VID/PID of given path return (dev['usage_page'] == 0xFF01 or dev['interface_number'] == 1)
def HidTransport(device, *args, **kwargs):
transport = path_to_transport(device[0])
return transport(device, *args, **kwargs)
# Backward compatibility hack; HidTransport is a function, not a class like before
HidTransport.enumerate = enumerate
HidTransport.find_by_path = find_by_path

View File

@ -16,91 +16,94 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function from __future__ import absolute_import
import os import os
import time import time
from select import select
from .transport import TransportV1 from .protocol_v1 import ProtocolV1
"""PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor simulator."""
class PipeTransport(TransportV1): class PipeTransport(object):
'''
PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor-emu.
'''
def __init__(self, device=None, is_device=False):
super(PipeTransport, self).__init__()
def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs):
if not device: if not device:
device = '/tmp/pipe.trezor' device = '/tmp/pipe.trezor'
self.is_device = is_device # set True if act as device self.device = device
self.is_device = is_device
self.filename_read = None
self.filename_write = None
self.read_f = None
self.write_f = None
self.protocol = ProtocolV1()
super(PipeTransport, self).__init__(device, *args, **kwargs) def __str__(self):
return self.device
@classmethod @staticmethod
def enumerate(cls): def enumerate():
raise Exception('This transport cannot enumerate devices') raise NotImplementedError('This transport cannot enumerate devices')
@classmethod @staticmethod
def find_by_path(cls, path=None): def find_by_path(path=None):
return cls(path) return PipeTransport(path)
def _open(self): def open(self):
if self.is_device: if self.is_device:
self.filename_read = self.device + '.to' self.filename_read = self.device + '.to'
self.filename_write = self.device + '.from' self.filename_write = self.device + '.from'
os.mkfifo(self.filename_read, 0o600) os.mkfifo(self.filename_read, 0o600)
os.mkfifo(self.filename_write, 0o600) os.mkfifo(self.filename_write, 0o600)
else: else:
self.filename_read = self.device + '.from' self.filename_read = self.device + '.from'
self.filename_write = self.device + '.to' self.filename_write = self.device + '.to'
if not os.path.exists(self.filename_write): if not os.path.exists(self.filename_write):
raise Exception("Not connected") raise Exception("Not connected")
self.write_fd = os.open(self.filename_write, os.O_RDWR) # |os.O_NONBLOCK) self.read_f = os.open(self.filename_read, 'rb', 0)
self.write_f = os.fdopen(self.write_fd, 'w+b', 0) self.write_f = os.open(self.filename_write, 'w+b', 0)
self.read_fd = os.open(self.filename_read, os.O_RDWR) # |os.O_NONBLOCK) self.protocol.session_begin(self)
self.read_f = os.fdopen(self.read_fd, 'rb', 0)
def _close(self): def close(self):
self.read_f.close() self.protocol.session_end(self)
self.write_f.close() if self.read_f:
self.read_f.close()
self.read_f = None
if self.write_f:
self.write_f.close()
self.write_f = None
if self.is_device: if self.is_device:
os.unlink(self.filename_read) os.unlink(self.filename_read)
os.unlink(self.filename_write) os.unlink(self.filename_write)
self.filename_read = None
self.filename_write = None
def _ready_to_read(self): def read(self):
rlist, _, _ = select([self.read_f], [], [], 0) return self.protocol.read(self)
return len(rlist) > 0
def _write_chunk(self, chunk): def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception("Unexpected data length") raise Exception('Unexpected chunk size: %d' % len(chunk))
self.write_f.write(chunk)
self.write_f.flush()
try: def read_chunk(self):
self.write_f.write(chunk)
self.write_f.flush()
except OSError:
print("Error while writing to socket")
raise
def _read_chunk(self):
while True: while True:
try: chunk = self.read_f.read(64)
data = self.read_f.read(64) if chunk:
except IOError: break
print("Failed to read from device") else:
raise
if not len(data):
time.sleep(0.001) time.sleep(0.001)
continue if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk))
break return bytearray(chunk)
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)

View File

@ -16,66 +16,76 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
'''UDP Socket implementation of Transport.''' from __future__ import absolute_import
import socket import socket
from select import select
from .transport import TransportV2 from .protocol_v2 import ProtocolV2
from .transport import Transport
class UdpTransport(TransportV2): class UdpTransport(Transport):
def __init__(self, device, *args, **kwargs): DEFAULT_HOST = '127.0.0.1'
if device is None: DEFAULT_PORT = 21324
device = ''
device = device.split(':') def __init__(self, device=None, protocol=None):
if len(device) < 2: super(UdpTransport, self).__init__()
if not device[0]:
# Default port used by trezor v2 if not device:
device = ('127.0.0.1', 21324) host = UdpTransport.DEFAULT_HOST
else: port = UdpTransport.DEFAULT_PORT
device = ('127.0.0.1', int(device[0]))
else: else:
device = (device[0], int(device[1])) host = device.split(':').get(0)
port = device.split(':').get(1, UdpTransport.DEFAULT_PORT)
port = int(port)
if not protocol:
protocol = ProtocolV2()
self.device = (host, port)
self.protocol = protocol
self.socket = None self.socket = None
super(UdpTransport, self).__init__(device, *args, **kwargs)
@classmethod def __str__(self):
def enumerate(cls): return self.device
raise Exception('This transport cannot enumerate devices')
@classmethod @staticmethod
def find_by_path(cls, path=None): def enumerate():
return cls(path) raise NotImplementedError('This transport cannot enumerate devices')
def _open(self): @staticmethod
def find_by_path(path=None):
return UdpTransport(path)
def open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)
self.socket.settimeout(10) self.socket.settimeout(10)
self.protocol.session_begin(self)
def _close(self): def close(self):
self.socket.close() if self.socket:
self.socket = None self.protocol.session_end(self)
self.socket.close()
self.socket = None
def _ready_to_read(self): def read(self):
rlist, _, _ = select([self.socket], [], [], 0) return self.protocol.read(self)
return len(rlist) > 0
def _write_chunk(self, chunk): def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception("Unexpected data length") raise Exception('Unexpected data length')
self.socket.sendall(chunk) self.socket.sendall(chunk)
def _read_chunk(self): def read_chunk(self):
while True: while True:
try: try:
data = self.socket.recv(64) chunk = self.socket.recv(64)
break break
except socket.timeout: except socket.timeout:
continue continue
if len(data) != 64: if len(chunk) != 64:
raise Exception("Unexpected chunk size: %d" % len(data)) raise Exception('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk)
return bytearray(data)