Progress on fast sync

Mariano Sorgente 2020-03-19 14:15:06 +09:00
parent 90e8287926
commit 324cae8dba
6 changed files with 324 additions and 106 deletions

View File

@@ -504,10 +504,9 @@ class ChiaServer:
yield connection, outbound_message
else:
await result
-except Exception as e:
+except Exception:
tb = traceback.format_exc()
-    self.log.error(f"{tb}")
-    self.log.error(f"Error {type(e)} {e}, closing connection {connection}")
+    self.log.error(f"Error, closing connection {connection}. {tb}")
self.global_connections.close(connection)
async def expand_outbound_messages(
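
The change above folds two error logs into a single line that embeds the formatted traceback and drops the now-unused exception binding. A minimal, self-contained sketch of the resulting pattern (do_work, handle, and the connection value are hypothetical stand-ins, not the repo's API):

import asyncio
import logging
import traceback

logging.basicConfig(level=logging.ERROR)
log = logging.getLogger("server")

async def do_work(connection: str) -> None:
    raise RuntimeError("boom")  # force the error path for the demo

async def handle(connection: str) -> None:
    try:
        await do_work(connection)
    except Exception:
        # format_exc() already embeds the exception type and message, so a
        # single log line carries everything the previous two lines logged.
        tb = traceback.format_exc()
        log.error(f"Error, closing connection {connection}. {tb}")

asyncio.run(handle("peer-1"))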

View File

@@ -3,7 +3,9 @@ import asyncio
import time
from typing import Dict, Optional, Tuple, List
import concurrent
import random
import logging
import traceback
from blspy import ExtendedPrivateKey
from src.full_node.full_node import OutboundMessageGenerator
@@ -236,107 +238,113 @@ class WalletNode:
raise ValueError("Not enough proof hashes fetched.")
# Creates map from height to difficulty
# heights: List[uint32] = []
# difficulty_weights: List[uint64] = []
# difficulty: uint64
# for i in range(tip_height):
# if self.proof_hashes[i][1] is not None:
# difficulty = self.proof_hashes[i][1][1]
# if i > fork_point_height and i % 2 == 1: # Only add odd heights
# heights.append(uint32(i))
# difficulty_weights.append(difficulty)
heights: List[uint32] = []
difficulty_weights: List[uint64] = []
difficulty: uint64
for i in range(tip_height):
if self.proof_hashes[i][1] is not None:
difficulty = self.proof_hashes[i][1][1]
if i > fork_point_height and i % 2 == 1: # Only add odd heights
heights.append(uint32(i))
difficulty_weights.append(difficulty)
# Randomly sample based on difficulty
# query_heights = random.choices(heights, difficulty_weights, k=50)
query_heights_odd = sorted(
list(
set(
random.choices(
heights, difficulty_weights, k=min(15, len(heights))
)
)
)
)
print("Query heights:", query_heights_odd)
query_heights: List[uint32] = []
for odd_height in query_heights_odd:
query_heights += [uint32(odd_height - 1), odd_height]
# Send requests for these heights
# Verify these proofs
# last_request_time = float(0)
# highest_height_requested = uint32(0)
# request_made = False
last_request_time = float(0)
highest_height_requested = uint32(0)
request_made = False
print("Query heights:", query_heights)
# # TODO: simplify and further pipeline this sync
# for height_index in range(len(query_heights)):
# total_time_slept = 0
# while True:
# if self._shut_down:
# return
# if total_time_slept > timeout:
# raise TimeoutError("Took too long to fetch blocks")
for height_index in range(len(query_heights)):
total_time_slept = 0
while True:
if self._shut_down:
return
if total_time_slept > timeout:
raise TimeoutError("Took too long to fetch blocks")
# # Request batches that we don't have yet
# for batch_start_index in range(
# height_index,
# min(height_index + self.config["num_sync_batches"]),
# len(query_heights),
# ):
# batch_end_index = min(batch_start_index + 1, len(query_heights))
# blocks_missing = any(
# [
# not (self.potential_blocks_received[uint32(h)]).is_set()
# for h in [
# query_heights[i]
# for i in range(batch_start_index, batch_end_index)
# ]
# ]
# )
# if (
# (
# time.time() - last_request_time > sleep_interval
# and blocks_missing
# )
# or (query_heights[batch_end_index] - 1)
# > highest_height_requested
# ):
# self.log.info(
# f"Requesting sync header {query_heights[batch_start_index]}"
# )
# if (
# query_heights[batch_end_index] - 1
# > highest_height_requested
# ):
# highest_height_requested = uint32(
# query_heights[batch_end_index - 1]
# )
# request_made = True
# request_header = wallet_protocol.RequestHeader(
# uint32(query_heights[batch_start_index]),
# self.header_hashes[query_heights[batch_start_index]],
# )
# yield OutboundMessage(
# NodeType.FULL_NODE,
# Message("request_header", request_header),
# Delivery.RANDOM,
# )
# if request_made:
# last_request_time = time.time()
# request_made = False
# Request batches that we don't have yet
for batch_start_index in range(
height_index,
min(
height_index + self.config["num_sync_batches"],
len(query_heights),
),
):
blocks_missing = not self.potential_blocks_received[
uint32(query_heights[batch_start_index])
].is_set()
if (
(
time.time() - last_request_time > sleep_interval
and blocks_missing
)
or (query_heights[batch_start_index])
> highest_height_requested
):
self.log.info(
f"Requesting sync header {query_heights[batch_start_index]}"
)
if (
query_heights[batch_start_index]
> highest_height_requested
):
highest_height_requested = uint32(
query_heights[batch_start_index]
)
request_made = True
request_header = wallet_protocol.RequestHeader(
uint32(query_heights[batch_start_index]),
self.header_hashes[query_heights[batch_start_index]],
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header", request_header),
Delivery.RANDOM,
)
if request_made:
last_request_time = time.time()
request_made = False
# Wrap the Event.wait() coroutine in a task: asyncio.wait_for cancels it
# on timeout, and a bare coroutine could not be awaited a second time to
# absorb that cancellation.
aw = asyncio.ensure_future(
    self.potential_blocks_received[
        uint32(query_heights[height_index])
    ].wait()
)
try:
    await asyncio.wait_for(aw, timeout=sleep_interval)
    break
except concurrent.futures.TimeoutError:
    try:
        await aw
    except asyncio.CancelledError:
        pass
total_time_slept += sleep_interval
self.log.info("Did not receive desired headers")
# awaitables = [
# self.potential_blocks_received[
# uint32(query_heights[height_index])
# ].wait()
# ]
# future = asyncio.gather(*awaitables, return_exceptions=True)
# try:
# await asyncio.wait_for(future, timeout=sleep_interval)
# break
# except concurrent.futures.TimeoutError:
# try:
# await future
# except asyncio.CancelledError:
# pass
# total_time_slept += sleep_interval
# self.log.info("Did not receive desired headers")
# hh = self.potential_header_hashes[query_heights[height_index]]
# block_record, header_block = self.cached_blocks[hh]
# TODO(mariano): Validate proof and hash of proof
# for query_height in query_heights:
# prev_height = query_height - 1
# pass
self.log.info(
f"Finished downloading sample of headers at heights: {query_heights}, validating."
)
# Validates the downloaded proofs
assert self.wallet_state_manager.validate_select_proofs(
self.proof_hashes,
query_heights_odd,
self.cached_blocks,
self.potential_header_hashes,
)
self.log.info("All proofs validated successfuly.")
# Add blockrecords one at a time, to catch up to starting height
weight = self.wallet_state_manager.block_records[fork_point_hash].weight
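
The sampling step near the top of this hunk picks random odd heights weighted by difficulty, then pairs each with its predecessor. A minimal sketch of that selection step, with illustrative constants rather than the repo's:

import random
from typing import List, Tuple

# (height, difficulty) pairs; the real code derives these from proof_hashes.
chain: List[Tuple[int, int]] = [(h, 10 if h < 50 else 40) for h in range(100)]
fork_point_height = 10

heights = [h for h, _ in chain if h > fork_point_height and h % 2 == 1]
weights = [d for h, d in chain if h > fork_point_height and h % 2 == 1]

# Higher-difficulty blocks are proportionally more likely to be sampled, so an
# attacker gains little by padding a fork with many low-difficulty blocks.
query_heights_odd = sorted(
    set(random.choices(heights, weights, k=min(15, len(heights))))
)

# Each sampled height is paired with its predecessor, which validation needs
# in order to recompute the challenge hash.
query_heights: List[int] = []
for odd_height in query_heights_odd:
    query_heights += [odd_height - 1, odd_height]
print(query_heights)
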
@@ -371,6 +379,9 @@ class WalletNode:
res == ReceiveBlockResult.ADDED_TO_HEAD
or res == ReceiveBlockResult.ADDED_AS_ORPHAN
)
self.log.info(
f"Fast sync successful up to height {header_validate_start_height - 1}"
)
# Download headers in batches, and verify them as they come in. We download a few batches ahead,
# in case there are delays. TODO(mariano): optimize sync by pipelining
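
The comment above describes keeping several batches in flight while earlier ones are verified. A toy sketch of such a windowed prefetch, where the fetch coroutine stands in for the network request (all names hypothetical):

import asyncio
from typing import Dict

async def fetch(batch: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for a network round trip
    return batch

async def sync(num_batches: int = 10, window: int = 3) -> None:
    pending: Dict[int, asyncio.Future] = {}
    for b in range(num_batches):
        # Top up the window so up to `window` batches download ahead of
        # the one being verified.
        nxt = b + len(pending)
        while len(pending) < window and nxt < num_batches:
            pending[nxt] = asyncio.ensure_future(fetch(nxt))
            nxt += 1
        headers = await pending.pop(b)  # block only on the next batch to verify
        assert headers == b  # verification of the batch would happen here

asyncio.run(sync())
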
@@ -575,10 +586,9 @@ class WalletNode:
self.wallet_state_manager.set_sync_mode(True)
async for ret_msg in self._sync():
yield ret_msg
-except asyncio.CancelledError:
-    self.log.error("Syncing failed, CancelledError")
-except BaseException as e:
-    self.log.error(f"Error {type(e)}{e} with syncing")
+except (BaseException, asyncio.CancelledError) as e:
+    tb = traceback.format_exc()
+    self.log.error(f"Error with syncing. {type(e)} {tb}")
self.wallet_state_manager.set_sync_mode(False)
else:
header_request = wallet_protocol.RequestHeader(

View File

@@ -717,6 +717,172 @@ class WalletStateManager:
assert new_lca == tmp_old # Genesis block is the same, genesis fork
return uint32(0)
def validate_select_proofs(
self,
all_proof_hashes: List[Tuple[bytes32, Optional[Tuple[uint64, uint64]]]],
heights: List[uint32],
cached_blocks: Dict[bytes32, Tuple[BlockRecord, HeaderBlock]],
potential_header_hashes: Dict[uint32, bytes32],
) -> bool:
"""
Given a full list of proof hashes (hash of pospace and time, along with difficulty resets), this function
checks that the proofs at the passed-in heights are correct. This is used to validate the weight of a chain,
by probabilistically sampling a few blocks and validating only those. Cached blocks and potential header hashes
contain the actual data for the header blocks to validate. This method also requires the previous block for
each height to be present, to ensure an attacker can't grind on the challenge hash.
"""
for height in heights:
prev_height = uint32(height - 1)
# Get previous header block
prev_hh = potential_header_hashes[prev_height]
_, prev_header_block = cached_blocks[prev_hh]
# Validate proof hash of previous header block
if (
std_hash(
prev_header_block.proof_of_space.get_hash()
+ prev_header_block.proof_of_time.output.get_hash()
)
!= all_proof_hashes[prev_height][0]
):
return False
# Calculate challenge hash (with difficulty)
if (
prev_header_block.challenge.prev_challenge_hash
!= prev_header_block.proof_of_space.challenge_hash
):
return False
if (
prev_header_block.challenge.prev_challenge_hash
!= prev_header_block.proof_of_time.challenge_hash
):
return False
if (
prev_header_block.challenge.proofs_hash
!= all_proof_hashes[prev_height][0]
):
return False
if (
prev_height % self.constants["DIFFICULTY_EPOCH"]
== self.constants["DIFFICULTY_DELAY"]
):
diff_change = all_proof_hashes[prev_height][1]
assert diff_change is not None
if prev_header_block.challenge.new_work_difficulty != diff_change[0]:
return False
else:
if prev_header_block.challenge.new_work_difficulty is not None:
return False
challenge_hash = prev_header_block.challenge.get_hash()
# Get header block
hh = potential_header_hashes[height]
_, header_block = cached_blocks[hh]
# Validate challenge hash is == pospace challenge hash
if challenge_hash != header_block.proof_of_space.challenge_hash:
return False
# Validate challenge hash is == potime challenge hash
if challenge_hash != header_block.proof_of_time.challenge_hash:
return False
# Validate proof hash
if (
std_hash(
header_block.proof_of_space.get_hash()
+ header_block.proof_of_time.output.get_hash()
)
!= all_proof_hashes[height][0]
):
return False
# Get difficulty
if (
height % self.constants["DIFFICULTY_EPOCH"]
< self.constants["DIFFICULTY_DELAY"]
):
diff_height = (
height
- (height % self.constants["DIFFICULTY_EPOCH"])
- (
self.constants["DIFFICULTY_EPOCH"]
- self.constants["DIFFICULTY_DELAY"]
)
)
else:
diff_height = (
height
- (height % self.constants["DIFFICULTY_EPOCH"])
+ self.constants["DIFFICULTY_DELAY"]
)
difficulty_change = all_proof_hashes[diff_height][1]
assert difficulty_change is not None
difficulty = difficulty_change[0]
# Validate pospace to get iters
quality_str = header_block.proof_of_space.verify_and_get_quality_string()
assert quality_str is not None
if (
height
< self.constants["DIFFICULTY_EPOCH"]
+ self.constants["DIFFICULTY_DELAY"]
):
min_iters = self.constants["MIN_ITERS_STARTING"]
else:
if (
height % self.constants["DIFFICULTY_EPOCH"]
< self.constants["DIFFICULTY_DELAY"]
):
height2 = (
height
- (height % self.constants["DIFFICULTY_EPOCH"])
- self.constants["DIFFICULTY_EPOCH"]
- 1
)
else:
height2 = height - (height % self.constants["DIFFICULTY_EPOCH"]) - 1
height1 = height2 - self.constants["DIFFICULTY_EPOCH"]
if height1 == -1:
iters1 = uint64(0)
else:
diff_change_1 = all_proof_hashes[height1][1]
assert diff_change_1 is not None
iters1 = diff_change_1[1]
diff_change_2 = all_proof_hashes[height2][1]
assert diff_change_2 is not None
iters2 = diff_change_2[1]
min_iters = uint64(
(iters2 - iters1)
// (
self.constants["DIFFICULTY_EPOCH"]
* self.constants["MIN_ITERS_PROPORTION"]
)
)
number_of_iters: uint64 = calculate_iterations_quality(
quality_str, header_block.proof_of_space.size, difficulty, min_iters,
)
# Validate potime
if number_of_iters != header_block.proof_of_time.number_of_iterations:
return False
if not header_block.proof_of_time.is_valid(
self.constants["DISCRIMINANT_SIZE_BITS"]
):
return False
return True
async def get_filter_additions_removals(
self, transactions_fitler: bytes
) -> Tuple[List[bytes32], List[bytes32]]:
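
The difficulty-epoch arithmetic in validate_select_proofs is easy to misread, so here is a small worked check of the diff_height branch, using illustrative constants (not necessarily the repo's values):

DIFFICULTY_EPOCH = 128  # illustrative value
DIFFICULTY_DELAY = 32   # illustrative value

def diff_height_for(height: int) -> int:
    """Height whose recorded difficulty reset applies to `height`."""
    if height % DIFFICULTY_EPOCH < DIFFICULTY_DELAY:
        # Early in an epoch, the previous epoch's reset still applies.
        return (
            height
            - (height % DIFFICULTY_EPOCH)
            - (DIFFICULTY_EPOCH - DIFFICULTY_DELAY)
        )
    # Otherwise, the reset recorded DIFFICULTY_DELAY blocks into this epoch applies.
    return height - (height % DIFFICULTY_EPOCH) + DIFFICULTY_DELAY

assert diff_height_for(130) == 32   # 130 % 128 = 2 < 32, so look back an epoch
assert diff_height_for(200) == 160  # 200 % 128 = 72 >= 32, so use this epoch's reset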

View File

@@ -21,11 +21,6 @@ class TestFullSync:
async for _ in setup_two_nodes():
yield _
@pytest.fixture(scope="function")
async def wallet_node(self):
async for _ in setup_node_and_wallet():
yield _
@pytest.mark.asyncio
async def test_basic_sync(self, two_nodes):
num_blocks = 100
@@ -111,6 +106,18 @@ class TestFullSync:
raise Exception("Took too long to process blocks")
class TestWalletSync:
@pytest.fixture(scope="function")
async def wallet_node(self):
async for _ in setup_node_and_wallet():
yield _
@pytest.fixture(scope="function")
async def wallet_node_starting_height(self):
async for _ in setup_node_and_wallet(dic={"starting_height": 100}):
yield _
@pytest.mark.asyncio
async def test_basic_sync_wallet(self, wallet_node):
num_blocks = 25
@@ -172,6 +179,41 @@ class TestFullSync:
f"Took too long to process blocks, stopped at: {time.time() - start}"
)
@pytest.mark.asyncio
async def test_fast_sync_wallet(self, wallet_node_starting_height):
num_blocks = 50
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [])
full_node_1, wallet_node, server_1, server_2 = wallet_node_starting_height
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
start = time.time()
found = False
while time.time() - start < 60:
# The second node should eventually catch up to the first one. The wallet's
# LCA trails the tip, so wait for it to reach height num_blocks - 6.
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== num_blocks - 6
):
found = True
break
await asyncio.sleep(0.1)
if not found:
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start}"
)
@pytest.mark.asyncio
async def test_short_sync_wallet(self, wallet_node):
num_blocks = 8

View File

@@ -145,7 +145,6 @@ class TestTransactions:
ph = await wallet_0.wallet.get_new_puzzlehash()
# wallet0 <-> server0 <-> server1
await wallet_server_0.start_client(

View File

@@ -144,6 +144,8 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}):
async def setup_wallet_node(port, introducer_port=None, key_seed=b"", dic={}):
config = load_config("config.yaml", "wallet")
if "starting_height" in dic:
config["starting_height"] = dic["starting_height"]
key_config = {
"wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(key_seed)).hex(),
}