From bc42eb68d6d874d02f959a48886c9518afbb0b8b Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Thu, 24 Aug 2017 14:29:27 +0200 Subject: [PATCH] transports: refactor, split protocol code --- tests/device_tests/common.py | 8 +- trezorctl | 13 +- trezorlib/client.py | 23 +-- trezorlib/debuglink.py | 3 +- trezorlib/protocol_v1.py | 80 +++++++++++ trezorlib/protocol_v2.py | 127 +++++++++++++++++ trezorlib/transport.py | 261 ++-------------------------------- trezorlib/transport_bridge.py | 120 +++++++--------- trezorlib/transport_hid.py | 228 +++++++++++++---------------- trezorlib/transport_pipe.py | 111 ++++++++------- trezorlib/transport_udp.py | 88 +++++++----- 11 files changed, 503 insertions(+), 559 deletions(-) create mode 100644 trezorlib/protocol_v1.py create mode 100644 trezorlib/protocol_v2.py diff --git a/tests/device_tests/common.py b/tests/device_tests/common.py index 29d2dd0..a55b082 100644 --- a/tests/device_tests/common.py +++ b/tests/device_tests/common.py @@ -59,16 +59,16 @@ def pipe_exists(path): return False -if HID_ENABLED and len(HidTransport.enumerate()) > 0: +if HID_ENABLED and HidTransport.enumerate(): devices = HidTransport.enumerate() print('Using TREZOR') TRANSPORT = HidTransport TRANSPORT_ARGS = (devices[0],) - TRANSPORT_KWARGS = {'debug_link': False} + TRANSPORT_KWARGS = {} DEBUG_TRANSPORT = HidTransport - DEBUG_TRANSPORT_ARGS = (devices[0],) - DEBUG_TRANSPORT_KWARGS = {'debug_link': True} + DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),) + DEBUG_TRANSPORT_KWARGS = {} elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): diff --git a/trezorctl b/trezorctl index 821ea7e..136c49a 100755 --- a/trezorctl +++ b/trezorctl @@ -65,11 +65,12 @@ def cli(ctx, transport, path, verbose, is_json): if ctx.invoked_subcommand == 'list': ctx.obj = transport else: - t = get_transport(transport, path) + def connect(): + return get_transport(transport, path) if verbose: - ctx.obj = TrezorClientVerbose(t) + ctx.obj = TrezorClientVerbose(connect) else: - ctx.obj = TrezorClient(t) + ctx.obj = TrezorClient(connect) @cli.resultcallback() @@ -108,11 +109,7 @@ def print_result(res, transport, path, verbose, is_json): def ls(transport_name): transport_class = get_transport_class_by_name(transport_name) devices = transport_class.enumerate() - if transport_name == 'usb': - return [dev[0] for dev in devices] - if transport_name == 'bridge': - return devices - return [] + return devices # diff --git a/trezorlib/client.py b/trezorlib/client.py index a42005b..5962c1d 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -153,11 +153,11 @@ def session(f): # with session activation / deactivation def wrapped_f(*args, **kwargs): client = args[0] + client.get_transport().session_begin() try: - client.transport.session_begin() return f(*args, **kwargs) finally: - client.transport.session_end() + client.get_transport().session_end() return wrapped_f @@ -179,17 +179,23 @@ def normalize_nfc(txt): class BaseClient(object): # Implements very basic layer of sending raw protobuf # messages to device and getting its response back. - def __init__(self, transport, **kwargs): - self.transport = transport + def __init__(self, connect, **kwargs): + self.connect = connect + self.transport = None super(BaseClient, self).__init__() # *args, **kwargs) + def get_transport(self): + if self.transport is None: + self.transport = self.connect() + return self.transport + def cancel(self): - self.transport.write(proto.Cancel()) + self.get_transport().write(proto.Cancel()) @session def call_raw(self, msg): - self.transport.write(msg) - return self.transport.read_blocking() + self.get_transport().write(msg) + return self.get_transport().read() @session def call(self, msg): @@ -212,9 +218,6 @@ class BaseClient(object): raise CallException(msg.code, msg.message) - def close(self): - self.transport.close() - class VerboseWireMixin(object): def call_raw(self, msg): diff --git a/trezorlib/debuglink.py b/trezorlib/debuglink.py index 2e91a34..1e06985 100644 --- a/trezorlib/debuglink.py +++ b/trezorlib/debuglink.py @@ -43,14 +43,13 @@ class DebugLink(object): def close(self): self.transport.session_end() - self.transport.close() def _call(self, msg, nowait=False): print("DEBUGLINK SEND", pprint(msg)) self.transport.write(msg) if nowait: return - ret = self.transport.read_blocking() + ret = self.transport.read() print("DEBUGLINK RECV", pprint(ret)) return ret diff --git a/trezorlib/protocol_v1.py b/trezorlib/protocol_v1.py new file mode 100644 index 0000000..9bf6b59 --- /dev/null +++ b/trezorlib/protocol_v1.py @@ -0,0 +1,80 @@ +# This file is part of the TREZOR project. +# +# Copyright (C) 2012-2016 Marek Palatinus +# Copyright (C) 2012-2016 Pavol Rusnak +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +from __future__ import absolute_import + +import struct +from . import mapping + +REPLEN = 64 + + +class ProtocolV1(object): + + def session_begin(self, transport): + pass + + def session_end(self, transport): + pass + + def write(self, transport, msg): + ser = msg.SerializeToString() + header = struct.pack(">HL", mapping.get_type(msg), len(ser)) + data = bytearray(b"##" + header + ser) + + while data: + # Report ID, data padded to 63 bytes + chunk = b'?' + data[:REPLEN-1] + chunk = chunk.ljust(REPLEN, bytes([0x00])) + transport.write_chunk(chunk) + data = data[63:] + + def read(self, transport): + # Read header with first part of message data + chunk = transport.read_chunk() + (msg_type, datalen, data) = self.parse_first(chunk) + + # Read the rest of the message + while len(data) < datalen: + chunk = transport.read_chunk() + data.extend(self.parse_next(chunk)) + + # Strip padding zeros + data = data[:datalen] + + # Parse to protobuf + msg = mapping.get_class(msg_type)() + msg.ParseFromString(bytes(data)) + return msg + + def parse_first(self, chunk): + if chunk[:3] != b'?##': + raise Exception('Unexpected magic characters') + try: + headerlen = struct.calcsize('>HL') + (msg_type, datalen) = struct.unpack('>HL', bytes(chunk[3:3 + headerlen])) + except: + raise Exception('Cannot parse header') + + data = chunk[3 + headerlen:] + return (msg_type, datalen, data) + + def parse_next(self, chunk): + if chunk[:1] != b'?': + raise Exception('Unexpected magic characters') + return chunk[1:] diff --git a/trezorlib/protocol_v2.py b/trezorlib/protocol_v2.py new file mode 100644 index 0000000..cea18ea --- /dev/null +++ b/trezorlib/protocol_v2.py @@ -0,0 +1,127 @@ +# This file is part of the TREZOR project. +# +# Copyright (C) 2012-2016 Marek Palatinus +# Copyright (C) 2012-2016 Pavol Rusnak +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +from __future__ import absolute_import + +import struct +from . import mapping + +REPLEN = 64 + + +class ProtocolV2(object): + + def __init__(self): + self.session = None + + def session_begin(self, transport): + chunk = struct.pack('>B', 0x03) + chunk = chunk.ljust(REPLEN, bytes([0x00])) + transport.write_chunk(chunk) + resp = transport.read_chunk() + self.session = self.parse_session_open(resp) + + def session_end(self, transport): + if not self.session: + return + chunk = struct.pack('>BL', 0x04, self.session) + chunk = chunk.ljust(REPLEN, bytes([0x00])) + transport.write_chunk(chunk) + resp = transport.read_chunk() + if resp[0] != 0x04: + raise Exception('Expected session close') + self.session = None + + def write(self, transport, msg): + if not self.session: + raise Exception('Missing session for v2 protocol') + + # Serialize whole message + data = bytearray(msg.SerializeToString()) + dataheader = struct.pack('>LL', mapping.get_type(msg), len(data)) + data = dataheader + data + seq = -1 + + # Write it out + while data: + if seq < 0: + repheader = struct.pack('>BL', 0x01, self.session) + else: + repheader = struct.pack('>BLL', 0x02, self.session, seq) + datalen = REPLEN - len(repheader) + chunk = repheader + data[:datalen] + chunk = chunk.ljust(REPLEN, bytes([0x00])) + transport.write_chunk(chunk) + data = data[datalen:] + seq += 1 + + def read(self, transport): + if not self.session: + raise Exception('Missing session for v2 protocol') + + # Read header with first part of message data + chunk = transport.read_chunk() + msg_type, datalen, data = self.parse_first(chunk) + + # Read the rest of the message + while len(data) < datalen: + chunk = transport.read_chunk() + next_data = self.parse_next(chunk) + data.extend(next_data) + + # Strip padding + data = data[:datalen] + + # Parse to protobuf + msg = mapping.get_class(msg_type)() + msg.ParseFromString(bytes(data)) + return msg + + def parse_first(self, chunk): + try: + headerlen = struct.calcsize('>BLLL') + (magic, session, msg_type, datalen) = struct.unpack('>BLLL', bytes(chunk[:headerlen])) + except: + raise Exception('Cannot parse header') + if magic != 0x01: + raise Exception('Unexpected magic character') + if session != self.session: + raise Exception('Session id mismatch') + return msg_type, datalen, chunk[headerlen:] + + def parse_next(self, chunk): + try: + headerlen = struct.calcsize('>BLL') + (magic, session, sequence) = struct.unpack('>BLL', bytes(chunk[:headerlen])) + except: + raise Exception('Cannot parse header') + if magic != 0x02: + raise Exception('Unexpected magic characters') + if session != self.session: + raise Exception('Session id mismatch') + return chunk[headerlen:] + + def parse_session_open(self, chunk): + try: + headerlen = struct.calcsize('>BL') + (magic, session) = struct.unpack('>BL', bytes(chunk[:headerlen])) + except: + raise Exception('Cannot parse header') + if magic != 0x03: + raise Exception('Unexpected magic character') + return session \ No newline at end of file diff --git a/trezorlib/transport.py b/trezorlib/transport.py index c2c9702..3bf6dbd 100644 --- a/trezorlib/transport.py +++ b/trezorlib/transport.py @@ -2,6 +2,7 @@ # # Copyright (C) 2012-2016 Marek Palatinus # Copyright (C) 2012-2016 Pavol Rusnak +# Copyright (C) 2016 Jochen Hoenicke # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -18,258 +19,24 @@ from __future__ import absolute_import -import struct -import binascii -from . import mapping - - -class NotImplementedException(Exception): - pass - - -class ConnectionError(Exception): - pass - class Transport(object): - def __init__(self, device, *args, **kwargs): - self.device = device - self.session_id = 0 - self.session_depth = 0 - self._open() + + def __init__(self): + self.session_counter = 0 def session_begin(self): - """ - Apply a lock to the device in order to preform synchronous multistep "conversations" with the device. For example, before entering the transaction signing workflow, one begins a session. After the transaction is complete, the session may be ended. - """ - if self.session_depth == 0: - self._session_begin() - self.session_depth += 1 + if self.session_counter == 0: + self.open() + self.session_counter += 1 def session_end(self): - """ - End a session. Se session_begin for an in depth description of TREZOR sessions. - """ - self.session_depth -= 1 - self.session_depth = max(0, self.session_depth) - if self.session_depth == 0: - self._session_end() + self.session_counter = max(self.session_counter - 1, 0) + if self.session_counter == 0: + self.close() + + def open(self): + raise NotImplementedError def close(self): - """ - Close the connection to the physical device or file descriptor represented by the Transport. - """ - self._close() - - def write(self, msg): - """ - Write mesage to tansport. msg should be a member of a valid `protobuf class `_ with a SerializeToString() method. - """ - raise NotImplementedException("Not implemented") - - def read(self): - """ - If there is data available to be read from the transport, reads the data and tries to parse it as a protobuf message. If the parsing succeeds, return a protobuf object. - Otherwise, returns None. - """ - if not self._ready_to_read(): - return None - - data = self._read() - if data is None: - return None - - return self._parse_message(data) - - def read_blocking(self): - """ - Same as read, except blocks until data is available to be read. - """ - while True: - data = self._read() - if data is not None: - break - - return self._parse_message(data) - - def _parse_message(self, data): - (session_id, msg_type, data) = data - - # Raise exception if we get the response with unexpected session ID - if session_id != self.session_id: - raise Exception("Session ID mismatch. Have %d, got %d" % - (self.session_id, session_id)) - - if msg_type == 'protobuf': - return data - else: - inst = mapping.get_class(msg_type)() - inst.ParseFromString(bytes(data)) - return inst - - # Functions to be implemented in specific transports: - def _open(self): - raise NotImplementedException("Not implemented") - - def _close(self): - raise NotImplementedException("Not implemented") - - def _write_chunk(self, chunk): - raise NotImplementedException("Not implemented") - - def _read_chunk(self): - raise NotImplementedException("Not implemented") - - def _ready_to_read(self): - """ - Returns True if there is data to be read from the transport. Otherwise, False. - """ - raise NotImplementedException("Not implemented") - - def _session_begin(self): - pass - - def _session_end(self): - pass - - -class TransportV1(Transport): - def write(self, msg): - ser = msg.SerializeToString() - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - data = bytearray(b"##" + header + ser) - - while len(data): - # Report ID, data padded to 63 bytes - chunk = b'?' + data[:63] + b'\0' * (63 - len(data[:63])) - self._write_chunk(chunk) - data = data[63:] - - def _read(self): - chunk = self._read_chunk() - (msg_type, datalen, data) = self.parse_first(chunk) - - while len(data) < datalen: - chunk = self._read_chunk() - data.extend(self.parse_next(chunk)) - - # Strip padding zeros - data = data[:datalen] - return (0, msg_type, data) - - def parse_first(self, chunk): - if chunk[:3] != b"?##": - raise Exception("Unexpected magic characters") - - try: - headerlen = struct.calcsize(">HL") - (msg_type, datalen) = struct.unpack(">HL", bytes(chunk[3:3 + headerlen])) - except: - raise Exception("Cannot parse header") - - data = chunk[3 + headerlen:] - return (msg_type, datalen, data) - - def parse_next(self, chunk): - if chunk[0:1] != b"?": - raise Exception("Unexpected magic characters") - - return chunk[1:] - - -class TransportV2(Transport): - def write(self, msg): - if not self.session_id: - raise Exception('Missing session_id for v2 transport') - - data = bytearray(msg.SerializeToString()) - - dataheader = struct.pack(">LL", mapping.get_type(msg), len(data)) - data = dataheader + data - seq = -1 - - while len(data): - if seq < 0: - repheader = struct.pack(">BL", 0x01, self.session_id) - else: - repheader = struct.pack(">BLL", 0x02, self.session_id, seq) - datalen = 64 - len(repheader) - chunk = repheader + data[:datalen] + b'\0' * (datalen - len(data[:datalen])) - self._write_chunk(chunk) - data = data[datalen:] - seq += 1 - - def _read(self): - if not self.session_id: - raise Exception('Missing session_id for v2 transport') - - chunk = self._read_chunk() - (session_id, msg_type, datalen, data) = self.parse_first(chunk) - - while len(data) < datalen: - chunk = self._read_chunk() - (next_session_id, next_data) = self.parse_next(chunk) - - if next_session_id != session_id: - raise Exception("Session id mismatch") - - data.extend(next_data) - - data = data[:datalen] # Strip padding - return (session_id, msg_type, data) - - def parse_first(self, chunk): - try: - headerlen = struct.calcsize(">BLLL") - (magic, session_id, msg_type, datalen) = struct.unpack(">BLLL", bytes(chunk[:headerlen])) - except: - raise Exception("Cannot parse header") - if magic != 0x01: - raise Exception("Unexpected magic character") - return (session_id, msg_type, datalen, chunk[headerlen:]) - - def parse_next(self, chunk): - try: - headerlen = struct.calcsize(">BLL") - (magic, session_id, sequence) = struct.unpack(">BLL", bytes(chunk[:headerlen])) - except: - raise Exception("Cannot parse header") - if magic != 0x02: - raise Exception("Unexpected magic characters") - return (session_id, chunk[headerlen:]) - - def parse_session_open(self, chunk): - try: - headerlen = struct.calcsize(">BL") - (magic, session_id) = struct.unpack(">BL", bytes(chunk[:headerlen])) - except: - raise Exception("Cannot parse header") - if magic != 0x03: - raise Exception("Unexpected magic character") - return session_id - - def _session_begin(self): - self._write_chunk(bytearray(b'\x03' + b'\0' * 63)) - self.session_id = self.parse_session_open(self._read_chunk()) - - def _session_end(self): - header = struct.pack(">L", self.session_id) - self._write_chunk(bytearray(b'\x04' + header + b'\0' * (63 - len(header)))) - if self._read_chunk()[0] != 0x04: - raise Exception("Expected session close") - self.session_id = None - - ''' - def read_headers(self, read_f): - c = read_f.read(2) - if c != b"?!": - raise Exception("Unexpected magic characters") - - try: - headerlen = struct.calcsize(">HL") - (session_id, msg_type, datalen) = struct.unpack(">LLL", read_f.read(headerlen)) - except: - raise Exception("Cannot parse header length") - - return (0, msg_type, datalen) - ''' + raise NotImplementedError \ No newline at end of file diff --git a/trezorlib/transport_bridge.py b/trezorlib/transport_bridge.py index 0413281..bd924f8 100644 --- a/trezorlib/transport_bridge.py +++ b/trezorlib/transport_bridge.py @@ -17,14 +17,13 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -'''BridgeTransport implements transport TREZOR Bridge (aka trezord).''' +from __future__ import absolute_import -import binascii -import json import requests from google.protobuf import json_format -from . import messages_pb2 as proto -from .transport import TransportV1 + +from . import messages_pb2 +from .transport import Transport TREZORD_HOST = 'https://localback.net:21324' CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin' @@ -34,93 +33,84 @@ def get_error(resp): return ' (error=%d str=%s)' % (resp.status_code, resp.json()['error']) -class BridgeTransport(TransportV1): +class BridgeTransport(Transport): + ''' + BridgeTransport implements transport through TREZOR Bridge (aka trezord). + ''' - CONFIGURED = False + configured = False - def __init__(self, device, *args, **kwargs): - self.configure() - - self.path = device['path'] + def __init__(self, device): + super(BridgeTransport, self).__init__() + self.device = device + self.conn = requests.Session() self.session = None self.response = None - self.conn = requests.Session() - super(BridgeTransport, self).__init__(device, *args, **kwargs) + def __str__(self): + return self.device['path'] @staticmethod def configure(): - if BridgeTransport.CONFIGURED: + if BridgeTransport.configured: return r = requests.get(CONFIG_URL, verify=False) if r.status_code != 200: raise Exception('Could not fetch config from %s' % CONFIG_URL) - - config = r.text - - r = requests.post(TREZORD_HOST + '/configure', data=config) + r = requests.post(TREZORD_HOST + '/configure', data=r.text) if r.status_code != 200: raise Exception('trezord: Could not configure' + get_error(r)) - BridgeTransport.CONFIGURED = True + BridgeTransport.configured = True - @classmethod - def enumerate(cls): - """ - Return a list of available TREZOR devices. - """ - cls.configure() + @staticmethod + def enumerate(): + BridgeTransport.configure() r = requests.get(TREZORD_HOST + '/enumerate') if r.status_code != 200: - raise Exception('trezord: Could not enumerate devices' + get_error(r)) - enum = r.json() - return enum + raise Exception('trezord: Could not enumerate devices' + + get_error(r)) + return [BridgeTransport(dev) for dev in r.json()] - @classmethod - def find_by_path(cls, path=None): - """ - Finds a device by transport-specific path. - If path is not set, return first device. - """ - devices = cls.enumerate() - for dev in devices: - if not path or dev['path'] == binascii.hexlify(path): - return cls(dev) - raise Exception('Device not found') + @staticmethod + def find_by_path(path): + for transport in BridgeTransport.enumerate(): + if path is None or transport.device['path'] == path: + return transport + raise Exception('Bridge device not found') - def _open(self): - r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path) + def open(self): + r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path']) if r.status_code != 200: - raise Exception('trezord: Could not acquire session' + get_error(r)) - resp = r.json() - self.session = resp['session'] + raise Exception('trezord: Could not acquire session' + + get_error(r)) + self.session = r.json()['session'] - def _close(self): + def close(self): + if not self.session: + return r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session) if r.status_code != 200: - raise Exception('trezord: Could not release session' + get_error(r)) - else: - self.session = None + raise Exception('trezord: Could not release session' + + get_error(r)) + self.session = None - def _ready_to_read(self): - return self.response is not None - - def write(self, protobuf_msg): - # Override main 'write' method, HTTP transport cannot be - # splitted to chunks - cls = protobuf_msg.__class__.__name__ - msg = json_format.MessageToJson(protobuf_msg, preserving_proto_field_name=True) - payload = '{"type": "%s", "message": %s}' % (cls, msg) - r = self.conn.post(TREZORD_HOST + '/call/%s' % self.session, data=payload) + def write(self, msg): + msgname = msg.__class__.__name__ + msgjson = json_format.MessageToJson( + msg, preserving_proto_field_name=True) + payload = '{"type": "%s", "message": %s}' % (msgname, msgjson) + r = self.conn.post( + TREZORD_HOST + '/call/%s' % self.session, data=payload) if r.status_code != 200: raise Exception('trezord: Could not write message' + get_error(r)) - else: - self.response = r.json() + self.response = r.json() - def _read(self): + def read(self): if self.response is None: raise Exception('No response stored') - cls = getattr(proto, self.response['type']) - inst = cls() - pb = json_format.ParseDict(self.response['message'], inst) - return (0, 'protobuf', pb) + msgtype = getattr(messages_pb2, self.response['type']) + msg = msgtype() + msg = json_format.ParseDict(self.response['message'], msg) + self.response = None + return msg diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index c365bd6..17b745b 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -16,168 +16,136 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -'''USB HID implementation of Transport.''' +from __future__ import absolute_import import time import hid -from .transport import TransportV1, TransportV2, ConnectionError + +from .protocol_v1 import ProtocolV1 +from .protocol_v2 import ProtocolV2 +from .transport import Transport + +DEV_TREZOR1 = (0x534c, 0x0001) +DEV_TREZOR2 = (0x1209, 0x53c0) +DEV_TREZOR2_BL = (0x1209, 0x1201) -def enumerate(): - """ - Return a list of available TREZOR devices. - """ - devices = {} - for d in hid.enumerate(0, 0): - vendor_id = d['vendor_id'] - product_id = d['product_id'] - serial_number = d['serial_number'] - interface_number = d['interface_number'] - usage_page = d['usage_page'] - path = d['path'] +class HidTransport(Transport): + ''' + HidTransport implements transport over USB HID interface. + ''' - if (vendor_id, product_id) in DEVICE_IDS: - devices.setdefault(serial_number, [None, None]) - # first match by usage_page, then try interface number - if usage_page == 0xFF00 or interface_number == 0: # normal link - devices[serial_number][0] = path - elif usage_page == 0xFF01 or interface_number == 1: # debug link - devices[serial_number][1] = path + def __init__(self, device, protocol=None): + super(HidTransport, self).__init__() - # List of two-tuples (path_normal, path_debuglink) - return sorted(devices.values()) - - -def find_by_path(path=None): - """ - Finds a device by transport-specific path. - If path is not set, return first device. - """ - devices = enumerate() - for dev in devices: - if not path or path in dev: - return HidTransport(dev) - raise Exception('Device not found') - - -def path_to_transport(path): - try: - device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0] - except IndexError: - raise ConnectionError("Connection failed") - - # VID/PID found, let's find proper transport - try: - transport = DEVICE_TRANSPORTS[(device['vendor_id'], device['product_id'])] - except IndexError: - raise Exception("Unknown transport for VID:PID %04x:%04x" % (device['vendor_id'], device['product_id'])) - - return transport - - -class _HidTransport(object): - def __init__(self, device, *args, **kwargs): + if protocol is None: + if is_trezor2(device): + protocol = ProtocolV2() + else: + protocol = ProtocolV1() + self.device = device + self.protocol = protocol self.hid = None self.hid_version = None - device = device[int(bool(kwargs.get('debug_link')))] - super(_HidTransport, self).__init__(device, *args, **kwargs) + def __str__(self): + return self.device['path'] - def is_connected(self): - """ - Check if the device is still connected. - """ - for d in hid.enumerate(0, 0): - if d['path'] == self.device: - return True - return False + @staticmethod + def enumerate(debug=False): + return [ + HidTransport(dev) for dev in hid.enumerate(0, 0) + if ((is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)) and + (is_debug(dev) == debug)) + ] - def _open(self): + @staticmethod + def find_by_path(path=None): + for transport in HidTransport.enumerate(): + if path is None or transport.device['path'] == path: + return transport + raise Exception('HID device not found') + + def find_debug(self): + if isinstance(self.protocol, ProtocolV2): + # For v2 protocol, lets use the same HID interface, but with a different session + debug = HidTransport(self.device, ProtocolV2()) + debug.hid = self.hid + debug.hid_version = self.hid_version + return debug + if isinstance(self.protocol, ProtocolV1): + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device['serial_number'] == self.device['serial_number']: + return debug + + def open(self): + if self.hid: + return self.hid = hid.device() - self.hid.open_path(self.device) + self.hid.open_path(self.device['path']) self.hid.set_nonblocking(True) - - # determine hid_version - if isinstance(self, HidTransportV2): - self.hid_version = 2 + if is_trezor1(self.device): + self.hid_version = self.probe_hid_version() else: - r = self.hid.write([0, 63, ] + [0xFF] * 63) - if r == 65: - self.hid_version = 2 - return - r = self.hid.write([63, ] + [0xFF] * 63) - if r == 64: - self.hid_version = 1 - return - raise ConnectionError("Unknown HID version") + self.hid_version = 2 + self.protocol.session_begin(self) - def _close(self): - self.hid.close() + def close(self): + self.protocol.session_end(self) + try: + self.hid.close() + except OSError: + pass # Failing to close the handle is not a problem self.hid = None + self.hid_version = None - def _write_chunk(self, chunk): + def read(self): + return self.protocol.read(self) + + def write(self, msg): + return self.protocol.write(self, msg) + + def write_chunk(self, chunk): if len(chunk) != 64: - raise Exception("Unexpected data length") - + raise Exception('Unexpected chunk size: %d' % len(chunk)) if self.hid_version == 2: self.hid.write(b'\0' + chunk) else: self.hid.write(chunk) - def _read_chunk(self): - start = time.time() - + def read_chunk(self): while True: - data = self.hid.read(64) - if not len(data): - if time.time() - start > 10: - # Over 10 s of no response, let's check if - # device is still alive - if not self.is_connected(): - raise ConnectionError("Connection failed") - - # Restart timer - start = time.time() - + chunk = self.hid.read(64) + if chunk: + break + else: time.sleep(0.001) - continue + if len(chunk) != 64: + raise Exception('Unexpected chunk size: %d' % len(chunk)) + return bytearray(chunk) - break - - if len(data) != 64: - raise Exception("Unexpected chunk size: %d" % len(data)) - - return bytearray(data) + def probe_hid_version(self): + n = self.hid.write([0, 63] + [0xFF] * 63) + if n == 65: + return 2 + n = self.hid.write([63] + [0xFF] * 63) + if n == 64: + return 1 + raise Exception('Unknown HID version') -class HidTransportV1(_HidTransport, TransportV1): - pass +def is_trezor1(dev): + return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR1 -class HidTransportV2(_HidTransport, TransportV2): - pass +def is_trezor2(dev): + return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2 -DEVICE_IDS = [ - (0x534c, 0x0001), # TREZOR - (0x1209, 0x53c0), # TREZORv2 Bootloader - (0x1209, 0x53c1), # TREZORv2 -] - -DEVICE_TRANSPORTS = { - (0x534c, 0x0001): HidTransportV1, # TREZOR - (0x1209, 0x53c0): HidTransportV1, # TREZORv2 Bootloader - (0x1209, 0x53c1): HidTransportV2, # TREZORv2 -} +def is_trezor2_bl(dev): + return (dev['vendor_id'], dev['product_id']) == DEV_TREZOR2_BL -# Backward compatible wrapper, decides for proper transport -# based on VID/PID of given path -def HidTransport(device, *args, **kwargs): - transport = path_to_transport(device[0]) - return transport(device, *args, **kwargs) - - -# Backward compatibility hack; HidTransport is a function, not a class like before -HidTransport.enumerate = enumerate -HidTransport.find_by_path = find_by_path +def is_debug(dev): + return (dev['usage_page'] == 0xFF01 or dev['interface_number'] == 1) diff --git a/trezorlib/transport_pipe.py b/trezorlib/transport_pipe.py index 00e33eb..1d53b54 100644 --- a/trezorlib/transport_pipe.py +++ b/trezorlib/transport_pipe.py @@ -16,91 +16,94 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -from __future__ import print_function +from __future__ import absolute_import + import os import time -from select import select -from .transport import TransportV1 - -"""PipeTransport implements fake wire transport over local named pipe. -Use this transport for talking with trezor simulator.""" +from .protocol_v1 import ProtocolV1 -class PipeTransport(TransportV1): +class PipeTransport(object): + ''' + PipeTransport implements fake wire transport over local named pipe. + Use this transport for talking with trezor-emu. + ''' + + def __init__(self, device=None, is_device=False): + super(PipeTransport, self).__init__() - def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs): if not device: device = '/tmp/pipe.trezor' - self.is_device = is_device # set True if act as device + self.device = device + self.is_device = is_device + self.filename_read = None + self.filename_write = None + self.read_f = None + self.write_f = None + self.protocol = ProtocolV1() - super(PipeTransport, self).__init__(device, *args, **kwargs) + def __str__(self): + return self.device - @classmethod - def enumerate(cls): - raise Exception('This transport cannot enumerate devices') + @staticmethod + def enumerate(): + raise NotImplementedError('This transport cannot enumerate devices') - @classmethod - def find_by_path(cls, path=None): - return cls(path) + @staticmethod + def find_by_path(path=None): + return PipeTransport(path) - def _open(self): + def open(self): if self.is_device: self.filename_read = self.device + '.to' self.filename_write = self.device + '.from' - os.mkfifo(self.filename_read, 0o600) os.mkfifo(self.filename_write, 0o600) else: self.filename_read = self.device + '.from' self.filename_write = self.device + '.to' - if not os.path.exists(self.filename_write): raise Exception("Not connected") - self.write_fd = os.open(self.filename_write, os.O_RDWR) # |os.O_NONBLOCK) - self.write_f = os.fdopen(self.write_fd, 'w+b', 0) + self.read_f = os.open(self.filename_read, 'rb', 0) + self.write_f = os.open(self.filename_write, 'w+b', 0) - self.read_fd = os.open(self.filename_read, os.O_RDWR) # |os.O_NONBLOCK) - self.read_f = os.fdopen(self.read_fd, 'rb', 0) + self.protocol.session_begin(self) - def _close(self): - self.read_f.close() - self.write_f.close() + def close(self): + self.protocol.session_end(self) + if self.read_f: + self.read_f.close() + self.read_f = None + if self.write_f: + self.write_f.close() + self.write_f = None if self.is_device: os.unlink(self.filename_read) os.unlink(self.filename_write) + self.filename_read = None + self.filename_write = None - def _ready_to_read(self): - rlist, _, _ = select([self.read_f], [], [], 0) - return len(rlist) > 0 + def read(self): + return self.protocol.read(self) - def _write_chunk(self, chunk): + def write(self, msg): + return self.protocol.write(self, msg) + + def write_chunk(self, chunk): if len(chunk) != 64: - raise Exception("Unexpected data length") + raise Exception('Unexpected chunk size: %d' % len(chunk)) + self.write_f.write(chunk) + self.write_f.flush() - try: - self.write_f.write(chunk) - self.write_f.flush() - except OSError: - print("Error while writing to socket") - raise - - def _read_chunk(self): + def read_chunk(self): while True: - try: - data = self.read_f.read(64) - except IOError: - print("Failed to read from device") - raise - - if not len(data): + chunk = self.read_f.read(64) + if chunk: + break + else: time.sleep(0.001) - continue - - break - - if len(data) != 64: - raise Exception("Unexpected chunk size: %d" % len(data)) - - return bytearray(data) + if len(chunk) != 64: + raise Exception('Unexpected chunk size: %d' % len(chunk)) + return bytearray(chunk) diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py index 1bf5ad8..6432874 100644 --- a/trezorlib/transport_udp.py +++ b/trezorlib/transport_udp.py @@ -16,66 +16,76 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -'''UDP Socket implementation of Transport.''' +from __future__ import absolute_import import socket -from select import select -from .transport import TransportV2 + +from .protocol_v2 import ProtocolV2 +from .transport import Transport -class UdpTransport(TransportV2): +class UdpTransport(Transport): - def __init__(self, device, *args, **kwargs): - if device is None: - device = '' - device = device.split(':') - if len(device) < 2: - if not device[0]: - # Default port used by trezor v2 - device = ('127.0.0.1', 21324) - else: - device = ('127.0.0.1', int(device[0])) + DEFAULT_HOST = '127.0.0.1' + DEFAULT_PORT = 21324 + + def __init__(self, device=None, protocol=None): + super(UdpTransport, self).__init__() + + if not device: + host = UdpTransport.DEFAULT_HOST + port = UdpTransport.DEFAULT_PORT else: - device = (device[0], int(device[1])) - + host = device.split(':').get(0) + port = device.split(':').get(1, UdpTransport.DEFAULT_PORT) + port = int(port) + if not protocol: + protocol = ProtocolV2() + self.device = (host, port) + self.protocol = protocol self.socket = None - super(UdpTransport, self).__init__(device, *args, **kwargs) - @classmethod - def enumerate(cls): - raise Exception('This transport cannot enumerate devices') + def __str__(self): + return self.device - @classmethod - def find_by_path(cls, path=None): - return cls(path) + @staticmethod + def enumerate(): + raise NotImplementedError('This transport cannot enumerate devices') - def _open(self): + @staticmethod + def find_by_path(path=None): + return UdpTransport(path) + + def open(self): self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.connect(self.device) self.socket.settimeout(10) + self.protocol.session_begin(self) - def _close(self): - self.socket.close() - self.socket = None + def close(self): + if self.socket: + self.protocol.session_end(self) + self.socket.close() + self.socket = None - def _ready_to_read(self): - rlist, _, _ = select([self.socket], [], [], 0) - return len(rlist) > 0 + def read(self): + return self.protocol.read(self) - def _write_chunk(self, chunk): + def write(self, msg): + return self.protocol.write(self, msg) + + def write_chunk(self, chunk): if len(chunk) != 64: - raise Exception("Unexpected data length") - + raise Exception('Unexpected data length') self.socket.sendall(chunk) - def _read_chunk(self): + def read_chunk(self): while True: try: - data = self.socket.recv(64) + chunk = self.socket.recv(64) break except socket.timeout: continue - if len(data) != 64: - raise Exception("Unexpected chunk size: %d" % len(data)) - - return bytearray(data) + if len(chunk) != 64: + raise Exception('Unexpected chunk size: %d' % len(chunk)) + return bytearray(chunk)