diff --git a/src/apps/common/sign.py b/src/apps/common/sign.py index 298e1511..008a4f3d 100644 --- a/src/apps/common/sign.py +++ b/src/apps/common/sign.py @@ -1,9 +1,8 @@ -from trezor.crypto.hashlib import sha256 -from trezor.crypto import HDNode -from trezor.utils import memcpy, memcpy_rev +from trezor.crypto.hashlib import sha256, ripemd160 +from trezor.crypto.curve import secp256k1 +from trezor.crypto import HDNode, base58 from . import coins -from . import seed from trezor.messages.CoinType import CoinType from trezor.messages.SignTx import SignTx @@ -108,17 +107,22 @@ async def sign_tx(tx: SignTx, root: HDNode): h_sign = tx_hash_init() # hash of what we are signing with this input h_second = tx_hash_init() # should be the same as h_first + input_sign = None + key_sign = None + key_sign_pub = None + for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT input = await request_tx_input(i) tx_write_input(h_second, input) if i == i_sign: - signing_key = node_derive(root, input.address_n) - signing_key_pub = signing_key.public_key() - input.script_sig = input_derive_scriptsig_for_signing( - input, signing_key_pub) + key_sign = node_derive(root, input.address_n) + key_sign_pub = key_sign.public_key() + script_sig = input_derive_script_pre_sign(input, key_sign_pub) + input_sign = input else: - input.script_sig = bytes() + script_sig = bytes() + input.script_sig = script_sig tx_write_input(h_sign, input) for o in range(tx.outputs_count): @@ -131,17 +135,30 @@ async def sign_tx(tx: SignTx, root: HDNode): if h_first_dig != tx_hash_digest(h_second): raise ValueError('Transaction has changed during signing') - sig = sign(signing_key, tx_hash_digest(h_sign)) - # TODO: serialize scriptsig again - # TODO: serialize input - serialized = xxx + signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign)) + script_sig = input_derive_script_post_sign( + input, key_sign_pub, signature) + input_sign.script_sig = script_sig + + # TODO: serialize the whole input at once, including the script_sig + input_sign_w = BufferWriter(bytearray(), 0) + tx_write_input(input_sign_w, input_sign) + input_sign_b = input_sign_w.getvalue() + + serialized = TxRequestSerializedType( + signature_index=i_sign, signature=signature, serialized_tx=input_sign_b) await send_serialized_tx(serialized) for o in range(tx.outputs_count): # STAGE_REQUEST_5_OUTPUT output = await request_tx_output(o) outputbin = output_compile(output, coin, root) - serialized = xxx + + outputbin_w = BufferWriter(bytearray(), 0) + tx_write_input(outputbin_w, outputbin) + outputbin_b = outputbin_w.getvalue() + + serialized = TxRequestSerializedType(serialized_tx=outputbin_b) await send_serialized_tx(serialized) await request_tx_finish() @@ -152,21 +169,21 @@ async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int: total_in = 0 # STAGE_REQUEST_2_PREV_META - tx = await TxRequest(type=TXMETA, hash=prev_hash) + tx = await request_tx_meta(prev_hash) h = tx_hash_init() tx_write_header(h, tx.version, tx.inputs_count) for i in range(tx.inputs_count): # STAGE_REQUEST_2_PREV_INPUT - input = await TxRequest(type=TXINPUT, hash=prev_hash, index=i) + input = await request_tx_input(i, prev_hash) tx_write_input(h, input) tx_write_middle(h, tx.outputs_count) for o in range(tx.outputs_count): # STAGE_REQUEST_2_PREV_OUTPUT - outputbin = await TxRequest(type=TXOUTPUT, hash=prev_hash, index=o) + outputbin = await request_tx_output(o, prev_hash) tx_write_output(h, outputbin) if o == prev_index: total_in += outputbin.value @@ -198,29 +215,31 @@ def tx_hash_digest(w): def output_compile(output: TxOutputType, coin: CoinType, root: HDNode) -> TxOutputBinType: bin = TxOutputBinType() bin.amount = output.amount - - if output.script_type == OutputScriptType.PAYTOADDRESS: - raw_address = output_paytoaddress_extract_raw_address(output, root) - if raw_address[0] != coin.address_type: - raise ValueError('Invalid address type') - bin.script_pubkey = script_paytoaddress_new(raw_address) - - else: - # TODO: other output script types - raise ValueError('Unknown output script type') - + bin.script_pubkey = output_derive_script(output, coin, root) return bin -def output_paytoaddress_extract_raw_address(output: TxOutputType, root: HDNode) -> bytes: - output_address_n = getattr(output, 'address_n', None) - output_address = getattr(output, 'address_n', None) - if output_address_n: - node = node_derive(root, output_address_n) +def output_derive_script(output: TxOutputType, coin: CoinType, root: HDNode) -> bytes: + if output.script_type == OutputScriptType.PAYTOADDRESS: + raw_address = output_paytoaddress_extract_raw_address(output, root) + if raw_address[0] != coin.address_type: # TODO: do this properly + raise ValueError('Invalid address type') + return script_paytoaddress_new(raw_address) + else: + # TODO: other output script types + raise ValueError('Unknown output script type') + return + + +def output_paytoaddress_extract_raw_address(o: TxOutputType, root: HDNode) -> bytes: + o_address_n = getattr(o, 'address_n', None) + o_address = getattr(o, 'address', None) + if o_address_n: + node = node_derive(root, o_address_n) # TODO: dont encode and decode again - raw_address = address_decode(node.address()) - elif output_address: - raw_address = address_decode(output_address) + raw_address = base58.decode_check(node.address()) + elif o_address: + raw_address = base58.decode_check(o_address) else: raise ValueError('Missing address') return raw_address @@ -231,7 +250,7 @@ def script_paytoaddress_new(raw_address: bytes) -> bytearray: s[0] = 0x76 # OP_DUP s[1] = 0xA9 # OP_HASH_160 s[2] = 0x14 # pushing 20 bytes - s[3:23] = raw_address + s[3:23] = raw_address[1:] # TODO: do this properly s[23] = 0x88 # OP_EQUALVERIFY s[24] = 0xAC # OP_CHECKSIG return s @@ -246,34 +265,31 @@ def output_is_change(output: TxOutputType): # === -def input_derive_scriptsig_for_signing(input: TxInputType, pubkey: bytes) -> bytes: +def input_derive_script_pre_sign(input: TxInputType, pubkey: bytes) -> bytes: if input.script_type == InputScriptType.SPENDADDRESS: - pubkeyhash = xxx - return script_spendaddress_new(pubkeyhash) - + return script_paytoaddress_new(ecdsa_get_pubkeyhash(pubkey)) else: # TODO: other input script types raise ValueError('Unknown input script type') -def script_spendaddress_new(pubkeyhash: bytes) -> bytearray: +def input_derive_script_post_sign(input: TxInputType, pubkey: bytes, signature: bytes) -> bytes: + if input.script_type == InputScriptType.SPENDADDRESS: + return script_spendaddress_new(pubkey, signature) + else: + # TODO: other input script types + raise ValueError('Unknown input script type') + + +def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray: s = bytearray(25) - s[0] = 0x76 # OP_DUP - s[1] = 0xA9 # OP_HASH_160 - s[2] = 0x14 # pushing 20 bytes - s[3:23] = pubkeyhash - s[23] = 0x88 # OP_EQUALVERIFY - s[24] = 0xAC # OP_CHECKSIG - return s - - -async def sign(privkey: bytes, digest: bytes) -> bytes: - # TODO: ecdsa secp256k1 digest sign - return b'' - - -# Addresses, HDNodes -# === + w = BufferWriter(s, 0) + write_op_push(w, len(signature) + 1) + write_bytes(w, signature) + w.writebyte(0x01) + write_op_push(w, len(pubkey)) + write_bytes(w, pubkey) + return def node_derive(root: HDNode, address_n: list) -> HDNode: @@ -283,9 +299,20 @@ def node_derive(root: HDNode, address_n: list) -> HDNode: return node -def address_decode(address: str) -> bytes: - # TODO: decode the address from base58 - return b'' +def ecdsa_get_pubkeyhash(pubkey: bytes) -> bytes: + if pubkey[0] == 0x04: + assert len(pubkey) == 65 # uncompressed format + elif pubkey[0] == 0x00: + assert len(pubkey) == 1 # point at infinity + else: + assert len(pubkey) == 33 # compresssed format + h = sha256(pubkey).digest() + h = ripemd160(h).digest() + return h + + +async def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes: + return secp256k1.sign(privkey, digest) # TX Serialization @@ -321,6 +348,25 @@ def tx_write_footer(w, locktime: int, add_hash_type: bool): write_uint32(w, 1) +def write_op_push(w, n: int): + wb = w.writebyte + if n < 0x4C: + wb(n & 0xFF) + elif n < 0xFF: + wb(0x4C) + wb(n & 0xFF) + elif n < 0xFFFF: + wb(0x4D) + wb(n & 0xFF) + wb((n >> 8) & 0xFF) + else: + wb(0x4E) + wb(n & 0xFF) + wb((n >> 8) & 0xFF) + wb((n >> 16) & 0xFF) + wb((n >> 24) & 0xFF) + + # Buffer IO & Serialization # === @@ -329,7 +375,7 @@ def write_varint(w, n: int): wb = w.writebyte if n < 253: wb(n & 0xFF) - elif n < 0x10000: + elif n < 65536: wb(253) wb(n & 0xFF) wb((n >> 8) & 0xFF) @@ -372,20 +418,25 @@ def write_bytes_rev(w, buf: bytearray): class BufferWriter: def __init__(self, buf: bytearray, ofs: int): + # TODO: re-think the use of bytearrays, buffers, and other byte IO + # i think we should just pass a pre-allocation size here, allocate the + # bytearray and then trim it to zero. in this case, write() would + # correspond to extend(), and writebyte() to append(). of course, the + # the use-case of non-destructively writing to existing bytearray still + # exists. self.buf = buf self.ofs = ofs - def write(self, buf): + def write(self, buf: bytearray): n = len(buf) - w = memcpy(self.buf, self.ofs, buf, 0, n) - self.ofs += w - return w + self.buf[self.ofs:self.ofs + n] = buf + self.ofs += n - def writebyte(self, b): + def writebyte(self, b: int): self.buf[self.ofs] = b self.ofs += 1 - def getvalue(self): + def getvalue(self) -> bytearray: return self.buf @@ -393,12 +444,14 @@ class HashWriter: def __init__(self, hashfunc): self.ctx = hashfunc() + self.buf = bytearray(1) # used in writebyte() - def write(self, buf): + def write(self, buf: bytearray): self.ctx.update(buf) - def writebyte(self, b): - self.ctx.update(bytes(b)) + def writebyte(self, b: int): + self.buf[0] = b + self.ctx.update(self.buf) - def getvalue(self): + def getvalue(self) -> bytes: return self.ctx.digest()