Clean up client caching and handling

This commit is contained in:
Neil Booth 2016-01-20 00:28:54 +09:00
parent a1d55fac4e
commit 24037be99c
6 changed files with 197 additions and 212 deletions

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import namedtuple
import traceback
import sys
import os
@ -226,6 +227,7 @@ class BasePlugin(PrintError):
def settings_dialog(self):
pass
Device = namedtuple("Device", "path id_ product_key")
class DeviceMgr(PrintError):
'''Manages hardware clients. A client communicates over a hardware
@ -262,82 +264,115 @@ class DeviceMgr(PrintError):
def __init__(self):
super(DeviceMgr, self).__init__()
# Keyed by wallet. The value is the hid_id if the wallet has
# been paired, and None otherwise.
# Keyed by wallet. The value is the device id if the wallet
# has been paired, and None otherwise.
self.wallets = {}
# A list of clients. We create a client for every device present
# that is of a registered hardware type
self.clients = []
# What we recognise. Keyed by (vendor_id, product_id) pairs,
# the value is a callback to create a client for those devices
self.recognised_hardware = {}
# A list of clients. The key is the client, the value is
# a (path, id_) pair.
self.clients = {}
# What we recognise. Each entry is a (vendor_id, product_id)
# pair.
self.recognised_hardware = set()
# For synchronization
self.lock = threading.RLock()
def register_devices(self, device_pairs, create_client):
def register_devices(self, device_pairs):
for pair in device_pairs:
self.recognised_hardware[pair] = create_client
self.recognised_hardware.add(pair)
def unpair(self, hid_id):
with self.lock:
wallet = self.wallet_by_hid_id(hid_id)
if wallet:
self.wallets[wallet] = None
def close_client(self, client):
with self.lock:
if client in self.clients:
self.clients.remove(client)
def create_client(self, device, handler, plugin):
client = plugin.create_client(device, handler)
if client:
client.close()
self.print_error("Registering", client)
with self.lock:
self.clients[client] = (device.path, device.id_)
return client
def close_wallet(self, wallet):
# Remove the wallet from our list; close any client
with self.lock:
hid_id = self.wallets.pop(wallet, None)
self.close_client(self.client_by_hid_id(hid_id))
def unpaired_clients(self, handler, classinfo):
'''Returns all unpaired clients of the given type.'''
self.scan_devices(handler)
with self.lock:
return [client for client in self.clients
if isinstance(client, classinfo)
and not self.wallet_by_hid_id(client.hid_id())]
def client_by_hid_id(self, hid_id, handler=None):
'''Like get_client() but when we don't care about wallet pairing. If
a device is wiped or in bootloader mode pairing is impossible;
in such cases we communicate by device ID and not wallet.'''
if handler:
self.scan_devices(handler)
with self.lock:
for client in self.clients:
if client.hid_id() == hid_id:
return client
return None
def wallet_hid_id(self, wallet):
def wallet_id(self, wallet):
with self.lock:
return self.wallets.get(wallet)
def wallet_by_hid_id(self, hid_id):
def wallet_by_id(self, id_):
with self.lock:
for wallet, wallet_hid_id in self.wallets.items():
if wallet_hid_id == hid_id:
for wallet, wallet_id in self.wallets.items():
if wallet_id == id_:
return wallet
return None
def paired_wallets(self):
def unpair_wallet(self, wallet):
with self.lock:
return [wallet for (wallet, hid_id) in self.wallets.items()
if hid_id is not None]
wallet_id = self.wallets.pop(wallet)
client = self.client_lookup(wallet_id)
self.clients.pop(client, None)
wallet.unpaired()
if client:
client.close()
def pair_wallet(self, wallet, client):
assert client in self.clients
self.print_error("paired:", wallet, client)
self.wallets[wallet] = client.hid_id()
wallet.connected()
def unpair_id(self, id_):
with self.lock:
wallet = self.wallet_by_id(id_)
if wallet:
self.unpair_wallet(wallet)
def pair_wallet(self, wallet, id_):
with self.lock:
self.wallets[wallet] = id_
wallet.paired()
def paired_wallets(self):
return list(self.wallets.keys())
def client_lookup(self, id_):
with self.lock:
for client, (path, client_id) in self.clients.items():
if client_id == id_:
return client
return None
def client_by_id(self, id_, handler):
'''Returns a client for the device ID if one is registered. If
a device is wiped or in bootloader mode pairing is impossible;
in such cases we communicate by device ID and not wallet.'''
self.scan_devices(handler)
return self.client_lookup(id_)
def client_for_wallet(self, plugin, wallet, force_pair):
assert wallet.handler
devices = self.scan_devices(wallet.handler)
wallet_id = self.wallet_id(wallet)
client = self.client_lookup(wallet_id)
if client:
return client
for device in devices:
if device.id_ == wallet_id:
return self.create_client(device, wallet.handler, plugin)
if force_pair:
first_address, derivation = wallet.first_address()
# Wallets don't have a first address in the install wizard
# until account creation
if not first_address:
self.print_error("no first address for ", wallet)
return None
# The wallet has not been previously paired, so get the
# first address of all unpaired clients and compare.
for device in devices:
# Skip already-paired devices
if self.wallet_by_id(device.id_):
continue
client = self.create_client(device, wallet.handler, plugin)
if client and not client.features.bootloader_mode:
# This will trigger a PIN/passphrase entry request
client_first_address = client.first_address(derivation)
if client_first_address == first_address:
self.pair_wallet(wallet, device.id_)
return client
return None
def scan_devices(self, handler):
# All currently supported hardware libraries use hid, so we
@ -349,76 +384,27 @@ class DeviceMgr(PrintError):
self.print_error("scanning devices...")
# First see what's connected that we know about
devices = {}
devices = []
for d in hid.enumerate(0, 0):
product_key = (d['vendor_id'], d['product_id'])
create_client = self.recognised_hardware.get(product_key)
if create_client:
devices[d['serial_number']] = (create_client, d['path'])
if product_key in self.recognised_hardware:
devices.append(Device(d['path'], d['serial_number'],
product_key))
# Now find out what was disconnected
pairs = [(dev.path, dev.id_) for dev in devices]
disconnected_ids = []
with self.lock:
disconnected = [client for client in self.clients
if not client.hid_id() in devices]
connected = {}
for client, pair in self.clients.items():
if pair in pairs:
connected[client] = pair
else:
disconnected_ids.append(pair[1])
self.clients = connected
# Close disconnected clients after informing their wallets
for client in disconnected:
wallet = self.wallet_by_hid_id(client.hid_id())
if wallet:
wallet.disconnected()
self.close_client(client)
# Unpair disconnected devices
for id_ in disconnected_ids:
self.unpair_id(id_)
# Now see if any new devices are present.
for hid_id, (create_client, path) in devices.items():
try:
client = create_client(path, handler, hid_id)
except BaseException as e:
self.print_error("could not create client", str(e))
client = None
if client:
self.print_error("client created for", path)
with self.lock:
self.clients.append(client)
# Inform re-paired wallet
wallet = self.wallet_by_hid_id(hid_id)
if wallet:
self.pair_wallet(wallet, client)
def get_client(self, wallet, force_pair=True):
'''Returns a client for the wallet, or None if one could not be found.
If force_pair is False then if an already paired client cannot
be found None is returned rather than requiring user
interaction.'''
# We must scan devices to get an up-to-date idea of which
# devices are present. Operating on a client when its device
# has been removed can cause the process to hang.
# Unfortunately there is no plugged / unplugged notification
# system.
self.scan_devices(wallet.handler)
# Previously paired wallets only need look for matching HID IDs
hid_id = self.wallet_hid_id(wallet)
if hid_id:
return self.client_by_hid_id(hid_id)
first_address, derivation = wallet.first_address()
# Wallets don't have a first address in the install wizard
# until account creation
if not first_address:
self.print_error("no first address for ", wallet)
return None
with self.lock:
# The wallet has not been previously paired, so get the
# first address of all unpaired clients and compare.
for client in self.clients:
# If already paired skip it
if self.wallet_by_hid_id(client.hid_id()):
continue
# This will trigger a PIN/passphrase entry request
if client.first_address(derivation) == first_address:
self.pair_wallet(wallet, client)
return client
# Not found
return None
return devices

View File

@ -2,10 +2,10 @@ from keepkeylib.client import proto, BaseClient, ProtocolMixin
from ..trezor.clientbase import TrezorClientBase
class KeepKeyClient(TrezorClientBase, ProtocolMixin, BaseClient):
def __init__(self, transport, handler, plugin, hid_id):
def __init__(self, transport, handler, plugin):
BaseClient.__init__(self, transport)
ProtocolMixin.__init__(self, transport)
TrezorClientBase.__init__(self, handler, plugin, hid_id, proto)
TrezorClientBase.__init__(self, handler, plugin, proto)
def recovery_device(self, *args):
ProtocolMixin.recovery_device(self, True, *args)

View File

@ -2,10 +2,10 @@ from trezorlib.client import proto, BaseClient, ProtocolMixin
from clientbase import TrezorClientBase
class TrezorClient(TrezorClientBase, ProtocolMixin, BaseClient):
def __init__(self, transport, handler, plugin, hid_id):
def __init__(self, transport, handler, plugin):
BaseClient.__init__(self, transport)
ProtocolMixin.__init__(self, transport)
TrezorClientBase.__init__(self, handler, plugin, hid_id, proto)
TrezorClientBase.__init__(self, handler, plugin, proto)
TrezorClientBase.wrap_methods(TrezorClient)

View File

@ -68,27 +68,22 @@ class GuiMixin(object):
class TrezorClientBase(GuiMixin, PrintError):
def __init__(self, handler, plugin, hid_id, proto):
def __init__(self, handler, plugin, proto):
assert hasattr(self, 'tx_api') # ProtocolMixin already constructed?
self.proto = proto
self.device = plugin.device
self.handler = handler
self.hid_id_ = hid_id
self.tx_api = plugin
self.types = plugin.types
self.msg_code_override = None
def __str__(self):
return "%s/%s" % (self.label(), self.hid_id())
return "%s/%s" % (self.label(), self.features.device_id)
def label(self):
'''The name given by the user to the device.'''
return self.features.label
def hid_id(self):
'''The HID ID of the device.'''
return self.hid_id_
def is_initialized(self):
'''True if initialized, False if wiped.'''
return self.features.initialized
@ -163,7 +158,7 @@ class TrezorClientBase(GuiMixin, PrintError):
def close(self):
'''Called when Our wallet was closed or the device removed.'''
self.print_error("disconnected")
self.print_error("closing client")
self.clear_session()
# Release the device
self.transport.close()

View File

@ -51,23 +51,26 @@ class TrezorCompatibleWallet(BIP44_Wallet):
self.session_timeout = seconds
self.storage.put('session_timeout', seconds)
def disconnected(self):
'''A device paired with the wallet was diconnected. Note this is
called in the context of the Plugins thread.'''
self.print_error("disconnected")
def unpaired(self):
'''A device paired with the wallet was diconnected. This can be
called in any thread context.'''
self.print_error("unpaired")
self.force_watching_only = True
self.handler.watching_only_changed()
def connected(self):
'''A device paired with the wallet was (re-)connected. Note this
is called in the context of the Plugins thread.'''
self.print_error("connected")
def paired(self):
'''A device paired with the wallet was (re-)connected. This can be
called in any thread context.'''
self.print_error("paired")
self.force_watching_only = False
self.handler.watching_only_changed()
def timeout(self):
'''Informs the wallet it timed out. Note this is called from
'''Called when the wallet session times out. Note this is called from
the Plugins thread.'''
client = self.get_client(force_pair=False)
if client:
client.clear_session()
self.print_error("timed out")
def get_action(self):
@ -178,8 +181,7 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
self.wallet_class.plugin = self
self.prevent_timeout = time.time() + 3600 * 24 * 365
if self.libraries_available:
self.device_manager().register_devices(
self.DEVICE_IDS, self.create_client)
self.device_manager().register_devices(self.DEVICE_IDS)
def is_enabled(self):
return self.libraries_available
@ -199,13 +201,11 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
if (isinstance(wallet, self.wallet_class)
and hasattr(wallet, 'last_operation')
and now > wallet.last_operation + wallet.session_timeout):
client = self.get_client(wallet, force_pair=False)
if client:
client.clear_session()
wallet.last_operation = self.prevent_timeout
wallet.timeout()
wallet.timeout()
wallet.last_operation = self.prevent_timeout
def create_client(self, path, handler, hid_id):
def create_client(self, device, handler):
path = device.path
pair = ((None, path) if self.HidTransport._detect_debuglink(path)
else (path, None))
try:
@ -215,50 +215,48 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
self.print_error("cannot connect at", path, str(e))
return None
self.print_error("connected to device at", path)
return self.client_class(transport, handler, self, hid_id)
def get_client(self, wallet, force_pair=True, check_firmware=True):
assert self.main_thread != threading.current_thread()
'''check_firmware is ignored unless force_pair is True.'''
client = self.device_manager().get_client(wallet, force_pair)
client = self.client_class(transport, handler, self)
# Try a ping for device sanity
try:
client.ping('t')
except BaseException as e:
self.print_error("ping failed", str(e))
return None
if not client.atleast_version(*self.minimum_firmware):
msg = (_('Outdated %s firmware for device labelled %s. Please '
'download the updated firmware from %s') %
(self.device, client.label(), self.firmware_URL))
handler.show_error(msg)
return None
return client
def get_client(self, wallet, force_pair=True):
# All client interaction should not be in the main GUI thread
assert self.main_thread != threading.current_thread()
devmgr = self.device_manager()
client = devmgr.client_for_wallet(self, wallet, force_pair)
if client:
self.print_error("set last_operation")
wallet.last_operation = time.time()
try:
client.ping('t')
except BaseException as e:
self.print_error("ping failed", str(e))
# Remove it from the manager's cache
self.device_manager().close_client(client)
client = None
if force_pair:
assert wallet.handler
if not client:
msg = (_('Could not connect to your %s. Verify the '
'cable is connected and that no other app is '
'using it.\nContinuing in watching-only mode '
'until the device is re-connected.') % self.device)
wallet.handler.show_error(msg)
raise DeviceDisconnectedError(msg)
if (check_firmware and not
client.atleast_version(*self.minimum_firmware)):
msg = (_('Outdated %s firmware for device labelled %s. Please '
'download the updated firmware from %s') %
(self.device, client.label(), self.firmware_URL))
wallet.handler.show_error(msg)
raise OutdatedFirmwareError(msg)
elif force_pair:
msg = (_('Could not connect to your %s. Verify the '
'cable is connected and that no other app is '
'using it.\nContinuing in watching-only mode '
'until the device is re-connected.') % self.device)
raise DeviceDisconnectedError(msg)
return client
@hook
def close_wallet(self, wallet):
if isinstance(wallet, self.wallet_class):
self.device_manager().close_wallet(wallet)
self.device_manager().unpair_wallet(wallet)
def initialize_device(self, wallet):
# Prevent timeouts during initialization
@ -310,27 +308,32 @@ class TrezorCompatiblePlugin(BasePlugin, ThreadJob):
wallet.thread.add(initialize_device)
def unpaired_clients(self, handler):
def unpaired_devices(self, handler):
'''Returns all connected, unpaired devices as a list of clients and a
list of descriptions.'''
devmgr = self.device_manager()
clients = devmgr.unpaired_clients(handler, self.client_class)
states = [_("wiped"), _("initialized")]
def client_desc(client):
label = client.label() or _("An unnamed device")
devices = devmgr.unpaired_devices(handler)
good_devices, descrs = [], []
for device in devices:
client = self.device_manager().create_client(device, handler, self)
if not client:
continue
state = states[client.is_initialized()]
return ("%s: serial number %s (%s)"
% (label, client.hid_id(), state))
return clients, list(map(client_desc, clients))
label = device.info['label'] or _("An unnamed device")
good_devices.append(device)
descrs.append("%s: device ID %s (%s)" % (label, device.id_, state))
return good_devices, descrs
def select_device(self, wallet):
'''Called when creating a new wallet. Select the device to use. If
the device is uninitialized, go through the intialization
process.'''
msg = _("Please select which %s device to use:") % self.device
clients, labels = self.unpaired_clients(wallet.handler)
client = clients[wallet.handler.query_choice(msg, labels)]
self.device_manager().pair_wallet(wallet, client)
devices, labels = self.unpaired_devices(wallet.handler)
device = devices[wallet.handler.query_choice(msg, labels)]
self.device_manager().pair_wallet(wallet, device.id_)
if not client.is_initialized():
self.initialize_device(wallet)

View File

@ -252,25 +252,25 @@ def qt_plugin_class(base_plugin_class):
menu.addAction(_("Show on %s") % self.device, show_address)
def settings_dialog(self, window):
hid_id = self.choose_device(window)
if hid_id:
SettingsDialog(window, self, hid_id).exec_()
device_id = self.choose_device(window)
if device_id:
SettingsDialog(window, self, device_id).exec_()
def choose_device(self, window):
'''This dialog box should be usable even if the user has
forgotten their PIN or it is in bootloader mode.'''
handler = window.wallet.handler
hid_id = self.device_manager().wallet_hid_id(window.wallet)
if not hid_id:
clients, labels = self.unpaired_clients(handler)
if clients:
device_id = self.device_manager().wallet_id(window.wallet)
if not device_id:
devices, labels = self.unpaired_devices(handler)
if devices:
msg = _("Select a %s device:") % self.device
choice = self.query_choice(window, msg, labels)
if choice is not None:
hid_id = clients[choice].hid_id()
device_id = devices[choice].id_
else:
handler.show_error(_("No devices found"))
return hid_id
return device_id
def query_choice(self, window, msg, choices):
dialog = WindowModalDialog(window)
@ -292,28 +292,29 @@ class SettingsDialog(WindowModalDialog):
We want users to be able to wipe a device even if they've forgotten
their PIN.'''
def __init__(self, window, plugin, hid_id):
def __init__(self, window, plugin, device_id):
title = _("%s Settings") % plugin.device
super(SettingsDialog, self).__init__(window, title)
self.setMaximumWidth(540)
devmgr = plugin.device_manager()
handler = window.wallet.handler
thread = window.wallet.thread
# wallet can be None, needn't be window.wallet
wallet = devmgr.wallet_by_hid_id(hid_id)
wallet = devmgr.wallet_by_id(device_id)
hs_rows, hs_cols = (64, 128)
self.current_label=None
def invoke_client(method, *args, **kw_args):
def task():
client = plugin.get_client(wallet, False)
client = devmgr.client_by_id(device_id, handler)
if not client:
raise RuntimeError("Device not connected")
if method:
getattr(client, method)(*args, **kw_args)
update(client.features)
wallet.thread.add(task)
thread.add(task)
def update(features):
self.current_label = features.label
@ -364,7 +365,7 @@ class SettingsDialog(WindowModalDialog):
if not self.question(msg, title=title):
return
invoke_client('toggle_passphrase')
devmgr.unpair(hid_id)
devmgr.unpair(device_id)
def change_homescreen():
from PIL import Image # FIXME
@ -402,7 +403,7 @@ class SettingsDialog(WindowModalDialog):
icon=QMessageBox.Critical):
return
invoke_client('wipe_device')
devmgr.unpair(hid_id)
devmgr.unpair(device_id)
def slider_moved():
mins = timeout_slider.sliderPosition()