Auto merge of #3148 - str4d:DOS-mitigation-tx-expiry, r=str4d
Don't increase banscore for expired transactions if they only just expired Closes #3141.
This commit is contained in:
commit
f5b1082f9c
|
@ -56,6 +56,7 @@ testScripts=(
|
|||
'bipdersig-p2p.py'
|
||||
'overwinter_peer_management.py'
|
||||
'rewind_index.py'
|
||||
'p2p_txexpiry_dos.py'
|
||||
);
|
||||
testScriptsExt=(
|
||||
'getblocktemplate_longpoll.py'
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
#!/usr/bin/env python2
|
||||
# Copyright (c) 2018 The Zcash developers
|
||||
# Distributed under the MIT software license, see the accompanying
|
||||
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
from test_framework.mininode import NodeConn, NodeConnCB, NetworkThread, \
|
||||
CTransaction, msg_tx, mininode_lock, OVERWINTER_PROTO_VERSION
|
||||
from test_framework.test_framework import BitcoinTestFramework
|
||||
from test_framework.util import initialize_chain_clean, start_nodes, \
|
||||
p2p_port, assert_equal
|
||||
|
||||
import time, cStringIO
|
||||
from binascii import hexlify, unhexlify
|
||||
|
||||
|
||||
class TestNode(NodeConnCB):
|
||||
def __init__(self):
|
||||
NodeConnCB.__init__(self)
|
||||
self.create_callback_map()
|
||||
self.connection = None
|
||||
|
||||
def add_connection(self, conn):
|
||||
self.connection = conn
|
||||
|
||||
# Spin until verack message is received from the node.
|
||||
# We use this to signal that our test can begin. This
|
||||
# is called from the testing thread, so it needs to acquire
|
||||
# the global lock.
|
||||
def wait_for_verack(self):
|
||||
while True:
|
||||
with mininode_lock:
|
||||
if self.verack_received:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
|
||||
# Wrapper for the NodeConn's send_message function
|
||||
def send_message(self, message):
|
||||
self.connection.send_message(message)
|
||||
|
||||
def on_close(self, conn):
|
||||
pass
|
||||
|
||||
def on_reject(self, conn, message):
|
||||
conn.rejectMessage = message
|
||||
|
||||
|
||||
class TxExpiryDoSTest(BitcoinTestFramework):
|
||||
|
||||
def setup_chain(self):
|
||||
print "Initializing test directory "+self.options.tmpdir
|
||||
initialize_chain_clean(self.options.tmpdir, 1)
|
||||
|
||||
def setup_network(self):
|
||||
self.nodes = start_nodes(1, self.options.tmpdir,
|
||||
extra_args=[['-nuparams=5ba81b19:10']])
|
||||
|
||||
def create_transaction(self, node, coinbase, to_address, amount, txModifier=None):
|
||||
from_txid = node.getblock(coinbase)['tx'][0]
|
||||
inputs = [{ "txid" : from_txid, "vout" : 0}]
|
||||
outputs = { to_address : amount }
|
||||
rawtx = node.createrawtransaction(inputs, outputs)
|
||||
tx = CTransaction()
|
||||
|
||||
if txModifier:
|
||||
f = cStringIO.StringIO(unhexlify(rawtx))
|
||||
tx.deserialize(f)
|
||||
txModifier(tx)
|
||||
rawtx = hexlify(tx.serialize())
|
||||
|
||||
signresult = node.signrawtransaction(rawtx)
|
||||
f = cStringIO.StringIO(unhexlify(signresult['hex']))
|
||||
tx.deserialize(f)
|
||||
return tx
|
||||
|
||||
def run_test(self):
|
||||
test_node = TestNode()
|
||||
|
||||
connections = []
|
||||
connections.append(NodeConn('127.0.0.1', p2p_port(0), self.nodes[0],
|
||||
test_node, "regtest", True))
|
||||
test_node.add_connection(connections[0])
|
||||
|
||||
# Start up network handling in another thread
|
||||
NetworkThread().start()
|
||||
|
||||
test_node.wait_for_verack()
|
||||
|
||||
# Verify mininodes are connected to zcashd nodes
|
||||
peerinfo = self.nodes[0].getpeerinfo()
|
||||
versions = [x["version"] for x in peerinfo]
|
||||
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
|
||||
assert_equal(0, peerinfo[0]["banscore"])
|
||||
|
||||
self.coinbase_blocks = self.nodes[0].generate(1)
|
||||
self.nodes[0].generate(100)
|
||||
self.nodeaddress = self.nodes[0].getnewaddress()
|
||||
|
||||
# Mininodes send transaction to zcashd node.
|
||||
def setExpiryHeight(tx):
|
||||
tx.nExpiryHeight = 101
|
||||
|
||||
spendtx = self.create_transaction(self.nodes[0],
|
||||
self.coinbase_blocks[0],
|
||||
self.nodeaddress, 1.0,
|
||||
txModifier=setExpiryHeight)
|
||||
test_node.send_message(msg_tx(spendtx))
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
# Verify test mininode has not been dropped
|
||||
# and still has a banscore of 0.
|
||||
peerinfo = self.nodes[0].getpeerinfo()
|
||||
versions = [x["version"] for x in peerinfo]
|
||||
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
|
||||
assert_equal(0, peerinfo[0]["banscore"])
|
||||
|
||||
# Mine a block and resend the transaction
|
||||
self.nodes[0].generate(1)
|
||||
test_node.send_message(msg_tx(spendtx))
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
# Verify test mininode has not been dropped
|
||||
# but has a banscore of 10.
|
||||
peerinfo = self.nodes[0].getpeerinfo()
|
||||
versions = [x["version"] for x in peerinfo]
|
||||
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
|
||||
assert_equal(10, peerinfo[0]["banscore"])
|
||||
|
||||
[ c.disconnect_node() for c in connections ]
|
||||
|
||||
if __name__ == '__main__':
|
||||
TxExpiryDoSTest().main()
|
|
@ -44,6 +44,8 @@ BIP0031_VERSION = 60000
|
|||
MY_VERSION = 170002 # past bip-31 for ping/pong
|
||||
MY_SUBVERSION = "/python-mininode-tester:0.0.1/"
|
||||
|
||||
OVERWINTER_VERSION_GROUP_ID = 0x03C48270
|
||||
|
||||
MAX_INV_SZ = 50000
|
||||
|
||||
|
||||
|
@ -565,20 +567,26 @@ class CTxOut(object):
|
|||
class CTransaction(object):
|
||||
def __init__(self, tx=None):
|
||||
if tx is None:
|
||||
self.fOverwintered = False
|
||||
self.nVersion = 1
|
||||
self.nVersionGroupId = 0
|
||||
self.vin = []
|
||||
self.vout = []
|
||||
self.nLockTime = 0
|
||||
self.nExpiryHeight = 0
|
||||
self.vjoinsplit = []
|
||||
self.joinSplitPubKey = None
|
||||
self.joinSplitSig = None
|
||||
self.sha256 = None
|
||||
self.hash = None
|
||||
else:
|
||||
self.fOverwintered = tx.fOverwintered
|
||||
self.nVersion = tx.nVersion
|
||||
self.nVersionGroupId = tx.nVersionGroupId
|
||||
self.vin = copy.deepcopy(tx.vin)
|
||||
self.vout = copy.deepcopy(tx.vout)
|
||||
self.nLockTime = tx.nLockTime
|
||||
self.nExpiryHeight = tx.nExpiryHeight
|
||||
self.vjoinsplit = copy.deepcopy(tx.vjoinsplit)
|
||||
self.joinSplitPubKey = tx.joinSplitPubKey
|
||||
self.joinSplitSig = tx.joinSplitSig
|
||||
|
@ -586,24 +594,46 @@ class CTransaction(object):
|
|||
self.hash = None
|
||||
|
||||
def deserialize(self, f):
|
||||
self.nVersion = struct.unpack("<i", f.read(4))[0]
|
||||
header = struct.unpack("<I", f.read(4))[0]
|
||||
self.fOverwintered = bool(header >> 31)
|
||||
self.nVersion = header & 0x7FFFFFFF
|
||||
self.nVersionGroupId = (struct.unpack("<I", f.read(4))[0]
|
||||
if self.fOverwintered else 0)
|
||||
|
||||
isOverwinterV3 = (self.fOverwintered and
|
||||
self.nVersionGroupId == OVERWINTER_VERSION_GROUP_ID and
|
||||
self.nVersion == 3)
|
||||
|
||||
self.vin = deser_vector(f, CTxIn)
|
||||
self.vout = deser_vector(f, CTxOut)
|
||||
self.nLockTime = struct.unpack("<I", f.read(4))[0]
|
||||
if isOverwinterV3:
|
||||
self.nExpiryHeight = struct.unpack("<I", f.read(4))[0]
|
||||
|
||||
if self.nVersion >= 2:
|
||||
self.vjoinsplit = deser_vector(f, JSDescription)
|
||||
if len(self.vjoinsplit) > 0:
|
||||
self.joinSplitPubKey = deser_uint256(f)
|
||||
self.joinSplitSig = f.read(64)
|
||||
|
||||
self.sha256 = None
|
||||
self.hash = None
|
||||
|
||||
def serialize(self):
|
||||
header = (int(self.fOverwintered)<<31) | self.nVersion
|
||||
isOverwinterV3 = (self.fOverwintered and
|
||||
self.nVersionGroupId == OVERWINTER_VERSION_GROUP_ID and
|
||||
self.nVersion == 3)
|
||||
|
||||
r = ""
|
||||
r += struct.pack("<i", self.nVersion)
|
||||
r += struct.pack("<I", header)
|
||||
if self.fOverwintered:
|
||||
r += struct.pack("<I", self.nVersionGroupId)
|
||||
r += ser_vector(self.vin)
|
||||
r += ser_vector(self.vout)
|
||||
r += struct.pack("<I", self.nLockTime)
|
||||
if isOverwinterV3:
|
||||
r += struct.pack("<I", self.nExpiryHeight)
|
||||
if self.nVersion >= 2:
|
||||
r += ser_vector(self.vjoinsplit)
|
||||
if len(self.vjoinsplit) > 0:
|
||||
|
@ -628,8 +658,10 @@ class CTransaction(object):
|
|||
return True
|
||||
|
||||
def __repr__(self):
|
||||
r = "CTransaction(nVersion=%i vin=%s vout=%s nLockTime=%i" \
|
||||
% (self.nVersion, repr(self.vin), repr(self.vout), self.nLockTime)
|
||||
r = ("CTransaction(fOverwintered=%r nVersion=%i nVersionGroupId=0x%08x "
|
||||
"vin=%s vout=%s nLockTime=%i nExpiryHeight=%i"
|
||||
% (self.fOverwintered, self.nVersion, self.nVersionGroupId,
|
||||
repr(self.vin), repr(self.vout), self.nLockTime, self.nExpiryHeight))
|
||||
if self.nVersion >= 2:
|
||||
r += " vjoinsplit=%s" % repr(self.vjoinsplit)
|
||||
if len(self.vjoinsplit) > 0:
|
||||
|
|
|
@ -897,7 +897,9 @@ bool ContextualCheckTransaction(const CTransaction& tx, CValidationState &state,
|
|||
|
||||
// Check that all transactions are unexpired
|
||||
if (IsExpiredTx(tx, nHeight)) {
|
||||
return state.DoS(dosLevel, error("ContextualCheckTransaction(): transaction is expired"), REJECT_INVALID, "tx-overwinter-expired");
|
||||
// Don't increase banscore if the transaction only just expired
|
||||
int expiredDosLevel = IsExpiredTx(tx, nHeight - 1) ? dosLevel : 0;
|
||||
return state.DoS(expiredDosLevel, error("ContextualCheckTransaction(): transaction is expired"), REJECT_INVALID, "tx-overwinter-expired");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue