diff --git a/tests/device_tests/common.py b/tests/device_tests/common.py index e58a415..331829b 100644 --- a/tests/device_tests/common.py +++ b/tests/device_tests/common.py @@ -59,51 +59,37 @@ def pipe_exists(path): return False -if HID_ENABLED and HidTransport.enumerate(): - - devices = HidTransport.enumerate() - print('Using TREZOR') - TRANSPORT = HidTransport - TRANSPORT_ARGS = (devices[0],) - TRANSPORT_KWARGS = {} - DEBUG_TRANSPORT = HidTransport - DEBUG_TRANSPORT_ARGS = (devices[0].find_debug(),) - DEBUG_TRANSPORT_KWARGS = {} - -elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): - - print('Using Emulator (v1=pipe)') - TRANSPORT = PipeTransport - TRANSPORT_ARGS = ('/tmp/pipe.trezor', False) - TRANSPORT_KWARGS = {} - DEBUG_TRANSPORT = PipeTransport - DEBUG_TRANSPORT_ARGS = ('/tmp/pipe.trezor_debug', False) - DEBUG_TRANSPORT_KWARGS = {} - -elif UDP_ENABLED: - - print('Using Emulator (v2=udp)') - TRANSPORT = UdpTransport - TRANSPORT_ARGS = ('', ) - TRANSPORT_KWARGS = {} - DEBUG_TRANSPORT = UdpTransport - DEBUG_TRANSPORT_ARGS = ('', ) - DEBUG_TRANSPORT_KWARGS = {} - - def get_transport(): - return TRANSPORT(*TRANSPORT_ARGS, **TRANSPORT_KWARGS) + if HID_ENABLED and HidTransport.enumerate(): + devices = HidTransport.enumerate() + wirelink = devices[0] + debuglink = devices[0].find_debug() + + elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): + wirelink = PipeTransport('/tmp/pipe.trezor', False) + debuglink = PipeTransport('/tmp/pipe.trezor_debug', False) + + elif UDP_ENABLED: + wirelink = UdpTransport() + debuglink = UdpTransport() + + return wirelink, debuglink -def get_debug_transport(): - return DEBUG_TRANSPORT(*DEBUG_TRANSPORT_ARGS, **DEBUG_TRANSPORT_KWARGS) +if HID_ENABLED and HidTransport.enumerate(): + print('Using TREZOR') +elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): + print('Using Emulator (v1=pipe)') +elif UDP_ENABLED: + print('Using Emulator (v2=udp)') class TrezorTest(unittest.TestCase): def setUp(self): - self.client = TrezorClientDebugLink(get_transport) - self.client.set_debuglink(get_debug_transport()) + wirelink, debuglink = get_transport() + self.client = TrezorClientDebugLink(wirelink) + self.client.set_debuglink(debuglink) self.client.set_tx_api(tx_api.TxApiBitcoin) # self.client.set_buttonwait(3) diff --git a/trezorctl b/trezorctl index f1c8443..e4e6ed8 100755 --- a/trezorctl +++ b/trezorctl @@ -19,10 +19,11 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . -import binascii -import json import base64 +import binascii import click +import functools +import json from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException import trezorlib.types_pb2 as types @@ -65,12 +66,10 @@ def cli(ctx, transport, path, verbose, is_json): if ctx.invoked_subcommand == 'list': ctx.obj = transport else: - def connect(): - return get_transport(transport, path) if verbose: - ctx.obj = TrezorClientVerbose(connect) + ctx.obj = lambda: TrezorClientVerbose(get_transport(transport, path)) else: - ctx.obj = TrezorClient(connect) + ctx.obj = lambda: TrezorClient(get_transport(transport, path)) @cli.resultcallback() @@ -123,33 +122,33 @@ def ls(transport_name): @click.option('-p', '--pin-protection', is_flag=True) @click.option('-r', '--passphrase-protection', is_flag=True) @click.pass_obj -def ping(client, message, button_protection, pin_protection, passphrase_protection): - return client.ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection) +def ping(connect, message, button_protection, pin_protection, passphrase_protection): + return connect().ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection) @cli.command(help='Clear session (remove cached PIN, passphrase, etc.).') @click.pass_obj -def clear_session(client): - return client.clear_session() +def clear_session(connect): + return connect().clear_session() @cli.command(help='Get example entropy.') @click.argument('size', type=int) @click.pass_obj -def get_entropy(client, size): - return binascii.hexlify(client.get_entropy(size)) +def get_entropy(connect, size): + return binascii.hexlify(connect().get_entropy(size)) @cli.command(help='Retrieve device features and settings.') @click.pass_obj -def get_features(client): - return client.features +def get_features(connect): + return connect().features @cli.command(help='List all supported coin types by the device.') @click.pass_obj -def list_coins(client): - return [coin.coin_name for coin in client.features.coins] +def list_coins(connect): + return [coin.coin_name for coin in connect().features.coins] # @@ -160,33 +159,33 @@ def list_coins(client): @cli.command(help='Change new PIN or remove existing.') @click.option('-r', '--remove', is_flag=True) @click.pass_obj -def change_pin(client, remove): - return client.change_pin(remove) +def change_pin(connect, remove): + return connect().change_pin(remove) @cli.command(help='Enable passphrase.') @click.pass_obj -def enable_passphrase(client): - return client.apply_settings(use_passphrase=True) +def enable_passphrase(connect): + return connect().apply_settings(use_passphrase=True) @cli.command(help='Disable passphrase.') @click.pass_obj -def disable_passphrase(client): - return client.apply_settings(use_passphrase=False) +def disable_passphrase(connect): + return connect().apply_settings(use_passphrase=False) @cli.command(help='Set new device label.') @click.option('-l', '--label') @click.pass_obj -def set_label(client, label): - return client.apply_settings(label=label) +def set_label(connect, label): + return connect().apply_settings(label=label) @cli.command(help='Set device flags.') @click.argument('flags') @click.pass_obj -def set_flags(client, flags): +def set_flags(connect, flags): flags = flags.lower() if flags.startswith('0b'): flags = int(flags, 2) @@ -194,13 +193,13 @@ def set_flags(client, flags): flags = int(flags, 16) else: flags = int(flags) - return client.apply_flags(flags=flags) + return connect().apply_flags(flags=flags) @cli.command(help='Set new homescreen.') @click.option('-f', '--filename', default=None) @click.pass_obj -def set_homescreen(client, filename): +def set_homescreen(connect, filename): if filename is not None: from PIL import Image im = Image.open(filename) @@ -217,20 +216,20 @@ def set_homescreen(client, filename): img = bytes(img) else: img = b'\x00' - return client.apply_settings(homescreen=img) + return connect().apply_settings(homescreen=img) @cli.command(help='Set U2F counter.') @click.argument('counter', type=int) @click.pass_obj -def set_u2f_counter(client, counter): - return client.set_u2f_counter(counter) +def set_u2f_counter(connect, counter): + return connect().set_u2f_counter(counter) @cli.command(help='Reset device to factory defaults and remove all private data.') @click.pass_obj -def wipe_device(client): - return client.wipe_device() +def wipe_device(connect): + return connect().wipe_device() @cli.command(help='Load custom configuration to the device.') @@ -243,10 +242,11 @@ def wipe_device(client): @click.option('-i', '--ignore-checksum', is_flag=True) @click.option('-s', '--slip0014', is_flag=True) @click.pass_obj -def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014): +def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014): if not mnemonic and not xprv and not slip0014: raise CallException(types.Failure_DataError, 'Please provide mnemonic or xprv') + client = connect() if mnemonic: return client.load_device_by_mnemonic( mnemonic, @@ -283,12 +283,12 @@ def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, labe @click.option('-t', '--type', 'rec_type', type=click.Choice(['scrambled', 'matrix']), default='scrambled') @click.option('-d', '--dry-run', is_flag=True) @click.pass_obj -def recovery_device(client, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run): +def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run): typemap = { 'scrambled': types.RecoveryDeviceType_ScrambledWords, 'matrix': types.RecoveryDeviceType_Matrix } - return client.recovery_device( + return connect().recovery_device( int(words), passphrase_protection, pin_protection, @@ -308,8 +308,8 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection @click.option('-u', '--u2f-counter', default=0) @click.option('-s', '--skip-backup', is_flag=True) @click.pass_obj -def reset_device(client, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup): - return client.reset_device( +def reset_device(connect, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup): + return connect().reset_device( True, int(strength), passphrase_protection, @@ -323,8 +323,8 @@ def reset_device(client, strength, pin_protection, passphrase_protection, label, @cli.command(help='Perform device seed backup.') @click.pass_obj -def backup_device(client): - return client.backup_device() +def backup_device(connect): + return connect().backup_device() # @@ -338,7 +338,7 @@ def backup_device(client): @click.option('-v', '--version') @click.option('-s', '--skip-check', is_flag=True) @click.pass_obj -def firmware_update(client, filename, url, version, skip_check): +def firmware_update(connect, filename, url, version, skip_check): if filename: fp = open(filename, 'rb').read() elif url: @@ -377,13 +377,13 @@ def firmware_update(client, filename, url, version, skip_check): click.echo('Please confirm action on device...') from io import BytesIO - return client.firmware_update(fp=BytesIO(fp)) + return connect().firmware_update(fp=BytesIO(fp)) @cli.command(help='Perform a self-test.') @click.pass_obj -def self_test(client): - return client.self_test() +def self_test(connect): + return connect().self_test() # @@ -397,7 +397,8 @@ def self_test(client): @click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address') @click.option('-d', '--show-display', is_flag=True) @click.pass_obj -def get_address(client, coin, address, script_type, show_display): +def get_address(connect, coin, address, script_type, show_display): + client = connect() address_n = client.expand_path(address) typemap = { 'address': types.SPENDADDRESS, @@ -414,7 +415,8 @@ def get_address(client, coin, address, script_type, show_display): @click.option('-e', '--curve') @click.option('-d', '--show-display', is_flag=True) @click.pass_obj -def get_public_node(client, coin, address, curve, show_display): +def get_public_node(connect, coin, address, curve, show_display): + client = connect() address_n = client.expand_path(address) result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) return { @@ -440,7 +442,8 @@ def get_public_node(client, coin, address, curve, show_display): @click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address') @click.argument('message') @click.pass_obj -def sign_message(client, coin, address, message, script_type): +def sign_message(connect, coin, address, message, script_type): + client = connect() address_n = client.expand_path(address) typemap = { 'address': types.SPENDADDRESS, @@ -462,16 +465,17 @@ def sign_message(client, coin, address, message, script_type): @click.argument('signature') @click.argument('message') @click.pass_obj -def verify_message(client, coin, address, signature, message): +def verify_message(connect, coin, address, signature, message): signature = base64.b64decode(signature) - return client.verify_message(coin, address, signature, message) + return connect().verify_message(coin, address, signature, message) @cli.command(help='Sign message with Ethereum address.') @click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0") @click.argument('message') @click.pass_obj -def ethereum_sign_message(client, address, message): +def ethereum_sign_message(connect, address, message): + client = connect() address_n = client.expand_path(address) ret = client.ethereum_sign_message(address_n, message) output = { @@ -494,10 +498,10 @@ def ethereum_decode_hex(value): @click.argument('signature') @click.argument('message') @click.pass_obj -def ethereum_verify_message(client, address, signature, message): +def ethereum_verify_message(connect, address, signature, message): address = ethereum_decode_hex(address) signature = ethereum_decode_hex(signature) - return client.ethereum_verify_message(address, signature, message) + return connect().ethereum_verify_message(address, signature, message) @cli.command(help='Encrypt value by given key and path.') @@ -505,7 +509,8 @@ def ethereum_verify_message(client, address, signature, message): @click.argument('key') @click.argument('value') @click.pass_obj -def encrypt_keyvalue(client, address, key, value): +def encrypt_keyvalue(connect, address, key, value): + client = connect() address_n = client.expand_path(address) res = client.encrypt_keyvalue(address_n, key, value) return binascii.hexlify(res) @@ -516,7 +521,8 @@ def encrypt_keyvalue(client, address, key, value): @click.argument('key') @click.argument('value') @click.pass_obj -def decrypt_keyvalue(client, address, key, value): +def decrypt_keyvalue(connect, address, key, value): + client = connect() address_n = client.expand_path(address) return client.decrypt_keyvalue(address_n, key, value.decode('hex')) @@ -528,7 +534,8 @@ def decrypt_keyvalue(client, address, key, value): @click.argument('pubkey') @click.argument('message') @click.pass_obj -def encrypt_message(client, coin, display_only, address, pubkey, message): +def encrypt_message(connect, coin, display_only, address, pubkey, message): + client = connect() pubkey = binascii.unhexlify(pubkey) address_n = client.expand_path(address) res = client.encrypt_message(pubkey, message, display_only, coin, address_n) @@ -544,7 +551,8 @@ def encrypt_message(client, coin, display_only, address, pubkey, message): @click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0") @click.argument('payload') @click.pass_obj -def decrypt_message(client, address, payload): +def decrypt_message(connect, address, payload): + client = connect() address_n = client.expand_path(address) payload = base64.b64decode(payload) nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:] @@ -560,7 +568,8 @@ def decrypt_message(client, address, payload): @click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0") @click.option('-d', '--show-display', is_flag=True) @click.pass_obj -def ethereum_get_address(client, address, show_display): +def ethereum_get_address(connect, address, show_display): + client = connect() address_n = client.expand_path(address) address = client.ethereum_get_address(address_n, show_display) return '0x%s' % binascii.hexlify(address).decode() @@ -578,7 +587,7 @@ def ethereum_get_address(client, address, show_display): @click.option('-p', '--publish', is_flag=True, help='Publish transaction via RPC') @click.argument('to') @click.pass_obj -def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to): +def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to): from ethjsonrpc import EthJsonRpc import rlp @@ -626,6 +635,7 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric to_address = ethereum_decode_hex(to) + client = connect() address_n = client.expand_path(address) address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)),) @@ -676,7 +686,8 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric @click.option('-N', '--network', type=int, default=0x68) @click.option('-d', '--show-display', is_flag=True) @click.pass_obj -def nem_get_address(client, address, network, show_display): +def nem_get_address(connect, address, network, show_display): + client = connect() address_n = client.expand_path(address) return client.nem_get_address(address_n, network, show_display) @@ -686,7 +697,8 @@ def nem_get_address(client, address, network, show_display): @click.option('-f', '--file', type=click.File('r'), default='-', help='Transaction in NIS (RequestPrepareAnnounce) format') @click.option('-b', '--broadcast', help='NIS to announce transaction to') @click.pass_obj -def nem_sign_tx(client, address, file, broadcast): +def nem_sign_tx(connect, address, file, broadcast): + client = connect() address_n = client.expand_path(address) transaction = client.nem_sign_tx(address_n, json.load(file)) diff --git a/trezorlib/client.py b/trezorlib/client.py index c2fa55e..2bfe829 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -153,11 +153,11 @@ def session(f): # with session activation / deactivation def wrapped_f(*args, **kwargs): client = args[0] - client.get_transport().session_begin() + client.transport.session_begin() try: return f(*args, **kwargs) finally: - client.get_transport().session_end() + client.transport.session_end() return wrapped_f @@ -179,26 +179,20 @@ def normalize_nfc(txt): class BaseClient(object): # Implements very basic layer of sending raw protobuf # messages to device and getting its response back. - def __init__(self, connect, **kwargs): - self.connect = connect - self.transport = None + def __init__(self, transport, **kwargs): + self.transport = transport super(BaseClient, self).__init__() # *args, **kwargs) - def get_transport(self): - if self.transport is None: - self.transport = self.connect() - return self.transport - def close(self): pass def cancel(self): - self.get_transport().write(proto.Cancel()) + self.transport.write(proto.Cancel()) @session def call_raw(self, msg): - self.get_transport().write(msg) - return self.get_transport().read() + self.transport.write(msg) + return self.transport.read() @session def call(self, msg):