Trezor: session timeout improvements

Move session timeout from wallet to config
Prevent timeouts whenever a device operation is in progress
Move timeout job from each plugin to device manager
This commit is contained in:
Neil Booth 2016-02-10 21:51:22 +09:00
parent 7fcc881dd4
commit 5f28834bb2
8 changed files with 73 additions and 82 deletions

View File

@ -45,8 +45,9 @@ class Plugins(DaemonThread):
self.plugins = {} self.plugins = {}
self.gui_name = gui_name self.gui_name = gui_name
self.descriptions = {} self.descriptions = {}
self.device_manager = DeviceMgr() self.device_manager = DeviceMgr(config)
self.load_plugins() self.load_plugins()
self.add_jobs(self.device_manager.thread_jobs())
self.start() self.start()
def load_plugins(self): def load_plugins(self):
@ -247,7 +248,7 @@ class DeviceUnpairableError(Exception):
Device = namedtuple("Device", "path interface_number id_ product_key") Device = namedtuple("Device", "path interface_number id_ product_key")
DeviceInfo = namedtuple("DeviceInfo", "device description initialized") DeviceInfo = namedtuple("DeviceInfo", "device description initialized")
class DeviceMgr(PrintError): class DeviceMgr(ThreadJob, PrintError):
'''Manages hardware clients. A client communicates over a hardware '''Manages hardware clients. A client communicates over a hardware
channel with the device. channel with the device.
@ -276,11 +277,9 @@ class DeviceMgr(PrintError):
the HID IDs. the HID IDs.
This plugin is thread-safe. Currently only devices supported by This plugin is thread-safe. Currently only devices supported by
hidapi are implemented. hidapi are implemented.'''
''' def __init__(self, config):
def __init__(self):
super(DeviceMgr, self).__init__() super(DeviceMgr, self).__init__()
# Keyed by wallet. The value is the device id if the wallet # Keyed by wallet. The value is the device id if the wallet
# has been paired, and None otherwise. # has been paired, and None otherwise.
@ -293,6 +292,20 @@ class DeviceMgr(PrintError):
self.recognised_hardware = set() self.recognised_hardware = set()
# For synchronization # For synchronization
self.lock = threading.RLock() self.lock = threading.RLock()
self.config = config
def thread_jobs(self):
# Thread job to handle device timeouts
return [self]
def run(self):
'''Handle device timeouts. Runs in the context of the Plugins
thread.'''
with self.lock:
clients = list(self.clients.keys())
cutoff = time.time() - self.config.get_session_timeout()
for client in clients:
client.timeout(cutoff)
def register_devices(self, device_pairs): def register_devices(self, device_pairs):
for pair in device_pairs: for pair in device_pairs:
@ -343,9 +356,6 @@ class DeviceMgr(PrintError):
self.wallets[wallet] = id_ self.wallets[wallet] = id_
wallet.paired() wallet.paired()
def paired_wallets(self):
return list(self.wallets.keys())
def client_lookup(self, id_): def client_lookup(self, id_):
with self.lock: with self.lock:
for client, (path, client_id) in self.clients.items(): for client, (path, client_id) in self.clients.items():

View File

@ -4,7 +4,7 @@ import threading
import os import os
from copy import deepcopy from copy import deepcopy
from util import user_dir, print_error, print_msg, print_stderr from util import user_dir, print_error, print_msg, print_stderr, PrintError
SYSTEM_CONFIG_PATH = "/etc/electrum.conf" SYSTEM_CONFIG_PATH = "/etc/electrum.conf"
@ -21,7 +21,7 @@ def set_config(c):
config = c config = c
class SimpleConfig(object): class SimpleConfig(PrintError):
""" """
The SimpleConfig class is responsible for handling operations involving The SimpleConfig class is responsible for handling operations involving
configuration files. configuration files.
@ -168,6 +168,13 @@ class SimpleConfig(object):
recent.remove(filename) recent.remove(filename)
self.set_key('recently_open', recent) self.set_key('recently_open', recent)
def set_session_timeout(self, seconds):
self.print_error("session timeout -> %d seconds" % seconds)
self.set_key('session_timeout', seconds)
def get_session_timeout(self):
return self.get('session_timeout', 300)
def read_system_config(path=SYSTEM_CONFIG_PATH): def read_system_config(path=SYSTEM_CONFIG_PATH):
"""Parse and return the system config settings in /etc/electrum.conf.""" """Parse and return the system config settings in /etc/electrum.conf."""

View File

@ -33,18 +33,11 @@ class BIP44_HW_Wallet(BIP44_Wallet):
def __init__(self, storage): def __init__(self, storage):
BIP44_Wallet.__init__(self, storage) BIP44_Wallet.__init__(self, storage)
# After timeout seconds we clear the device session
self.session_timeout = storage.get('session_timeout', 180)
# Errors and other user interaction is done through the wallet's # Errors and other user interaction is done through the wallet's
# handler. The handler is per-window and preserved across # handler. The handler is per-window and preserved across
# device reconnects # device reconnects
self.handler = None self.handler = None
def set_session_timeout(self, seconds):
self.print_error("setting session timeout to %d seconds" % seconds)
self.session_timeout = seconds
self.storage.put('session_timeout', seconds)
def unpaired(self): def unpaired(self):
'''A device paired with the wallet was diconnected. This can be '''A device paired with the wallet was diconnected. This can be
called in any thread context.''' called in any thread context.'''
@ -55,14 +48,6 @@ class BIP44_HW_Wallet(BIP44_Wallet):
called in any thread context.''' called in any thread context.'''
self.print_error("paired") self.print_error("paired")
def timeout(self):
'''Called when the wallet session times out. Note this is called from
the Plugins thread.'''
client = self.get_client(force_pair=False)
if client:
client.clear_session()
self.print_error("timed out")
def get_action(self): def get_action(self):
pass pass

View File

@ -17,14 +17,11 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
from electrum.util import ThreadJob
from electrum.plugins import BasePlugin, hook from electrum.plugins import BasePlugin, hook
from electrum.i18n import _ from electrum.i18n import _
class HW_PluginBase(BasePlugin, ThreadJob): class HW_PluginBase(BasePlugin):
# Derived classes provide: # Derived classes provide:
# #
# class-static variables: client_class, firmware_URL, handler_class, # class-static variables: client_class, firmware_URL, handler_class,
@ -35,7 +32,6 @@ class HW_PluginBase(BasePlugin, ThreadJob):
BasePlugin.__init__(self, parent, config, name) BasePlugin.__init__(self, parent, config, name)
self.device = self.wallet_class.device self.device = self.wallet_class.device
self.wallet_class.plugin = self self.wallet_class.plugin = self
self.prevent_timeout = time.time() + 3600 * 24 * 365
def is_enabled(self): def is_enabled(self):
return self.libraries_available return self.libraries_available
@ -43,21 +39,6 @@ class HW_PluginBase(BasePlugin, ThreadJob):
def device_manager(self): def device_manager(self):
return self.parent.device_manager return self.parent.device_manager
def thread_jobs(self):
# Thread job to handle device timeouts
return [self] if self.libraries_available else []
def run(self):
'''Handle device timeouts. Runs in the context of the Plugins
thread.'''
now = time.time()
for wallet in self.device_manager().paired_wallets():
if (isinstance(wallet, self.wallet_class)
and hasattr(wallet, 'last_operation')
and now > wallet.last_operation + wallet.session_timeout):
wallet.timeout()
wallet.last_operation = self.prevent_timeout
@hook @hook
def close_wallet(self, wallet): def close_wallet(self, wallet):
if isinstance(wallet, self.wallet_class): if isinstance(wallet, self.wallet_class):

View File

@ -392,8 +392,4 @@ class LedgerPlugin(HW_PluginBase):
wallet.proper_device = False wallet.proper_device = False
self.client = client self.client = client
if client:
self.print_error("set last_operation")
wallet.last_operation = time.time()
return self.client return self.client

View File

@ -1,4 +1,4 @@
from sys import stderr import time
from electrum.i18n import _ from electrum.i18n import _
from electrum.util import PrintError, UserCancelled from electrum.util import PrintError, UserCancelled
@ -82,6 +82,7 @@ class TrezorClientBase(GuiMixin, PrintError):
self.tx_api = plugin self.tx_api = plugin
self.types = plugin.types self.types = plugin.types
self.msg_code_override = None self.msg_code_override = None
self.used()
def __str__(self): def __str__(self):
return "%s/%s" % (self.label(), self.features.device_id) return "%s/%s" % (self.label(), self.features.device_id)
@ -97,6 +98,20 @@ class TrezorClientBase(GuiMixin, PrintError):
def is_pairable(self): def is_pairable(self):
return not self.features.bootloader_mode return not self.features.bootloader_mode
def used(self):
self.print_error("used")
self.last_operation = time.time()
def prevent_timeouts(self):
self.print_error("prevent timeouts")
self.last_operation = float('inf')
def timeout(self, cutoff):
'''Time out the client if the last operation was before cutoff.'''
if self.last_operation < cutoff:
self.print_error("timed out")
self.clear_session()
@staticmethod @staticmethod
def expand_path(n): def expand_path(n):
'''Convert bip32 path to list of uint32 integers with prime flags '''Convert bip32 path to list of uint32 integers with prime flags
@ -158,6 +173,7 @@ class TrezorClientBase(GuiMixin, PrintError):
'''Clear the session to force pin (and passphrase if enabled) '''Clear the session to force pin (and passphrase if enabled)
re-entry. Does not leak exceptions.''' re-entry. Does not leak exceptions.'''
self.print_error("clear session:", self) self.print_error("clear session:", self)
self.prevent_timeouts()
try: try:
super(TrezorClientBase, self).clear_session() super(TrezorClientBase, self).clear_session()
except BaseException as e: except BaseException as e:
@ -185,8 +201,10 @@ class TrezorClientBase(GuiMixin, PrintError):
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
try: try:
self.prevent_timeouts()
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
finally: finally:
self.used()
self.handler.finished() self.handler.finished()
return wrapped return wrapped

View File

@ -1,7 +1,6 @@
import base64 import base64
import re import re
import threading import threading
import time
from binascii import unhexlify from binascii import unhexlify
from functools import partial from functools import partial
@ -136,8 +135,7 @@ class TrezorCompatiblePlugin(HW_PluginBase):
devmgr = self.device_manager() devmgr = self.device_manager()
client = devmgr.client_for_wallet(self, wallet, force_pair) client = devmgr.client_for_wallet(self, wallet, force_pair)
if client: if client:
self.print_error("set last_operation") client.used()
wallet.last_operation = time.time()
return client return client
@ -147,9 +145,6 @@ class TrezorCompatiblePlugin(HW_PluginBase):
self.device_manager().unpair_wallet(wallet) self.device_manager().unpair_wallet(wallet)
def initialize_device(self, wallet): def initialize_device(self, wallet):
# Prevent timeouts during initialization
wallet.last_operation = self.prevent_timeout
# Initialization method # Initialization method
msg = _("Choose how you want to initialize your %s.\n\n" msg = _("Choose how you want to initialize your %s.\n\n"
"The first two methods are secure as no secret information " "The first two methods are secure as no secret information "

View File

@ -345,6 +345,7 @@ class SettingsDialog(WindowModalDialog):
self.setMaximumWidth(540) self.setMaximumWidth(540)
devmgr = plugin.device_manager() devmgr = plugin.device_manager()
config = devmgr.config
handler = window.wallet.handler handler = window.wallet.handler
thread = window.wallet.thread thread = window.wallet.thread
# wallet can be None, needn't be window.wallet # wallet can be None, needn't be window.wallet
@ -459,8 +460,7 @@ class SettingsDialog(WindowModalDialog):
timeout_minutes.setText(_("%2d minutes") % mins) timeout_minutes.setText(_("%2d minutes") % mins)
def slider_released(): def slider_released():
seconds = timeout_slider.sliderPosition() * 60 config.set_session_timeout(timeout_slider.sliderPosition() * 60)
wallet.set_session_timeout(seconds)
# Information tab # Information tab
info_tab = QWidget() info_tab = QWidget()
@ -549,29 +549,28 @@ class SettingsDialog(WindowModalDialog):
settings_glayout.addWidget(homescreen_msg, 5, 1, 1, -1) settings_glayout.addWidget(homescreen_msg, 5, 1, 1, -1)
# Settings tab - Session Timeout # Settings tab - Session Timeout
if wallet: timeout_label = QLabel(_("Session Timeout"))
timeout_label = QLabel(_("Session Timeout")) timeout_minutes = QLabel()
timeout_minutes = QLabel() timeout_slider = QSlider(Qt.Horizontal)
timeout_slider = QSlider(Qt.Horizontal) timeout_slider.setRange(1, 60)
timeout_slider.setRange(1, 60) timeout_slider.setSingleStep(1)
timeout_slider.setSingleStep(1) timeout_slider.setTickInterval(5)
timeout_slider.setTickInterval(5) timeout_slider.setTickPosition(QSlider.TicksBelow)
timeout_slider.setTickPosition(QSlider.TicksBelow) timeout_slider.setTracking(True)
timeout_slider.setTracking(True) timeout_msg = QLabel(
timeout_msg = QLabel( _("Clear the session after the specified period "
_("Clear the session after the specified period " "of inactivity. Once a session has timed out, "
"of inactivity. Once a session has timed out, " "your PIN and passphrase (if enabled) must be "
"your PIN and passphrase (if enabled) must be " "re-entered to use the device."))
"re-entered to use the device.")) timeout_msg.setWordWrap(True)
timeout_msg.setWordWrap(True) timeout_slider.setSliderPosition(config.get_session_timeout() // 60)
timeout_slider.setSliderPosition(wallet.session_timeout // 60) slider_moved()
slider_moved() timeout_slider.valueChanged.connect(slider_moved)
timeout_slider.valueChanged.connect(slider_moved) timeout_slider.sliderReleased.connect(slider_released)
timeout_slider.sliderReleased.connect(slider_released) settings_glayout.addWidget(timeout_label, 6, 0)
settings_glayout.addWidget(timeout_label, 6, 0) settings_glayout.addWidget(timeout_slider, 6, 1, 1, 3)
settings_glayout.addWidget(timeout_slider, 6, 1, 1, 3) settings_glayout.addWidget(timeout_minutes, 6, 4)
settings_glayout.addWidget(timeout_minutes, 6, 4) settings_glayout.addWidget(timeout_msg, 7, 1, 1, -1)
settings_glayout.addWidget(timeout_msg, 7, 1, 1, -1)
settings_layout.addLayout(settings_glayout) settings_layout.addLayout(settings_glayout)
settings_layout.addStretch(1) settings_layout.addStretch(1)