Fixes from straya's feedback

This commit is contained in:
Mariano Sorgente 2020-01-30 11:52:03 +09:00
parent b17ce9e200
commit 6cd4298ae1
No known key found for this signature in database
GPG Key ID: 0F866338C369278C
4 changed files with 58 additions and 11 deletions

View File

@ -138,7 +138,7 @@ class Blockchain:
[(height, index) for index, height in enumerate(heights)], reverse=True
)
curr_block: Optional[SmallHeaderBlock] = self.headers[tip_header_hash]
curr_block: Optional[SmallHeaderBlock] = self.headers.get(tip_header_hash, None)
if curr_block is None:
raise BlockNotInBlockchain(
@ -153,7 +153,9 @@ class Blockchain:
if curr_block is None:
raise ValueError(f"Do not have header {height}")
headers.append((index, curr_block))
return [b.header_hash for index, b in sorted(headers)]
# Return sorted by index (original order)
return [b.header_hash for _, b in sorted(headers, key=lambda pair: pair[0])]
def find_fork_point(self, alternate_chain: List[bytes32]) -> uint32:
"""

View File

@ -155,7 +155,22 @@ class FullNodeStore:
header_blocks: List[HeaderBlock] = []
for row in rows:
header_blocks.append(FullBlock.from_bytes(row[2]).header_block)
return sorted(header_blocks, key=lambda hb: hb.height)
# Sorts the passed in header hashes by hash, with original index
header_hashes_sorted = sorted(
enumerate(header_hashes), key=lambda pair: pair[1]
)
# Sorts the fetched header blocks by hash
header_blocks_sorted = sorted(header_blocks, key=lambda hb: hb.header_hash)
# Combine both and sort by the original indeces
combined = sorted(
zip(header_hashes_sorted, header_blocks_sorted), key=lambda pair: pair[0][0]
)
# Return only the header blocks in the original order
return [pair[1] for pair in combined]
async def get_small_header_blocks(self) -> List[SmallHeaderBlock]:
cursor = await self.db.execute("SELECT * from small_header_blocks")

View File

@ -14,6 +14,7 @@ from src.types.header import Header, HeaderData
from src.types.header_block import HeaderBlock
from src.types.proof_of_space import ProofOfSpace
from src.util.ints import uint8, uint32, uint64
from src.util.errors import BlockNotInBlockchain
from tests.block_tools import BlockTools
bt = BlockTools()
@ -70,6 +71,32 @@ class TestBlockValidation:
) == ReceiveBlockResult.ADDED_TO_HEAD
return (blocks, b)
@pytest.mark.asyncio
async def test_get_header_hashes(self, initial_blockchain):
blocks, b = initial_blockchain
header_hashes_1 = b.get_header_hashes_by_height(
[0, 8, 3], blocks[8].header_hash
)
assert header_hashes_1 == [
blocks[0].header_hash,
blocks[8].header_hash,
blocks[3].header_hash,
]
try:
b.get_header_hashes_by_height([0, 8, 3], blocks[6].header_hash)
thrown = False
except ValueError:
thrown = True
assert thrown
try:
b.get_header_hashes_by_height([0, 8, 3], blocks[9].header_hash)
thrown_2 = False
except BlockNotInBlockchain:
thrown_2 = True
assert thrown_2
@pytest.mark.asyncio
async def test_prev_pointer(self, initial_blockchain):
blocks, b = initial_blockchain
@ -309,6 +336,7 @@ class TestReorgs:
@pytest.mark.asyncio
async def test_reorg_from_genesis(self):
blocks = bt.get_consecutive_blocks(test_constants, 20, [], 9, b"0")
print(len(blocks))
b: Blockchain = await Blockchain.create({}, test_constants)
for i in range(1, len(blocks)):
await b.receive_block(blocks[i], blocks[i - 1].header_block)
@ -333,11 +361,14 @@ class TestReorgs:
# Reorg back to original branch
blocks_reorg_chain_2 = bt.get_consecutive_blocks(
test_constants, 3, blocks, 9, b"3"
test_constants, 3, blocks[:-1], 9, b"3"
)
assert (
await b.receive_block(
blocks_reorg_chain_2[20], blocks_reorg_chain_2[19].header_block
)
== ReceiveBlockResult.ADDED_AS_ORPHAN
)
await b.receive_block(
blocks_reorg_chain_2[20], blocks_reorg_chain_2[19].header_block
) == ReceiveBlockResult.ADDED_AS_ORPHAN
assert (
await b.receive_block(
blocks_reorg_chain_2[21], blocks_reorg_chain_2[20].header_block

View File

@ -69,10 +69,10 @@ class TestStore:
# Get header_blocks
header_blocks = await db.get_header_blocks_by_hash(
[blocks[0].header_hash, blocks[4].header_hash]
[blocks[4].header_hash, blocks[0].header_hash]
)
assert header_blocks[0] == blocks[0].header_block
assert header_blocks[1] == blocks[4].header_block
assert header_blocks[0] == blocks[4].header_block
assert header_blocks[1] == blocks[0].header_block
# Save/get sync
for sync_mode in (False, True):
@ -148,7 +148,6 @@ class TestStore:
await db_2.close()
os.remove(db_filename)
os.remove(db_filename_2)
os.remove(db_filename_3)
raise
# Different database should have different data