diff --git a/cmd.py b/cmd.py index b9dc23c..3e1bf2c 100755 --- a/cmd.py +++ b/cmd.py @@ -14,8 +14,8 @@ def parse_args(commands): parser = argparse.ArgumentParser(description='Commandline tool for Trezor devices.') parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Transport used for talking with the device") parser.add_argument('-p', '--path', dest='path', default='', help="Path used by the transport (usually serial port)") - parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='socket', help="Debuglink transport") - parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='127.0.0.1:2000', help="Path used by the transport (usually serial port)") + parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Debuglink transport") + parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='', help="Path used by the transport (usually serial port)") parser.add_argument('-j', '--json', dest='json', action='store_true', help="Prints result as json object") parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging') @@ -42,7 +42,7 @@ def parse_args(commands): return parser.parse_args() -def get_transport(transport_string, path): +def get_transport(transport_string, path, **kwargs): if transport_string == 'usb': from trezorlib.transport_hid import HidTransport @@ -52,23 +52,23 @@ def get_transport(transport_string, path): except IndexError: raise Exception("No Trezor found on USB") - return HidTransport(path) + return HidTransport(path, **kwargs) if transport_string == 'serial': from trezorlib.transport_serial import SerialTransport - return SerialTransport(path) + return SerialTransport(path, **kwargs) if transport_string == 'pipe': from trezorlib.transport_pipe import PipeTransport - return PipeTransport(path, is_device=False) + return PipeTransport(path, is_device=False, **kwargs) if transport_string == 'socket': from trezorlib.transport_socket import SocketTransportClient - return SocketTransportClient(path) + return SocketTransportClient(path, **kwargs) if transport_string == 'fake': from trezorlib.transport_fake import FakeTransport - return FakeTransport(path) + return FakeTransport(path, **kwargs) raise NotImplemented("Unknown transport") @@ -96,8 +96,8 @@ class Commands(object): def ping(self, args): return self.client.ping(args.msg) - def get_master_public_key(self, args): - return self.client.get_master_public_key() + def get_public_node(self, args): + return self.client.get_public_node(args.n) def get_serial_number(self, args): return binascii.hexlify(self.client.get_serial_number()) @@ -113,6 +113,9 @@ class Commands(object): return self.client.load_device(seed, args.pin) + def sign_message(self, args): + return self.client.sign_message(args.n, args.message) + def firmware_update(self, args): if not args.file: raise Exception("Must provide firmware filename") @@ -129,10 +132,11 @@ class Commands(object): get_entropy.help = 'Get example entropy' get_features.help = 'Retrieve device features and settings' get_serial_number.help = 'Get device\'s unique identifier' - get_master_public_key.help = 'Get master public key' + get_public_node.help = 'Get public node of given path' set_label.help = 'Set new wallet label' set_coin.help = 'Switch device to another crypto currency' load_device.help = 'Load custom configuration to the device' + sign_message.help = 'Sign message using address of given path' firmware_update.help = 'Upload new firmware to device (must be in bootloader mode)' get_address.arguments = ( @@ -162,6 +166,15 @@ class Commands(object): (('-n', '--pin'), {'type': str, 'default': ''}), ) + sign_message.arguments = ( + (('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}), + (('message',), {'type': str}), + ) + + get_public_node.arguments = ( + (('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}), + ) + firmware_update.arguments = ( (('-f', '--file'), {'type': str}), ) @@ -239,7 +252,10 @@ def main(): transport = get_transport(args.transport, args.path) if args.debug: - debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path) + if args.debuglink_transport == 'usb' and args.debuglink_path == '': + debuglink_transport = get_transport('usb', args.path, debug_link=True) + else: + debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path) debuglink = DebugLink(debuglink_transport) else: debuglink = None diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index 8b407bb..f58b01e 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -5,8 +5,8 @@ import time from transport import Transport, NotImplementedException DEVICE_IDS = [ - (0x10c4, 0xea80), # Trezor Pi - (0x534c, 0x0001), # Trezor + (0x10c4, 0xea80), # Shield + (0x534c, 0x0001), # Trezor ] class FakeRead(object): @@ -21,6 +21,8 @@ class HidTransport(Transport): def __init__(self, device, *args, **kwargs): self.hid = None self.buffer = '' + if bool(kwargs.get('debug_link')): + device = device[:-2] + '01' super(HidTransport, self).__init__(device, *args, **kwargs) @classmethod @@ -29,18 +31,18 @@ class HidTransport(Transport): for d in hid.enumerate(0, 0): vendor_id = d.get('vendor_id') product_id = d.get('product_id') - serial_number = d.get('serial_number') - - if (vendor_id, product_id) in DEVICE_IDS: - devices.append("0x%04x:0x%04x:%s" % (vendor_id, product_id, serial_number)) + path = d.get('path') + + if (vendor_id, product_id) in DEVICE_IDS and path.endswith(':00'): + devices.append(path) return devices def _open(self): self.buffer = '' - path = self.device.split(':') + print self.device self.hid = hid.device() - self.hid.open(int(path[0], 16), int(path[1], 16)) + self.hid.open_path(self.device) self.hid.set_nonblocking(True) self.hid.send_feature_report([0x41, 0x01]) # enable UART self.hid.send_feature_report([0x43, 0x03]) # purge TX/RX FIFOs