Progress on fast sync
This commit is contained in:
parent
90e8287926
commit
324cae8dba
|
@ -504,10 +504,9 @@ class ChiaServer:
|
|||
yield connection, outbound_message
|
||||
else:
|
||||
await result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
self.log.error(f"{tb}")
|
||||
self.log.error(f"Error {type(e)} {e}, closing connection {connection}")
|
||||
self.log.error(f"Error, closing connection {connection}. {tb}")
|
||||
self.global_connections.close(connection)
|
||||
|
||||
async def expand_outbound_messages(
|
||||
|
|
|
@ -3,7 +3,9 @@ import asyncio
|
|||
import time
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import concurrent
|
||||
import random
|
||||
import logging
|
||||
import traceback
|
||||
from blspy import ExtendedPrivateKey
|
||||
|
||||
from src.full_node.full_node import OutboundMessageGenerator
|
||||
|
@ -236,107 +238,113 @@ class WalletNode:
|
|||
raise ValueError("Not enough proof hashes fetched.")
|
||||
|
||||
# Creates map from height to difficulty
|
||||
# heights: List[uint32] = []
|
||||
# difficulty_weights: List[uint64] = []
|
||||
# difficulty: uint64
|
||||
# for i in range(tip_height):
|
||||
# if self.proof_hashes[i][1] is not None:
|
||||
# difficulty = self.proof_hashes[i][1][1]
|
||||
# if i > fork_point_height and i % 2 == 1: # Only add odd heights
|
||||
# heights.append(uint32(i))
|
||||
# difficulty_weights.append(difficulty)
|
||||
heights: List[uint32] = []
|
||||
difficulty_weights: List[uint64] = []
|
||||
difficulty: uint64
|
||||
for i in range(tip_height):
|
||||
if self.proof_hashes[i][1] is not None:
|
||||
difficulty = self.proof_hashes[i][1][1]
|
||||
if i > fork_point_height and i % 2 == 1: # Only add odd heights
|
||||
heights.append(uint32(i))
|
||||
difficulty_weights.append(difficulty)
|
||||
|
||||
# Randomly sample based on difficulty
|
||||
# query_heights = random.choices(heights, difficulty_weights, k=50)
|
||||
query_heights_odd = sorted(
|
||||
list(
|
||||
set(
|
||||
random.choices(
|
||||
heights, difficulty_weights, k=min(15, len(heights))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
print("Query heights:", query_heights_odd)
|
||||
query_heights: List[uint32] = []
|
||||
for odd_height in query_heights_odd:
|
||||
query_heights += [uint32(odd_height - 1), odd_height]
|
||||
|
||||
# Send requests for these heights
|
||||
# Verify these proofs
|
||||
# last_request_time = float(0)
|
||||
# highest_height_requested = uint32(0)
|
||||
# request_made = False
|
||||
last_request_time = float(0)
|
||||
highest_height_requested = uint32(0)
|
||||
request_made = False
|
||||
print("Query heights:", query_heights)
|
||||
|
||||
# # TODO: simplify and further pipeline this sync
|
||||
# for height_index in range(len(query_heights)):
|
||||
# total_time_slept = 0
|
||||
# while True:
|
||||
# if self._shut_down:
|
||||
# return
|
||||
# if total_time_slept > timeout:
|
||||
# raise TimeoutError("Took too long to fetch blocks")
|
||||
for height_index in range(len(query_heights)):
|
||||
total_time_slept = 0
|
||||
while True:
|
||||
if self._shut_down:
|
||||
return
|
||||
if total_time_slept > timeout:
|
||||
raise TimeoutError("Took too long to fetch blocks")
|
||||
|
||||
# # Request batches that we don't have yet
|
||||
# for batch_start_index in range(
|
||||
# height_index,
|
||||
# min(height_index + self.config["num_sync_batches"]),
|
||||
# len(query_heights),
|
||||
# ):
|
||||
# batch_end_index = min(batch_start_index + 1, len(query_heights))
|
||||
# blocks_missing = any(
|
||||
# [
|
||||
# not (self.potential_blocks_received[uint32(h)]).is_set()
|
||||
# for h in [
|
||||
# query_heights[i]
|
||||
# for i in range(batch_start_index, batch_end_index)
|
||||
# ]
|
||||
# ]
|
||||
# )
|
||||
# if (
|
||||
# (
|
||||
# time.time() - last_request_time > sleep_interval
|
||||
# and blocks_missing
|
||||
# )
|
||||
# or (query_heights[batch_end_index] - 1)
|
||||
# > highest_height_requested
|
||||
# ):
|
||||
# self.log.info(
|
||||
# f"Requesting sync header {query_heights[batch_start_index]}"
|
||||
# )
|
||||
# if (
|
||||
# query_heights[batch_end_index] - 1
|
||||
# > highest_height_requested
|
||||
# ):
|
||||
# highest_height_requested = uint32(
|
||||
# query_heights[batch_end_index - 1]
|
||||
# )
|
||||
# request_made = True
|
||||
# request_header = wallet_protocol.RequestHeader(
|
||||
# uint32(query_heights[batch_start_index]),
|
||||
# self.header_hashes[query_heights[batch_start_index]],
|
||||
# )
|
||||
# yield OutboundMessage(
|
||||
# NodeType.FULL_NODE,
|
||||
# Message("request_header", request_header),
|
||||
# Delivery.RANDOM,
|
||||
# )
|
||||
# if request_made:
|
||||
# last_request_time = time.time()
|
||||
# request_made = False
|
||||
# Request batches that we don't have yet
|
||||
for batch_start_index in range(
|
||||
height_index,
|
||||
min(
|
||||
height_index + self.config["num_sync_batches"],
|
||||
len(query_heights),
|
||||
),
|
||||
):
|
||||
blocks_missing = not self.potential_blocks_received[
|
||||
uint32(query_heights[batch_start_index])
|
||||
].is_set()
|
||||
if (
|
||||
(
|
||||
time.time() - last_request_time > sleep_interval
|
||||
and blocks_missing
|
||||
)
|
||||
or (query_heights[batch_start_index])
|
||||
> highest_height_requested
|
||||
):
|
||||
self.log.info(
|
||||
f"Requesting sync header {query_heights[batch_start_index]}"
|
||||
)
|
||||
if (
|
||||
query_heights[batch_start_index]
|
||||
> highest_height_requested
|
||||
):
|
||||
highest_height_requested = uint32(
|
||||
query_heights[batch_start_index]
|
||||
)
|
||||
request_made = True
|
||||
request_header = wallet_protocol.RequestHeader(
|
||||
uint32(query_heights[batch_start_index]),
|
||||
self.header_hashes[query_heights[batch_start_index]],
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_header", request_header),
|
||||
Delivery.RANDOM,
|
||||
)
|
||||
if request_made:
|
||||
last_request_time = time.time()
|
||||
request_made = False
|
||||
try:
|
||||
aw = self.potential_blocks_received[
|
||||
uint32(query_heights[height_index])
|
||||
].wait()
|
||||
await asyncio.wait_for(aw, timeout=sleep_interval)
|
||||
break
|
||||
except concurrent.futures.TimeoutError:
|
||||
try:
|
||||
await aw
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
total_time_slept += sleep_interval
|
||||
self.log.info("Did not receive desired headers")
|
||||
|
||||
# awaitables = [
|
||||
# self.potential_blocks_received[
|
||||
# uint32(query_heights[height_index])
|
||||
# ].wait()
|
||||
# ]
|
||||
# future = asyncio.gather(*awaitables, return_exceptions=True)
|
||||
# try:
|
||||
# await asyncio.wait_for(future, timeout=sleep_interval)
|
||||
# break
|
||||
# except concurrent.futures.TimeoutError:
|
||||
# try:
|
||||
# await future
|
||||
# except asyncio.CancelledError:
|
||||
# pass
|
||||
# total_time_slept += sleep_interval
|
||||
# self.log.info("Did not receive desired headers")
|
||||
|
||||
# hh = self.potential_header_hashes[query_heights[height_index]]
|
||||
# block_record, header_block = self.cached_blocks[hh]
|
||||
|
||||
# TODO(mariano): Validate proof and hash of proof
|
||||
# for query_height in query_heights:
|
||||
# prev_height = query_height - 1
|
||||
|
||||
# pass
|
||||
self.log.info(
|
||||
f"Finished downloading sample of headers at heights: {query_heights}, validating."
|
||||
)
|
||||
# Validates the downloaded proofs
|
||||
assert self.wallet_state_manager.validate_select_proofs(
|
||||
self.proof_hashes,
|
||||
query_heights_odd,
|
||||
self.cached_blocks,
|
||||
self.potential_header_hashes,
|
||||
)
|
||||
self.log.info("All proofs validated successfuly.")
|
||||
|
||||
# Add blockrecords one at a time, to catch up to starting height
|
||||
weight = self.wallet_state_manager.block_records[fork_point_hash].weight
|
||||
|
@ -371,6 +379,9 @@ class WalletNode:
|
|||
res == ReceiveBlockResult.ADDED_TO_HEAD
|
||||
or res == ReceiveBlockResult.ADDED_AS_ORPHAN
|
||||
)
|
||||
self.log.info(
|
||||
f"Fast sync successful up to height {header_validate_start_height - 1}"
|
||||
)
|
||||
|
||||
# Download headers in batches, and verify them as they come in. We download a few batches ahead,
|
||||
# in case there are delays. TODO(mariano): optimize sync by pipelining
|
||||
|
@ -575,10 +586,9 @@ class WalletNode:
|
|||
self.wallet_state_manager.set_sync_mode(True)
|
||||
async for ret_msg in self._sync():
|
||||
yield ret_msg
|
||||
except asyncio.CancelledError:
|
||||
self.log.error("Syncing failed, CancelledError")
|
||||
except BaseException as e:
|
||||
self.log.error(f"Error {type(e)}{e} with syncing")
|
||||
except (BaseException, asyncio.CancelledError) as e:
|
||||
tb = traceback.format_exc()
|
||||
self.log.error(f"Error with syncing. {type(e)} {tb}")
|
||||
self.wallet_state_manager.set_sync_mode(False)
|
||||
else:
|
||||
header_request = wallet_protocol.RequestHeader(
|
||||
|
|
|
@ -717,6 +717,172 @@ class WalletStateManager:
|
|||
assert new_lca == tmp_old # Genesis block is the same, genesis fork
|
||||
return uint32(0)
|
||||
|
||||
def validate_select_proofs(
|
||||
self,
|
||||
all_proof_hashes: List[Tuple[bytes32, Optional[Tuple[uint64, uint64]]]],
|
||||
heights: List[uint32],
|
||||
cached_blocks: Dict[bytes32, Tuple[BlockRecord, HeaderBlock]],
|
||||
potential_header_hashes: Dict[uint32, bytes32],
|
||||
) -> bool:
|
||||
"""
|
||||
Given a full list of proof hashes (hash of pospace and time, along with difficulty resets), this function
|
||||
checks that the proofs at the passed in heights are correct. This is used to validate the weight of a chain,
|
||||
by probabilisticly sampling a few blocks, and only validating these. Cached blocks and potential header hashes
|
||||
contains the actual data for the header blocks to validate. This method also requires the previous block for
|
||||
each height to be present, to ensure an attacker can't grind on the challenge hash.
|
||||
"""
|
||||
|
||||
for height in heights:
|
||||
breakpoint()
|
||||
prev_height = uint32(height - 1)
|
||||
# Get previous header block
|
||||
prev_hh = potential_header_hashes[prev_height]
|
||||
_, prev_header_block = cached_blocks[prev_hh]
|
||||
|
||||
# Validate proof hash of previous header block
|
||||
if (
|
||||
std_hash(
|
||||
prev_header_block.proof_of_space.get_hash()
|
||||
+ prev_header_block.proof_of_time.output.get_hash()
|
||||
)
|
||||
!= all_proof_hashes[prev_height][0]
|
||||
):
|
||||
return False
|
||||
|
||||
# Calculate challenge hash (with difficulty)
|
||||
if (
|
||||
prev_header_block.challenge.prev_challenge_hash
|
||||
!= prev_header_block.proof_of_space.challenge_hash
|
||||
):
|
||||
return False
|
||||
if (
|
||||
prev_header_block.challenge.prev_challenge_hash
|
||||
!= prev_header_block.proof_of_time.challenge_hash
|
||||
):
|
||||
return False
|
||||
if (
|
||||
prev_header_block.challenge.proofs_hash
|
||||
!= all_proof_hashes[prev_height][0]
|
||||
):
|
||||
return False
|
||||
if (
|
||||
prev_height % self.constants["DIFFICULTY_EPOCH"]
|
||||
== self.constants["DIFFICULTY_DELAY"]
|
||||
):
|
||||
diff_change = all_proof_hashes[prev_height][1]
|
||||
assert diff_change is not None
|
||||
if prev_header_block.challenge.new_work_difficulty != diff_change[0]:
|
||||
return False
|
||||
else:
|
||||
if prev_header_block.challenge.new_work_difficulty is not None:
|
||||
return False
|
||||
challenge_hash = prev_header_block.challenge.get_hash()
|
||||
|
||||
# Get header block
|
||||
hh = potential_header_hashes[height]
|
||||
_, header_block = cached_blocks[hh]
|
||||
|
||||
# Validate challenge hash is == pospace challenge hash
|
||||
if challenge_hash != header_block.proof_of_space.challenge_hash:
|
||||
return False
|
||||
# Validate challenge hash is == potime challenge hash
|
||||
if challenge_hash != header_block.proof_of_time.challenge_hash:
|
||||
return False
|
||||
# Validate proof hash
|
||||
if (
|
||||
std_hash(
|
||||
header_block.proof_of_space.get_hash()
|
||||
+ header_block.proof_of_time.output.get_hash()
|
||||
)
|
||||
!= all_proof_hashes[height][0]
|
||||
):
|
||||
return False
|
||||
|
||||
# Get difficulty
|
||||
if (
|
||||
height % self.constants["DIFFICULTY_EPOCH"]
|
||||
< self.constants["DIFFICULTY_DELAY"]
|
||||
):
|
||||
diff_height = (
|
||||
height
|
||||
- (height % self.constants["DIFFICULTY_EPOCH"])
|
||||
- (
|
||||
self.constants["DIFFICULTY_EPOCH"]
|
||||
- self.constants["DIFFICULTY_DELAY"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
diff_height = (
|
||||
height
|
||||
- (height % self.constants["DIFFICULTY_EPOCH"])
|
||||
+ self.constants["DIFFICULTY_DELAY"]
|
||||
)
|
||||
|
||||
difficulty_change = all_proof_hashes[diff_height][1]
|
||||
assert difficulty_change is not None
|
||||
difficulty = difficulty_change[0]
|
||||
|
||||
# Validate pospace to get iters
|
||||
quality_str = header_block.proof_of_space.verify_and_get_quality_string()
|
||||
assert quality_str is not None
|
||||
|
||||
if (
|
||||
height
|
||||
< self.constants["DIFFICULTY_EPOCH"]
|
||||
+ self.constants["DIFFICULTY_DELAY"]
|
||||
):
|
||||
min_iters = self.constants["MIN_ITERS_STARTING"]
|
||||
else:
|
||||
if (
|
||||
height % self.constants["DIFFICULTY_EPOCH"]
|
||||
< self.constants["DIFFICULTY_DELAY"]
|
||||
):
|
||||
height2 = (
|
||||
height
|
||||
- (height % self.constants["DIFFICULTY_EPOCH"])
|
||||
- self.constants["DIFFICULTY_EPOCH"]
|
||||
- 1
|
||||
)
|
||||
else:
|
||||
height2 = height - (height % self.constants["DIFFICULTY_EPOCH"]) - 1
|
||||
|
||||
height1 = height2 - self.constants["DIFFICULTY_EPOCH"]
|
||||
if height1 == -1:
|
||||
iters1 = uint64(0)
|
||||
else:
|
||||
diff_change_1 = all_proof_hashes[height1][1]
|
||||
assert diff_change_1 is not None
|
||||
iters1 = diff_change_1[1]
|
||||
for ph, i in enumerate(all_proof_hashes):
|
||||
print(i, ph)
|
||||
print("Height2:", height2)
|
||||
diff_change_2 = all_proof_hashes[height2][1]
|
||||
assert diff_change_2 is not None
|
||||
iters2 = diff_change_2[1]
|
||||
|
||||
min_iters = uint64(
|
||||
(iters2 - iters1)
|
||||
// (
|
||||
self.constants["DIFFICULTY_EPOCH"]
|
||||
* self.constants["MIN_ITERS_PROPORTION"]
|
||||
)
|
||||
)
|
||||
|
||||
number_of_iters: uint64 = calculate_iterations_quality(
|
||||
quality_str, header_block.proof_of_space.size, difficulty, min_iters,
|
||||
)
|
||||
|
||||
# Validate potime
|
||||
if number_of_iters != header_block.proof_of_time.number_of_iterations:
|
||||
return False
|
||||
|
||||
if not header_block.proof_of_time.is_valid(
|
||||
self.constants["DISCRIMINANT_SIZE_BITS"]
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def get_filter_additions_removals(
|
||||
self, transactions_fitler: bytes
|
||||
) -> Tuple[List[bytes32], List[bytes32]]:
|
||||
|
|
|
@ -21,11 +21,6 @@ class TestFullSync:
|
|||
async for _ in setup_two_nodes():
|
||||
yield _
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def wallet_node(self):
|
||||
async for _ in setup_node_and_wallet():
|
||||
yield _
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync(self, two_nodes):
|
||||
num_blocks = 100
|
||||
|
@ -111,6 +106,18 @@ class TestFullSync:
|
|||
|
||||
raise Exception("Took too long to process blocks")
|
||||
|
||||
|
||||
class TestWalletSync:
|
||||
@pytest.fixture(scope="function")
|
||||
async def wallet_node(self):
|
||||
async for _ in setup_node_and_wallet():
|
||||
yield _
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def wallet_node_starting_height(self):
|
||||
async for _ in setup_node_and_wallet(dic={"starting_height": 100}):
|
||||
yield _
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync_wallet(self, wallet_node):
|
||||
num_blocks = 25
|
||||
|
@ -172,6 +179,41 @@ class TestFullSync:
|
|||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_sync_wallet(self, wallet_node_starting_height):
|
||||
num_blocks = 50
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [])
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node_starting_height
|
||||
|
||||
for i in range(1, len(blocks)):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
found = False
|
||||
while time.time() - start < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== num_blocks - 6
|
||||
):
|
||||
found = True
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
if not found:
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_sync_wallet(self, wallet_node):
|
||||
num_blocks = 8
|
||||
|
|
|
@ -145,7 +145,6 @@ class TestTransactions:
|
|||
|
||||
ph = await wallet_0.wallet.get_new_puzzlehash()
|
||||
|
||||
|
||||
# wallet0 <-> sever0 <-> server1
|
||||
|
||||
await wallet_server_0.start_client(
|
||||
|
|
|
@ -144,6 +144,8 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}):
|
|||
|
||||
async def setup_wallet_node(port, introducer_port=None, key_seed=b"", dic={}):
|
||||
config = load_config("config.yaml", "wallet")
|
||||
if "starting_height" in dic:
|
||||
config["starting_height"] = dic["starting_height"]
|
||||
key_config = {
|
||||
"wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(key_seed)).hex(),
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue