diff --git a/plugins/trezor/client.py b/plugins/trezor/client.py index 6b355161..422bb9c3 100644 --- a/plugins/trezor/client.py +++ b/plugins/trezor/client.py @@ -1,6 +1,7 @@ from sys import stderr from electrum.i18n import _ +from electrum.util import PrintError class GuiMixin(object): @@ -58,18 +59,31 @@ class GuiMixin(object): def trezor_client_class(protocol_mixin, base_client, proto): '''Returns a class dynamically.''' - class TrezorClient(protocol_mixin, GuiMixin, base_client): + class TrezorClient(protocol_mixin, GuiMixin, base_client, PrintError): - def __init__(self, transport, device): + def __init__(self, transport, plugin): base_client.__init__(self, transport) protocol_mixin.__init__(self, transport) self.proto = proto - self.device = device + self.device = plugin.device + self.handler = plugin.handler + self.tx_api = plugin + self.bad = False + + def firmware_version(self): + f = self.features + v = (f.major_version, f.minor_version, f.patch_version) + self.print_error('firmware version', v) + return v + + def atleast_version(self, major, minor=0, patch=0): + return cmp(self.firmware_version(), (major, minor, patch)) def call_raw(self, msg): try: return base_client.call_raw(self, msg) except: + self.print_error("Marking %s client bad" % self.device) self.bad = True raise diff --git a/plugins/trezor/plugin.py b/plugins/trezor/plugin.py index 704d1c73..8c3d05da 100644 --- a/plugins/trezor/plugin.py +++ b/plugins/trezor/plugin.py @@ -224,37 +224,33 @@ class TrezorCompatiblePlugin(BasePlugin): return False return True - def get_client(self): + def create_client(self): if not self.libraries_available: self.give_error(_('please install the %s libraries from %s') % (self.device, self.libraries_URL)) + devices = self.HidTransport.enumerate() + if not devices: + self.give_error(_('Could not connect to your %s. Please ' + 'verify the cable is connected and that no ' + 'other app is using it.' % self.device)) + + transport = self.HidTransport(devices[0]) + client = self.client_class(transport, self) + if not client.atleast_version(*self.minimum_firmware): + self.give_error(_('Outdated %s firmware. Please update the ' + 'firmware from %s') + % (self.device, self.firmware_URL)) + return client + + def get_client(self): if not self.client or self.client.bad: - d = self.HidTransport.enumerate() - if not d: - self.give_error(_('Could not connect to your %s. Please ' - 'verify the cable is connected and that no ' - 'other app is using it.' % self.device)) - transport = self.HidTransport(d[0]) - self.client = self.client_class(transport, self.device) - self.client.handler = self.handler - self.client.set_tx_api(self) - self.client.bad = False - if not self.atleast_version(*self.minimum_firmware): - self.client = None - self.give_error(_('Outdated %s firmware. Please update the ' - 'firmware from %s') % (self.device, - self.firmware_URL)) + self.client = self.create_client() + return self.client - def compare_version(self, major, minor=0, patch=0): - f = self.get_client().features - v = [f.major_version, f.minor_version, f.patch_version] - self.print_error('firmware version', v) - return cmp(v, [major, minor, patch]) - def atleast_version(self, major, minor=0, patch=0): - return self.compare_version(major, minor, patch) >= 0 + return self.get_client().atleast_version(major, minor, patch) @hook def close_wallet(self): @@ -395,6 +391,7 @@ class TrezorCompatiblePlugin(BasePlugin): o.script_pubkey = vout['scriptPubKey'].decode('hex') return t + # This function is called from the trezor libraries (via tx_api) def get_tx(self, tx_hash): tx = self.prev_tx[tx_hash] tx.deserialize()