[TREZOR] Added Segwit support.

Following changes were necessary outside the TREZOR plugin.
- transaction.py: update_transaction handles segwit transactions.
- keystore.py: added a segwit parameter to bip44_derivation,
  use m/49' instead of m/44' for segwit.
This commit is contained in:
Jochen Hoenicke 2017-08-16 15:45:38 +02:00
parent fbe27fce04
commit ec0de566a8
4 changed files with 30 additions and 18 deletions

View File

@ -227,7 +227,7 @@ class BaseWizard(object):
self.derivation_dialog(f) self.derivation_dialog(f)
def derivation_dialog(self, f): def derivation_dialog(self, f):
default = bip44_derivation(0) default = bip44_derivation(0, self.config.get('segwit'))
message = '\n'.join([ message = '\n'.join([
_('Enter your wallet derivation here.'), _('Enter your wallet derivation here.'),
_('If you are not sure what this is, leave this field unchanged.') _('If you are not sure what this is, leave this field unchanged.')

View File

@ -684,11 +684,10 @@ is_private_key = lambda x: is_xprv(x) or is_private_key_list(x)
is_bip32_key = lambda x: is_xprv(x) or is_xpub(x) is_bip32_key = lambda x: is_xprv(x) or is_xpub(x)
def bip44_derivation(account_id): def bip44_derivation(account_id, segwit=False):
if bitcoin.TESTNET: bip = 49 if segwit else 44
return "m/44'/1'/%d'"% int(account_id) coin = 1 if bitcoin.TESTNET else 0
else: return "m/%d'/%d'/%d'" % (bip, coin, int(account_id))
return "m/44'/0'/%d'"% int(account_id)
def from_seed(seed, passphrase): def from_seed(seed, passphrase):
t = seed_type(seed) t = seed_type(seed)

View File

@ -421,8 +421,7 @@ def parse_input(vds):
def parse_witness(vds): def parse_witness(vds):
n = vds.read_compact_size() n = vds.read_compact_size()
for i in range(n): return list(vds.read_bytes(vds.read_compact_size()).encode('hex') for i in xrange(n))
x = vds.read_bytes(vds.read_compact_size())
def parse_output(vds, i): def parse_output(vds, i):
d = {} d = {}
@ -548,7 +547,12 @@ class Transaction:
for i, txin in enumerate(self.inputs()): for i, txin in enumerate(self.inputs()):
pubkeys, x_pubkeys = self.get_sorted_pubkeys(txin) pubkeys, x_pubkeys = self.get_sorted_pubkeys(txin)
sigs1 = txin.get('signatures') sigs1 = txin.get('signatures')
sigs2 = d['inputs'][i].get('signatures') if d.get('witness') is None:
sigs2 = d['inputs'][i].get('signatures')
else:
# signatures are in the witnesses. But the last item is
# the pubkey or the multisig script, so skip that.
sigs2 = d['witness'][i][:-1]
for sig in sigs2: for sig in sigs2:
if sig in sigs1: if sig in sigs1:
continue continue

View File

@ -26,6 +26,9 @@ class TrezorCompatibleKeyStore(Hardware_KeyStore):
def get_derivation(self): def get_derivation(self):
return self.derivation return self.derivation
def is_segwit(self):
return self.derivation.startswith("m/49'/")
def get_client(self, force_pair=True): def get_client(self, force_pair=True):
return self.plugin.get_client(self, force_pair) return self.plugin.get_client(self, force_pair)
@ -241,8 +244,8 @@ class TrezorCompatiblePlugin(HW_PluginBase):
self.prev_tx = prev_tx self.prev_tx = prev_tx
self.xpub_path = xpub_path self.xpub_path = xpub_path
client = self.get_client(keystore) client = self.get_client(keystore)
inputs = self.tx_inputs(tx, True) inputs = self.tx_inputs(tx, True, keystore.is_segwit())
outputs = self.tx_outputs(keystore.get_derivation(), tx) outputs = self.tx_outputs(keystore.get_derivation(), tx, keystore.is_segwit())
signed_tx = client.sign_tx(self.get_coin_name(), inputs, outputs, lock_time=tx.locktime)[1] signed_tx = client.sign_tx(self.get_coin_name(), inputs, outputs, lock_time=tx.locktime)[1]
raw = bh2u(signed_tx) raw = bh2u(signed_tx)
tx.update_signatures(raw) tx.update_signatures(raw)
@ -258,7 +261,7 @@ class TrezorCompatiblePlugin(HW_PluginBase):
address_n = client.expand_path(address_path) address_n = client.expand_path(address_path)
client.get_address(self.get_coin_name(), address_n, True) client.get_address(self.get_coin_name(), address_n, True)
def tx_inputs(self, tx, for_sig=False): def tx_inputs(self, tx, for_sig=False, segwit=False):
inputs = [] inputs = []
for txin in tx.inputs(): for txin in tx.inputs():
txinputtype = self.types.TxInputType() txinputtype = self.types.TxInputType()
@ -273,6 +276,7 @@ class TrezorCompatiblePlugin(HW_PluginBase):
xpub, s = parse_xpubkey(x_pubkey) xpub, s = parse_xpubkey(x_pubkey)
xpub_n = self.client_class.expand_path(self.xpub_path[xpub]) xpub_n = self.client_class.expand_path(self.xpub_path[xpub])
txinputtype.address_n.extend(xpub_n + s) txinputtype.address_n.extend(xpub_n + s)
txinputtype.script_type = self.types.SPENDP2SHWITNESS if segwit else self.types.SPENDADDRESS
else: else:
def f(x_pubkey): def f(x_pubkey):
if is_xpubkey(x_pubkey): if is_xpubkey(x_pubkey):
@ -288,8 +292,9 @@ class TrezorCompatiblePlugin(HW_PluginBase):
signatures=map(lambda x: bfh(x)[:-1] if x else b'', txin.get('signatures')), signatures=map(lambda x: bfh(x)[:-1] if x else b'', txin.get('signatures')),
m=txin.get('num_sig'), m=txin.get('num_sig'),
) )
script_type = self.types.SPENDP2SHWITNESS if segwit else self.types.SPENDMULTISIG
txinputtype = self.types.TxInputType( txinputtype = self.types.TxInputType(
script_type=self.types.SPENDMULTISIG, script_type=script_type,
multisig=multisig multisig=multisig
) )
# find which key is mine # find which key is mine
@ -304,6 +309,8 @@ class TrezorCompatiblePlugin(HW_PluginBase):
prev_hash = unhexlify(txin['prevout_hash']) prev_hash = unhexlify(txin['prevout_hash'])
prev_index = txin['prevout_n'] prev_index = txin['prevout_n']
if 'value' in txin:
txinputtype.amount = txin['value']
txinputtype.prev_hash = prev_hash txinputtype.prev_hash = prev_hash
txinputtype.prev_index = prev_index txinputtype.prev_index = prev_index
@ -317,7 +324,7 @@ class TrezorCompatiblePlugin(HW_PluginBase):
return inputs return inputs
def tx_outputs(self, derivation, tx): def tx_outputs(self, derivation, tx, segwit=False):
outputs = [] outputs = []
has_change = False has_change = False
@ -327,14 +334,16 @@ class TrezorCompatiblePlugin(HW_PluginBase):
has_change = True # no more than one change address has_change = True # no more than one change address
addrtype, hash_160 = bc_address_to_hash_160(address) addrtype, hash_160 = bc_address_to_hash_160(address)
index, xpubs, m = info index, xpubs, m = info
if addrtype == ADDRTYPE_P2PKH: if len(xpubs) == 1:
script_type = self.types.PAYTOP2SHWITNESS if segwit else self.types.PAYTOADDRESS
address_n = self.client_class.expand_path(derivation + "/%d/%d"%index) address_n = self.client_class.expand_path(derivation + "/%d/%d"%index)
txoutputtype = self.types.TxOutputType( txoutputtype = self.types.TxOutputType(
amount = amount, amount = amount,
script_type = self.types.PAYTOADDRESS, script_type = script_type,
address_n = address_n, address_n = address_n,
) )
elif addrtype == ADDRTYPE_P2SH: else:
script_type = self.types.PAYTOP2SHWITNESS if segwit else self.types.PAYTOMULTISIG
address_n = self.client_class.expand_path("/%d/%d"%index) address_n = self.client_class.expand_path("/%d/%d"%index)
nodes = map(self.ckd_public.deserialize, xpubs) nodes = map(self.ckd_public.deserialize, xpubs)
pubkeys = [ self.types.HDNodePathType(node=node, address_n=address_n) for node in nodes] pubkeys = [ self.types.HDNodePathType(node=node, address_n=address_n) for node in nodes]
@ -346,7 +355,7 @@ class TrezorCompatiblePlugin(HW_PluginBase):
multisig = multisig, multisig = multisig,
amount = amount, amount = amount,
address_n = self.client_class.expand_path(derivation + "/%d/%d"%index), address_n = self.client_class.expand_path(derivation + "/%d/%d"%index),
script_type = self.types.PAYTOMULTISIG) script_type = script_type)
else: else:
txoutputtype = self.types.TxOutputType() txoutputtype = self.types.TxOutputType()
txoutputtype.amount = amount txoutputtype.amount = amount