signtx: add first test, make it all work

This commit is contained in:
Jan Pochyla 2016-11-08 18:49:58 +01:00
parent b20e62ffa8
commit d2c71b3a58
4 changed files with 173 additions and 57 deletions

View File

@ -106,20 +106,23 @@ _coins = [
},
]
def by_shortcut(shortcut):
for c in _couns:
for c in _coins:
if c['coin_shortcut'] == shortcut:
return c
raise Exception('Unknown coin shortcut "%s"' % shortcut)
def by_name(name):
for c in _couns:
for c in _coins:
if c['coin_name'] == name:
return c
raise Exception('Unknown coin name "%s"' % name)
def by_address_type(version):
for c in _couns:
for c in _coins:
if c['address_type'] == version:
return c
raise Exception('Unknown coin address type %d' % version)

View File

@ -1,6 +1,6 @@
from trezor.crypto.hashlib import sha256, ripemd160
from trezor.crypto.curve import secp256k1
from trezor.crypto import base58
from trezor.crypto import base58, der
from . import coins
@ -21,7 +21,7 @@ from trezor.messages import OutputScriptType, InputScriptType
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.type = TXMETA
tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None
ack = yield tx_req
@ -30,7 +30,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.type = TXINPUT
tx_req.request_type = TXINPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
@ -39,19 +39,19 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.type = TXOUTPUT
tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
if tx_hash is None:
return ack.outputs[0]
return ack.tx.outputs[0]
else:
return ack.bin_outputs[0]
return ack.tx.bin_outputs[0]
def request_tx_finish(tx_req: TxRequest):
tx_req.type = TXFINISHED
tx_req.request_type = TXFINISHED
tx_req.details = None
yield tx_req
tx_req.serialized = None
@ -62,8 +62,8 @@ def request_tx_finish(tx_req: TxRequest):
async def sign_tx(tx: SignTx, root):
tx_version = getattr(tx, 'version', 0)
tx_lock_time = getattr(tx, 'lock_time', 1)
tx_version = getattr(tx, 'version', 1)
tx_lock_time = getattr(tx, 'lock_time', 0)
tx_inputs_count = getattr(tx, 'inputs_count', 0)
tx_outputs_count = getattr(tx, 'outputs_count', 0)
coin_name = getattr(tx, 'coin_name', 'Bitcoin')
@ -91,7 +91,7 @@ async def sign_tx(tx: SignTx, root):
for i in range(tx_inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input(h_first, txi)
write_tx_input_check(h_first, txi)
total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
@ -118,9 +118,10 @@ async def sign_tx(tx: SignTx, root):
tx_ser = TxRequestSerializedType()
for i_sign in range(tx.inputs_count):
for i_sign in range(tx_inputs_count):
# hash of what we are signing with this input
h_sign = HashWriter(sha256)
# h_sign = BufferWriter()
# same as h_first, checked at the end of this iteration
h_second = HashWriter(sha256)
@ -130,10 +131,10 @@ async def sign_tx(tx: SignTx, root):
write_tx_header(h_sign, tx_version, tx_inputs_count)
for i in range(tx.inputs_count):
for i in range(tx_inputs_count):
# STAGE_REQUEST_4_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input(h_second, txi)
write_tx_input_check(h_second, txi)
if i == i_sign:
txi_sign = txi
key_sign = node_derive(root, txi.address_n)
@ -146,7 +147,7 @@ async def sign_tx(tx: SignTx, root):
write_tx_middle(h_sign, tx_outputs_count)
for o in range(tx.outputs_count):
for o in range(tx_outputs_count):
# STAGE_REQUEST_4_OUTPUT
txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
@ -156,30 +157,29 @@ async def sign_tx(tx: SignTx, root):
write_tx_footer(h_sign, tx_lock_time, True)
import ubinascii
# check the control digests
h_first_dig = tx_hash_digest(h_first, False)
h_second_dig = tx_hash_digest(h_second, False)
if h_first_dig != h_second_dig:
if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False):
raise ValueError('Transaction has changed during signing')
# compute the signature from the tx digest
h_sign_dig = tx_hash_digest(h_sign, True)
signature = ecdsa_sign(key_sign, h_sign_dig)
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True))
tx_ser.signature_index = i_sign
tx_ser.signature = signature
# serialize input with correct signature
txi_sign.script_sig = input_derive_script_post_sign(
txi_sign, key_sign_pub, signature)
txi_sign_w = BufferWriter()
w_txi_sign = BufferWriter()
if i_sign == 0:
write_tx_header(txi_sign_w, tx_version, tx_inputs_count)
write_tx_input(txi_sign_w, txi_sign)
tx_ser.serialized_tx = txi_sign_w.getvalue()
write_tx_header(w_txi_sign, tx_version, tx_inputs_count)
write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign.getvalue()
tx_req.serialized = tx_ser
for o in range(tx.outputs_count):
for o in range(tx_outputs_count):
# STAGE_REQUEST_5_OUTPUT
txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
@ -190,7 +190,7 @@ async def sign_tx(tx: SignTx, root):
if o == 0:
write_tx_middle(w_txo_bin, tx_outputs_count)
write_tx_output(w_txo_bin, txo_bin)
if o == tx_outputs_count:
if o == tx_outputs_count - 1:
write_tx_footer(w_txo_bin, tx_lock_time, False)
tx_ser.signature_index = None
tx_ser.signature = None
@ -205,12 +205,12 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META
tx = await request_tx_meta(prev_hash)
tx = await request_tx_meta(tx_req, prev_hash)
tx_version = getattr(tx, 'version', 0)
tx_lock_time = getattr(tx, 'lock_time', 1)
tx_inputs_count = getattr(tx, 'inputs_count', 0)
tx_outputs_count = getattr(tx, 'outputs_count', 0)
tx_inputs_count = getattr(tx, 'inputs_cnt', 0)
tx_outputs_count = getattr(tx, 'outputs_cnt', 0)
txh = HashWriter(sha256)
@ -228,16 +228,17 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
txo_bin = await request_tx_output(tx_req, o, prev_hash)
write_tx_output(txh, txo_bin)
if o == prev_index:
total_out += txo_bin.value
total_out += txo_bin.amount
write_tx_footer(txh, tx_lock_time, False)
if tx_hash_digest(txh, True) != prev_hash:
prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance
if tx_hash_digest(txh, True) != prev_hash_rev:
raise ValueError('Encountered invalid prev_hash')
return total_out
def tx_hash_digest(w, double: bool):
def tx_hash_digest(w, double: bool) -> bytes:
d = w.getvalue()
if double:
d = sha256(d).digest()
@ -250,8 +251,8 @@ def tx_hash_digest(w, double: bool):
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if o.script_type == OutputScriptType.PAYTOADDRESS:
return script_paytoaddress_new(
output_paytoaddress_extract_raw_address(o, coin, root))
ra = output_paytoaddress_extract_raw_address(o, coin, root)
return script_paytoaddress_new(ra[1:])
else:
raise ValueError('Invalid output script type')
return
@ -269,7 +270,7 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo
raw_address = base58.decode_check(o_address)
else:
raise ValueError('Missing address')
if raw_address[0] != coin.address_type:
if raw_address[0] != coin['address_type']:
raise ValueError('Invalid address type')
return raw_address
@ -284,14 +285,16 @@ def output_is_change(output: TxOutputType):
def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
i_script_type = getattr(i, 'script_type', InputScriptType.SPENDADDRESS)
if i_script_type == InputScriptType.SPENDADDRESS:
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
else:
raise ValueError('Unknown input script type')
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
i_script_type = getattr(i, 'script_type', InputScriptType.SPENDADDRESS)
if i_script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature)
else:
raise ValueError('Unknown input script type')
@ -315,20 +318,23 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
return h
def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
return secp256k1.sign(privkey, digest)
def ecdsa_sign(node, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest)
print(len(sig))
sigder = der.convert_seq((sig[:32], sig[32:]))
return sigder
# TX Scripts
# ===
def script_paytoaddress_new(raw_address: bytes) -> bytearray:
def script_paytoaddress_new(pubkeyhash: bytes) -> bytearray:
s = bytearray(25)
s[0] = 0x76 # OP_DUP
s[1] = 0xA9 # OP_HASH_160
s[2] = 0x14 # pushing 20 bytes
s[3:23] = raw_address[1:] # TODO: do this properly
s[3:23] = pubkeyhash
s[23] = 0x88 # OP_EQUALVERIFY
s[24] = 0xAC # OP_CHECKSIG
return s
@ -354,11 +360,22 @@ def write_tx_header(w, version: int, inputs_count: int):
def write_tx_input(w, i: TxInputType):
i_sequence = getattr(i, 'sequence', 4294967295)
write_bytes_rev(w, i.prev_hash)
write_uint32(w, i.prev_index)
write_varint(w, len(i.script_sig))
write_bytes(w, i.script_sig)
write_uint32(w, i.sequence)
write_uint32(w, i_sequence)
def write_tx_input_check(w, i: TxInputType):
i_sequence = getattr(i, 'sequence', 4294967295)
write_bytes(w, i.prev_hash)
write_uint32(w, i.prev_index)
write_uint32(w, len(i.address_n))
for n in i.address_n:
write_uint32(w, n)
write_uint32(w, i_sequence)
def write_tx_middle(w, outputs_count: int):
@ -446,26 +463,16 @@ def write_bytes_rev(w, buf: bytearray):
class BufferWriter:
def __init__(self, buf: bytearray=None, ofs: int=0):
# TODO: re-think the use of bytearrays, buffers, and other byte IO
# i think we should just pass a pre-allocation size here, allocate the
# bytearray and then trim it to zero. in this case, write() would
# correspond to extend(), and writebyte() to append(). of course, the
# the use-case of non-destructively writing to existing bytearray still
# exists.
def __init__(self, buf: bytearray=None):
if buf is None:
buf = bytearray()
self.buf = buf
self.ofs = ofs
def write(self, buf: bytearray):
n = len(buf)
self.buf[self.ofs:self.ofs + n] = buf
self.ofs += n
self.buf.extend(buf)
def writebyte(self, b: int):
self.buf[self.ofs] = b
self.ofs += 1
self.buf.append(b)
def getvalue(self) -> bytearray:
return self.buf

View File

@ -140,6 +140,17 @@ class MessageType(Type):
WIRE_TYPE = 2
FIELDS = {}
def __init__(self, **kwargs):
for kw in kwargs:
setattr(self, kw, kwargs[kw])
def __eq__(self, rhs):
return (self.__class__ is rhs.__class__ and
self.__dict__ == rhs.__dict__)
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.__dict__)
@classmethod
async def load(cls, source=None, target=None):
if target is None:

View File

@ -0,0 +1,95 @@
from common import *
from trezor.crypto import bip32, bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxAck import TxAck
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType, InputScriptType
from apps.common import signtx
class TestSignTx(unittest.TestCase):
# pylint: disable=C0301
def test_one_one_fee(self):
# tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882
# input 0: 0.0039 BTC
ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1)
pinp1 = TxInputType(script_sig=unhexlify(b'483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'),
prev_hash=unhexlify(b'c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'),
prev_index=1)
pinp2 = TxInputType(script_sig=unhexlify(b'48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'),
prev_hash=unhexlify(b'1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'),
prev_index=1)
pout1 = TxOutputBinType(script_pubkey=unhexlify(b'76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'),
amount=390000)
inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e
# amount=390000,
prev_hash=unhexlify(b'd5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'),
prev_index=0)
out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1',
amount=390000 - 10000,
script_type=OutputScriptType.PAYTOADDRESS)
tx = SignTx(inputs_count=1, outputs_count=1)
messages = [
None,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None),
TxAck(tx=ptx1),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp1])),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp2])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None),
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
signature_index=0,
signature=unhexlify(b'30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'),
serialized_tx=unhexlify(b'010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType(
signature_index=None,
signature=None,
serialized_tx=unhexlify(b'0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'),
)),
]
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signtx.sign_tx(tx, root)
i = 0
try:
for i in range(0, len(messages) - 1, 2):
res = signer.send(messages[i])
self.assertEqual(res, messages[i + 1])
except StopIteration:
pass
self.assertEqual(i, len(messages) - 2)
# Accepted by network: tx fd79435246dee76b2f159d2db08032d666c95adc544de64c8c49f474df4a7fee
# self.assertEqual(hexlify(serialized_tx), b'010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000')
if __name__ == '__main__':
unittest.main()