transport: add TransportException

Fixes #134
This commit is contained in:
Jan Pochyla 2017-09-05 17:15:00 +02:00
parent ac0184413d
commit 66ba2c20c0
5 changed files with 33 additions and 25 deletions

View File

@ -20,6 +20,10 @@
from __future__ import absolute_import from __future__ import absolute_import
class TransportException(Exception):
pass
class Transport(object): class Transport(object):
def __init__(self): def __init__(self):

View File

@ -23,7 +23,7 @@ import requests
from google.protobuf import json_format from google.protobuf import json_format
from . import messages_pb2 from . import messages_pb2
from .transport import Transport from .transport import Transport, TransportException
TREZORD_HOST = 'https://localback.net:21324' TREZORD_HOST = 'https://localback.net:21324'
CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin' CONFIG_URL = 'https://wallet.trezor.io/data/config_signed.bin'
@ -57,10 +57,12 @@ class BridgeTransport(Transport):
return return
r = requests.get(CONFIG_URL, verify=False) r = requests.get(CONFIG_URL, verify=False)
if r.status_code != 200: if r.status_code != 200:
raise Exception('Could not fetch config from %s' % CONFIG_URL) raise TransportException(
'Could not fetch config from %s' % CONFIG_URL)
r = requests.post(TREZORD_HOST + '/configure', data=r.text) r = requests.post(TREZORD_HOST + '/configure', data=r.text)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not configure' + get_error(r)) raise TransportException('trezord: Could not configure' +
get_error(r))
BridgeTransport.configured = True BridgeTransport.configured = True
@staticmethod @staticmethod
@ -68,8 +70,8 @@ class BridgeTransport(Transport):
BridgeTransport.configure() BridgeTransport.configure()
r = requests.get(TREZORD_HOST + '/enumerate') r = requests.get(TREZORD_HOST + '/enumerate')
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not enumerate devices' + raise TransportException('trezord: Could not enumerate devices' +
get_error(r)) get_error(r))
return [BridgeTransport(dev) for dev in r.json()] return [BridgeTransport(dev) for dev in r.json()]
@staticmethod @staticmethod
@ -77,13 +79,13 @@ class BridgeTransport(Transport):
for transport in BridgeTransport.enumerate(): for transport in BridgeTransport.enumerate():
if path is None or transport.device['path'] == path: if path is None or transport.device['path'] == path:
return transport return transport
raise Exception('Bridge device not found') raise TransportException('Bridge device not found')
def open(self): def open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path']) r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.device['path'])
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not acquire session' + raise TransportException('trezord: Could not acquire session' +
get_error(r)) get_error(r))
self.session = r.json()['session'] self.session = r.json()['session']
def close(self): def close(self):
@ -91,8 +93,8 @@ class BridgeTransport(Transport):
return return
r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session) r = self.conn.post(TREZORD_HOST + '/release/%s' % self.session)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not release session' + raise TransportException('trezord: Could not release session' +
get_error(r)) get_error(r))
self.session = None self.session = None
def write(self, msg): def write(self, msg):
@ -103,12 +105,13 @@ class BridgeTransport(Transport):
r = self.conn.post( r = self.conn.post(
TREZORD_HOST + '/call/%s' % self.session, data=payload) TREZORD_HOST + '/call/%s' % self.session, data=payload)
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not write message' + get_error(r)) raise TransportException('trezord: Could not write message' +
get_error(r))
self.response = r.json() self.response = r.json()
def read(self): def read(self):
if self.response is None: if self.response is None:
raise Exception('No response stored') raise TransportException('No response stored')
msgtype = getattr(messages_pb2, self.response['type']) msgtype = getattr(messages_pb2, self.response['type'])
msg = msgtype() msg = msgtype()
msg = json_format.ParseDict(self.response['message'], msg) msg = json_format.ParseDict(self.response['message'], msg)

View File

@ -23,7 +23,7 @@ import hid
from .protocol_v1 import ProtocolV1 from .protocol_v1 import ProtocolV1
from .protocol_v2 import ProtocolV2 from .protocol_v2 import ProtocolV2
from .transport import Transport from .transport import Transport, TransportException
DEV_TREZOR1 = (0x534c, 0x0001) DEV_TREZOR1 = (0x534c, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53c1) DEV_TREZOR2 = (0x1209, 0x53c1)
@ -96,7 +96,7 @@ class HidTransport(Transport):
for transport in HidTransport.enumerate(): for transport in HidTransport.enumerate():
if path is None or transport.device['path'] == path: if path is None or transport.device['path'] == path:
return transport return transport
raise Exception('HID device not found') raise TransportException('HID device not found')
def find_debug(self): def find_debug(self):
if isinstance(self.protocol, ProtocolV2): if isinstance(self.protocol, ProtocolV2):
@ -109,7 +109,7 @@ class HidTransport(Transport):
for debug in HidTransport.enumerate(debug=True): for debug in HidTransport.enumerate(debug=True):
if debug.device['serial_number'] == self.device['serial_number']: if debug.device['serial_number'] == self.device['serial_number']:
return debug return debug
raise Exception('Debug HID device not found') raise TransportException('Debug HID device not found')
def open(self): def open(self):
self.hid.open() self.hid.open()
@ -121,7 +121,7 @@ class HidTransport(Transport):
def close(self): def close(self):
self.protocol.session_end(self) self.protocol.session_end(self)
self.hid.close() self.hid.close()
self.hid_version = None self.hid_version = None
def read(self): def read(self):
@ -132,7 +132,7 @@ class HidTransport(Transport):
def write_chunk(self, chunk): def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk)) raise TransportException('Unexpected chunk size: %d' % len(chunk))
if self.hid_version == 2: if self.hid_version == 2:
self.hid.handle.write(b'\0' + chunk) self.hid.handle.write(b'\0' + chunk)
else: else:
@ -146,7 +146,7 @@ class HidTransport(Transport):
else: else:
time.sleep(0.001) time.sleep(0.001)
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk)) raise TransportException('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk) return bytearray(chunk)
def probe_hid_version(self): def probe_hid_version(self):
@ -156,7 +156,7 @@ class HidTransport(Transport):
n = self.hid.handle.write([63] + [0xFF] * 63) n = self.hid.handle.write([63] + [0xFF] * 63)
if n == 64: if n == 64:
return 1 return 1
raise Exception('Unknown HID version') raise TransportException('Unknown HID version')
def is_trezor1(dev): def is_trezor1(dev):

View File

@ -22,9 +22,10 @@ import os
import time import time
from .protocol_v1 import ProtocolV1 from .protocol_v1 import ProtocolV1
from .transport import Transport, TransportException
class PipeTransport(object): class PipeTransport(Transport):
''' '''
PipeTransport implements fake wire transport over local named pipe. PipeTransport implements fake wire transport over local named pipe.
Use this transport for talking with trezor-emu. Use this transport for talking with trezor-emu.
@ -64,7 +65,7 @@ class PipeTransport(object):
self.filename_read = self.device + '.from' self.filename_read = self.device + '.from'
self.filename_write = self.device + '.to' self.filename_write = self.device + '.to'
if not os.path.exists(self.filename_write): if not os.path.exists(self.filename_write):
raise Exception("Not connected") raise TransportException('Not connected')
self.read_f = os.open(self.filename_read, 'rb', 0) self.read_f = os.open(self.filename_read, 'rb', 0)
self.write_f = os.open(self.filename_write, 'w+b', 0) self.write_f = os.open(self.filename_write, 'w+b', 0)
@ -93,7 +94,7 @@ class PipeTransport(object):
def write_chunk(self, chunk): def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk)) raise TransportException('Unexpected chunk size: %d' % len(chunk))
self.write_f.write(chunk) self.write_f.write(chunk)
self.write_f.flush() self.write_f.flush()
@ -105,5 +106,5 @@ class PipeTransport(object):
else: else:
time.sleep(0.001) time.sleep(0.001)
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk)) raise TransportException('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk) return bytearray(chunk)

View File

@ -76,7 +76,7 @@ class UdpTransport(Transport):
def write_chunk(self, chunk): def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected data length') raise TransportException('Unexpected data length')
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self): def read_chunk(self):
@ -87,5 +87,5 @@ class UdpTransport(Transport):
except socket.timeout: except socket.timeout:
continue continue
if len(chunk) != 64: if len(chunk) != 64:
raise Exception('Unexpected chunk size: %d' % len(chunk)) raise TransportException('Unexpected chunk size: %d' % len(chunk))
return bytearray(chunk) return bytearray(chunk)