diff --git a/trezorctl b/trezorctl index 80da178..b81505b 100755 --- a/trezorctl +++ b/trezorctl @@ -24,6 +24,7 @@ import json import base64 import click +from trezorlib import protobuf_json from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException import trezorlib.types_pb2 as types @@ -55,13 +56,6 @@ def get_transport(transport_name, path): return dev -def output(res): - if output.json: - click.echo(json.dumps(res, sort_keys=True, indent=4)) - else: - click.echo(res) - - @click.group() @click.option('-t', '--transport', type=click.Choice(['usb', 'udp', 'pipe', 'bridge']), default='usb', help='Select transport used for communication.') @click.option('-p', '--path', help='Select device by transport-specific path.') @@ -69,7 +63,6 @@ def output(res): @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') @click.pass_context def cli(ctx, transport, path, verbose, is_json): - output.json = is_json if ctx.invoked_subcommand == 'list': ctx.obj = transport else: @@ -80,6 +73,24 @@ def cli(ctx, transport, path, verbose, is_json): ctx.obj = TrezorClient(t) +@cli.resultcallback() +def print_result(res, transport, path, verbose, is_json): + if is_json: + if hasattr(res, '__module__') and res.__module__ == 'messages_pb2': + click.echo(protobuf_json.pb2json(res)) + else: + click.echo(json.dumps(res, sort_keys=True, indent=4)) + else: + if isinstance(res, list): + for line in res: + click.echo(line) + elif isinstance(res, dict): + for k, v in res.items(): + click.echo('%s: %s' % (k, v)) + else: + click.echo(res) + + # # Common functions # @@ -90,18 +101,7 @@ def cli(ctx, transport, path, verbose, is_json): def ls(transport_name): transport_class = get_transport_class_by_name(transport_name) devices = transport_class.enumerate() - if output.json: - click.echo(json.dumps(devices, sort_keys=True, indent=4)) - else: - if transport_name == 'usb': - for dev in devices: - if dev[1] is not None: - click.echo('%s - debuglink enabled' % dev[0]) - else: - click.echo(dev[0]) - else: - for dev in devices: - click.echo(dev) + return [d[0] for d in devices] # @@ -116,23 +116,20 @@ def ls(transport_name): @click.option('-r', '--passphrase-protection', is_flag=True) @click.pass_obj def ping(client, message, button_protection, pin_protection, passphrase_protection): - ret = client.ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection) - output(ret) + return client.ping(message, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection) @cli.command(help='Clear session (remove cached PIN, passphrase, etc.).') @click.pass_obj def clear_session(client): - ret = client.clear_session() - output(ret) + return client.clear_session() @cli.command(help='Get example entropy.') @click.argument('size', type=int) @click.pass_obj def get_entropy(client, size): - ret = binascii.hexlify(client.get_entropy(size)) - output(ret) + return binascii.hexlify(client.get_entropy(size)) @cli.command(help='Retrieve device features and settings.') @@ -144,8 +141,7 @@ def get_features(client): @cli.command(help='List all supported coin types by the device.') @click.pass_obj def list_coins(client): - ret = [coin.coin_name for coin in client.features.coins] - output(ret) + return [coin.coin_name for coin in client.features.coins] # @@ -157,30 +153,26 @@ def list_coins(client): @click.option('-r', '--remove', is_flag=True) @click.pass_obj def change_pin(client, remove): - ret = client.change_pin(remove) - output(ret) + return client.change_pin(remove) @cli.command(help='Enable passphrase.') @click.pass_obj def enable_passphrase(client): - ret = client.apply_settings(use_passphrase=True) - output(ret) + return client.apply_settings(use_passphrase=True) @cli.command(help='Disable passphrase.') @click.pass_obj def disable_passphrase(client): - ret = client.apply_settings(use_passphrase=False) - output(ret) + return client.apply_settings(use_passphrase=False) @cli.command(help='Set new device label.') @click.option('-l', '--label') @click.pass_obj def set_label(client, label): - ret = client.apply_settings(label=label) - output(ret) + return client.apply_settings(label=label) @cli.command(help='Set new homescreen.') @@ -201,23 +193,20 @@ def set_homescreen(client, filename): img = ''.join(chr(int(img[i:i + 8], 2)) for i in range(0, len(img), 8)) else: img = '\x00' - ret = client.apply_settings(homescreen=img) - output(ret) + return client.apply_settings(homescreen=img) @cli.command(help='Set U2F counter.') @click.argument('counter', type=int) @click.pass_obj def set_u2f_counter(client, counter): - ret = client.set_u2f_counter(counter) - output(ret) + return client.set_u2f_counter(counter) @cli.command(help='Reset device to factory defaults and remove all private data.') @click.pass_obj def wipe_device(client): - ret = client.wipe_device() - output(ret) + return client.wipe_device() @cli.command(help='Load custom configuration to the device.') @@ -234,7 +223,7 @@ def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, labe raise CallException(types.Failure_DataError, 'Please provide mnemonic or xprv') if mnemonic: - ret = client.load_device_by_mnemonic( + return client.load_device_by_mnemonic( mnemonic, pin, passphrase_protection, @@ -244,14 +233,13 @@ def load_device(client, mnemonic, expand, xprv, pin, passphrase_protection, labe expand ) if xprv: - ret = client.load_device_by_xprv( + return client.load_device_by_xprv( xprv, pin, passphrase_protection, label, 'english' ) - output(ret) @cli.command(help='Start safe recovery workflow.') @@ -268,7 +256,7 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection 'scrambled': types.RecoveryDeviceType_ScrambledWords, 'matrix': types.RecoveryDeviceType_Matrix } - ret = client.recovery_device( + return client.recovery_device( words, passphrase_protection, pin_protection, @@ -278,7 +266,6 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection expand, dry_run ) - output(ret) @cli.command(help='Perform device setup and generate new seed.') @@ -290,7 +277,7 @@ def recovery_device(client, words, expand, pin_protection, passphrase_protection @click.option('-s', '--skip-backup', is_flag=True) @click.pass_obj def reset_device(client, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup): - ret = client.reset_device( + return client.reset_device( True, strength, passphrase_protection, @@ -300,14 +287,12 @@ def reset_device(client, strength, pin_protection, passphrase_protection, label, u2f_counter, skip_backup ) - output(ret) @cli.command(help='Perform device seed backup.') @click.pass_obj def backup_device(client): - ret = client.backup_device() - output(ret) + return client.backup_device() # @@ -360,15 +345,13 @@ def firmware_update(client, filename, url, version, skip_check): click.echo('Please confirm action on device...') from io import BytesIO - ret = client.firmware_update(fp=BytesIO(fp)) - output(ret) + return client.firmware_update(fp=BytesIO(fp)) @cli.command(help='Perform a self-test.') @click.pass_obj def self_test(client): - ret = client.self_test() - output(ret) + return client.self_test() # @@ -390,8 +373,7 @@ def get_address(client, coin, address, script_type, show_display): 'p2shsegwit': types.SPENDP2SHWITNESS, } script_type = typemap[script_type] - ret = client.get_address(coin, address_n, show_display, script_type=script_type) - output(ret) + return client.get_address(coin, address_n, show_display, script_type=script_type) @cli.command(help='Get public node of given path.') @@ -402,8 +384,7 @@ def get_address(client, coin, address, script_type, show_display): @click.pass_obj def get_public_node(client, coin, address, curve, show_display): address_n = client.expand_path(address) - ret = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) - output(ret) + return client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) # @@ -419,12 +400,11 @@ def get_public_node(client, coin, address, curve, show_display): def sign_message(client, coin, address, message): address_n = client.expand_path(address) res = client.sign_message(coin, address_n, message) - ret = { + return { 'message': message, 'address': res.address, 'signature': base64.b64encode(res.signature) } - output(ret) @cli.command(help='Verify message.') @@ -435,8 +415,7 @@ def sign_message(client, coin, address, message): @click.pass_obj def verify_message(client, coin, address, signature, message): signature = base64.b64decode(signature) - ret = client.verify_message(coin, address, signature, message) - output(ret) + return client.verify_message(coin, address, signature, message) @cli.command(help='Encrypt value by given key and path.') @@ -447,8 +426,7 @@ def verify_message(client, coin, address, signature, message): def encrypt_keyvalue(client, address, key, value): address_n = client.expand_path(address) res = client.encrypt_keyvalue(address_n, key, value) - ret = binascii.hexlify(res) - output(ret) + return binascii.hexlify(res) @cli.command(help='Decrypt value by given key and path.') @@ -458,8 +436,7 @@ def encrypt_keyvalue(client, address, key, value): @click.pass_obj def decrypt_keyvalue(client, address, key, value): address_n = client.expand_path(address) - ret = client.decrypt_keyvalue(address_n, key, value.decode('hex')) - output(ret) + return client.decrypt_keyvalue(address_n, key, value.decode('hex')) @cli.command(help='Encrypt message.') @@ -473,13 +450,12 @@ def encrypt_message(client, coin, display_only, address, pubkey, message): pubkey = binascii.unhexlify(pubkey) address_n = client.expand_path(address) res = client.encrypt_message(pubkey, message, display_only, coin, address_n) - ret = { + return { 'nonce': binascii.hexlify(res.nonce), 'message': binascii.hexlify(res.message), 'hmac': binascii.hexlify(res.hmac), 'payload': base64.b64encode(res.nonce + res.message + res.hmac), } - output(ret) @cli.command(help='Decrypt message.') @@ -490,8 +466,7 @@ def decrypt_message(client, address, payload): address_n = client.expand_path(address) payload = base64.b64decode(payload) nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:] - ret = client.decrypt_message(address_n, nonce, message, msg_hmac) - output(ret) + return client.decrypt_message(address_n, nonce, message, msg_hmac) # @@ -506,8 +481,7 @@ def decrypt_message(client, address, payload): def ethereum_get_address(client, address, show_display): address_n = client.expand_path(address) address = client.ethereum_get_address(address_n, show_display) - ret = '0x%s' % binascii.hexlify(address) - output(ret) + return '0x%s' % binascii.hexlify(address) @cli.command(help='Sign (and optionally publish) Ethereum transaction. Use TO as destination address or set TO to "" for contract creation.') @@ -615,10 +589,9 @@ def ethereum_sign_tx(client, host, chain_id, address, value, gas_limit, gas_pric if publish: tx_hash = eth.eth_sendRawTransaction(tx_hex) - ret = 'Transaction published with ID: %s' % tx_hash + return 'Transaction published with ID: %s' % tx_hash else: - ret = 'Signed raw transaction: %s' % tx_hex - output(ret) + return 'Signed raw transaction: %s' % tx_hex #