transports: refactor, split protocol code

This commit is contained in:
Jan Pochyla 2017-08-24 14:29:27 +02:00
parent 7019438a49
commit bc42eb68d6
11 changed files with 503 additions and 559 deletions

View File

@ -59,16 +59,16 @@ def pipe_exists(path):
return False
if HID_ENABLED and len(HidTransport.enumerate()) > 0:
if HID_ENABLED and HidTransport.enumerate():
devices = HidTransport.enumerate()
print('Using TREZOR')
TRANSPORT = HidTransport
TRANSPORT_ARGS = (devices[0],)
TRANSPORT_KWARGS = {'debug_link': False}
TRANSPORT_KWARGS = {}
DEBUG_TRANSPORT = HidTransport
DEBUG_TRANSPORT_ARGS = (devices[0],)
DEBUG_TRANSPORT_KWARGS = {'debug_link': True}
DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),)
DEBUG_TRANSPORT_KWARGS = {}
elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'):

View File

@ -65,11 +65,12 @@ def cli(ctx, transport, path, verbose, is_json):
if ctx.invoked_subcommand == 'list':
ctx.obj = transport
else:
t = get_transport(transport, path)
def connect():
return get_transport(transport, path)
if verbose:
ctx.obj = TrezorClientVerbose(t)
ctx.obj = TrezorClientVerbose(connect)
else:
ctx.obj = TrezorClient(t)
ctx.obj = TrezorClient(connect)
@cli.resultcallback()
@ -108,11 +109,7 @@ def print_result(res, transport, path, verbose, is_json):
def ls(transport_name):
transport_class = get_transport_class_by_name(transport_name)
devices = transport_class.enumerate()
if transport_name == 'usb':
return [dev[0] for dev in devices]
if transport_name == 'bridge':
return devices
return []
return devices
#

View File

@ -153,11 +153,11 @@ def session(f):
# with session activation / deactivation
def wrapped_f(*args, **kwargs):
client = args[0]
client.get_transport().session_begin()
try:
client.transport.session_begin()
return f(*args, **kwargs)
finally:
client.transport.session_end()
client.get_transport().session_end()
return wrapped_f
@ -179,17 +179,23 @@ def normalize_nfc(txt):
class BaseClient(object):
# Implements very basic layer of sending raw protobuf
# messages to device and getting its response back.
def __init__(self, transport, **kwargs):
self.transport = transport
def __init__(self, connect, **kwargs):
self.connect = connect
self.transport = None
super(BaseClient, self).__init__() # *args, **kwargs)
def get_transport(self):
if self.transport is None:
self.transport = self.connect()
return self.transport
def cancel(self):
self.transport.write(proto.Cancel())
self.get_transport().write(proto.Cancel())
@session
def call_raw(self, msg):
self.transport.write(msg)
return self.transport.read_blocking()
self.get_transport().write(msg)
return self.get_transport().read()
@session
def call(self, msg):
@ -212,9 +218,6 @@ class BaseClient(object):
raise CallException(msg.code, msg.message)
def close(self):
self.transport.close()
class VerboseWireMixin(object):
def call_raw(self, msg):

View File

@ -43,14 +43,13 @@ class DebugLink(object):
def close(self):
self.transport.session_end()
self.transport.close()
def _call(self, msg, nowait=False):
print("DEBUGLINK SEND", pprint(msg))
self.transport.write(msg)
if nowait:
return
ret = self.transport.read_blocking()
ret = self.transport.read()
print("DEBUGLINK RECV", pprint(ret))
return ret

80
trezorlib/protocol_v1.py Normal file
View File

@ -0,0 +1,80 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import struct
from . import mapping
REPLEN = 64
class ProtocolV1(object):
def session_begin(self, transport):
pass
def session_end(self, transport):
pass
def write(self, transport, msg):
ser = msg.SerializeToString()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while data:
# Report ID, data padded to 63 bytes
chunk = b'?' + data[:REPLEN-1]
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
data = data[63:]
def read(self, transport):
# Read header with first part of message data
chunk = transport.read_chunk()
(msg_type, datalen, data) = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding zeros
data = data[:datalen]
# Parse to protobuf
msg = mapping.get_class(msg_type)()
msg.ParseFromString(bytes(data))
return msg
def parse_first(self, chunk):
if chunk[:3] != b'?##':
raise Exception('Unexpected magic characters')
try:
headerlen = struct.calcsize('>HL')
(msg_type, datalen) = struct.unpack('>HL', bytes(chunk[3:3 + headerlen]))
except:
raise Exception('Cannot parse header')
data = chunk[3 + headerlen:]
return (msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[:1] != b'?':
raise Exception('Unexpected magic characters')
return chunk[1:]

127
trezorlib/protocol_v2.py Normal file
View File

@ -0,0 +1,127 @@
# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import struct
from . import mapping
REPLEN = 64
class ProtocolV2(object):
def __init__(self):
self.session = None
def session_begin(self, transport):
chunk = struct.pack('>B', 0x03)
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
resp = transport.read_chunk()
self.session = self.parse_session_open(resp)
def session_end(self, transport):
if not self.session:
return
chunk = struct.pack('>BL', 0x04, self.session)
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
resp = transport.read_chunk()
if resp[0] != 0x04:
raise Exception('Expected session close')
self.session = None
def write(self, transport, msg):
if not self.session:
raise Exception('Missing session for v2 protocol')
# Serialize whole message
data = bytearray(msg.SerializeToString())
dataheader = struct.pack('>LL', mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack('>BL', 0x01, self.session)
else:
repheader = struct.pack('>BLL', 0x02, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, bytes([0x00]))
transport.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self, transport):
if not self.session:
raise Exception('Missing session for v2 protocol')
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
next_data = self.parse_next(chunk)
data.extend(next_data)
# Strip padding
data = data[:datalen]
# Parse to protobuf
msg = mapping.get_class(msg_type)()
msg.ParseFromString(bytes(data))
return msg
def parse_first(self, chunk):
try:
headerlen = struct.calcsize('>BLLL')
(magic, session, msg_type, datalen) = struct.unpack('>BLLL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x01:
raise Exception('Unexpected magic character')
if session != self.session:
raise Exception('Session id mismatch')
return msg_type, datalen, chunk[headerlen:]
def parse_next(self, chunk):
try:
headerlen = struct.calcsize('>BLL')
(magic, session, sequence) = struct.unpack('>BLL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x02:
raise Exception('Unexpected magic characters')
if session != self.session:
raise Exception('Session id mismatch')
return chunk[headerlen:]
def parse_session_open(self, chunk):
try:
headerlen = struct.calcsize('>BL')
(magic, session) = struct.unpack('>BL', bytes(chunk[:headerlen]))
except:
raise Exception('Cannot parse header')
if magic != 0x03:
raise Exception('Unexpected magic character')
return session

View File

@ -2,6 +2,7 @@
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
# Copyright (C) 2016 Jochen Hoenicke <hoenicke@gmail.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
@ -18,258 +19,24 @@
from __future__ import absolute_import
import struct
import binascii
from . import mapping
class NotImplementedException(Exception):
pass
class ConnectionError(Exception):
pass
class Transport(object):
def __init__(self, device, *args, **kwargs):
self.device = device
self.session_id = 0
self.session_depth = 0
self._open()
def __init__(self):
self.session_counter = 0
def session_begin(self):
"""
Apply a lock to the device in order to preform synchronous multistep "conversations" with the device. For example, before entering the transaction signing workflow, one begins a session. After the transaction is complete, the session may be ended.
"""
if self.session_depth == 0:
self._session_begin()
self.session_depth += 1
if self.session_counter == 0:
self.open()
self.session_counter += 1
def session_end(self):
"""
End a session. Se session_begin for an in depth description of TREZOR sessions.
"""
self.session_depth -= 1
self.session_depth = max(0, self.session_depth)
if self.session_depth == 0:
self._session_end()
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.close()
def open(self):
raise NotImplementedError
def close(self):
"""
Close the connection to the physical device or file descriptor represented by the Transport.
"""
self._close()
def write(self, msg):
"""
Write mesage to tansport. msg should be a member of a valid `protobuf class <https://developers.google.com/protocol-buffers/docs/pythontutorial>`_ with a SerializeToString() method.
"""
raise NotImplementedException("Not implemented")
def read(self):
"""
If there is data available to be read from the transport, reads the data and tries to parse it as a protobuf message. If the parsing succeeds, return a protobuf object.
Otherwise, returns None.
"""
if not self._ready_to_read():
return None
data = self._read()
if data is None:
return None
return self._parse_message(data)
def read_blocking(self):
"""
Same as read, except blocks until data is available to be read.
"""
while True:
data = self._read()
if data is not None:
break
return self._parse_message(data)
def _parse_message(self, data):
(session_id, msg_type, data) = data
# Raise exception if we get the response with unexpected session ID
if session_id != self.session_id:
raise Exception("Session ID mismatch. Have %d, got %d" %
(self.session_id, session_id))
if msg_type == 'protobuf':
return data
else:
inst = mapping.get_class(msg_type)()
inst.ParseFromString(bytes(data))
return inst
# Functions to be implemented in specific transports:
def _open(self):
raise NotImplementedException("Not implemented")
def _close(self):
raise NotImplementedException("Not implemented")
def _write_chunk(self, chunk):
raise NotImplementedException("Not implemented")
def _read_chunk(self):
raise NotImplementedException("Not implemented")
def _ready_to_read(self):
"""
Returns True if there is data to be read from the transport. Otherwise, False.
"""
raise NotImplementedException("Not implemented")
def _session_begin(self):
pass
def _session_end(self):
pass
class TransportV1(Transport):
def write(self, msg):
ser = msg.SerializeToString()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while len(data):
# Report ID, data padded to 63 bytes
chunk = b'?' + data[:63] + b'\0' * (63 - len(data[:63]))
self._write_chunk(chunk)
data = data[63:]
def _read(self):
chunk = self._read_chunk()
(msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding zeros
data = data[:datalen]
return (0, msg_type, data)
def parse_first(self, chunk):
if chunk[:3] != b"?##":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
(msg_type, datalen) = struct.unpack(">HL", bytes(chunk[3:3 + headerlen]))
except:
raise Exception("Cannot parse header")
data = chunk[3 + headerlen:]
return (msg_type, datalen, data)
def parse_next(self, chunk):
if chunk[0:1] != b"?":
raise Exception("Unexpected magic characters")
return chunk[1:]
class TransportV2(Transport):
def write(self, msg):
if not self.session_id:
raise Exception('Missing session_id for v2 transport')
data = bytearray(msg.SerializeToString())
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
while len(data):
if seq < 0:
repheader = struct.pack(">BL", 0x01, self.session_id)
else:
repheader = struct.pack(">BLL", 0x02, self.session_id, seq)
datalen = 64 - len(repheader)
chunk = repheader + data[:datalen] + b'\0' * (datalen - len(data[:datalen]))
self._write_chunk(chunk)
data = data[datalen:]
seq += 1
def _read(self):
if not self.session_id:
raise Exception('Missing session_id for v2 transport')
chunk = self._read_chunk()
(session_id, msg_type, datalen, data) = self.parse_first(chunk)
while len(data) < datalen:
chunk = self._read_chunk()
(next_session_id, next_data) = self.parse_next(chunk)
if next_session_id != session_id:
raise Exception("Session id mismatch")
data.extend(next_data)
data = data[:datalen] # Strip padding
return (session_id, msg_type, data)
def parse_first(self, chunk):
try:
headerlen = struct.calcsize(">BLLL")
(magic, session_id, msg_type, datalen) = struct.unpack(">BLLL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x01:
raise Exception("Unexpected magic character")
return (session_id, msg_type, datalen, chunk[headerlen:])
def parse_next(self, chunk):
try:
headerlen = struct.calcsize(">BLL")
(magic, session_id, sequence) = struct.unpack(">BLL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x02:
raise Exception("Unexpected magic characters")
return (session_id, chunk[headerlen:])
def parse_session_open(self, chunk):
try:
headerlen = struct.calcsize(">BL")
(magic, session_id) = struct.unpack(">BL", bytes(chunk[:headerlen]))
except:
raise Exception("Cannot parse header")
if magic != 0x03:
raise Exception("Unexpected magic character")
return session_id
def _session_begin(self):
self._write_chunk(bytearray(b'\x03' + b'\0' * 63))
self.session_id = self.parse_session_open(self._read_chunk())
def _session_end(self):
header = struct.pack(">L", self.session_id)
self._write_chunk(bytearray(b'\x04' + header + b'\0' * (63 - len(header))))
if self._read_chunk()[0] != 0x04:
raise Exception("Expected session close")
self.session_id = None
'''
def read_headers(self, read_f):
c = read_f.read(2)
if c != b"?!":
raise Exception("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
(session_id, msg_type, datalen) = struct.unpack(">LLL", read_f.read(headerlen))
except:
raise Exception("Cannot parse header length")
return (0, msg_type, datalen)
'''
raise NotImplementedError

View File

@ -17,14 +17,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
'''BridgeTransport implements transport TREZOR Bridge (aka trezord).'''
from __future__ import absolute_import
import binascii
import json
import requests
from google.protobuf import json_format
from . import messages_pb2 as proto
from .transport import TransportV1
from . import messages_pb2
from .transport import Transport
TREZORD_HOST = 'https://localback.net:21324'
CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin'
@ -34,93 +33,84 @@ def get_error(resp):
return ' (error=%d str=%s)' % (resp.status_code, resp.json()['error'])
class BridgeTransport(TransportV1):
class BridgeTransport(Transport):
'''
BridgeTransport implements transport through TREZOR Bridge (aka trezord).
'''
CONFIGURED = False
configured = False
def __init__(self, device, *args, **kwargs):
self.configure()
self.path = device['path']
def __init__(self, device):
super(BridgeTransport, self).__init__()
self.device = device
self.conn = requests.Session()
self.session = None
self.response = None
self.conn = requests.Session()
super(BridgeTransport, self).__init__(device, *args, **kwargs)
def __str__(self):
return self.device['path']
@staticmethod
def configure():
if BridgeTransport.CONFIGURED:
if BridgeTransport.configured:
return
r = requests.get(CONFIG_URL, verify=False)
if r.status_code != 200:
raise Exception('Could not fetch config from %s' % CONFIG_URL)
config = r.text
r = requests.post(TREZORD_HOST + '/configure', data=config)
r = requests.post(TREZORD_HOST + '/configure', data=r.text)
if r.status_code != 200:
raise Exception('trezord: Could not configure' + get_error(r))
BridgeTransport.CONFIGURED = True
BridgeTransport.configured = True
@classmethod
def enumerate(cls):
"""
Return a list of available TREZOR devices.
"""
cls.configure()
@staticmethod
def enumerate():
BridgeTransport.configure()
r = requests.get(TREZORD_HOST + '/enumerate')
if r.status_code != 200:
raise Exception('trezord: Could not enumerate devices' + get_error(r))
enum = r.json()
return enum
raise Exception('trezord: Could not enumerate devices' +
get_error(r))
return [BridgeTransport(dev) for dev in r.json()]
@classmethod
def find_by_path(cls, path=None):
"""
Finds a device by transport-specific path.
If path is not set, return first device.
"""
devices = cls.enumerate()
for dev in devices:
if not path or dev['path'] == binascii.hexlify(path):
return cls(dev)
raise Exception('Device not found')
@staticmethod
def find_by_path(path):
for transport in BridgeTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport
raise Exception('Bridge device not found')
def _open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path)
def open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path'])
if r.status_code != 200:
raise Exception('trezord: Could not acquire session' + get_error(r))
resp = r.json()
self.session = resp['session']
raise Exception('trezord: Could not acquire session' +
get_error(r))
self.session = r.json()['session']
def _close(self):
def close(self):
if not self.session:
return
r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session)
if r.status_code != 200:
raise Exception('trezord: Could not release session' + get_error(r))
else:
self.session = None
raise Exception('trezord: Could not release session' +
get_error(r))
self.session = None
def _ready_to_read(self):
return self.response is not None
def write(self, protobuf_msg):
# Override main 'write' method, HTTP transport cannot be
# splitted to chunks
cls = protobuf_msg.__class__.__name__
msg = json_format.MessageToJson(protobuf_msg, preserving_proto_field_name=True)
payload = '{"type": "%s", "message": %s}' % (cls, msg)
r = self.conn.post(TREZORD_HOST + '/call/%s' % self.session, data=payload)
def write(self, msg):
msgname = msg.__class__.__name__
msgjson = json_format.MessageToJson(
msg, preserving_proto_field_name=True)
payload = '{"type": "%s", "message": %s}' % (msgname, msgjson)
r = self.conn.post(
TREZORD_HOST + '/call/%s' % self.session, data=payload)
if r.status_code != 200:
raise Exception('trezord: Could not write message' + get_error(r))
else:
self.response = r.json()
self.response = r.json()
def _read(self):
def read(self):
if self.response is None:
raise Exception('No response stored')
cls = getattr(proto, self.response['type'])
inst = cls()
pb = json_format.ParseDict(self.response['message'], inst)
return (0, 'protobuf', pb)
msgtype = getattr(messages_pb2, self.response['type'])
msg = msgtype()
msg = json_format.ParseDict(self.response['message'], msg)
self.response = None
return msg

View File

@ -16,168 +16,136 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
'''USB HID implementation of Transport.'''
from __future__ import absolute_import
import time
import hid
from .transport import TransportV1, TransportV2, ConnectionError
from .protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2
from .transport import Transport
DEV_TREZOR1 = (0x534c, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53c0)
DEV_TREZOR2_BL = (0x1209, 0x1201)
def enumerate():
"""
Return a list of available TREZOR devices.
"""
devices = {}
for d in hid.enumerate(0, 0):
vendor_id = d['vendor_id']
product_id = d['product_id']
serial_number = d['serial_number']
interface_number = d['interface_number']
usage_page = d['usage_page']
path = d['path']
class HidTransport(Transport):
'''
HidTransport implements transport over USB HID interface.
'''
if (vendor_id, product_id) in DEVICE_IDS:
devices.setdefault(serial_number, [None, None])
# first match by usage_page, then try interface number
if usage_page == 0xFF00 or interface_number == 0: # normal link
devices[serial_number][0] = path
elif usage_page == 0xFF01 or interface_number == 1: # debug link
devices[serial_number][1] = path
def __init__(self, device, protocol=None):
super(HidTransport, self).__init__()
# List of two-tuples (path_normal, path_debuglink)
return sorted(devices.values())
def find_by_path(path=None):
"""
Finds a device by transport-specific path.
If path is not set, return first device.
"""
devices = enumerate()
for dev in devices:
if not path or path in dev:
return HidTransport(dev)
raise Exception('Device not found')
def path_to_transport(path):
try:
device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0]
except IndexError:
raise ConnectionError("Connection failed")
# VID/PID found, let's find proper transport
try:
transport = DEVICE_TRANSPORTS[(device['vendor_id'], device['product_id'])]
except IndexError:
raise Exception("Unknown transport for VID:PID %04x:%04x" % (device['vendor_id'], device['product_id']))
return transport
class _HidTransport(object):
def __init__(self, device, *args, **kwargs):
if protocol is None:
if is_trezor2(device):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
self.device = device
self.protocol = protocol
self.hid = None
self.hid_version = None
device = device[int(bool(kwargs.get('debug_link')))]
super(_HidTransport, self).__init__(device, *args, **kwargs)
def __str__(self):
return self.device['path']
def is_connected(self):
"""
Check if the device is still connected.
"""
for d in hid.enumerate(0, 0):
if d['path'] == self.device:
return True
return False
@staticmethod
def enumerate(debug=False):
return [
HidTransport(dev) for dev in hid.enumerate(0, 0)
if ((is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)) and
(is_debug(dev) == debug))
]
def _open(self):
@staticmethod
def find_by_path(path=None):
for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path:
return transport
raise Exception('HID device not found')
def find_debug(self):
if isinstance(self.protocol, ProtocolV2):
# For v2 protocol, lets use the same HID interface, but with a different session
debug = HidTransport(self.device, ProtocolV2())
debug.hid = self.hid
debug.hid_version = self.hid_version
return debug
if isinstance(self.protocol, ProtocolV1):
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device['serial_number'] == self.device['serial_number']:
return debug
def open(self):
if self.hid:
return
self.hid = hid.device()
self.hid.open_path(self.device)
self.hid.open_path(self.device['path'])
self.hid.set_nonblocking(True)
# determine hid_version
if isinstance(self, HidTransportV2):
self.hid_version = 2
if is_trezor1(self.device):
self.hid_version = self.probe_hid_version()
else:
r = self.hid.write([0, 63, ] + [0xFF] * 63)
if r == 65:
self.hid_version = 2
return
r = self.hid.write([63, ] + [0xFF] * 63)
if r == 64:
self.hid_version = 1
return
raise ConnectionError("Unknown HID version")
self.hid_version = 2
self.protocol.session_begin(self)
def _close(self):
self.hid.close()
def close(self):
self.protocol.session_end(self)
try:
self.hid.close()
except OSError:
pass # Failing to close the handle is not a problem
self.hid = None
self.hid_version = None
def _write_chunk(self, chunk):
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64:
raise Exception("Unexpected data length")
raise Exception('Unexpected chunk size: %d' % len(chunk))
if self.hid_version == 2:
self.hid.write(b'\0' + chunk)
else:
self.hid.write(chunk)
def _read_chunk(self):
start = time.time()
def read_chunk(self):
while True:
data = self.hid.read(64)
if not len(data):
if time.time() - start > 10:
# Over 10 s of no response, let's check if
# device is still alive
if not self.is_connected():
raise ConnectionError("Connection failed")
# Restart timer
start = time.time()
chunk = self.hid.read(64)
if chunk:
break
else:
time.sleep(0.001)
continue
if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk)
break
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)
def probe_hid_version(self):
n = self.hid.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.hid.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise Exception('Unknown HID version')
class HidTransportV1(_HidTransport, TransportV1):
pass
def is_trezor1(dev):
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR1
class HidTransportV2(_HidTransport, TransportV2):
pass
def is_trezor2(dev):
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2
DEVICE_IDS = [
(0x534c, 0x0001), # TREZOR
(0x1209, 0x53c0), # TREZORv2 Bootloader
(0x1209, 0x53c1), # TREZORv2
]
DEVICE_TRANSPORTS = {
(0x534c, 0x0001): HidTransportV1, # TREZOR
(0x1209, 0x53c0): HidTransportV1, # TREZORv2 Bootloader
(0x1209, 0x53c1): HidTransportV2, # TREZORv2
}
def is_trezor2_bl(dev):
return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2_BL
# Backward compatible wrapper, decides for proper transport
# based on VID/PID of given path
def HidTransport(device, *args, **kwargs):
transport = path_to_transport(device[0])
return transport(device, *args, **kwargs)
# Backward compatibility hack; HidTransport is a function, not a class like before
HidTransport.enumerate = enumerate
HidTransport.find_by_path = find_by_path
def is_debug(dev):
return (dev['usage_page'] == 0xFF01 or dev['interface_number'] == 1)

View File

@ -16,91 +16,94 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function
from __future__ import absolute_import
import os
import time
from select import select
from .transport import TransportV1
"""PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor simulator."""
from .protocol_v1 import ProtocolV1
class PipeTransport(TransportV1):
class PipeTransport(object):
'''
PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor-emu.
'''
def __init__(self, device=None, is_device=False):
super(PipeTransport, self).__init__()
def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs):
if not device:
device = '/tmp/pipe.trezor'
self.is_device = is_device # set True if act as device
self.device = device
self.is_device = is_device
self.filename_read = None
self.filename_write = None
self.read_f = None
self.write_f = None
self.protocol = ProtocolV1()
super(PipeTransport, self).__init__(device, *args, **kwargs)
def __str__(self):
return self.device
@classmethod
def enumerate(cls):
raise Exception('This transport cannot enumerate devices')
@staticmethod
def enumerate():
raise NotImplementedError('This transport cannot enumerate devices')
@classmethod
def find_by_path(cls, path=None):
return cls(path)
@staticmethod
def find_by_path(path=None):
return PipeTransport(path)
def _open(self):
def open(self):
if self.is_device:
self.filename_read = self.device + '.to'
self.filename_write = self.device + '.from'
os.mkfifo(self.filename_read, 0o600)
os.mkfifo(self.filename_write, 0o600)
else:
self.filename_read = self.device + '.from'
self.filename_write = self.device + '.to'
if not os.path.exists(self.filename_write):
raise Exception("Not connected")
self.write_fd = os.open(self.filename_write, os.O_RDWR) # |os.O_NONBLOCK)
self.write_f = os.fdopen(self.write_fd, 'w+b', 0)
self.read_f = os.open(self.filename_read, 'rb', 0)
self.write_f = os.open(self.filename_write, 'w+b', 0)
self.read_fd = os.open(self.filename_read, os.O_RDWR) # |os.O_NONBLOCK)
self.read_f = os.fdopen(self.read_fd, 'rb', 0)
self.protocol.session_begin(self)
def _close(self):
self.read_f.close()
self.write_f.close()
def close(self):
self.protocol.session_end(self)
if self.read_f:
self.read_f.close()
self.read_f = None
if self.write_f:
self.write_f.close()
self.write_f = None
if self.is_device:
os.unlink(self.filename_read)
os.unlink(self.filename_write)
self.filename_read = None
self.filename_write = None
def _ready_to_read(self):
rlist, _, _ = select([self.read_f], [], [], 0)
return len(rlist) > 0
def read(self):
return self.protocol.read(self)
def _write_chunk(self, chunk):
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64:
raise Exception("Unexpected data length")
raise Exception('Unexpected chunk size: %d' % len(chunk))
self.write_f.write(chunk)
self.write_f.flush()
try:
self.write_f.write(chunk)
self.write_f.flush()
except OSError:
print("Error while writing to socket")
raise
def _read_chunk(self):
def read_chunk(self):
while True:
try:
data = self.read_f.read(64)
except IOError:
print("Failed to read from device")
raise
if not len(data):
chunk = self.read_f.read(64)
if chunk:
break
else:
time.sleep(0.001)
continue
break
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)
if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk)

View File

@ -16,66 +16,76 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
'''UDP Socket implementation of Transport.'''
from __future__ import absolute_import
import socket
from select import select
from .transport import TransportV2
from .protocol_v2 import ProtocolV2
from .transport import Transport
class UdpTransport(TransportV2):
class UdpTransport(Transport):
def __init__(self, device, *args, **kwargs):
if device is None:
device = ''
device = device.split(':')
if len(device) < 2:
if not device[0]:
# Default port used by trezor v2
device = ('127.0.0.1', 21324)
else:
device = ('127.0.0.1', int(device[0]))
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 21324
def __init__(self, device=None, protocol=None):
super(UdpTransport, self).__init__()
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
else:
device = (device[0], int(device[1]))
host = device.split(':').get(0)
port = device.split(':').get(1, UdpTransport.DEFAULT_PORT)
port = int(port)
if not protocol:
protocol = ProtocolV2()
self.device = (host, port)
self.protocol = protocol
self.socket = None
super(UdpTransport, self).__init__(device, *args, **kwargs)
@classmethod
def enumerate(cls):
raise Exception('This transport cannot enumerate devices')
def __str__(self):
return self.device
@classmethod
def find_by_path(cls, path=None):
return cls(path)
@staticmethod
def enumerate():
raise NotImplementedError('This transport cannot enumerate devices')
def _open(self):
@staticmethod
def find_by_path(path=None):
return UdpTransport(path)
def open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(10)
self.protocol.session_begin(self)
def _close(self):
self.socket.close()
self.socket = None
def close(self):
if self.socket:
self.protocol.session_end(self)
self.socket.close()
self.socket = None
def _ready_to_read(self):
rlist, _, _ = select([self.socket], [], [], 0)
return len(rlist) > 0
def read(self):
return self.protocol.read(self)
def _write_chunk(self, chunk):
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64:
raise Exception("Unexpected data length")
raise Exception('Unexpected data length')
self.socket.sendall(chunk)
def _read_chunk(self):
def read_chunk(self):
while True:
try:
data = self.socket.recv(64)
chunk = self.socket.recv(64)
break
except socket.timeout:
continue
if len(data) != 64:
raise Exception("Unexpected chunk size: %d" % len(data))
return bytearray(data)
if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk)